将ocr解析模块独立出来

This commit is contained in:
2025-10-31 14:38:37 +08:00
parent d78f1fe91d
commit 4318179904
13 changed files with 3262 additions and 23 deletions

View File

@@ -0,0 +1,175 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP 客户端
用于调用独立的 OCR 服务的 HTTP API
"""
import os
import logging
import requests
from typing import Optional, Union, Dict, Any
logger = logging.getLogger(__name__)
class OCRHttpClient:
    """HTTP client for the standalone OCR service.

    All endpoints live under ``{base_url}/api/v1/ocr``. The parse methods
    return the decoded JSON envelope of the service, whose shape is::

        {
            "success": bool,
            "message": str,
            "data": {
                "pages": [
                    {
                        "page_number": int,
                        "boxes": [
                            {
                                "text": str,
                                "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
                                "confidence": float
                            },
                            ...
                        ]
                    },
                    ...
                ]
            }
        }
    """

    def __init__(self, base_url: Optional[str] = None, timeout: int = 300):
        """Create a client.

        Args:
            base_url: Base URL of the OCR service. When omitted, the
                ``OCR_SERVICE_URL`` environment variable is used, falling
                back to ``http://localhost:8000``.
            timeout: Timeout in seconds applied to the parse requests.
        """
        if base_url is None:
            base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
        # Drop trailing slashes so joining with api_prefix never yields "//".
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.api_prefix = "/api/v1/ocr"
        logger.info(f"Initialized OCR HTTP client with base_url: {self.base_url}")

    @staticmethod
    def _validate_result(result: Any) -> Dict[str, Any]:
        """Check the decoded response envelope and return it unchanged.

        Raises:
            ValueError: If the payload is not a JSON object, or if its
                ``success`` flag is missing or false.
        """
        if not isinstance(result, dict):
            # Without this guard a list/str payload would surface as an
            # AttributeError on ``.get`` instead of the documented ValueError.
            raise ValueError(
                f"OCR service returned unexpected payload type: {type(result).__name__}"
            )
        if not result.get("success", False):
            raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
        return result

    def _post(self, url: str, **request_kwargs: Any) -> Dict[str, Any]:
        """POST to ``url`` and return the validated JSON envelope.

        Transport and HTTP-status errors are logged and re-raised as
        ``requests.RequestException``; envelope problems raise ``ValueError``.
        """
        try:
            response = requests.post(url, timeout=self.timeout, **request_kwargs)
            response.raise_for_status()
            result = response.json()
        except requests.RequestException as e:
            logger.error(f"Failed to call OCR service: {e}")
            raise
        return self._validate_result(result)

    def parse_pdf_by_path(self, file_path: str, page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
        """Parse a PDF that the OCR service can read from its own filesystem.

        Args:
            file_path: Path of the PDF as seen by the OCR service.
            page_from: First page to parse, 1-based.
            page_to: Last page to parse; 0 means "until the last page".
            zoomin: Image zoom factor (the service accepts 1-5).

        Returns:
            The response envelope described in the class docstring.

        Raises:
            requests.RequestException: The HTTP request failed.
            ValueError: The response is malformed or reports a failure.
        """
        url = f"{self.base_url}{self.api_prefix}/parse/path"
        data = {
            "file_path": file_path,
            "page_from": page_from,
            "page_to": page_to,
            "zoomin": zoomin
        }
        logger.info(f"Calling OCR service: {url} for file: {file_path}")
        return self._post(url, data=data)

    def parse_pdf_by_bytes(self, pdf_bytes: bytes, filename: str = "document.pdf",
                           page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
        """Parse a PDF supplied as raw bytes (sent as a multipart upload).

        Args:
            pdf_bytes: Binary content of the PDF.
            filename: File name sent along with the upload.
            page_from: First page to parse, 1-based.
            page_to: Last page to parse; 0 means "until the last page".
            zoomin: Image zoom factor (the service accepts 1-5).

        Returns:
            The response envelope described in the class docstring.

        Raises:
            requests.RequestException: The HTTP request failed.
            ValueError: The response is malformed or reports a failure.
        """
        url = f"{self.base_url}{self.api_prefix}/parse/bytes"
        files = {
            "pdf_bytes": (filename, pdf_bytes, "application/pdf")
        }
        data = {
            "filename": filename,
            "page_from": page_from,
            "page_to": page_to,
            "zoomin": zoomin
        }
        logger.info(f"Calling OCR service: {url} with {len(pdf_bytes)} bytes")
        return self._post(url, files=files, data=data)

    def health_check(self) -> Dict[str, Any]:
        """Query the service health endpoint and return its JSON body.

        Uses a fixed 10 second timeout rather than ``self.timeout`` so a
        hung service is detected quickly.

        Raises:
            requests.RequestException: The HTTP request failed.
        """
        url = f"{self.base_url}{self.api_prefix}/health"
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            logger.error(f"Failed to check OCR service health: {e}")
            raise

View File

@@ -35,6 +35,7 @@ from pypdf import PdfReader as pdf2_read
from api import settings
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
from deepdoc.parser.ocr_http_client import OCRHttpClient
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer
from rag.prompts.generator import vision_llm_describe_prompt
@@ -59,9 +60,23 @@ class RAGFlowPdfParser:
"""
# 检查是否使用 HTTP OCR 服务
use_http_ocr = os.getenv("USE_OCR_HTTP", "false").lower() in ("true", "1", "yes")
ocr_service_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
if use_http_ocr:
logging.info(f"Using HTTP OCR service: {ocr_service_url}")
self.ocr = None # 不使用本地 OCR
self.ocr_http_client = OCRHttpClient(base_url=ocr_service_url)
self.use_http_ocr = True
else:
logging.info("Using local OCR")
self.ocr = OCR()
self.ocr_http_client = None
self.use_http_ocr = False
self.parallel_limiter = None
if PARALLEL_DEVICES > 1:
if not self.use_http_ocr and PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
@@ -276,7 +291,97 @@ class RAGFlowPdfParser:
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
def _convert_http_ocr_result(self, ocr_result: dict, zoomin: int = 3):
    """Convert the OCR HTTP API response into the parser's internal box lists.

    Expected input shape::

        {
            "success": bool,
            "data": {
                "pages": [
                    {
                        "page_number": int,
                        "boxes": [
                            {
                                "text": str,
                                "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
                                "confidence": float
                            }
                        ]
                    }
                ]
            }
        }

    On success ``self.boxes`` is replaced by one list of box dicts per page
    and one entry per page is appended to ``self.mean_height`` and
    ``self.mean_width``. An invalid envelope only logs a warning and leaves
    the instance untouched.

    Args:
        ocr_result: Decoded JSON envelope returned by the OCR service.
        zoomin: Zoom factor the coordinates were produced at; every
            coordinate is divided by it.
    """
    data = ocr_result.get("data") if isinstance(ocr_result, dict) else None
    # "data" may be present but null in the envelope, so check the type and
    # not just the key.
    if not isinstance(data, dict) or not ocr_result.get("success", False):
        logging.warning("Invalid OCR HTTP result")
        return

    pages_data = data.get("pages") or []
    page_chars = getattr(self, "page_chars", None) or []
    self.boxes = []

    for page_data in pages_data:
        # The HTTP API reports 1-based page numbers.
        page_num = page_data.get("page_number", 0)
        # self.page_from is 0-based, hence the +1 when computing the index
        # into page_chars relative to the first parsed page.
        page_index = page_num - (self.page_from + 1)
        chars_for_page = page_chars[page_index] if 0 <= page_index < len(page_chars) else []

        ragflow_boxes = []
        for box in page_data.get("boxes") or []:
            bbox = box.get("bbox") or []
            if len(bbox) != 4:
                continue
            try:
                # Corner order: top-left, top-right, bottom-right, bottom-left.
                x0 = min(bbox[0][0], bbox[3][0]) / zoomin
                x1 = max(bbox[1][0], bbox[2][0]) / zoomin
                top = min(bbox[0][1], bbox[1][1]) / zoomin
                bottom = max(bbox[2][1], bbox[3][1]) / zoomin
            except (TypeError, IndexError, KeyError):
                # A point that is not an (x, y) pair of numbers: skip this
                # box instead of aborting the whole conversion.
                logging.warning("Skipping OCR box with malformed bbox")
                continue
            ragflow_boxes.append({
                "x0": x0,
                "x1": x1,
                "top": top,
                "bottom": bottom,
                "text": box.get("text", ""),
                "page_number": page_num,
                "layoutno": "",
                "layout_type": ""
            })

        # Median box height of the page (0 when the page has no boxes).
        heights = [b["bottom"] - b["top"] for b in ragflow_boxes]
        self.mean_height.append(np.median(heights) if heights else 0)

        # Median character width from the extracted chars (8 as the fallback).
        widths = [c.get("width", 8) for c in chars_for_page]
        self.mean_width.append(np.median(widths) if widths else 8)

        self.boxes.append(ragflow_boxes)

    logging.info(f"Converted {len(pages_data)} pages from HTTP OCR result")
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
# 如果使用 HTTP OCR这个方法不会被调用
if self.use_http_ocr:
logging.warning("__ocr called when using HTTP OCR, this should not happen")
return
start = timer()
bxs = self.ocr.detect(np.array(img), device_id)
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
@@ -927,6 +1032,7 @@ class RAGFlowPdfParser:
self.page_cum_height = [0]
self.page_layout = []
self.page_from = page_from
start = timer()
try:
with sys.modules[LOCK_KEY_pdfplumber]:
@@ -946,6 +1052,42 @@ class RAGFlowPdfParser:
logging.exception("RAGFlowPdfParser __images__")
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
# 如果使用 HTTP OCR在获取图片和字符信息后调用 HTTP API 获取 OCR 结果
if self.use_http_ocr:
try:
if callback:
callback(0.1, "Calling OCR HTTP service...")
# 调用 HTTP OCR 服务
if isinstance(fnm, str):
# 文件路径
ocr_result = self.ocr_http_client.parse_pdf_by_path(
fnm,
page_from=page_from + 1, # HTTP API 使用从1开始的页码
page_to=(page_to + 1) if page_to < 299 else 0, # 转换为从1开始0 表示最后一页
zoomin=zoomin
)
else:
# 二进制数据
ocr_result = self.ocr_http_client.parse_pdf_by_bytes(
fnm,
filename="document.pdf",
page_from=page_from + 1,
page_to=(page_to + 1) if page_to < 299 else 0,
zoomin=zoomin
)
# 将 HTTP API 返回的结果转换为 RAGFlow 格式
self._convert_http_ocr_result(ocr_result, zoomin)
if callback:
callback(0.4, "OCR HTTP service completed")
except Exception as e:
logging.error(f"Failed to call OCR HTTP service: {e}", exc_info=True)
# 如果 HTTP OCR 失败,回退到空结果或抛出异常
raise
self.outlines = []
try:
with pdf2_read(fnm if isinstance(fnm, str) else BytesIO(fnm)) as pdf:
@@ -999,6 +1141,8 @@ class RAGFlowPdfParser:
if callback and i % 6 == 5:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
# 如果使用 HTTP OCR已经在上面的代码中获取了结果跳过本地 OCR
if not self.use_http_ocr:
async def __img_ocr_launcher():
def __ocr_preprocess():
chars = self.page_chars[i] if not self.is_english else []
@@ -1020,8 +1164,11 @@ class RAGFlowPdfParser:
await __img_ocr(i, 0, img, chars, None)
start = timer()
trio.run(__img_ocr_launcher)
else:
# HTTP OCR 模式:初始化 page_cum_height
for i, img in enumerate(self.page_images):
self.page_cum_height.append(img.size[1] / zoomin)
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

View File

@@ -199,3 +199,5 @@ POSTGRES_USER=rag_flow
POSTGRES_PASSWORD=infini_rag_flow
POSTGRES_PORT=5432
DB_TYPE=postgres
USE_OCR_HTTP=true

53
ocr/__init__.py Normal file
View File

@@ -0,0 +1,53 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
独立的 OCR 模块
此模块从 RAGFlow 项目中提取,已经移除了对 RAGFlow 特定模块的依赖。
可以直接作为独立模块使用。
使用方法:
from ocr import OCR
import cv2
ocr = OCR()
img = cv2.imread("image.jpg")
results = ocr(img)
"""
# Import bootstrap: support both running this file directly and importing it
# as the ``ocr`` package.
import sys
from pathlib import Path

# ``__package__`` is None when the file is executed as a script.
_package = __package__

if _package is not None:
    # Imported as a package: siblings are reachable via relative imports.
    from .ocr import OCR, TextDetector, TextRecognizer
    from .pdf_parser import SimplePdfParser
else:
    # Executed directly: make the parent directory importable so the
    # absolute ``ocr.*`` imports below resolve.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))
    from ocr.ocr import OCR, TextDetector, TextRecognizer
    from ocr.pdf_parser import SimplePdfParser

__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser']

332
ocr/api.py Normal file
View File

@@ -0,0 +1,332 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理的FastAPI路由
提供HTTP接口用于PDF的OCR识别
"""
import asyncio
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Import bootstrap. Three load modes are supported:
#   * ``__package__ is None``  - this file is executed directly as a script;
#   * ``__package__ == "ocr"`` - imported as part of the package (``ocr.api``);
#   * ``__package__ == ""``    - imported as a top-level module because the
#     service directory itself is on ``sys.path`` (e.g. ``python ocr/main.py``).
_package = __package__
if _package is None:
    # Executed directly: make the parent directory importable and use
    # absolute package imports.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))
    from ocr.pdf_parser import SimplePdfParser
    from ocr.config import MODEL_DIR
elif _package:
    # Inside the package: genuine relative imports. Plain ``import config``
    # here would look for an unrelated top-level module.
    from .pdf_parser import SimplePdfParser
    from .config import MODEL_DIR
else:
    # Top-level module next to its siblings: plain imports resolve from the
    # service directory.
    from pdf_parser import SimplePdfParser
    from config import MODEL_DIR

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/ocr", tags=["OCR"])

# Process-wide parser instance, created lazily on first use.
_parser_instance: Optional[SimplePdfParser] = None
def get_parser() -> SimplePdfParser:
    """Return the shared ``SimplePdfParser``, building it on the first call."""
    global _parser_instance
    if _parser_instance is not None:
        return _parser_instance
    logger.info(f"Initializing OCR parser with model_dir={MODEL_DIR}")
    _parser_instance = SimplePdfParser(model_dir=MODEL_DIR)
    return _parser_instance
class ParseResponse(BaseModel):
    """Response envelope returned by the PDF parsing endpoints."""
    # Whether parsing succeeded; the endpoints in this module set it to True.
    success: bool
    # Human-readable status message.
    message: str
    # Parser output as returned by ``SimplePdfParser.parse_pdf``; optional.
    data: Optional[dict] = None
@router.get(
    "/health",
    summary="健康检查",
    description="检查OCR服务的健康状态和配置信息",
    response_description="返回服务状态和模型目录信息"
)
async def health_check():
    """Health-check endpoint.

    Returns:
        dict: A static "healthy" status, the service name and the configured
        model directory.
    """
    status_payload = dict(
        status="healthy",
        service="OCR PDF Parser",
        model_dir=MODEL_DIR,
    )
    return status_payload
@router.post(
    "/parse",
    response_model=ParseResponse,
    summary="上传并解析PDF文件",
    description="上传PDF文件并通过OCR识别提取文本内容",
    response_description="返回OCR识别结果"
)
async def parse_pdf_endpoint(
    file: UploadFile = File(..., description="PDF文件支持上传任意PDF文档"),
    page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
    page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
    zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
    """Upload a PDF file and run OCR on it.

    Args:
        file: Uploaded PDF (multipart/form-data).
        page_from: First page to parse, 1-based.
        page_to: Last page to parse; 0 means "until the last page".
        zoomin: Image zoom factor, 1-5.

    Returns:
        ParseResponse: ``success``, ``message`` and the parser output in
        ``data``.

    Raises:
        HTTPException: 400 if the upload is not named ``*.pdf`` or is empty.
        HTTPException: 500 if parsing fails.
    """
    # ``filename`` may be missing on an upload; treat that as "not a PDF"
    # instead of crashing with an AttributeError.
    if not (file.filename or "").lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="只支持PDF文件")

    temp_file = None
    try:
        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="文件为空")

        # Persist the upload so the parser can open it by path; removed in
        # the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(content)
            temp_file = tmp.name

        logger.info(f"Parsing PDF file: {file.filename}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")

        # parse_pdf is synchronous; run it in a worker thread so the event
        # loop stays responsive.
        parser = get_parser()
        result = await asyncio.to_thread(
            parser.parse_pdf,
            temp_file,
            zoomin,
            page_from - 1,  # convert to a 0-based index
            (page_to - 1) if page_to > 0 else 299,  # 0-based; 299 stands for "to the end"
            None  # callback
        )

        return ParseResponse(
            success=True,
            message=f"成功解析PDF: {file.filename}",
            data=result
        )
    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) must not be turned
        # into a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error parsing PDF: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"解析PDF时发生错误: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup of the temporary file.
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@router.post(
    "/parse/bytes",
    response_model=ParseResponse,
    summary="通过二进制数据解析PDF",
    description="直接通过二进制数据解析PDF文件无需上传文件",
    response_description="返回OCR识别结果"
)
async def parse_pdf_bytes(
    pdf_bytes: bytes = File(..., description="PDF文件的二进制数据multipart/form-data格式"),
    filename: str = Form("document.pdf", description="文件名(仅用于日志记录,不影响解析)"),
    page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
    page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
    zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
    """Run OCR on a PDF submitted as raw bytes.

    Args:
        pdf_bytes: Binary PDF content, submitted as a file part.
        filename: Name used in log and response messages only.
        page_from: First page to parse, 1-based.
        page_to: Last page to parse; 0 means "until the last page".
        zoomin: Image zoom factor, 1-5.

    Returns:
        ParseResponse: The parsing result envelope.

    Raises:
        HTTPException: 400 if no data was submitted.
        HTTPException: 500 if parsing fails.
    """
    if not pdf_bytes:
        raise HTTPException(status_code=400, detail="PDF数据为空")

    # Path of the on-disk copy handed to the parser; stays None until the
    # copy has been written.
    temp_file = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as handle:
            handle.write(pdf_bytes)
            temp_file = handle.name

        logger.info(f"Parsing PDF bytes (filename: {filename}), pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")

        # Translate the 1-based API page range into the parser's 0-based one;
        # 299 is passed when the caller asked for "until the last page".
        first_index = page_from - 1
        last_index = (page_to - 1) if page_to > 0 else 299

        # The parser is synchronous, so it is executed in a worker thread.
        pdf_parser = get_parser()
        parsed = await asyncio.to_thread(
            pdf_parser.parse_pdf, temp_file, zoomin, first_index, last_index, None
        )
        return ParseResponse(success=True, message=f"成功解析PDF: {filename}", data=parsed)
    except Exception as e:
        logger.error(f"Error parsing PDF bytes: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"解析PDF时发生错误: {str(e)}")
    finally:
        # Remove the temporary copy; failures are only logged.
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@router.post(
    "/parse/path",
    response_model=ParseResponse,
    summary="通过文件路径解析PDF",
    description="通过服务器本地文件路径解析PDF文件",
    response_description="返回OCR识别结果"
)
async def parse_pdf_path(
    file_path: str = Form(..., description="PDF文件在服务器上的本地路径必须是可访问的绝对路径"),
    page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
    page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
    zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
    """Run OCR on a PDF that already exists on the server's filesystem.

    Args:
        file_path: Path of the PDF on the machine running this service.
        page_from: First page to parse, 1-based.
        page_to: Last page to parse; 0 means "until the last page".
        zoomin: Image zoom factor, 1-5.

    Returns:
        ParseResponse: The parsing result envelope.

    Raises:
        HTTPException: 404 if the path is not an existing regular file.
        HTTPException: 400 if the path does not end in ``.pdf``.
        HTTPException: 500 if parsing fails.

    Note:
        SECURITY: the path comes straight from the request and is not
        confined to any base directory, so any client can make the service
        open arbitrary ``*.pdf`` files it can read. Expose this endpoint
        only on trusted networks.
    """
    # isfile (not exists): a directory whose name ends in ".pdf" must be
    # reported as "not found" rather than failing later inside the parser
    # with a 500.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")

    if not file_path.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="只支持PDF文件")

    try:
        logger.info(f"Parsing PDF from path: {file_path}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")

        # parse_pdf is synchronous; run it in a worker thread so the event
        # loop stays responsive.
        parser = get_parser()
        result = await asyncio.to_thread(
            parser.parse_pdf,
            file_path,
            zoomin,
            page_from - 1,  # convert to a 0-based index
            (page_to - 1) if page_to > 0 else 299,  # 0-based; 299 stands for "to the end"
            None  # callback
        )

        return ParseResponse(
            success=True,
            message=f"成功解析PDF: {file_path}",
            data=result
        )
    except Exception as e:
        logger.error(f"Error parsing PDF from path: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"解析PDF时发生错误: {str(e)}"
        ) from e

42
ocr/config.py Normal file
View File

@@ -0,0 +1,42 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 模块配置文件
"""
import os
import logging
# Number of parallel devices (GPUs); 0 means CPU mode.
PARALLEL_DEVICES = 0
try:
    import torch.cuda
    PARALLEL_DEVICES = torch.cuda.device_count()
    logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
    logging.info("can't import package 'torch', using CPU mode")

# Model directory: taken from the OCR_MODEL_DIR environment variable when
# set, otherwise <directory above this package>/models/deepdoc.
MODEL_DIR = os.getenv("OCR_MODEL_DIR")
if MODEL_DIR is None:
    _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    _default_model_dir = os.path.join(_base_dir, "models", "deepdoc")
    # When the default directory is absent, MODEL_DIR stays None so that the
    # consumer of this setting decides how to obtain the models.
    MODEL_DIR = _default_model_dir if os.path.exists(_default_model_dir) else None

202
ocr/main.py Normal file
View File

@@ -0,0 +1,202 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理服务的主程序入口
独立运行不依赖RAGFlow的其他部分
"""
import argparse
import logging
import os
import sys
import signal
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Import bootstrap. Three load modes are supported:
#   * ``python ocr/main.py``   - ``__package__`` is None;
#   * ``python -m ocr.main``   - ``__package__`` is "ocr";
#   * imported as a top-level module named ``main`` - ``__package__`` is "".
_package = __package__
if _package is None:
    # Executed directly: also make the parent directory importable.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))
    from api import router as ocr_router
    from config import MODEL_DIR
elif _package:
    # Running inside the package: use genuine relative imports. A plain
    # ``from api import ...`` would look up an unrelated top-level ``api``
    # module (or fail) when started with ``python -m ocr.main``.
    from .api import router as ocr_router
    from .config import MODEL_DIR
else:
    # Top-level module next to its siblings: plain imports resolve from the
    # service directory.
    from api import router as ocr_router
    from config import MODEL_DIR

# Logging: INFO level, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
    """Build the FastAPI application: docs URLs, CORS, OCR routes and a root endpoint."""
    application = FastAPI(
        title="OCR PDF Parser API",
        description="独立的OCR PDF处理服务提供PDF文档的OCR识别功能",
        version="1.0.0",
        docs_url="/apidocs",  # Swagger UI
        redoc_url="/redoc",  # ReDoc (alternative UI)
        openapi_url="/openapi.json"  # OpenAPI JSON schema
    )

    # CORS is wide open here; restrict allow_origins to concrete domains in
    # production.
    cors_settings = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    application.add_middleware(CORSMiddleware, **cors_settings)

    # Mount the OCR endpoints.
    application.include_router(ocr_router)

    @application.get("/")
    async def root():
        service_info = {
            "service": "OCR PDF Parser",
            "version": "1.0.0",
            "docs": "/apidocs",
            "health": "/api/v1/ocr/health"
        }
        return service_info

    return application
def signal_handler(sig, frame):
    """Log the shutdown request and terminate the process with exit code 0."""
    logger.info("Received shutdown signal, exiting...")
    raise SystemExit(0)
def _override_model_dir(model_dir: str) -> None:
    """Make ``model_dir`` effective for this process and any child process.

    ``MODEL_DIR`` is computed once when the config module is imported, which
    has already happened by the time the command line is parsed. Setting the
    environment variable alone therefore only helps processes started later,
    so the constant is also re-bound in every already-loaded module that
    lives in this service's directory.
    """
    os.environ["OCR_MODEL_DIR"] = model_dir
    service_dir = Path(__file__).resolve().parent
    for module in list(sys.modules.values()):
        module_file = getattr(module, "__file__", None)
        if not module_file:
            continue
        if Path(module_file).resolve().parent != service_dir:
            continue
        if hasattr(module, "MODEL_DIR"):
            module.MODEL_DIR = model_dir


def main():
    """Command-line entry point: parse arguments and start the uvicorn server."""
    parser = argparse.ArgumentParser(description="OCR PDF处理服务")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="服务器监听地址 (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="服务器端口 (default: 8000)"
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="开发模式:自动重载代码"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="工作进程数 (default: 1)"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
        help="日志级别 (default: info)"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help=f"OCR模型目录路径 (default: {MODEL_DIR})"
    )

    args = parser.parse_args()

    # Apply a custom model directory, if one was given.
    if args.model_dir:
        _override_model_dir(args.model_dir)
        logger.info(f"Using custom model directory: {args.model_dir}")

    # Warn early when the configured model directory is missing.
    model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
    if model_dir and not os.path.exists(model_dir):
        logger.warning(f"Model directory does not exist: {model_dir}")
        logger.info("Models will be downloaded on first use")

    # Exit cleanly on Ctrl-C / SIGTERM.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Startup banner.
    logger.info("=" * 60)
    logger.info("OCR PDF Parser Service")
    logger.info("=" * 60)
    logger.info(f"Host: {args.host}")
    logger.info(f"Port: {args.port}")
    logger.info(f"Model Directory: {model_dir}")
    logger.info(f"Workers: {args.workers}")
    logger.info(f"Reload: {args.reload}")
    logger.info(f"Log Level: {args.log_level}")
    logger.info("=" * 60)
    logger.info(f"API Documentation (Swagger): http://{args.host}:{args.port}/apidocs")
    logger.info(f"API Documentation (ReDoc): http://{args.host}:{args.port}/redoc")
    logger.info(f"Health Check: http://{args.host}:{args.port}/api/v1/ocr/health")
    logger.info("=" * 60)

    # uvicorn refuses to enable reload or multiple workers for an application
    # object; those modes need an import string so that child processes can
    # build their own application. Use the ``create_app`` factory for them.
    if args.reload or args.workers > 1:
        module_spec = globals().get("__spec__")
        module_name = module_spec.name if module_spec is not None else Path(__file__).stem
        app_target = f"{module_name}:create_app"
        use_factory = True
    else:
        app_target = create_app()
        use_factory = False

    # Start the server.
    try:
        uvicorn.run(
            app_target,
            factory=use_factory,
            host=args.host,
            port=args.port,
            log_level=args.log_level,
            reload=args.reload,
            workers=args.workers if not args.reload else 1,  # reload mode is single-process
            access_log=True
        )
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()

785
ocr/ocr.py Normal file
View File

@@ -0,0 +1,785 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import logging
import copy
import time
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
# Import bootstrap. Three load modes are supported:
#   * ``__package__ is None``  - this file is executed directly as a script;
#   * ``__package__ == "ocr"`` - imported as part of the package (``ocr.ocr``);
#   * ``__package__ == ""``    - imported as a top-level module because the
#     service directory itself is on ``sys.path``.
_package = __package__
if _package is None:
    # Executed directly: make the parent directory importable and use
    # absolute package imports.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))
    from ocr.utils import get_project_base_directory
    from ocr.config import PARALLEL_DEVICES, MODEL_DIR
    from ocr.operators import *  # noqa: F403
    import ocr.operators as operators
    from ocr.postprocess import build_post_process
elif _package:
    # Inside the package: genuine relative imports. Plain ``import utils`` /
    # ``import config`` here would pick up unrelated top-level modules or fail.
    from .utils import get_project_base_directory
    from .config import PARALLEL_DEVICES, MODEL_DIR
    from .operators import *  # noqa: F403
    from . import operators
    from .postprocess import build_post_process
else:
    # Top-level module next to its siblings: plain imports resolve from the
    # service directory.
    from utils import get_project_base_directory
    from config import PARALLEL_DEVICES, MODEL_DIR
    from operators import *  # noqa: F403
    import operators
    from postprocess import build_post_process
import math
import numpy as np
import cv2
import onnxruntime as ort
loaded_models = {}
def transform(data, ops=None):
    """Feed ``data`` through each callable in ``ops`` in order.

    Stops and returns ``None`` as soon as one step yields ``None``;
    otherwise returns the output of the last step (or ``data`` itself when
    there are no steps).
    """
    pipeline = [] if ops is None else ops
    current = data
    for step in pipeline:
        current = step(current)
        if current is None:
            return None
    return current
def create_operators(op_param_list, global_config=None):
    """Instantiate operators from a list of single-key config dicts.

    Each entry has the form ``{"OperatorName": {kwargs} | None}``. The name
    is looked up on the ``operators`` module and called with the kwargs,
    extended by ``global_config`` when given.

    Args:
        op_param_list (list): Operator configs, one single-key dict each.
        global_config (dict | None): Extra keyword arguments merged into
            every operator's kwargs (overriding same-named keys).

    Returns:
        list: The constructed operator instances, in input order.

    Raises:
        AssertionError: If ``op_param_list`` is not a list or an entry is
            not a single-key dict.
    """
    assert isinstance(
        op_param_list, list), ('operator config should be a list')
    ops = []
    for op_config in op_param_list:
        assert isinstance(op_config,
                          dict) and len(op_config) == 1, "yaml format error"
        op_name = list(op_config)[0]
        # Work on a copy: merging global_config must not leak into the
        # caller's config dict, which may be reused for later calls.
        param = {} if op_config[op_name] is None else dict(op_config[op_name])
        if global_config is not None:
            param.update(global_config)
        op = getattr(operators, op_name)(**param)
        ops.append(op)
    return ops
def load_model(model_dir, nm, device_id: int | None = None):
    """Load ``<model_dir>/<nm>.onnx`` as an ONNX Runtime session, with caching.

    Sessions are cached in the module-level ``loaded_models`` dict, keyed by
    the model path plus the device id (when one is given).

    Args:
        model_dir: Directory containing the model file.
        nm: Model file name without the ``.onnx`` extension.
        device_id: GPU index to run on; ``None`` results in a CPU session.

    Returns:
        tuple: ``(InferenceSession, RunOptions)``.

    Raises:
        ValueError: If the model file does not exist.
    """
    model_file_path = os.path.join(model_dir, nm + ".onnx")
    if device_id is None:
        model_cached_tag = model_file_path
    else:
        model_cached_tag = model_file_path + str(device_id)

    global loaded_models
    cached = loaded_models.get(model_cached_tag)
    if cached:
        logging.info(f"load_model {model_file_path} reuses cached model")
        return cached

    if not os.path.exists(model_file_path):
        raise ValueError("not find model file path {}".format(
            model_file_path))

    def cuda_is_available():
        # Any failure means "no CUDA": torch missing, or device_id being
        # None (the ``>`` comparison then raises TypeError).
        try:
            import torch
            usable = torch.cuda.is_available() and torch.cuda.device_count() > device_id
        except Exception:
            return False
        return True if usable else False

    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    options.intra_op_num_threads = 2
    options.inter_op_num_threads = 2

    # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
    # Shrink the memory arena after each run.
    run_options = ort.RunOptions()
    use_gpu = cuda_is_available()
    if use_gpu:
        cuda_provider_options = {
            "device_id": device_id,  # run on this specific GPU
            "gpu_mem_limit": 512 * 1024 * 1024,  # cap GPU memory
            "arena_extend_strategy": "kNextPowerOfTwo",  # GPU memory allocation strategy
        }
        sess = ort.InferenceSession(
            model_file_path,
            options=options,
            providers=['CUDAExecutionProvider'],
            provider_options=[cuda_provider_options]
        )
        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
        logging.info(f"load_model {model_file_path} uses GPU")
    else:
        sess = ort.InferenceSession(
            model_file_path,
            options=options,
            providers=['CPUExecutionProvider'])
        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
        logging.info(f"load_model {model_file_path} uses CPU")

    entry = (sess, run_options)
    loaded_models[model_cached_tag] = entry
    return entry
class TextRecognizer:
def __init__(self, model_dir, device_id: int | None = None):
    """Set up the recognizer: input shape, CTC decoder and the ``rec`` model.

    Args:
        model_dir: Directory holding ``rec.onnx`` and the ``ocr.res``
            character dictionary.
        device_id: Device index forwarded to ``load_model``.
    """
    # Recognition input shape as (channels, height, width).
    self.rec_image_shape = [3, 48, 320]
    self.rec_batch_num = 16
    self.postprocess_op = build_post_process({
        'name': 'CTCLabelDecode',
        "character_dict_path": os.path.join(model_dir, "ocr.res"),
        "use_space_char": True
    })
    session, session_run_options = load_model(model_dir, 'rec', device_id)
    self.predictor = session
    self.run_options = session_run_options
    self.input_tensor = self.predictor.get_inputs()[0]
def resize_norm_img(self, img, max_wh_ratio):
    """Resize ``img`` to the recognizer height, normalize and right-pad it.

    The target width is ``height * max_wh_ratio`` unless the model input
    declares a fixed positive width, which then takes precedence. Pixel
    values are scaled to [-1, 1]; the returned array is channel-first
    float32 and zero-padded on the right up to the target width.
    """
    imgC, imgH, imgW = self.rec_image_shape
    assert imgC == img.shape[2]

    imgW = int((imgH * max_wh_ratio))
    model_w = self.input_tensor.shape[3:][0]
    # A string entry denotes a symbolic (dynamic) dimension: keep the
    # computed width in that case.
    if not isinstance(model_w, str) and model_w is not None and model_w > 0:
        imgW = model_w

    src_h, src_w = img.shape[:2]
    aspect = src_w / float(src_h)
    scaled_w = math.ceil(imgH * aspect)
    # Keep the aspect ratio but never exceed the target width.
    resized_w = imgW if scaled_w > imgW else int(scaled_w)

    resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5

    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    return padding_im
def resize_norm_img_vl(self, img, image_shape):
    """Reverse the channel order, resize to ``image_shape`` and scale to [0, 1].

    Returns a channel-first float32 array.
    """
    _, target_h, target_w = image_shape
    channel_reversed = img[:, :, ::-1]  # bgr2rgb
    scaled = cv2.resize(
        channel_reversed, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return scaled.astype('float32').transpose((2, 0, 1)) / 255
def resize_norm_img_srn(self, img, image_shape):
    """Resize to a bucketed width, convert to grayscale and left-align on a black canvas.

    The width bucket is 1x, 2x or 3x the target height, chosen from the
    image's aspect ratio; wider images use the full target width. Returns a
    float32 array of shape ``(1, imgH, imgW)``.
    """
    imgC, imgH, imgW = image_shape
    canvas = np.zeros((imgH, imgW))

    im_hei, im_wid = img.shape[0], img.shape[1]

    # Pick the first bucket wide enough for the aspect ratio; fall back to
    # the full target width.
    target_w = imgW
    for factor in (1, 2, 3):
        if im_wid <= im_hei * factor:
            target_w = imgH * factor
            break
    img_new = cv2.resize(img, (target_w, imgH))

    gray = cv2.cvtColor(np.asarray(img_new), cv2.COLOR_BGR2GRAY)
    canvas[:, 0:gray.shape[1]] = gray
    canvas = canvas[:, :, np.newaxis]

    rows, cols, _ = canvas.shape
    return np.reshape(canvas, (1, rows, cols)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
    """Build the auxiliary position and attention-bias inputs for SRN.

    Returns:
        list: ``[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
        gsrm_slf_attn_bias2]`` where the position arrays are int64 with a
        leading batch axis of 1, and the bias arrays have shape
        ``(1, num_heads, max_text_length, max_text_length)`` with ``-1e9``
        at masked positions (strictly above the diagonal for bias1,
        strictly below for bias2).
    """
    imgC, imgH, imgW = image_shape
    feature_dim = int((imgH / 8) * (imgW / 8))

    def column_positions(count):
        # Positions 0..count-1 as an int64 column with a batch axis in front.
        column = np.array(range(0, count)).reshape((count, 1)).astype('int64')
        return column[np.newaxis, :]

    def tiled_bias(mask):
        # One copy of the mask per attention head, scaled to -1e9.
        mask = mask.reshape([-1, 1, max_text_length, max_text_length])
        return np.tile(mask, [1, num_heads, 1, 1]).astype('float32') * [-1e9]

    all_ones = np.ones((1, max_text_length, max_text_length))
    upper_bias = tiled_bias(np.triu(all_ones, 1))
    lower_bias = tiled_bias(np.tril(all_ones, -1))

    return [
        column_positions(feature_dim), column_positions(max_text_length),
        upper_bias, lower_bias
    ]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
    """Assemble the five SRN inputs for a single image.

    Calls ``resize_norm_img_srn`` for the image tensor and
    ``srn_other_inputs`` for the auxiliary tensors, then fixes their dtypes.

    Returns:
        tuple ``(norm_img, encoder_word_pos, gsrm_word_pos,
        gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)``; the image gains a
        leading batch axis, positions are int64 and biases are float32.
    """
    batched_img = self.resize_norm_img_srn(img, image_shape)[np.newaxis, :]
    enc_pos, gsrm_pos, bias_upper, bias_lower = self.srn_other_inputs(
        image_shape, num_heads, max_text_length)
    return (
        batched_img,
        enc_pos.astype(np.int64),
        gsrm_pos.astype(np.int64),
        bias_upper.astype(np.float32),
        bias_lower.astype(np.float32),
    )
def resize_norm_img_sar(self, img, image_shape,
                        width_downsample_ratio=0.25):
    """Resize to a fixed height, normalise to [-1, 1] and right-pad with -1.

    Args:
        img: Image array (HWC; a 2-D array when ``image_shape[0] == 1``).
        image_shape: Four-element (channels, height, min_width, max_width).
            ``min_width`` may be None; ``max_width`` is also checked for
            None but is needed to size the padded output.
        width_downsample_ratio: The resized width is snapped to a multiple
            of ``int(1 / width_downsample_ratio)``.

    Returns:
        tuple ``(padding_im, resize_shape, pad_shape, valid_ratio)`` where
        ``valid_ratio`` is the fraction of ``max_width`` covered by content.
    """
    imgC, imgH, imgW_min, imgW_max = image_shape
    src_h, src_w = img.shape[0], img.shape[1]

    # make sure the new width is an integral multiple of width_divisor.
    width_divisor = int(1 / width_downsample_ratio)
    target_w = math.ceil(imgH * (src_w / float(src_h)))
    if target_w % width_divisor != 0:
        target_w = round(target_w / width_divisor) * width_divisor
    if imgW_min is not None:
        target_w = max(imgW_min, target_w)

    valid_ratio = 1.0
    if imgW_max is not None:
        valid_ratio = min(1.0, 1.0 * target_w / imgW_max)
        target_w = min(imgW_max, target_w)

    resized = cv2.resize(img, (target_w, imgH)).astype('float32')
    # Scale to [0, 1]; single-channel input gains a channel axis, otherwise
    # HWC is transposed to CHW.
    if image_shape[0] == 1:
        resized = (resized / 255)[np.newaxis, :]
    else:
        resized = resized.transpose((2, 0, 1)) / 255
    # Shift/scale to [-1, 1].
    resized -= 0.5
    resized /= 0.5
    resize_shape = resized.shape

    padded = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
    padded[:, :, 0:target_w] = resized
    return padded, resize_shape, padded.shape, valid_ratio
def resize_norm_img_spin(self, img):
    """Convert to grayscale, resize to 100x32 and normalise with 127.5.

    Args:
        img: Image array accepted by ``cv2.COLOR_BGR2GRAY``.

    Returns:
        float32 array of shape (1, 32, 100) computed as
        ``(pixel - 127.5) * (1 / 127.5)``.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # NOTE(review): cv2.INTER_CUBIC is passed positionally; in the documented
    # signature resize(src, dsize, dst, fx, fy, interpolation) the third
    # positional slot is `dst` -- confirm whether cubic interpolation is
    # really applied. Left unchanged to preserve existing behaviour.
    gray = cv2.resize(gray, tuple([100, 32]), cv2.INTER_CUBIC)
    chw = np.expand_dims(np.array(gray, np.float32), -1).transpose((2, 0, 1))

    mean = np.float32(np.array([127.5], dtype=np.float32).reshape(1, -1))
    std = np.array([127.5], dtype=np.float32)
    stdinv = 1 / np.float32(std.reshape(1, -1))

    chw -= mean
    chw *= stdinv
    return chw
def resize_norm_img_svtr(self, img, image_shape):
    """Resize to a fixed size and normalise to [-1, 1].

    Args:
        img: HWC image array.
        image_shape: Three-element (channels, height, width); only height
            and width are used.

    Returns:
        float32 CHW array computed as ``(pixel / 255 - 0.5) / 0.5``.
    """
    _, target_h, target_w = image_shape
    scaled = cv2.resize(
        img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    chw = scaled.astype('float32').transpose((2, 0, 1)) / 255
    chw -= 0.5
    chw /= 0.5
    return chw
def resize_norm_img_abinet(self, img, image_shape):
    """Resize to a fixed size and apply per-channel mean/std normalisation.

    Args:
        img: HWC image array with three channels.
        image_shape: Three-element (channels, height, width); only height
            and width are used.

    Returns:
        float32 CHW array: ``(pixel / 255 - mean) / std`` with
        mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
    """
    _, target_h, target_w = image_shape
    scaled = cv2.resize(
        img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    scaled = scaled.astype('float32') / 255.

    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    normalised = (
        scaled - channel_mean[None, None, ...]) / channel_std[None, None, ...]

    return normalised.transpose((2, 0, 1)).astype('float32')
def norm_img_can(self, img, image_shape):
    """Grayscale an image, pad it up to the model size and scale by 1/255.

    Args:
        img: Image array accepted by ``cv2.COLOR_BGR2GRAY``.
        image_shape: Not read by this method; sizes come from
            ``self.rec_image_shape``.

    Returns:
        float32 array of shape (1, H, W). When ``self.rec_image_shape[0]``
        is 1 and the image is smaller than the configured height/width, it
        is padded on the bottom/right with 255.
    """
    # CAN only predicts on gray scale images.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    if self.rec_image_shape[0] == 1:
        h, w = gray.shape
        _, min_h, min_w = self.rec_image_shape
        if h < min_h or w < min_w:
            pad_bottom = max(min_h - h, 0)
            pad_right = max(min_w - w, 0)
            gray = np.pad(gray, ((0, pad_bottom), (0, pad_right)),
                          'constant',
                          constant_values=(255))

    # (H, W) -> (1, H, W)
    return (np.expand_dims(gray, 0) / 255.0).astype('float32')
def close(self):
    """Drop the ``predictor`` attribute (if present) and run the GC.

    Safe to call more than once.
    """
    logging.info('Close text recognizer.')
    try:
        # Release the inference session manually.
        del self.predictor
    except AttributeError:
        # Already closed, or never initialised.
        pass
    gc.collect()
def __call__(self, img_list):
    """Recognise text in a list of cropped images.

    Images are processed in batches of ``self.rec_batch_num`` in order of
    increasing width/height ratio; results are written back to the original
    positions.

    Args:
        img_list: Sequence of image arrays (height on axis 0, width on
            axis 1).

    Returns:
        tuple ``(rec_res, elapsed_seconds)``; ``rec_res[i]`` is whatever
        ``self.postprocess_op`` produced for ``img_list[i]`` (initialised to
        ``['', 0.0]``).

    Raises:
        Exception: re-raised from ``self.predictor.run`` after the fourth
            consecutive failure for a batch.
    """
    total = len(img_list)
    # Sorting by aspect ratio can speed up the recognition process.
    aspect_ratios = [im.shape[1] / float(im.shape[0]) for im in img_list]
    order = np.argsort(np.array(aspect_ratios))

    rec_res = [['', 0.0]] * total
    batch_size = self.rec_batch_num
    started = time.time()

    for batch_start in range(0, total, batch_size):
        batch_end = min(total, batch_start + batch_size)
        batch_imgs = [img_list[order[k]]
                      for k in range(batch_start, batch_end)]

        # The batch is resized for its widest member, but never narrower
        # than the configured model aspect ratio.
        imgC, imgH, imgW = self.rec_image_shape[:3]
        max_wh_ratio = imgW / imgH
        for im in batch_imgs:
            h, w = im.shape[0:2]
            max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

        batch = np.concatenate([
            self.resize_norm_img(im, max_wh_ratio)[np.newaxis, :]
            for im in batch_imgs
        ]).copy()
        feed = {self.input_tensor.name: batch}

        # Up to four attempts, sleeping 5 seconds between failures.
        failures = 0
        while True:
            try:
                outputs = self.predictor.run(None, feed, self.run_options)
                break
            except Exception:
                if failures >= 3:
                    raise
                failures += 1
                time.sleep(5)

        decoded = self.postprocess_op(outputs[0])
        for offset, item in enumerate(decoded):
            rec_res[order[batch_start + offset]] = item

    return rec_res, time.time() - started
def __del__(self):
    """Finalizer: delegate to ``close()`` when the object is collected."""
    self.close()
class TextDetector:
def __init__(self, model_dir, device_id: int | None = None):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
'limit_type': "max",
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
img_h, img_w = self.input_tensor.shape[2:]
if isinstance(img_h, str) or isinstance(img_w, str):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
pre_process_list[0] = {
'DetResizeForTest': {
'image_shape': [img_h, img_w]
}
}
self.preprocess_op = create_operators(pre_process_list)
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def close(self):
logging.info("Close text detector.")
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
st = time.time()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
input_dict = {}
input_dict[self.input_tensor.name] = img
for i in range(100000):
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
return dt_boxes, time.time() - st
def __del__(self):
self.close()
class OCR:
def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
try:
# 使用配置中的 MODEL_DIR如果不存在则尝试默认路径
if MODEL_DIR and os.path.exists(MODEL_DIR):
model_dir = MODEL_DIR
else:
model_dir = os.path.join(
get_project_base_directory(),
"models", "deepdoc")
# Append muti-gpus task to the list
if PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
except Exception:
# 如果模型目录不存在,尝试从 HuggingFace 下载
default_model_dir = os.path.join(
get_project_base_directory(), "models", "deepdoc")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=default_model_dir,
local_dir_use_symlinks=False)
if PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
else:
# 如果指定了 model_dir直接使用
if PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
self.drop_score = 0.5
self.crop_image_res_index = 0
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
# Try original orientation
rec_result = self.text_recognizer[0]([dst_img])
text, score = rec_result[0][0]
best_score = score
best_img = dst_img
# Try clockwise 90° rotation
rotated_cw = np.rot90(dst_img, k=3)
rec_result = self.text_recognizer[0]([rotated_cw])
rotated_cw_text, rotated_cw_score = rec_result[0][0]
if rotated_cw_score > best_score:
best_score = rotated_cw_score
best_img = rotated_cw
# Try counter-clockwise 90° rotation
rotated_ccw = np.rot90(dst_img, k=1)
rec_result = self.text_recognizer[0]([rotated_ccw])
rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
if rotated_ccw_score > best_score:
best_img = rotated_ccw
# Use the best image
dst_img = best_img
return dst_img
def sorted_boxes(self, dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def detect(self, img, device_id: int | None = None):
if device_id is None:
device_id = 0
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if img is None:
return None, None, time_dict
start = time.time()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
return zip(self.sorted_boxes(dt_boxes), [
("", 0) for _ in range(len(dt_boxes))])
def recognize(self, ori_im, box, device_id: int | None = None):
if device_id is None:
device_id = 0
img_crop = self.get_rotate_crop_image(ori_im, box)
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
text, score = rec_res[0]
if score < self.drop_score:
return ""
return text
def recognize_batch(self, img_list, device_id: int | None = None):
if device_id is None:
device_id = 0
rec_res, elapse = self.text_recognizer[device_id](img_list)
texts = []
for i in range(len(rec_res)):
text, score = rec_res[i]
if score < self.drop_score:
text = ""
texts.append(text)
return texts
def __call__(self, img, device_id = 0, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if device_id is None:
device_id = 0
if img is None:
return None, None, time_dict
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
img_crop_list = []
dt_boxes = self.sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
time_dict['rec'] = elapse
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
# for bno in range(len(img_crop_list)):
# print(f"{bno}, {rec_res[bno]}")
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))

726
ocr/operators.py Normal file
View File

@@ -0,0 +1,726 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import sys
import six
import cv2
import numpy as np
import math
from PIL import Image
class DecodeImage:
    """Decode an encoded image held as bytes in ``data['image']``.

    Args:
        img_mode: 'RGB' reverses the decoded channel axis; 'GRAY' applies
            ``cv2.COLOR_GRAY2BGR``; any other value leaves the decoded
            image untouched.
        channel_first: When true, transpose HWC to CHW.
        ignore_orientation: When true, decode with
            ``cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR``.
    """

    def __init__(self,
                 img_mode='RGB',
                 channel_first=False,
                 ignore_orientation=False,
                 **kwargs):
        self.img_mode = img_mode
        self.channel_first = channel_first
        self.ignore_orientation = ignore_orientation

    def __call__(self, data):
        """Decode ``data['image']`` in place.

        Returns:
            ``data`` with the decoded array under 'image', or None when the
            bytes cannot be decoded.

        Raises:
            AssertionError: if the input is not non-empty bytes, or the
                decoded image does not have three channels in 'RGB' mode.
        """
        img = data['image']
        # The Python 2 (str) branch was dead code on Python 3 and is gone.
        assert isinstance(img, bytes) and len(
            img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')
        if self.ignore_orientation:
            img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
                               cv2.IMREAD_COLOR)
        else:
            img = cv2.imdecode(img, 1)
        if img is None:
            return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            # Wrap the shape in a 1-tuple: formatting with the bare shape
            # tuple raised TypeError instead of showing this message.
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
                img.shape,)
            img = img[:, :, ::-1]

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data
class StandardizeImag:
    """Scale an image by 1/255 and/or standardise it with mean and std.

    Args:
        mean (list): per-channel value subtracted from the image
        std (list): per-channel value the image is divided by
        is_scale (bool): whether to multiply by 1/255 first
        norm_type (str): 'mean_std' applies mean/std; any other value
            skips that step
    """

    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean, self.std = mean, std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): HWC image; a float32 input is modified in place
                because no copy is made for it
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): processed float32 image
            im_info (dict): the same ``im_info`` object
        """
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            im *= 1.0 / 255.0

        if self.norm_type == 'mean_std':
            # Broadcast the per-channel statistics over height and width.
            im -= np.array(self.mean)[np.newaxis, np.newaxis, :]
            im /= np.array(self.std)[np.newaxis, np.newaxis, :]
        return im, im_info
class NormalizeImage:
    """Normalise an image: ``(img * scale - mean) / std``.

    Args:
        scale: Number, or a string expression that is evaluated; defaults
            to 1/255.
        mean: Per-channel mean; defaults to [0.485, 0.456, 0.406].
        std: Per-channel std; defaults to [0.229, 0.224, 0.225].
        order: 'chw' shapes mean/std as (3, 1, 1); anything else as
            (1, 1, 3).
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            # NOTE(review): eval() executes arbitrary code. This is only
            # acceptable while `scale` comes from trusted, in-repo
            # configuration (e.g. '1./255.'); never pass external input.
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)

        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]

        stat_shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(stat_shape).astype('float32')
        self.std = np.array(std).reshape(stat_shape).astype('float32')

    def __call__(self, data):
        """Normalise ``data['image']`` (ndarray or PIL image) in ``data``."""
        img = data['image']
        if not isinstance(img, np.ndarray):
            # Only non-array inputs can be PIL images.
            from PIL import Image
            if isinstance(img, Image.Image):
                img = np.array(img)
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        scaled = img.astype('float32') * self.scale
        data['image'] = (scaled - self.mean) / self.std
        return data
class ToCHWImage:
    """Transpose ``data['image']`` from HWC to CHW layout."""

    def __init__(self, **kwargs):
        # No configuration; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """Replace ``data['image']`` with its (2, 0, 1) transpose."""
        img = data['image']
        if not isinstance(img, np.ndarray):
            # Only non-array inputs can be PIL images.
            from PIL import Image
            if isinstance(img, Image.Image):
                img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data
class KeepKeys:
    """Turn a data dict into a list of the values under ``keep_keys``."""

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return ``[data[k] for k in keep_keys]`` in that order.

        Raises:
            KeyError: if a requested key is missing from ``data``.
        """
        return [data[key] for key in self.keep_keys]
class Pad:
    """Zero-pad an image on the bottom and right.

    Args:
        size: Target (height, width); an int means a square. When falsy,
            each side is rounded up to a multiple of ``size_div``.
        size_div: Divisor used when ``size`` is not given.

    Raises:
        TypeError: if ``size`` is neither None, an int, a list nor a tuple.
    """

    def __init__(self, size=None, size_div=32, **kwargs):
        if not (size is None or isinstance(size, (int, list, tuple))):
            raise TypeError("Type of target_size is invalid. Now is {}".format(
                type(size)))
        self.size = [size, size] if isinstance(size, int) else size
        self.size_div = size_div

    def __call__(self, data):
        """Pad ``data['image']`` and store the result back into ``data``.

        With an explicit size, the image must be strictly smaller than the
        target in both dimensions (enforced with an assertion).
        """
        img = data['image']
        img_h, img_w = img.shape[0], img.shape[1]

        if self.size:
            target_h, target_w = self.size
            assert (
                img_h < target_h and img_w < target_w
            ), '(h, w) of target size should be greater than (img_h, img_w)'
        else:
            div = self.size_div

            def _round_up(length):
                # Next multiple of `div`, and at least `div` itself.
                return max(int(math.ceil(length / div) * div), div)

            target_h = _round_up(img.shape[0])
            target_w = _round_up(img.shape[1])

        data['image'] = cv2.copyMakeBorder(
            img,
            0,
            target_h - img_h,
            0,
            target_w - img_w,
            cv2.BORDER_CONSTANT,
            value=0)
        return data
class LinearResize:
    """Resize an image towards ``target_size``, optionally keeping its ratio.

    Args:
        target_size (int|list): target size; an int means a square
        keep_ratio (bool): whether to keep the aspect ratio, default true
        interp (int): cv2 interpolation flag
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        self.target_size = (
            [target_size, target_size]
            if isinstance(target_size, int) else target_size)
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): HWC image
            im_info (dict): info of image
        Returns:
            im (np.ndarray): resized image
            im_info (dict): updated with 'im_shape' (new height, width) and
                'scale_factor' (y scale, x scale), both float32
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        _ = im.shape[2]  # requires a channel axis, as the original did
        scale_y, scale_x = self.generate_scale(im)
        resized = cv2.resize(
            im,
            None,
            None,
            fx=scale_x,
            fy=scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(resized.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [scale_y, scale_x]).astype('float32')
        return resized, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): HWC image
        Returns:
            im_scale_y: the resize ratio of Y
            im_scale_x: the resize ratio of X
        """
        origin_shape = im.shape[:2]
        _ = im.shape[2]  # requires a channel axis, as the original did

        if not self.keep_ratio:
            resize_h, resize_w = self.target_size
            return (resize_h / float(origin_shape[0]),
                    resize_w / float(origin_shape[1]))

        short_side = np.min(origin_shape)
        long_side = np.max(origin_shape)
        target_short = np.min(self.target_size)
        target_long = np.max(self.target_size)
        # Fit the short side first; if the long side then overflows, fit
        # the long side instead.
        im_scale = float(target_short) / float(short_side)
        if np.round(im_scale * long_side) > target_long:
            im_scale = float(target_long) / float(long_side)
        return im_scale, im_scale
class Resize:
    """Resize an image to a fixed (height, width) and rescale its polygons."""

    def __init__(self, size=(640, 640), **kwargs):
        self.size = size

    def resize_image(self, img):
        """Resize ``img`` to ``self.size``.

        Returns:
            ``(resized_image, [ratio_h, ratio_w])``.
        """
        target_h, target_w = self.size
        src_h, src_w = img.shape[:2]  # (h, w, c)
        ratios = [float(target_h) / src_h, float(target_w) / src_w]
        return cv2.resize(img, (int(target_w), int(target_h))), ratios

    def __call__(self, data):
        """Resize ``data['image']``; scale ``data['polys']`` when present."""
        has_polys = 'polys' in data
        if has_polys:
            text_polys = data['polys']

        img_resize, [ratio_h, ratio_w] = self.resize_image(data['image'])

        if has_polys:
            scaled = [
                [[pt[0] * ratio_w, pt[1] * ratio_h] for pt in poly]
                for poly in text_polys
            ]
            data['polys'] = np.array(scaled, dtype=np.float32)

        data['image'] = img_resize
        return data
class DetResizeForTest:
    """Resize an image for text detection and record the resize ratios.

    The mode is chosen from the keyword arguments, in this priority:
      * ``image_shape`` (+ optional ``keep_ratio``): resize to a fixed
        (height, width) -- ``resize_type`` 1.
      * ``limit_side_len`` (+ optional ``limit_type``, default 'min'):
        bound one side and round both to multiples of 32 -- type 0.
      * ``resize_long``: scale the longer side and round both up to
        multiples of 128 -- type 2.
      * none of these: type 0 with ``limit_side_len=736``,
        ``limit_type='min'``.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
            if 'keep_ratio' in kwargs:
                self.keep_ratio = kwargs['keep_ratio']
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize ``data['image']`` and store ``data['shape']``.

        ``data['shape']`` is ``[src_h, src_w, ratio_h, ratio_w]`` where the
        source size is taken before any padding of tiny images.
        """
        img = data['image']
        src_h, src_w, _ = img.shape
        # Very small images are zero-padded to at least 32x32 first.
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def image_padding(self, im, value=0):
        """Pad ``im`` on the bottom/right to at least 32x32 with ``value``."""
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """Resize to ``self.image_shape``.

        With ``keep_ratio`` set to True the width follows the aspect ratio
        and is rounded up to a multiple of 32.
        """
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            resize_w = math.ceil(resize_w / 32) * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """Resize to a size that is a multiple of 32, as the network needs.

        Args:
            img(array): array with shape [h, w, c]

        Returns:
            ``(img, [ratio_h, ratio_w])``.

        Raises:
            Exception: for an unsupported ``limit_type``; any error from
                ``cv2.resize`` is logged and re-raised.
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        if self.limit_type == 'max':
            # Shrink only when the longer side exceeds the limit.
            if max(h, w) > limit_side_len:
                ratio = float(limit_side_len) / max(h, w)
            else:
                ratio = 1.
        elif self.limit_type == 'min':
            # Enlarge only when the shorter side is below the limit.
            if min(h, w) < limit_side_len:
                ratio = float(limit_side_len) / min(h, w)
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')

        # Both sides are rounded to the nearest multiple of 32, minimum 32,
        # so they are always positive.
        resize_h = max(int(round(int(h * ratio) / 32) * 32), 32)
        resize_w = max(int(round(int(w * ratio) / 32) * 32), 32)

        try:
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            # Previously this called sys.exit(0), ending the whole process
            # with a success code. Log and let the caller decide instead.
            logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
            raise
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Scale the longer side to ``self.resize_long``.

        Both sides are then rounded up to a multiple of 128.
        """
        h, w, _ = img.shape

        ratio = float(self.resize_long) / max(h, w)
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]
class E2EResizeForTest:
    """Resize an image for end-to-end models, rounding sides up to 128.

    Required keyword arguments:
        max_side_len: Size limit used by both resize strategies.
        valid_set: 'totaltext' selects ``resize_image_for_totaltext``;
            anything else selects ``resize_image``.
    """

    def __init__(self, **kwargs):
        super(E2EResizeForTest, self).__init__()
        self.max_side_len = kwargs['max_side_len']
        self.valid_set = kwargs['valid_set']

    def __call__(self, data):
        """Resize ``data['image']`` and store ``data['shape']``.

        ``data['shape']`` is ``[src_h, src_w, ratio_h, ratio_w]``.
        """
        img = data['image']
        src_h, src_w, _ = img.shape
        if self.valid_set == 'totaltext':
            resizer = self.resize_image_for_totaltext
        else:
            resizer = self.resize_image
        im_resized, (ratio_h, ratio_w) = resizer(
            img, max_side_len=self.max_side_len)
        data['image'] = im_resized
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def _scale_to_stride(self, im, ratio, max_stride=128):
        """Scale by ``ratio`` and round both sides up to ``max_stride``."""
        h, w, _ = im.shape
        new_h = (int(h * ratio) + max_stride - 1) // max_stride * max_stride
        new_w = (int(w * ratio) + max_stride - 1) // max_stride * max_stride
        resized = cv2.resize(im, (int(new_w), int(new_h)))
        return resized, (new_h / float(h), new_w / float(w))

    def resize_image_for_totaltext(self, im, max_side_len=512):
        """Scale by 1.25, or to ``max_side_len`` height if that overflows.

        :return: the resized image and ``(ratio_h, ratio_w)``
        """
        h, _, _ = im.shape
        ratio = 1.25
        if h * ratio > max_side_len:
            ratio = float(max_side_len) / h
        return self._scale_to_stride(im, ratio)

    def resize_image(self, im, max_side_len=512):
        """
        resize image to a size multiple of max_stride which is required by the network
        :param im: the image to resize
        :param max_side_len: target length of the longer side before rounding
        :return: the resized image and ``(ratio_h, ratio_w)``
        """
        h, w, _ = im.shape
        # Fix the longer side (ties use the width).
        longer = h if h > w else w
        return self._scale_to_stride(im, float(max_side_len) / longer)
class KieResize:
    """Resize an image and its boxes for KIE models.

    ``img_scale`` (required keyword) is stored as ``max_side`` and
    ``min_side``; ``resize_image`` itself uses a fixed [512, 1024] scale and
    a 1024x1024 canvas.
    """

    def __init__(self, **kwargs):
        super(KieResize, self).__init__()
        scale_cfg = kwargs['img_scale']
        self.max_side, self.min_side = scale_cfg[0], kwargs['img_scale'][1]

    def __call__(self, data):
        """Resize ``data['image']`` and rescale ``data['points']``.

        Adds 'ori_image', 'ori_boxes' and 'shape' (new height, width) and
        replaces 'image' and 'points'.
        """
        img = data['image']
        points = data['points']
        src_h, src_w, _ = img.shape
        im_resized, scale_factor, [ratio_h, ratio_w], [new_h, new_w] = \
            self.resize_image(img)
        data['ori_image'] = img
        data['ori_boxes'] = points
        data['points'] = self.resize_boxes(img, points, scale_factor)
        data['image'] = im_resized
        data['shape'] = np.array([new_h, new_w])
        return data

    def resize_image(self, img):
        """Resize ``img`` and paste it at the top-left of a zero canvas.

        Returns:
            ``(canvas, scale_factor, [h_scale, w_scale], [new_h, new_w])``
            where ``scale_factor`` is float32 ``[w, h, w, h]``.
        """
        canvas = np.zeros([1024, 1024, 3], dtype='float32')
        scale = [512, 1024]
        h, w = img.shape[:2]
        # Largest factor that keeps the long side within max(scale) and the
        # short side within min(scale).
        factor = min(max(scale) / max(h, w), min(scale) / min(h, w))
        resize_w = int(w * float(factor) + 0.5)
        resize_h = int(h * float(factor) + 0.5)
        max_stride = 32
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(img, (resize_w, resize_h))
        new_h, new_w = im.shape[:2]
        w_scale = new_w / w
        h_scale = new_h / h
        scale_factor = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        canvas[:new_h, :new_w, :] = im
        return canvas, scale_factor, [h_scale, w_scale], [new_h, new_w]

    def resize_boxes(self, im, points, scale_factor):
        """Scale ``points`` and clip them to the shape of ``im``.

        Even columns are clipped to ``[0, im width]`` and odd columns to
        ``[0, im height]``.
        """
        scaled = points * scale_factor
        height, width = im.shape[:2]
        scaled[:, 0::2] = np.clip(scaled[:, 0::2], 0, width)
        scaled[:, 1::2] = np.clip(scaled[:, 1::2], 0, height)
        return scaled
class SRResize:
    """Resize the low/high-resolution image pair for super-resolution.

    Args:
        imgH, imgW: High-resolution target height and width.
        down_sample_scale: Integer divisor giving the low-resolution size.
        infer_mode: When true, only the low-resolution image is processed.
        keep_ratio, min_ratio, mask: Stored but not read by ``__call__``.
    """

    def __init__(self,
                 imgH=32,
                 imgW=128,
                 down_sample_scale=4,
                 keep_ratio=False,
                 min_ratio=1,
                 mask=False,
                 infer_mode=False,
                 **kwargs):
        self.imgH, self.imgW = imgH, imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio
        self.down_sample_scale = down_sample_scale
        self.mask = mask
        self.infer_mode = infer_mode

    def __call__(self, data):
        """Write 'img_lr' (and 'img_hr' outside infer mode) into ``data``.

        Reads 'image_lr'; outside infer mode also reads 'image_hr' and
        'label' (the label is looked up but not used).
        """
        lr_size = (self.imgW // self.down_sample_scale,
                   self.imgH // self.down_sample_scale)
        data["img_lr"] = ResizeNormalize(lr_size)(data["image_lr"])
        if self.infer_mode:
            return data

        hr_image = data["image_hr"]
        _ = data["label"]  # presence is required, value is unused
        data["img_hr"] = ResizeNormalize((self.imgW, self.imgH))(hr_image)
        return data
class ResizeNormalize:
    """Resize an image object and convert it to a CHW float32 array in [0, 1].

    Args:
        size: Target size passed to the image's ``resize`` method.
        interpolation: Resampling filter passed to ``resize``.
    """

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return the resized image as float32 CHW divided by 255."""
        resized = img.resize(self.size, self.interpolation)
        hwc = np.array(resized).astype("float32")
        return hwc.transpose((2, 0, 1)) / 255
class GrayImageChannelFormat:
    """Convert ``data['image']`` to a single-channel (1, h, w) array.

    Args:
        inverse: When true, store ``abs(gray - 1)`` instead of the
            grayscale image.
    """

    def __init__(self, inverse=False, **kwargs):
        self.inverse = inverse

    def __call__(self, data):
        """Replace 'image' with the gray image; keep the input as 'src_image'."""
        original = data['image']
        gray = np.expand_dims(cv2.cvtColor(original, cv2.COLOR_BGR2GRAY), 0)

        # NOTE(review): for unsigned integer input `gray - 1` wraps 0 around
        # to the maximum value -- confirm the expected input range.
        data['image'] = np.abs(gray - 1) if self.inverse else gray

        data['src_image'] = original
        return data
class Permute:
    """Transpose an image from HWC to CHW layout.

    Takes no configuration.
    """

    def __init__(self, ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): HWC image
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): a CHW copy of the image
            im_info (dict): the same ``im_info`` object
        """
        chw = im.transpose((2, 0, 1))
        return chw.copy(), im_info
class PadStride:
    """Zero-pad a CHW image so height and width are multiples of a stride.

    Args:
        stride (int): required divisor of the padded height and width;
            values <= 0 disable padding
    """

    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): CHW image
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): the input itself when stride <= 0, otherwise a
                new float32 array padded on the bottom/right with zeros
            im_info (dict): the same ``im_info`` object
        """
        stride = self.coarsest_stride
        if stride <= 0:
            return im, im_info

        channels, height, width = im.shape
        padded_h = int(np.ceil(float(height) / stride) * stride)
        padded_w = int(np.ceil(float(width) / stride) * stride)
        canvas = np.zeros((channels, padded_h, padded_w), dtype=np.float32)
        canvas[:, :height, :width] = im
        return canvas, im_info
def decode_image(im_file, im_info):
    """Load an image (if given a path) and record its shape in ``im_info``.

    Args:
        im_file (str|np.ndarray): a file path, or an already-loaded image
            that is used as is.
        im_info (dict): image metadata, updated in place.

    Returns:
        tuple: ``(im, im_info)``. ``im_info['im_shape']`` holds the first two
        dimensions of ``im`` as float32 and ``im_info['scale_factor']`` is
        reset to ``[1., 1.]``.
    """
    if not isinstance(im_file, str):
        im = im_file
    else:
        with open(im_file, 'rb') as f:
            raw_bytes = f.read()
        encoded = np.frombuffer(raw_bytes, dtype='uint8')
        # imdecode yields BGR channel order; callers need RGB.
        im = cv2.cvtColor(cv2.imdecode(encoded, 1), cv2.COLOR_BGR2RGB)

    im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
    im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
    return im, im_info
def preprocess(im, preprocess_ops):
    """Decode an image and run it through a chain of preprocessing operators.

    Args:
        im: image path or array, handed to ``decode_image``.
        preprocess_ops: iterable of callables, each invoked as
            ``op(im, im_info)`` and returning the updated ``(im, im_info)``.

    Returns:
        tuple: the final ``(im, im_info)`` pair.
    """
    info = {
        'scale_factor': np.array([1., 1.], dtype=np.float32),
        'im_shape': None,
    }
    image, info = decode_image(im, info)
    for op in preprocess_ops:
        image, info = op(image, info)
    return image, info
def nms(bboxes, scores, iou_thresh):
    """Greedy non-maximum suppression over axis-aligned boxes.

    Boxes are visited from highest to lowest score. Each visited box is kept
    and every remaining box whose IoU with it exceeds ``iou_thresh`` is
    dropped.

    Args:
        bboxes: array-like whose rows are ``[x1, y1, x2, y2]``.
        scores: array-like with one score per box.
        iou_thresh: boxes with ``IoU <= iou_thresh`` against a kept box
            survive.

    Returns:
        list: indices into ``bboxes`` of the kept boxes, ordered by
        descending score. Empty when there are no boxes.
    """
    import numpy as np

    bboxes = np.asarray(bboxes)
    scores = np.asarray(scores)
    if bboxes.size == 0:
        return []

    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]
    # Use the same inclusive (+1) convention as the intersection below.
    # Without it the intersection of two identical small boxes is larger
    # than their areas, the union turns negative and duplicates survive.
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)

    indices = []
    index = scores.argsort()[::-1]
    while index.size > 0:
        i = index[0]
        indices.append(i)
        rest = index[1:]

        x11 = np.maximum(x1[i], x1[rest])
        y11 = np.maximum(y1[i], y1[rest])
        x22 = np.minimum(x2[i], x2[rest])
        y22 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0, x22 - x11 + 1)
        h = np.maximum(0, y22 - y11 + 1)
        overlaps = w * h

        ious = overlaps / (areas[i] + areas[rest] - overlaps)
        # Keep only the candidates that do not overlap the chosen box too much.
        index = rest[ious <= iou_thresh]
    return indices

339
ocr/pdf_parser.py Normal file
View File

@@ -0,0 +1,339 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
简化的PDF解析器只使用OCR处理PDF文档
从 RAGFlow 的 RAGFlowPdfParser 中提取OCR相关功能移除了
- 布局识别Layout Recognition
- 表格结构识别Table Structure Recognition
- 文本合并和语义分析
- RAG相关功能
只保留:
- PDF转图片
- OCR文本检测和识别
- 基本的文本和位置信息返回
"""
import logging
import sys
import threading
from io import BytesIO
from pathlib import Path
from timeit import default_timer as timer
import numpy as np
import pdfplumber
import trio
# Import handling: support both running this file directly and importing it
# as a module.
try:
    _package = __package__
except NameError:
    _package = None
if _package is None:
    # No package context (file run directly): put the parent directory on
    # sys.path and import through the absolute ``ocr`` package path.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))
    from ocr.config import PARALLEL_DEVICES
    from ocr.ocr import OCR
else:
    # Imported as a module.
    # NOTE(review): the original comment describes these as relative imports,
    # but ``from config import ...`` / ``from ocr import ...`` are absolute
    # imports resolved through sys.path -- confirm they resolve to the intended
    # modules when this file is loaded as part of the ``ocr`` package.
    from config import PARALLEL_DEVICES
    from ocr import OCR
# A single lock shared through sys.modules under a string key; it is held
# while pdfplumber opens and renders a PDF (see __convert_pdf_to_images).
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
class SimplePdfParser:
    """OCR-only PDF parser.

    Renders PDF pages to images, runs OCR text detection and recognition on
    every page and returns the recognized text with its position.

    Usage:
        parser = SimplePdfParser()
        result = parser.parse_pdf("file.pdf")  # or pass the raw PDF bytes

    Result format:
        {
            "pages": [
                {
                    "page_number": 1,
                    "boxes": [
                        {
                            "text": "recognized text",
                            "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
                            "confidence": 0.9
                        },
                        ...
                    ]
                },
                ...
            ]
        }
    """

    def __init__(self, model_dir=None):
        """Create the OCR engine and, for multi-device setups, the limiters.

        Args:
            model_dir: OCR model directory, forwarded to ``OCR``; ``None``
                lets ``OCR`` choose its default location.
        """
        self.ocr = OCR(model_dir=model_dir)
        # One single-slot limiter per device so that at most one page is
        # processed on a given device at a time.
        self.parallel_limiter = None
        if PARALLEL_DEVICES > 1:
            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]

    def __ocr_page(self, page_num, img, zoomin=3, device_id=None):
        """Run detection and recognition on one rendered page.

        Args:
            page_num: page number, used only in log messages.
            img: the rendered page; converted with ``np.array``.
            zoomin: rendering zoom factor; detected coordinates are divided
                by it to map them back to unzoomed page coordinates.
            device_id: device index forwarded to the OCR engine.

        Returns:
            list: one ``{"text": str, "bbox": list, "confidence": float}``
            dict per box with non-empty recognized text.
        """
        start = timer()
        img_np = np.array(img)

        # Detection yields (box_coords, (text, score)) entries; text and score
        # are placeholders at this stage and are filled in by recognition.
        detection_result = self.ocr.detect(img_np, device_id)
        if detection_result is None:
            return []

        bxs = list(detection_result)
        logging.info(f"Page {page_num}: OCR detection found {len(bxs)} boxes in {timer() - start:.2f}s")
        if not bxs:
            return []

        start = timer()
        boxes_to_reg = []
        for detection in bxs:
            # Only the quadrilateral (first field) is needed. Indexing instead
            # of fixed-arity unpacking avoids failing on the two-field entries
            # that detection returns.
            box_coords_np = np.array(detection[0], dtype=np.float32)
            # Crop with the zoomed coordinates because img_np is the zoomed
            # image; report the coordinates scaled back by zoomin.
            crop_img = self.ocr.get_rotate_crop_image(img_np, box_coords_np)
            boxes_to_reg.append({
                "bbox": (box_coords_np / zoomin).tolist(),
                "crop_img": crop_img
            })

        # Recognize all crops of the page in one batch.
        crop_imgs = [b["crop_img"] for b in boxes_to_reg]
        texts = self.ocr.recognize_batch(crop_imgs, device_id)
        ocr_results = [
            {
                "text": text,
                "bbox": b["bbox"],
                "confidence": 0.9  # simplified: no per-box confidence is computed
            }
            for b, text in zip(boxes_to_reg, texts)
            if text  # drop boxes whose recognized text is empty
        ]

        logging.info(f"Page {page_num}: OCR recognition {len(ocr_results)} boxes cost {timer() - start:.2f}s")
        return ocr_results

    async def __ocr_page_async(self, page_num, img, zoomin, device_id, limiter, callback):
        """Run ``__ocr_page`` in a worker thread, optionally under a limiter.

        Args:
            page_num: page number.
            img: the rendered page.
            zoomin: rendering zoom factor.
            device_id: device index forwarded to the OCR engine.
            limiter: async context manager bounding concurrency, or a falsy
                value for no limit.
            callback: optional progress callback ``callback(prog, msg)``.

        Returns:
            list: the OCR results of the page.
        """
        def run_ocr():
            return self.__ocr_page(page_num, img, zoomin, device_id)

        if limiter:
            async with limiter:
                result = await trio.to_thread.run_sync(run_ocr)
        else:
            result = await trio.to_thread.run_sync(run_ocr)

        if callback and page_num % 5 == 0:
            # Positional arguments, like every other callback call in this
            # class.
            # NOTE(review): the progress value assumes roughly 100 pages and
            # is not tied to the real page count.
            callback(page_num * 0.9 / 100, f"Processing page {page_num}...")

        return result

    def __convert_pdf_to_images(self, pdf_source, zoomin=3, page_from=0, page_to=299):
        """Render a range of PDF pages to images.

        Args:
            pdf_source: PDF file path (``str`` or ``pathlib.Path``) or the
                PDF content as bytes.
            zoomin: zoom factor; pages are rendered at ``72 * zoomin`` DPI
                (216 DPI for the default of 3).
            page_from: index of the first page to render (0-based).
            page_to: end index of the page slice (exclusive).

        Returns:
            list: one rendered image per page; empty when opening or
            rendering fails (the failure is logged, not raised).
        """
        start = timer()
        page_images = []

        try:
            with sys.modules[LOCK_KEY_pdfplumber]:
                source = pdf_source if isinstance(pdf_source, (str, Path)) else BytesIO(pdf_source)
                # The context manager closes the document on every path.
                with pdfplumber.open(source) as pdf:
                    try:
                        page_images = [
                            p.to_image(resolution=72 * zoomin, antialias=True).annotated
                            for p in pdf.pages[page_from:page_to]
                        ]
                    except Exception as e:
                        logging.warning(f"Failed to convert PDF pages {page_from}-{page_to}: {str(e)}")
        except Exception as e:
            logging.exception(f"Error converting PDF to images: {str(e)}")

        logging.info(f"Converted {len(page_images)} pages to images in {timer() - start:.2f}s")
        return page_images

    def parse_pdf(self, pdf_source, zoomin=3, page_from=0, page_to=299, callback=None):
        """Parse a PDF document by running OCR on its rendered pages.

        Args:
            pdf_source: PDF file path (``str`` or ``pathlib.Path``) or the
                PDF content as bytes.
            zoomin: rendering zoom factor, default 3.
            page_from: index of the first page to process (0-based).
            page_to: end index of the page slice (exclusive).
            callback: optional progress callback ``callback(prog, msg)`` with
                ``prog`` a float and ``msg`` a str.

        Returns:
            dict: ``{"pages": [{"page_number": int, "boxes": [...]}, ...]}``
            where each box is
            ``{"text": str, "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
            "confidence": float}``. ``page_number`` is 1-based. The list is
            empty when no page could be rendered.
        """
        if callback:
            callback(0.0, "Starting PDF parsing...")

        # 1. Render the pages.
        if callback:
            callback(0.1, "Converting PDF to images...")
        page_images = self.__convert_pdf_to_images(pdf_source, zoomin, page_from, page_to)

        if not page_images:
            logging.warning("No pages converted from PDF")
            return {"pages": []}

        # 2. OCR every page.
        async def process_all_pages():
            if self.parallel_limiter:
                # Parallel processing across devices. nursery.start_soon()
                # returns nothing that can be awaited, so every task writes
                # its result into its own slot; the nursery block only exits
                # once all tasks have finished.
                boxes_by_index = [None] * len(page_images)

                async def ocr_into_slot(index, img):
                    device_id = index % PARALLEL_DEVICES
                    boxes_by_index[index] = await self.__ocr_page_async(
                        page_from + index + 1, img, zoomin, device_id,
                        self.parallel_limiter[device_id], callback
                    )

                async with trio.open_nursery() as nursery:
                    for index, img in enumerate(page_images):
                        nursery.start_soon(ocr_into_slot, index, img)

                return [
                    {"page_number": page_from + index + 1, "boxes": boxes}
                    for index, boxes in enumerate(boxes_by_index)
                ]

            # Serial processing on a single device.
            pages_result = []
            for i, img in enumerate(page_images):
                page_num = page_from + i + 1
                result = await trio.to_thread.run_sync(
                    lambda img=img, pn=page_num: self.__ocr_page(pn, img, zoomin, 0)
                )
                pages_result.append({
                    "page_number": page_num,
                    "boxes": result
                })
                if callback:
                    callback(0.1 + (i + 1) * 0.9 / len(page_images), f"Processing page {page_num}...")
            return pages_result

        # Drive the async pipeline to completion.
        if callback:
            callback(0.2, "Starting OCR processing...")

        start = timer()
        pages_result = trio.run(process_all_pages)
        logging.info(f"OCR processing completed in {timer() - start:.2f}s")

        if callback:
            callback(1.0, "OCR processing completed")

        return {
            "pages": pages_result
        }
# Backward-compatible alias: PdfParser refers to the same class as SimplePdfParser.
PdfParser = SimplePdfParser

371
ocr/postprocess.py Normal file
View File

@@ -0,0 +1,371 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
def build_post_process(config, global_config=None):
    """Instantiate the post-processor named by ``config['name']``.

    Args:
        config (dict): must contain ``'name'``; the remaining items become
            constructor keyword arguments. The dict is deep-copied, so the
            caller's object is not modified.
        global_config (dict, optional): extra keyword arguments merged over
            ``config``.

    Returns:
        The constructed post-processor, or ``None`` when the name is the
        string ``"None"``.

    Raises:
        ValueError: if the name is not a supported post-processor.
    """
    registry = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode}

    params = copy.deepcopy(config)
    kind = params.pop('name')
    if kind == "None":
        return None

    if global_config is not None:
        params.update(global_config)

    factory = registry.get(kind)
    if factory is None:
        raise ValueError(
            'post process only support {}'.format(list(registry)))
    return factory(**params)
class DBPostProcess:
    """
    The post process for Differentiable Binarization (DB).

    Calling an instance thresholds the maps in ``outs_dict['maps']`` and
    extracts text regions from the resulting binary mask, either as
    four-point boxes (``box_type='quad'``) or polygons (``box_type='poly'``).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 box_type='quad',
                 **kwargs):
        """
        Args:
            thresh: map values above this are treated as foreground.
            box_thresh: candidates whose mean score is below this are dropped.
            max_candidates: maximum number of contours examined per image.
            unclip_ratio: expansion ratio passed to ``unclip``.
            use_dilation: when true, the mask is dilated with a 2x2 kernel of
                ones before contours are extracted.
            score_mode: ``"fast"`` or ``"slow"``; selects the scoring method
                used by ``boxes_from_bitmap``.
            box_type: ``'quad'`` or ``'poly'``.
            **kwargs: accepted and ignored.
        """
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        # Minimum short-side length a candidate must have to be kept.
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """
        Extract polygons from a binary mask.

        Args:
            pred: score map used to score each candidate.
            _bitmap: binary mask; its shape is unpacked as (height, width).
            dest_width, dest_height: size the coordinates are rescaled to.

        Returns:
            tuple: ``(boxes, scores)`` where ``boxes`` is a list of polygons
            (each a list of ``[x, y]`` points) and ``scores`` the matching
            mean scores.
        """
        bitmap = _bitmap
        height, width = bitmap.shape

        boxes = []
        scores = []

        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
                                       cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours[:self.max_candidates]:
            # Simplify the contour; the tolerance is proportional to its
            # perimeter.
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                # Skip candidates whose expansion produced more than one path.
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)

            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue

            # Rescale from mask coordinates to the destination size.
            box = np.array(box)
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """
        Extract four-point boxes from a binary mask.

        Args:
            pred: score map used to score each candidate.
            _bitmap: binary mask; its shape is unpacked as (height, width).
            dest_width, dest_height: size the coordinates are rescaled to.

        Returns:
            tuple: ``(boxes, scores)`` where ``boxes`` is an int32 array of
            boxes and ``scores`` the list of matching mean scores.
        """
        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns either three or two values; both layouts are
        # accepted here.
        if len(outs) == 3:
            _img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            # Expand the box, then fit a minimum-area rectangle again.
            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            # Rescale from mask coordinates to the destination size.
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype("int32"))
            scores.append(score)
        return np.array(boxes, dtype="int32"), scores

    def unclip(self, box, unclip_ratio):
        """
        Expand a polygon outward.

        The offset distance is ``area * unclip_ratio / perimeter`` of the
        polygon; the expansion uses round joins.

        Returns:
            np.ndarray: the path(s) produced by the offset operation.
        """
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        """
        Fit a minimum-area rectangle to ``contour``.

        Returns:
            tuple: ``(box, short_side)``. ``box`` lists the four corners as
            left point with the smaller y, right point with the smaller y,
            right point with the larger y, left point with the larger y.
            ``short_side`` is the smaller of the rectangle's two side lengths.
        """
        bounding_box = cv2.minAreaRect(contour)
        # Sort the corners by x so the first two are the left pair and the
        # last two the right pair.
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """
        box_score_fast: use bbox mean score as the mean score

        The mean of ``bitmap`` is taken over the polygon ``_box``, filled
        inside its axis-aligned bounding rectangle. ``_box`` is copied and
        not modified.
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        # Bounding rectangle of the box, clamped to the bitmap.
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        # Shift the box into the local coordinates of the cropped rectangle.
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        """
        box_score_slow: use polygon mean score as the mean score

        The mean of ``bitmap`` is taken over the filled ``contour`` itself
        rather than a fitted box. ``contour`` is copied and not modified.
        """
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        # Bounding rectangle of the contour, clamped to the bitmap.
        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        # Shift the contour into the local coordinates of the cropped rectangle.
        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin

        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch of maps.

        Args:
            outs_dict (dict): ``outs_dict['maps']`` holds the maps; anything
                that is not an ndarray is converted with ``.numpy()``. Only
                index 0 of the second axis is used.
            shape_list: per-image entries unpacked as
                ``(src_h, src_w, ratio_h, ratio_w)``; only the source height
                and width are used.

        Returns:
            list: one ``{'points': boxes}`` dict per image. The scores
            computed for the boxes are not returned.

        Raises:
            ValueError: if ``box_type`` is neither ``'quad'`` nor ``'poly'``.
        """
        pred = outs_dict['maps']
        if not isinstance(pred, np.ndarray):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            if self.box_type == 'poly':
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            elif self.box_type == 'quad':
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError(
                    "box_type can only be one of ['quad', 'poly']")

            boxes_batch.append({'points': boxes})
        return boxes_batch
class BaseRecLabelDecode:
    """Map recognition output indices to characters and text.

    Args:
        character_dict_path: path of a UTF-8 file with one character per
            line; ``None`` selects the built-in digits + lowercase alphabet.
        use_space_char: append a space to the characters read from the file.
    """

    def __init__(self, character_dict_path=None, use_space_char=False):
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is not None:
            with open(character_dict_path, "rb") as fin:
                for raw_line in fin.readlines():
                    entry = raw_line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(entry)
            if use_space_char:
                self.character_str.append(" ")
            charset = list(self.character_str)
            # Dictionaries whose path contains 'arabic' get reversed output.
            if 'arabic' in character_dict_path:
                self.reverse = True
        else:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            charset = list(self.character_str)

        charset = self.add_special_char(charset)
        self.dict = {char: position for position, char in enumerate(charset)}
        self.character = charset

    def pred_reverse(self, pred):
        """Reverse ``pred`` while keeping runs of ASCII-like characters intact.

        Consecutive characters matching the pattern below form one segment;
        every other character is its own segment. The segments are then
        joined in reverse order.
        """
        segments = []
        run = ''
        for ch in pred:
            if re.search('[a-zA-Z0-9 :*./%+-]', ch):
                run += ch
                continue
            if run != '':
                segments.append(run)
            segments.append(ch)
            run = ''
        if run != '':
            segments.append(run)
        return ''.join(reversed(segments))

    def add_special_char(self, dict_character):
        """Hook for subclasses to extend the character list; identity here."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, confidence)`` pairs.

        Args:
            text_index: batch of index arrays.
            text_prob: optional batch of per-position probabilities.
            is_remove_duplicate: drop a position whose index equals the
                previous position's index.

        Returns:
            list: one ``(text, mean_confidence)`` tuple per batch item.
        """
        blanks = self.get_ignored_tokens()
        decoded = []
        for row in range(len(text_index)):
            ids = text_index[row]
            keep = np.ones(len(ids), dtype=bool)
            if is_remove_duplicate:
                keep[1:] = ids[1:] != ids[:-1]
            for blank in blanks:
                keep &= ids != blank

            chars = [self.character[token] for token in ids[keep]]
            confidences = (
                text_prob[row][keep] if text_prob is not None
                else [1] * len(keep)
            )
            # An empty selection would make the mean undefined.
            if len(confidences) == 0:
                confidences = [0]

            text = ''.join(chars)
            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)
            decoded.append((text, np.mean(confidences).tolist()))
        return decoded

    def get_ignored_tokens(self):
        """Return the indices skipped while decoding (0 is the CTC blank)."""
        return [0]  # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
    """Decode CTC recognition output into text.

    Index 0 of the character table is a ``'blank'`` entry.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions, and the label too when one is given.

        Args:
            preds: prediction array, or a tuple/list whose last element is
                used. Anything that is not an ndarray is converted with
                ``.numpy()``. Argmax and max are taken over axis 2.
            label: optional index sequences decoded alongside the predictions.

        Returns:
            The decoded predictions, or ``(predictions, label)`` when
            ``label`` is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        if not isinstance(preds, np.ndarray):
            preds = preds.numpy()

        best_index = preds.argmax(axis=2)
        best_prob = preds.max(axis=2)
        decoded = self.decode(best_index, best_prob, is_remove_duplicate=True)
        if label is None:
            return decoded
        return decoded, self.decode(label)

    def add_special_char(self, dict_character):
        """Prepend the ``'blank'`` entry to the character list."""
        return ['blank'] + dict_character

25
ocr/requirements.txt Normal file
View File

@@ -0,0 +1,25 @@
# OCR PDF处理模块依赖
# 核心依赖
numpy>=1.21.0
opencv-python>=4.5.0
pillow>=8.0.0
pdfplumber>=0.9.0
onnxruntime>=1.12.0
trio>=0.22.0
# 几何计算依赖
shapely>=1.8.0
pyclipper>=1.2.0
# Web框架依赖
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# 模型下载依赖
huggingface_hub>=0.16.0
# 可选依赖(用于GPU检测和加速)
# torch>=1.12.0 # 如果需要GPU支持,取消注释并安装
# onnxruntime-gpu>=1.12.0 # 如果需要GPU支持,取消注释并安装

40
ocr/utils.py Normal file
View File

@@ -0,0 +1,40 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 模块工具函数
"""
import os
def get_project_base_directory(*args):
    """Return the project root, optionally joined with sub-path components.

    The root is the parent of the directory that contains this file, with
    symbolic links in the file path resolved.

    Args:
        *args: optional path components appended to the root.

    Returns:
        str: the root directory, or the root joined with ``args``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    root = os.path.dirname(module_dir)
    return os.path.join(root, *args) if args else root