diff --git a/deepdoc/parser/ocr_http_client.py b/deepdoc/parser/ocr_http_client.py new file mode 100644 index 0000000..babc6ba --- /dev/null +++ b/deepdoc/parser/ocr_http_client.py @@ -0,0 +1,175 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +OCR HTTP 客户端 +用于调用独立的 OCR 服务的 HTTP API +""" + +import os +import logging +import requests +from typing import Optional, Union, Dict, Any + +logger = logging.getLogger(__name__) + + +class OCRHttpClient: + """OCR HTTP 客户端,用于调用独立的 OCR 服务""" + + def __init__(self, base_url: Optional[str] = None, timeout: int = 300): + """ + 初始化 OCR HTTP 客户端 + + Args: + base_url: OCR 服务的基础 URL,如果不提供则从环境变量 OCR_SERVICE_URL 读取 + 默认值为 http://localhost:8000 + timeout: 请求超时时间(秒),默认 300 秒 + """ + if base_url is None: + base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000") + + # 确保 URL 不包含尾随斜杠 + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.api_prefix = "/api/v1/ocr" + + logger.info(f"Initialized OCR HTTP client with base_url: {self.base_url}") + + def parse_pdf_by_path(self, file_path: str, page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]: + """ + 通过文件路径解析 PDF + + Args: + file_path: PDF 文件的本地路径 + page_from: 起始页码(从1开始) + page_to: 结束页码(0表示最后一页) + zoomin: 图像放大倍数(1-5) + + Returns: + dict: 解析结果,格式: + { + "success": bool, + "message": str, + "data": { + "pages": [ + { + "page_number": int, + "boxes": [ + { + "text": str, + "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], + "confidence": float + }, + ... + ] + }, + ... + ] + } + } + + Raises: + requests.RequestException: HTTP 请求失败 + ValueError: 响应格式不正确 + """ + url = f"{self.base_url}{self.api_prefix}/parse/path" + + data = { + "file_path": file_path, + "page_from": page_from, + "page_to": page_to, + "zoomin": zoomin + } + + try: + logger.info(f"Calling OCR service: {url} for file: {file_path}") + response = requests.post(url, data=data, timeout=self.timeout) + response.raise_for_status() + + result = response.json() + if not result.get("success", False): + raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}") + + return result + + except requests.RequestException as e: + logger.error(f"Failed to call OCR service: {e}") + raise + + def parse_pdf_by_bytes(self, pdf_bytes: bytes, filename: str = "document.pdf", + page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]: + """ + 通过二进制数据解析 PDF + + Args: + pdf_bytes: PDF 文件的二进制数据 + filename: 文件名(仅用于日志) + page_from: 起始页码(从1开始) + page_to: 结束页码(0表示最后一页) + zoomin: 图像放大倍数(1-5) + + Returns: + dict: 解析结果,格式同 parse_pdf_by_path + + Raises: + requests.RequestException: HTTP 请求失败 + ValueError: 响应格式不正确 + """ + url = f"{self.base_url}{self.api_prefix}/parse/bytes" + + files = { + "pdf_bytes": (filename, pdf_bytes, "application/pdf") + } + + data = { + "filename": filename, + "page_from": page_from, + "page_to": page_to, + "zoomin": zoomin + } + + try: + logger.info(f"Calling OCR service: {url} with {len(pdf_bytes)} bytes") + response = requests.post(url, files=files, data=data, timeout=self.timeout) + response.raise_for_status() + + result = response.json() + if not result.get("success", False): + raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}") + + return result + + except requests.RequestException as e: + logger.error(f"Failed to call OCR service: {e}") + raise + + def health_check(self) -> Dict[str, Any]: + """ + 检查 OCR 服务健康状态 + + Returns: + dict: 健康状态信息 + """ + url = f"{self.base_url}{self.api_prefix}/health" + + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to check OCR service health: {e}") + raise + diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index ea3a87b..4853b91 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -35,6 +35,7 @@ from pypdf import PdfReader as pdf2_read from api import settings from api.utils.file_utils import get_project_base_directory from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer +from deepdoc.parser.ocr_http_client import OCRHttpClient from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.nlp import rag_tokenizer from rag.prompts.generator import vision_llm_describe_prompt @@ -58,10 +59,24 @@ class RAGFlowPdfParser: ^_- """ - - self.ocr = OCR() + + # 检查是否使用 HTTP OCR 服务 + use_http_ocr = os.getenv("USE_OCR_HTTP", "false").lower() in ("true", "1", "yes") + ocr_service_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000") + + if use_http_ocr: + logging.info(f"Using HTTP OCR service: {ocr_service_url}") + self.ocr = None # 不使用本地 OCR + self.ocr_http_client = OCRHttpClient(base_url=ocr_service_url) + self.use_http_ocr = True + else: + logging.info("Using local OCR") + self.ocr = OCR() + self.ocr_http_client = None + self.use_http_ocr = False + self.parallel_limiter = None - if PARALLEL_DEVICES > 1: + if not self.use_http_ocr and PARALLEL_DEVICES > 1: self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower() @@ -276,7 +291,97 @@ class RAGFlowPdfParser: b["H_right"] = spans[ii]["x1"] b["SP"] = ii + def _convert_http_ocr_result(self, ocr_result: dict, zoomin: int = 3): + """ + 将 HTTP OCR API 返回的结果转换为 RAGFlow 内部格式 + + Args: + ocr_result: HTTP API 返回的结果,格式: + { + "success": bool, + "data": { + "pages": [ + { + "page_number": int, + "boxes": [ + { + "text": str, + "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], + "confidence": float + } + ] + } + ] + } + } + zoomin: 放大倍数 + """ + if not ocr_result.get("success", False) or "data" not in ocr_result: + logging.warning("Invalid OCR HTTP result") + return + + pages_data = ocr_result["data"].get("pages", []) + self.boxes = [] + + for page_data in pages_data: + page_num = page_data.get("page_number", 0) # HTTP API 返回的页码(从1开始) + boxes = page_data.get("boxes", []) + + # 转换为 RAGFlow 格式的 boxes + ragflow_boxes = [] + # 计算在 page_chars 中的索引:HTTP API 返回的页码是从1开始的,需要转换为相对于 page_from 的索引 + page_index = page_num - (self.page_from + 1) # page_from 是从0开始,所以需要 +1 + chars_for_page = self.page_chars[page_index] if hasattr(self, 'page_chars') and 0 <= page_index < len(self.page_chars) else [] + + for box in boxes: + bbox = box.get("bbox", []) + if len(bbox) != 4: + continue + + # 从 bbox 提取坐标(bbox 格式: [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]) + x0 = min(bbox[0][0], bbox[3][0]) / zoomin + x1 = max(bbox[1][0], bbox[2][0]) / zoomin + top = min(bbox[0][1], bbox[1][1]) / zoomin + bottom = max(bbox[2][1], bbox[3][1]) / zoomin + + # 创建 RAGFlow 格式的 box + ragflow_box = { + "x0": x0, + "x1": x1, + "top": top, + "bottom": bottom, + "text": box.get("text", ""), + "page_number": page_num, + "layoutno": "", + "layout_type": "" + } + + ragflow_boxes.append(ragflow_box) + + # 计算 mean_height + if ragflow_boxes: + heights = [b["bottom"] - b["top"] for b in ragflow_boxes] + self.mean_height.append(np.median(heights) if heights else 0) + else: + self.mean_height.append(0) + + # 计算 mean_width + if chars_for_page: + widths = [c.get("width", 8) for c in chars_for_page] + self.mean_width.append(np.median(widths) if widths else 8) + else: + self.mean_width.append(8) + + self.boxes.append(ragflow_boxes) + + logging.info(f"Converted {len(pages_data)} pages from HTTP OCR result") + def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None): + # 如果使用 HTTP OCR,这个方法不会被调用 + if self.use_http_ocr: + logging.warning("__ocr called when using HTTP OCR, this should not happen") + return + start = timer() bxs = self.ocr.detect(np.array(img), device_id) logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)") @@ -927,6 +1032,7 @@ class RAGFlowPdfParser: self.page_cum_height = [0] self.page_layout = [] self.page_from = page_from + start = timer() try: with sys.modules[LOCK_KEY_pdfplumber]: @@ -945,6 +1051,42 @@ class RAGFlowPdfParser: except Exception: logging.exception("RAGFlowPdfParser __images__") logging.info(f"__images__ dedupe_chars cost {timer() - start}s") + + # 如果使用 HTTP OCR,在获取图片和字符信息后调用 HTTP API 获取 OCR 结果 + if self.use_http_ocr: + try: + if callback: + callback(0.1, "Calling OCR HTTP service...") + + # 调用 HTTP OCR 服务 + if isinstance(fnm, str): + # 文件路径 + ocr_result = self.ocr_http_client.parse_pdf_by_path( + fnm, + page_from=page_from + 1, # HTTP API 使用从1开始的页码 + page_to=(page_to + 1) if page_to < 299 else 0, # 转换为从1开始,0 表示最后一页 + zoomin=zoomin + ) + else: + # 二进制数据 + ocr_result = self.ocr_http_client.parse_pdf_by_bytes( + fnm, + filename="document.pdf", + page_from=page_from + 1, + page_to=(page_to + 1) if page_to < 299 else 0, + zoomin=zoomin + ) + + # 将 HTTP API 返回的结果转换为 RAGFlow 格式 + self._convert_http_ocr_result(ocr_result, zoomin) + + if callback: + callback(0.4, "OCR HTTP service completed") + + except Exception as e: + logging.error(f"Failed to call OCR HTTP service: {e}", exc_info=True) + # 如果 HTTP OCR 失败,回退到空结果或抛出异常 + raise self.outlines = [] try: @@ -999,29 +1141,34 @@ class RAGFlowPdfParser: if callback and i % 6 == 5: callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") - async def __img_ocr_launcher(): - def __ocr_preprocess(): - chars = self.page_chars[i] if not self.is_english else [] - self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0) - self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8) - self.page_cum_height.append(img.size[1] / zoomin) - return chars + # 如果使用 HTTP OCR,已经在上面的代码中获取了结果,跳过本地 OCR + if not self.use_http_ocr: + async def __img_ocr_launcher(): + def __ocr_preprocess(): + chars = self.page_chars[i] if not self.is_english else [] + self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0) + self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8) + self.page_cum_height.append(img.size[1] / zoomin) + return chars - if self.parallel_limiter: - async with trio.open_nursery() as nursery: + if self.parallel_limiter: + async with trio.open_nursery() as nursery: + for i, img in enumerate(self.page_images): + chars = __ocr_preprocess() + + nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES]) + await trio.sleep(0.1) + else: for i, img in enumerate(self.page_images): chars = __ocr_preprocess() + await __img_ocr(i, 0, img, chars, None) - nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES]) - await trio.sleep(0.1) - else: - for i, img in enumerate(self.page_images): - chars = __ocr_preprocess() - await __img_ocr(i, 0, img, chars, None) - - start = timer() - - trio.run(__img_ocr_launcher) + start = timer() + trio.run(__img_ocr_launcher) + else: + # HTTP OCR 模式:初始化 page_cum_height + for i, img in enumerate(self.page_images): + self.page_cum_height.append(img.size[1] / zoomin) logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") diff --git a/docker/.env b/docker/.env index 335b15d..4fa4816 100644 --- a/docker/.env +++ b/docker/.env @@ -198,4 +198,6 @@ POSTGRES_DBNAME=rag_flow POSTGRES_USER=rag_flow POSTGRES_PASSWORD=infini_rag_flow POSTGRES_PORT=5432 -DB_TYPE=postgres \ No newline at end of file +DB_TYPE=postgres + +USE_OCR_HTTP=true \ No newline at end of file diff --git a/ocr/__init__.py b/ocr/__init__.py new file mode 100644 index 0000000..60ab8da --- /dev/null +++ b/ocr/__init__.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +独立的 OCR 模块 + +此模块从 RAGFlow 项目中提取,已经移除了对 RAGFlow 特定模块的依赖。 +可以直接作为独立模块使用。 + +使用方法: + from ocr import OCR + import cv2 + + ocr = OCR() + img = cv2.imread("image.jpg") + results = ocr(img) +""" + +# 处理导入问题:支持直接运行和模块导入 +import sys +from pathlib import Path + +try: + _package = __package__ +except NameError: + _package = None + +if _package is None: + # 直接运行时,添加父目录到路径并使用绝对导入 + parent_dir = Path(__file__).parent.parent + if str(parent_dir) not in sys.path: + sys.path.insert(0, str(parent_dir)) + from ocr.ocr import OCR, TextDetector, TextRecognizer + from ocr.pdf_parser import SimplePdfParser +else: + # 作为模块导入时使用相对导入 + from .ocr import OCR, TextDetector, TextRecognizer + from .pdf_parser import SimplePdfParser + +__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser'] + diff --git a/ocr/api.py b/ocr/api.py new file mode 100644 index 0000000..771d38f --- /dev/null +++ b/ocr/api.py @@ -0,0 +1,332 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +OCR PDF处理的FastAPI路由 +提供HTTP接口用于PDF的OCR识别 +""" + +import asyncio +import logging +import os +import sys +import tempfile +from pathlib import Path +from typing import Optional + +from fastapi import APIRouter, File, Form, HTTPException, UploadFile +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +# 处理导入问题:支持直接运行和模块导入 + +try: + _package = __package__ +except NameError: + _package = None + +if _package is None: + # 直接运行时,添加父目录到路径并使用绝对导入 + parent_dir = Path(__file__).parent.parent + if str(parent_dir) not in sys.path: + sys.path.insert(0, str(parent_dir)) + from ocr.pdf_parser import SimplePdfParser + from ocr.config import MODEL_DIR +else: + # 作为模块导入时使用相对导入 + from pdf_parser import SimplePdfParser + from config import MODEL_DIR + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr", tags=["OCR"]) + +# 全局解析器实例(懒加载) +_parser_instance: Optional[SimplePdfParser] = None + + +def get_parser() -> SimplePdfParser: + """获取全局解析器实例(单例模式)""" + global _parser_instance + if _parser_instance is None: + logger.info(f"Initializing OCR parser with model_dir={MODEL_DIR}") + _parser_instance = SimplePdfParser(model_dir=MODEL_DIR) + return _parser_instance + + +class ParseResponse(BaseModel): + """解析响应模型""" + success: bool + message: str + data: Optional[dict] = None + + +@router.get( + "/health", + summary="健康检查", + description="检查OCR服务的健康状态和配置信息", + response_description="返回服务状态和模型目录信息" +) +async def health_check(): + """ + 健康检查端点 + + 用于检查OCR服务的运行状态和配置信息。 + + Returns: + dict: 包含服务状态和模型目录的信息 + """ + return { + "status": "healthy", + "service": "OCR PDF Parser", + "model_dir": MODEL_DIR + } + + +@router.post( + "/parse", + response_model=ParseResponse, + summary="上传并解析PDF文件", + description="上传PDF文件并通过OCR识别提取文本内容", + response_description="返回OCR识别结果" +) +async def parse_pdf_endpoint( + file: UploadFile = File(..., description="PDF文件,支持上传任意PDF文档"), + page_from: int = Form(1, ge=1, description="起始页码(从1开始,默认为1)"), + page_to: int = Form(0, ge=0, description="结束页码(0表示解析到最后一页,默认为0)"), + zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数(1-5,数值越大识别精度越高但速度越慢,默认为3)") +): + """ + 上传并解析PDF文件 + + 通过上传PDF文件,使用OCR技术识别并提取其中的文本内容。 + 支持指定解析的页码范围,以及调整图像放大倍数以平衡识别精度和速度。 + + Args: + file: 上传的PDF文件(multipart/form-data格式) + page_from: 起始页码(从1开始,最小值为1) + page_to: 结束页码(0表示解析到最后一页,最小值为0) + zoomin: 图像放大倍数(1-5之间,数值越大识别精度越高但处理速度越慢) + + Returns: + ParseResponse: 包含解析结果的响应对象,包括: + - success: 是否成功 + - message: 操作结果消息 + - data: OCR识别的文本内容和元数据 + + Raises: + HTTPException: 400 - 如果文件不是PDF格式或文件为空 + HTTPException: 500 - 如果解析过程中发生错误 + """ + if not file.filename.lower().endswith('.pdf'): + raise HTTPException(status_code=400, detail="只支持PDF文件") + + # 保存上传的文件到临时目录 + temp_file = None + try: + # 读取文件内容 + content = await file.read() + if not content: + raise HTTPException(status_code=400, detail="文件为空") + + # 创建临时文件 + with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp: + tmp.write(content) + temp_file = tmp.name + + logger.info(f"Parsing PDF file: {file.filename}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}") + + # 解析PDF(parse_pdf是同步方法,使用to_thread在线程池中执行) + parser = get_parser() + result = await asyncio.to_thread( + parser.parse_pdf, + temp_file, + zoomin, + page_from - 1, # 转换为从0开始的索引 + (page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引 + None # callback + ) + + return ParseResponse( + success=True, + message=f"成功解析PDF: {file.filename}", + data=result + ) + + except Exception as e: + logger.error(f"Error parsing PDF: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"解析PDF时发生错误: {str(e)}" + ) + + finally: + # 清理临时文件 + if temp_file and os.path.exists(temp_file): + try: + os.unlink(temp_file) + except Exception as e: + logger.warning(f"Failed to delete temp file {temp_file}: {e}") + + +@router.post( + "/parse/bytes", + response_model=ParseResponse, + summary="通过二进制数据解析PDF", + description="直接通过二进制数据解析PDF文件,无需上传文件", + response_description="返回OCR识别结果" +) +async def parse_pdf_bytes( + pdf_bytes: bytes = File(..., description="PDF文件的二进制数据(multipart/form-data格式)"), + filename: str = Form("document.pdf", description="文件名(仅用于日志记录,不影响解析)"), + page_from: int = Form(1, ge=1, description="起始页码(从1开始,默认为1)"), + page_to: int = Form(0, ge=0, description="结束页码(0表示解析到最后一页,默认为0)"), + zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数(1-5,数值越大识别精度越高但速度越慢,默认为3)") +): + """ + 直接通过二进制数据解析PDF + + 适用于已获取PDF二进制数据的场景,无需文件上传步骤。 + 直接将PDF的二进制数据提交即可进行OCR识别。 + + Args: + pdf_bytes: PDF文件的二进制数据(以文件形式提交) + filename: 文件名(仅用于日志记录,不影响实际解析过程) + page_from: 起始页码(从1开始,最小值为1) + page_to: 结束页码(0表示解析到最后一页,最小值为0) + zoomin: 图像放大倍数(1-5之间,数值越大识别精度越高但处理速度越慢) + + Returns: + ParseResponse: 包含解析结果的响应对象 + + Raises: + HTTPException: 400 - 如果PDF数据为空 + HTTPException: 500 - 如果解析过程中发生错误 + """ + if not pdf_bytes: + raise HTTPException(status_code=400, detail="PDF数据为空") + + # 保存到临时文件 + temp_file = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp: + tmp.write(pdf_bytes) + temp_file = tmp.name + + logger.info(f"Parsing PDF bytes (filename: {filename}), pages {page_from}-{page_to or 'end'}, zoomin={zoomin}") + + # 解析PDF(parse_pdf是同步方法,使用to_thread在线程池中执行) + parser = get_parser() + result = await asyncio.to_thread( + parser.parse_pdf, + temp_file, + zoomin, + page_from - 1, # 转换为从0开始的索引 + (page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引 + None # callback + ) + + return ParseResponse( + success=True, + message=f"成功解析PDF: {filename}", + data=result + ) + + except Exception as e: + logger.error(f"Error parsing PDF bytes: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"解析PDF时发生错误: {str(e)}" + ) + + finally: + # 清理临时文件 + if temp_file and os.path.exists(temp_file): + try: + os.unlink(temp_file) + except Exception as e: + logger.warning(f"Failed to delete temp file {temp_file}: {e}") + + +@router.post( + "/parse/path", + response_model=ParseResponse, + summary="通过文件路径解析PDF", + description="通过服务器本地文件路径解析PDF文件", + response_description="返回OCR识别结果" +) +async def parse_pdf_path( + file_path: str = Form(..., description="PDF文件在服务器上的本地路径(必须是可访问的绝对路径)"), + page_from: int = Form(1, ge=1, description="起始页码(从1开始,默认为1)"), + page_to: int = Form(0, ge=0, description="结束页码(0表示解析到最后一页,默认为0)"), + zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数(1-5,数值越大识别精度越高但速度越慢,默认为3)") +): + """ + 通过文件路径解析PDF + + 适用于PDF文件已经存在于服务器上的场景。 + 通过提供文件路径直接进行OCR识别,无需上传文件。 + + Args: + file_path: PDF文件在服务器上的本地路径(必须是服务器可访问的绝对路径) + page_from: 起始页码(从1开始,最小值为1) + page_to: 结束页码(0表示解析到最后一页,最小值为0) + zoomin: 图像放大倍数(1-5之间,数值越大识别精度越高但处理速度越慢) + + Returns: + ParseResponse: 包含解析结果的响应对象 + + Raises: + HTTPException: 400 - 如果文件不是PDF格式 + HTTPException: 404 - 如果文件不存在 + HTTPException: 500 - 如果解析过程中发生错误 + + Note: + 此端点需要确保提供的文件路径在服务器上可访问。 + 建议仅在内网环境或受信任的环境中使用,避免路径遍历安全风险。 + """ + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}") + + if not file_path.lower().endswith('.pdf'): + raise HTTPException(status_code=400, detail="只支持PDF文件") + + try: + logger.info(f"Parsing PDF from path: {file_path}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}") + + # 解析PDF(parse_pdf是同步方法,使用to_thread在线程池中执行) + parser = get_parser() + result = await asyncio.to_thread( + parser.parse_pdf, + file_path, + zoomin, + page_from - 1, # 转换为从0开始的索引 + (page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引 + None # callback + ) + + return ParseResponse( + success=True, + message=f"成功解析PDF: {file_path}", + data=result + ) + + except Exception as e: + logger.error(f"Error parsing PDF from path: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"解析PDF时发生错误: {str(e)}" + ) + diff --git a/ocr/config.py b/ocr/config.py new file mode 100644 index 0000000..d8d4d1a --- /dev/null +++ b/ocr/config.py @@ -0,0 +1,42 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +OCR 模块配置文件 +""" +import os +import logging + +# 并行设备数量(GPU数量,0表示使用CPU) +PARALLEL_DEVICES = 0 +try: + import torch.cuda + PARALLEL_DEVICES = torch.cuda.device_count() + logging.info(f"found {PARALLEL_DEVICES} gpus") +except Exception: + logging.info("can't import package 'torch', using CPU mode") + +# 模型目录 +# 可以从环境变量获取,或使用默认路径 +MODEL_DIR = os.getenv("OCR_MODEL_DIR", None) +if MODEL_DIR is None: + # 默认模型目录:当前项目根目录下的 models/deepdoc 目录 + # 如果不存在,将在 OCR 类初始化时尝试从 HuggingFace 下载 + _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + MODEL_DIR = os.path.join(_base_dir, "models", "deepdoc") + # 如果目录不存在,设置为 None,让 OCR 类处理下载逻辑 + if not os.path.exists(MODEL_DIR): + MODEL_DIR = None + diff --git a/ocr/main.py b/ocr/main.py new file mode 100644 index 0000000..52f18e7 --- /dev/null +++ b/ocr/main.py @@ -0,0 +1,202 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +OCR PDF处理服务的主程序入口 +独立运行,不依赖RAGFlow的其他部分 +""" + +import argparse +import logging +import os +import sys +import signal +from pathlib import Path + +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +# 处理直接运行时的导入问题 +# 当直接运行 python ocr/main.py 时,__package__ 为 None +# 当作为模块运行时(python -m ocr.main),__package__ 为 'ocr' +try: + _package = __package__ +except NameError: + _package = None + +if _package is None: + # 直接运行脚本时,添加父目录到路径 + parent_dir = Path(__file__).parent.parent + if str(parent_dir) not in sys.path: + sys.path.insert(0, str(parent_dir)) + from api import router as ocr_router + from config import MODEL_DIR +else: + # 作为模块导入时使用相对导入 + from api import router as ocr_router + from config import MODEL_DIR + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) + ] +) + +logger = logging.getLogger(__name__) + + +def create_app() -> FastAPI: + """创建FastAPI应用实例""" + app = FastAPI( + title="OCR PDF Parser API", + description="独立的OCR PDF处理服务,提供PDF文档的OCR识别功能", + version="1.0.0", + docs_url="/apidocs", # Swagger UI 文档地址 + redoc_url="/redoc", # ReDoc 文档地址(备用) + openapi_url="/openapi.json" # OpenAPI JSON schema 地址 + ) + + # 添加CORS中间件 + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 生产环境中应该设置具体的域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # 注册OCR路由 + app.include_router(ocr_router) + + # 根路径 + @app.get("/") + async def root(): + return { + "service": "OCR PDF Parser", + "version": "1.0.0", + "docs": "/apidocs", + "health": "/api/v1/ocr/health" + } + + return app + + +def signal_handler(sig, frame): + """信号处理器,用于优雅关闭""" + logger.info("Received shutdown signal, exiting...") + sys.exit(0) + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description="OCR PDF处理服务") + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="服务器监听地址 (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="服务器端口 (default: 8000)" + ) + parser.add_argument( + "--reload", + action="store_true", + help="开发模式:自动重载代码" + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="工作进程数 (default: 1)" + ) + parser.add_argument( + "--log-level", + type=str, + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="日志级别 (default: info)" + ) + parser.add_argument( + "--model-dir", + type=str, + default=None, + help=f"OCR模型目录路径 (default: {MODEL_DIR})" + ) + + args = parser.parse_args() + + # 设置模型目录(如果提供) + if args.model_dir: + os.environ["OCR_MODEL_DIR"] = args.model_dir + logger.info(f"Using custom model directory: {args.model_dir}") + + # 检查模型目录 + model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR) + if model_dir and not os.path.exists(model_dir): + logger.warning(f"Model directory does not exist: {model_dir}") + logger.info("Models will be downloaded on first use") + + # 注册信号处理器 + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # 显示启动信息 + logger.info("=" * 60) + logger.info("OCR PDF Parser Service") + logger.info("=" * 60) + logger.info(f"Host: {args.host}") + logger.info(f"Port: {args.port}") + logger.info(f"Model Directory: {model_dir}") + logger.info(f"Workers: {args.workers}") + logger.info(f"Reload: {args.reload}") + logger.info(f"Log Level: {args.log_level}") + logger.info("=" * 60) + logger.info(f"API Documentation (Swagger): http://{args.host}:{args.port}/apidocs") + logger.info(f"API Documentation (ReDoc): http://{args.host}:{args.port}/redoc") + logger.info(f"Health Check: http://{args.host}:{args.port}/api/v1/ocr/health") + logger.info("=" * 60) + + # 创建应用 + app = create_app() + + # 启动服务器 + try: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + reload=args.reload, + workers=args.workers if not args.reload else 1, # reload模式不支持多进程 + access_log=True + ) + except KeyboardInterrupt: + logger.info("Server stopped by user") + except Exception as e: + logger.error(f"Server error: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() + diff --git a/ocr/ocr.py b/ocr/ocr.py new file mode 100644 index 0000000..268e4c9 --- /dev/null +++ b/ocr/ocr.py @@ -0,0 +1,785 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import gc +import logging +import copy +import time +import os +import sys +from pathlib import Path + +from huggingface_hub import snapshot_download + +# 处理导入问题:支持直接运行和模块导入 +try: + _package = __package__ +except NameError: + _package = None + +if _package is None: + # 直接运行时,添加父目录到路径并使用绝对导入 + parent_dir = Path(__file__).parent.parent + if str(parent_dir) not in sys.path: + sys.path.insert(0, str(parent_dir)) + from ocr.utils import get_project_base_directory + from ocr.config import PARALLEL_DEVICES, MODEL_DIR + from ocr.operators import * # noqa: F403 + import ocr.operators as operators + from ocr.postprocess import build_post_process +else: + # 作为模块导入时使用相对导入 + from utils import get_project_base_directory + from config import PARALLEL_DEVICES, MODEL_DIR + from operators import * # noqa: F403 + import operators + from postprocess import build_post_process + +import math +import numpy as np +import cv2 +import onnxruntime as ort + +loaded_models = {} + +def transform(data, ops=None): + """ transform """ + if ops is None: + ops = [] + for op in ops: + data = op(data) + if data is None: + return None + return data + + +def create_operators(op_param_list, global_config=None): + """ + create operators based on the config + + Args: + params(list): a dict list, used to create some operators + """ + assert isinstance( + op_param_list, list), ('operator config should be a list') + ops = [] + for operator in op_param_list: + assert isinstance(operator, + dict) and len(operator) == 1, "yaml format error" + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name] + if global_config is not None: + param.update(global_config) + op = getattr(operators, op_name)(**param) + ops.append(op) + return ops + + +def load_model(model_dir, nm, device_id: int | None = None): + model_file_path = os.path.join(model_dir, nm + ".onnx") + model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path + + global loaded_models + loaded_model = loaded_models.get(model_cached_tag) + if loaded_model: + logging.info(f"load_model {model_file_path} reuses cached model") + return loaded_model + + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format( + model_file_path)) + + def cuda_is_available(): + try: + import torch + if torch.cuda.is_available() and torch.cuda.device_count() > device_id: + return True + except Exception: + return False + return False + + options = ort.SessionOptions() + options.enable_cpu_mem_arena = False + options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + options.intra_op_num_threads = 2 + options.inter_op_num_threads = 2 + + # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 + # Shrink GPU memory after execution + run_options = ort.RunOptions() + if cuda_is_available(): + cuda_provider_options = { + "device_id": device_id, # Use specific GPU + "gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory + "arena_extend_strategy": "kNextPowerOfTwo", # gpu memory allocation strategy + } + sess = ort.InferenceSession( + model_file_path, + options=options, + providers=['CUDAExecutionProvider'], + provider_options=[cuda_provider_options] + ) + run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id)) + logging.info(f"load_model {model_file_path} uses GPU") + else: + sess = ort.InferenceSession( + model_file_path, + options=options, + providers=['CPUExecutionProvider']) + run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu") + logging.info(f"load_model {model_file_path} uses CPU") + loaded_model = (sess, run_options) + loaded_models[model_cached_tag] = loaded_model + return loaded_model + + +class TextRecognizer: + def __init__(self, model_dir, device_id: int | None = None): + self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")] + self.rec_batch_num = 16 + postprocess_params = { + 'name': 'CTCLabelDecode', + "character_dict_path": os.path.join(model_dir, "ocr.res"), + "use_space_char": True + } + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.run_options = load_model(model_dir, 'rec', device_id) + self.input_tensor = self.predictor.get_inputs()[0] + + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + + assert imgC == img.shape[2] + imgW = int((imgH * max_wh_ratio)) + w = self.input_tensor.shape[3:][0] + if isinstance(w, str): + pass + elif w is not None and w > 0: + imgW = w + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def resize_norm_img_vl(self, img, image_shape): + + imgC, imgH, imgW = image_shape + img = img[:, :, ::-1] # bgr2rgb + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + return resized_image + + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile( + gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile( + gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def process_image_srn(self, img, image_shape, num_heads, max_text_length): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + self.srn_other_inputs(image_shape, num_heads, max_text_length) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + encoder_word_pos = encoder_word_pos.astype(np.int64) + gsrm_word_pos = gsrm_word_pos.astype(np.int64) + + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + + def resize_norm_img_sar(self, img, image_shape, + width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + + def resize_norm_img_spin(self, img): + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # return padding_im + img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC) + img = np.array(img, np.float32) + img = np.expand_dims(img, -1) + img = img.transpose((2, 0, 1)) + mean = [127.5] + std = [127.5] + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + mean = np.float32(mean.reshape(1, -1)) + stdinv = 1 / np.float32(std.reshape(1, -1)) + img -= mean + img *= stdinv + return img + + def resize_norm_img_svtr(self, img, image_shape): + + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image + + def resize_norm_img_abinet(self, img, image_shape): + + imgC, imgH, imgW = image_shape + + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image / 255. + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + resized_image = ( + resized_image - mean[None, None, ...]) / std[None, None, ...] + resized_image = resized_image.transpose((2, 0, 1)) + resized_image = resized_image.astype('float32') + + return resized_image + + def norm_img_can(self, img, image_shape): + + img = cv2.cvtColor( + img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image + + if self.rec_image_shape[0] == 1: + h, w = img.shape + _, imgH, imgW = self.rec_image_shape + if h < imgH or w < imgW: + padding_h = max(imgH - h, 0) + padding_w = max(imgW - w, 0) + img_padded = np.pad(img, ((0, padding_h), (0, padding_w)), + 'constant', + constant_values=(255)) + img = img_padded + + img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w + img = img.astype('float32') + + return img + + def close(self): + # close session and release manually + logging.info('Close text recognizer.') + if hasattr(self, "predictor"): + del self.predictor + gc.collect() + + def __call__(self, img_list): + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + st = time.time() + + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + # max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + for i in range(100000): + try: + outputs = self.predictor.run(None, input_dict, self.run_options) + break + except Exception as e: + if i >= 3: + raise e + time.sleep(5) + preds = outputs[0] + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] + + return rec_res, time.time() - st + + def __del__(self): + self.close() + + +class TextDetector: + def __init__(self, model_dir, device_id: int | None = None): + pre_process_list = [{ + 'DetResizeForTest': { + 'limit_side_len': 960, + 'limit_type': "max", + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] + postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000, + "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"} + + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.run_options = load_model(model_dir, 'det', device_id) + self.input_tensor = self.predictor.get_inputs()[0] + + img_h, img_w = self.input_tensor.shape[2:] + if isinstance(img_h, str) or isinstance(img_w, str): + pass + elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0: + pre_process_list[0] = { + 'DetResizeForTest': { + 'image_shape': [img_h, img_w] + } + } + self.preprocess_op = create_operators(pre_process_list) + + def order_points_clockwise(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) + diff = np.diff(np.array(tmp), axis=1) + rect[1] = tmp[np.argmin(diff)] + rect[3] = tmp[np.argmax(diff)] + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + if isinstance(box, list): + box = np.array(box) + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + if isinstance(box, list): + box = np.array(box) + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def close(self): + logging.info("Close text detector.") + if hasattr(self, "predictor"): + del self.predictor + gc.collect() + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + + st = time.time() + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + input_dict = {} + input_dict[self.input_tensor.name] = img + for i in range(100000): + try: + outputs = self.predictor.run(None, input_dict, self.run_options) + break + except Exception as e: + if i >= 3: + raise e + time.sleep(5) + + post_result = self.postprocess_op({"maps": outputs[0]}, shape_list) + dt_boxes = post_result[0]['points'] + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + + return dt_boxes, time.time() - st + + def __del__(self): + self.close() + + +class OCR: + def __init__(self, model_dir=None): + """ + If you have trouble downloading HuggingFace models, -_^ this might help!! + + For Linux: + export HF_ENDPOINT=https://hf-mirror.com + + For Windows: + Good luck + ^_- + + """ + if not model_dir: + try: + # 使用配置中的 MODEL_DIR,如果不存在则尝试默认路径 + if MODEL_DIR and os.path.exists(MODEL_DIR): + model_dir = MODEL_DIR + else: + model_dir = os.path.join( + get_project_base_directory(), + "models", "deepdoc") + + # Append muti-gpus task to the list + if PARALLEL_DEVICES > 0: + self.text_detector = [] + self.text_recognizer = [] + for device_id in range(PARALLEL_DEVICES): + self.text_detector.append(TextDetector(model_dir, device_id)) + self.text_recognizer.append(TextRecognizer(model_dir, device_id)) + else: + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = [TextRecognizer(model_dir)] + + except Exception: + # 如果模型目录不存在,尝试从 HuggingFace 下载 + default_model_dir = os.path.join( + get_project_base_directory(), "models", "deepdoc") + model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", + local_dir=default_model_dir, + local_dir_use_symlinks=False) + + if PARALLEL_DEVICES > 0: + self.text_detector = [] + self.text_recognizer = [] + for device_id in range(PARALLEL_DEVICES): + self.text_detector.append(TextDetector(model_dir, device_id)) + self.text_recognizer.append(TextRecognizer(model_dir, device_id)) + else: + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = [TextRecognizer(model_dir)] + else: + # 如果指定了 model_dir,直接使用 + if PARALLEL_DEVICES > 0: + self.text_detector = [] + self.text_recognizer = [] + for device_id in range(PARALLEL_DEVICES): + self.text_detector.append(TextDetector(model_dir, device_id)) + self.text_recognizer.append(TextRecognizer(model_dir, device_id)) + else: + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = [TextRecognizer(model_dir)] + + self.drop_score = 0.5 + self.crop_image_res_index = 0 + + def get_rotate_crop_image(self, img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + # Try original orientation + rec_result = self.text_recognizer[0]([dst_img]) + text, score = rec_result[0][0] + best_score = score + best_img = dst_img + + # Try clockwise 90° rotation + rotated_cw = np.rot90(dst_img, k=3) + rec_result = self.text_recognizer[0]([rotated_cw]) + rotated_cw_text, rotated_cw_score = rec_result[0][0] + if rotated_cw_score > best_score: + best_score = rotated_cw_score + best_img = rotated_cw + + # Try counter-clockwise 90° rotation + rotated_ccw = np.rot90(dst_img, k=1) + rec_result = self.text_recognizer[0]([rotated_ccw]) + rotated_ccw_text, rotated_ccw_score = rec_result[0][0] + if rotated_ccw_score > best_score: + best_img = rotated_ccw + + # Use the best image + dst_img = best_img + return dst_img + + def sorted_boxes(self, dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + for j in range(i, -1, -1): + if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \ + (_boxes[j + 1][0][0] < _boxes[j][0][0]): + tmp = _boxes[j] + _boxes[j] = _boxes[j + 1] + _boxes[j + 1] = tmp + else: + break + return _boxes + + def detect(self, img, device_id: int | None = None): + if device_id is None: + device_id = 0 + + time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} + + if img is None: + return None, None, time_dict + + start = time.time() + dt_boxes, elapse = self.text_detector[device_id](img) + time_dict['det'] = elapse + + if dt_boxes is None: + end = time.time() + time_dict['all'] = end - start + return None, None, time_dict + + return zip(self.sorted_boxes(dt_boxes), [ + ("", 0) for _ in range(len(dt_boxes))]) + + def recognize(self, ori_im, box, device_id: int | None = None): + if device_id is None: + device_id = 0 + + img_crop = self.get_rotate_crop_image(ori_im, box) + + rec_res, elapse = self.text_recognizer[device_id]([img_crop]) + text, score = rec_res[0] + if score < self.drop_score: + return "" + return text + + def recognize_batch(self, img_list, device_id: int | None = None): + if device_id is None: + device_id = 0 + rec_res, elapse = self.text_recognizer[device_id](img_list) + texts = [] + for i in range(len(rec_res)): + text, score = rec_res[i] + if score < self.drop_score: + text = "" + texts.append(text) + return texts + + def __call__(self, img, device_id = 0, cls=True): + time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} + if device_id is None: + device_id = 0 + + if img is None: + return None, None, time_dict + + start = time.time() + ori_im = img.copy() + dt_boxes, elapse = self.text_detector[device_id](img) + time_dict['det'] = elapse + + if dt_boxes is None: + end = time.time() + time_dict['all'] = end - start + return None, None, time_dict + + img_crop_list = [] + + dt_boxes = self.sorted_boxes(dt_boxes) + + for bno in range(len(dt_boxes)): + tmp_box = copy.deepcopy(dt_boxes[bno]) + img_crop = self.get_rotate_crop_image(ori_im, tmp_box) + img_crop_list.append(img_crop) + + rec_res, elapse = self.text_recognizer[device_id](img_crop_list) + + time_dict['rec'] = elapse + + filter_boxes, filter_rec_res = [], [] + for box, rec_result in zip(dt_boxes, rec_res): + text, score = rec_result + if score >= self.drop_score: + filter_boxes.append(box) + filter_rec_res.append(rec_result) + end = time.time() + time_dict['all'] = end - start + + # for bno in range(len(img_crop_list)): + # print(f"{bno}, {rec_res[bno]}") + + return list(zip([a.tolist() for a in filter_boxes], filter_rec_res)) + diff --git a/ocr/operators.py b/ocr/operators.py new file mode 100644 index 0000000..d7ff403 --- /dev/null +++ b/ocr/operators.py @@ -0,0 +1,726 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import sys +import six +import cv2 +import numpy as np +import math +from PIL import Image + + +class DecodeImage: + """ decode image """ + + def __init__(self, + img_mode='RGB', + channel_first=False, + ignore_orientation=False, + **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + self.ignore_orientation = ignore_orientation + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert isinstance(img, str) and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert isinstance(img, bytes) and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + if self.ignore_orientation: + img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | + cv2.IMREAD_COLOR) + else: + img = cv2.imdecode(img, 1) + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( + img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + data['image'] = img + return data + + +class StandardizeImag: + """normalize image + Args: + mean (list): im - mean + std (list): im / std + is_scale (bool): whether need im / 255 + norm_type (str): type in ['mean_std', 'none'] + """ + + def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): + self.mean = mean + self.std = std + self.is_scale = is_scale + self.norm_type = norm_type + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.astype(np.float32, copy=False) + if self.is_scale: + scale = 1.0 / 255.0 + im *= scale + + if self.norm_type == 'mean_std': + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + im -= mean + im /= std + return im, im_info + + +class NormalizeImage: + """ normalize image such as subtract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + assert isinstance(img, + np.ndarray), "invalid input 'img' in NormalizeImage" + data['image'] = ( + img.astype('float32') * self.scale - self.mean) / self.std + return data + + +class ToCHWImage: + """ convert hwc image to chw image + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + data['image'] = img.transpose((2, 0, 1)) + return data + + +class KeepKeys: + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class Pad: + def __init__(self, size=None, size_div=32, **kwargs): + if size is not None and not isinstance(size, (int, list, tuple)): + raise TypeError("Type of target_size is invalid. Now is {}".format( + type(size))) + if isinstance(size, int): + size = [size, size] + self.size = size + self.size_div = size_div + + def __call__(self, data): + + img = data['image'] + img_h, img_w = img.shape[0], img.shape[1] + if self.size: + resize_h2, resize_w2 = self.size + assert ( + img_h < resize_h2 and img_w < resize_w2 + ), '(h, w) of target size should be greater than (img_h, img_w)' + else: + resize_h2 = max( + int(math.ceil(img.shape[0] / self.size_div) * self.size_div), + self.size_div) + resize_w2 = max( + int(math.ceil(img.shape[1] / self.size_div) * self.size_div), + self.size_div) + img = cv2.copyMakeBorder( + img, + 0, + resize_h2 - img_h, + 0, + resize_w2 - img_w, + cv2.BORDER_CONSTANT, + value=0) + data['image'] = img + return data + + +class LinearResize: + """resize image by target_size and max_size + Args: + target_size (int): the target size of image + keep_ratio (bool): whether keep_ratio or not, default true + interp (int): method of resize + """ + + def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR): + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + self.keep_ratio = keep_ratio + self.interp = interp + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + _im_channel = im.shape[2] + im_scale_y, im_scale_x = self.generate_scale(im) + im = cv2.resize( + im, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') + im_info['scale_factor'] = np.array( + [im_scale_y, im_scale_x]).astype('float32') + return im, im_info + + def generate_scale(self, im): + """ + Args: + im (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + origin_shape = im.shape[:2] + _im_c = im.shape[2] + if self.keep_ratio: + im_size_min = np.min(origin_shape) + im_size_max = np.max(origin_shape) + target_size_min = np.min(self.target_size) + target_size_max = np.max(self.target_size) + im_scale = float(target_size_min) / float(im_size_min) + if np.round(im_scale * im_size_max) > target_size_max: + im_scale = float(target_size_max) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = self.target_size + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / float(origin_shape[1]) + return im_scale_y, im_scale_x + + +class Resize: + def __init__(self, size=(640, 640), **kwargs): + self.size = size + + def resize_image(self, img): + resize_h, resize_w = self.size + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + return img, [ratio_h, ratio_w] + + def __call__(self, data): + img = data['image'] + if 'polys' in data: + text_polys = data['polys'] + + img_resize, [ratio_h, ratio_w] = self.resize_image(img) + if 'polys' in data: + new_boxes = [] + for box in text_polys: + new_box = [] + for cord in box: + new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) + new_boxes.append(new_box) + data['polys'] = np.array(new_boxes, dtype=np.float32) + data['image'] = img_resize + return data + + +class DetResizeForTest: + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.resize_type = 0 + self.keep_ratio = False + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + if 'keep_ratio' in kwargs: + self.keep_ratio = kwargs['keep_ratio'] + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if sum([src_h, src_w]) < 64: + img = self.image_padding(img) + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def image_padding(self, im, value=0): + h, w, c = im.shape + im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value + im_pad[:h, :w, :] = im + return im_pad + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + if self.keep_ratio is True: + resize_w = ori_w * resize_h / ori_h + N = math.ceil(resize_w / 32) + resize_w = N * 32 + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, c = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + elif self.limit_type == 'min': + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h, w) + else: + raise Exception('not support limit type, image ') + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = max(int(round(resize_h / 32) * 32), 32) + resize_w = max(int(round(resize_w / 32) * 32), 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except BaseException: + logging.exception("{} {} {}".format(img.shape, resize_w, resize_h)) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] + + +class E2EResizeForTest: + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs['max_side_len'] + self.valid_set = kwargs['valid_set'] + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if self.valid_set == 'totaltext': + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len) + else: + im_resized, (ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len) + data['image'] = im_resized + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + h, w, _ = im.shape + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + +class KieResize: + def __init__(self, **kwargs): + super(KieResize, self).__init__() + self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[ + 'img_scale'][1] + + def __call__(self, data): + img = data['image'] + points = data['points'] + src_h, src_w, _ = img.shape + im_resized, scale_factor, [ratio_h, ratio_w + ], [new_h, new_w] = self.resize_image(img) + resize_points = self.resize_boxes(img, points, scale_factor) + data['ori_image'] = img + data['ori_boxes'] = points + data['points'] = resize_points + data['image'] = im_resized + data['shape'] = np.array([new_h, new_w]) + return data + + def resize_image(self, img): + norm_img = np.zeros([1024, 1024, 3], dtype='float32') + scale = [512, 1024] + h, w = img.shape[:2] + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float( + scale_factor) + 0.5) + max_stride = 32 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(img, (resize_w, resize_h)) + new_h, new_w = im.shape[:2] + w_scale = new_w / w + h_scale = new_h / h + scale_factor = np.array( + [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + norm_img[:new_h, :new_w, :] = im + return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w] + + def resize_boxes(self, im, points, scale_factor): + points = points * scale_factor + img_shape = im.shape[:2] + points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) + points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) + return points + + +class SRResize: + def __init__(self, + imgH=32, + imgW=128, + down_sample_scale=4, + keep_ratio=False, + min_ratio=1, + mask=False, + infer_mode=False, + **kwargs): + self.imgH = imgH + self.imgW = imgW + self.keep_ratio = keep_ratio + self.min_ratio = min_ratio + self.down_sample_scale = down_sample_scale + self.mask = mask + self.infer_mode = infer_mode + + def __call__(self, data): + imgH = self.imgH + imgW = self.imgW + images_lr = data["image_lr"] + transform2 = ResizeNormalize( + (imgW // self.down_sample_scale, imgH // self.down_sample_scale)) + images_lr = transform2(images_lr) + data["img_lr"] = images_lr + if self.infer_mode: + return data + + images_HR = data["image_hr"] + _label_strs = data["label"] + transform = ResizeNormalize((imgW, imgH)) + images_HR = transform(images_HR) + data["img_hr"] = images_HR + return data + + +class ResizeNormalize: + def __init__(self, size, interpolation=Image.BICUBIC): + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + img = img.resize(self.size, self.interpolation) + img_numpy = np.array(img).astype("float32") + img_numpy = img_numpy.transpose((2, 0, 1)) / 255 + return img_numpy + + +class GrayImageChannelFormat: + """ + format gray scale image's channel: (3,h,w) -> (1,h,w) + Args: + inverse: inverse gray image + """ + + def __init__(self, inverse=False, **kwargs): + self.inverse = inverse + + def __call__(self, data): + img = data['image'] + img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img_expanded = np.expand_dims(img_single_channel, 0) + + if self.inverse: + data['image'] = np.abs(img_expanded - 1) + else: + data['image'] = img_expanded + + data['src_image'] = img + return data + + +class Permute: + """permute image + Args: + to_bgr (bool): whether convert RGB to BGR + channel_first (bool): whether convert HWC to CHW + """ + + def __init__(self, ): + super(Permute, self).__init__() + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.transpose((2, 0, 1)).copy() + return im, im_info + + +class PadStride: + """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config + Args: + stride (bool): model with FPN need image shape % stride == 0 + """ + + def __init__(self, stride=0): + self.coarsest_stride = stride + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + coarsest_stride = self.coarsest_stride + if coarsest_stride <= 0: + return im, im_info + im_c, im_h, im_w = im.shape + pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) + pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) + padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + return padding_im, im_info + + +def decode_image(im_file, im_info): + """read rgb image + Args: + im_file (str|np.ndarray): input can be image path or np.ndarray + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + if isinstance(im_file, str): + with open(im_file, 'rb') as f: + im_read = f.read() + data = np.frombuffer(im_read, dtype='uint8') + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + else: + im = im_file + im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32) + im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) + return im, im_info + + +def preprocess(im, preprocess_ops): + # process image by preprocess_ops + im_info = { + 'scale_factor': np.array( + [1., 1.], dtype=np.float32), + 'im_shape': None, + } + im, im_info = decode_image(im, im_info) + for operator in preprocess_ops: + im, im_info = operator(im, im_info) + return im, im_info + + +def nms(bboxes, scores, iou_thresh): + import numpy as np + x1 = bboxes[:, 0] + y1 = bboxes[:, 1] + x2 = bboxes[:, 2] + y2 = bboxes[:, 3] + areas = (y2 - y1) * (x2 - x1) + + indices = [] + index = scores.argsort()[::-1] + while index.size > 0: + i = index[0] + indices.append(i) + x11 = np.maximum(x1[i], x1[index[1:]]) + y11 = np.maximum(y1[i], y1[index[1:]]) + x22 = np.minimum(x2[i], x2[index[1:]]) + y22 = np.minimum(y2[i], y2[index[1:]]) + w = np.maximum(0, x22 - x11 + 1) + h = np.maximum(0, y22 - y11 + 1) + overlaps = w * h + ious = overlaps / (areas[i] + areas[index[1:]] - overlaps) + idx = np.where(ious <= iou_thresh)[0] + index = index[idx + 1] + return indices + diff --git a/ocr/pdf_parser.py b/ocr/pdf_parser.py new file mode 100644 index 0000000..800c6e4 --- /dev/null +++ b/ocr/pdf_parser.py @@ -0,0 +1,339 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +简化的PDF解析器,只使用OCR处理PDF文档 + +从 RAGFlow 的 RAGFlowPdfParser 中提取OCR相关功能,移除了: +- 布局识别(Layout Recognition) +- 表格结构识别(Table Structure Recognition) +- 文本合并和语义分析 +- RAG相关功能 + +只保留: +- PDF转图片 +- OCR文本检测和识别 +- 基本的文本和位置信息返回 +""" + +import logging +import sys +import threading +from io import BytesIO +from pathlib import Path +from timeit import default_timer as timer + +import numpy as np +import pdfplumber +import trio + +# 处理导入问题:支持直接运行和模块导入 +try: + _package = __package__ +except NameError: + _package = None + +if _package is None: + # 直接运行时,添加父目录到路径并使用绝对导入 + parent_dir = Path(__file__).parent.parent + if str(parent_dir) not in sys.path: + sys.path.insert(0, str(parent_dir)) + from ocr.config import PARALLEL_DEVICES + from ocr.ocr import OCR +else: + # 作为模块导入时使用相对导入 + from config import PARALLEL_DEVICES + from ocr import OCR + +LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber" +if LOCK_KEY_pdfplumber not in sys.modules: + sys.modules[LOCK_KEY_pdfplumber] = threading.Lock() + + +class SimplePdfParser: + """ + 简化的PDF解析器,只使用OCR处理PDF + + 使用方法: + parser = SimplePdfParser() + result = parser.parse_pdf("file.pdf") # 或传入二进制数据 + # result 格式: + # { + # "pages": [ + # { + # "page_number": 1, + # "boxes": [ + # { + # "text": "识别到的文本", + # "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], + # "confidence": 0.95 + # }, + # ... + # ] + # }, + # ... + # ] + # } + """ + + def __init__(self, model_dir=None): + """ + 初始化PDF解析器 + + Args: + model_dir: OCR模型目录,如果为None则使用默认路径 + """ + self.ocr = OCR(model_dir=model_dir) + self.parallel_limiter = None + if PARALLEL_DEVICES > 1: + self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] + + def __ocr_page(self, page_num, img, zoomin=3, device_id=None): + """ + 对单页进行OCR处理 + + Args: + page_num: 页码 + img: PIL Image对象 + zoomin: 放大倍数(用于坐标缩放) + device_id: GPU设备ID + + Returns: + list: OCR结果列表,每个元素为 {"text": str, "bbox": list, "confidence": float} + """ + start = timer() + img_np = np.array(img) + + # 文本检测 + # detect方法返回: zip对象,格式为 (box_coords, (text, score)) + # 但检测阶段text和score都是默认值,需要后续识别 + detection_result = self.ocr.detect(img_np, device_id) + + if detection_result is None: + return [] + + # 转换为列表并提取box坐标 + # detect返回的格式是zip,每个元素是 (box_coords, (text, score)) + # 在检测阶段,text是空字符串,score是0 + bxs = list(detection_result) + + logging.info(f"Page {page_num}: OCR detection found {len(bxs)} boxes in {timer() - start:.2f}s") + + if not bxs: + return [] + + # 解析检测结果并准备识别 + boxes_to_reg = [] + + start = timer() + for box_coords, _, _ in bxs: + # box_coords 是四边形坐标: [[x0, y0], [x1, y0], [x1, y1], [x0, y1]] + # 转换为原始坐标(考虑zoomin) + box_coords_np = np.array(box_coords, dtype=np.float32) + original_coords = box_coords_np / zoomin # 缩放回原始坐标 + + # 裁剪图像用于识别 + # 使用放大后的坐标裁剪(因为img_np是放大后的图像) + crop_box = box_coords_np + crop_img = self.ocr.get_rotate_crop_image(img_np, crop_box) + boxes_to_reg.append({ + "bbox": original_coords.tolist(), + "crop_img": crop_img + }) + + # 批量识别文本 + ocr_results = [] + if boxes_to_reg: + crop_imgs = [b["crop_img"] for b in boxes_to_reg] + texts = self.ocr.recognize_batch(crop_imgs, device_id) + + # 组装结果 + for i, b in enumerate(boxes_to_reg): + if i < len(texts) and texts[i]: # 过滤空文本 + ocr_results.append({ + "text": texts[i], + "bbox": b["bbox"], + "confidence": 0.9 # 简化版本,不计算具体置信度 + }) + + logging.info(f"Page {page_num}: OCR recognition {len(ocr_results)} boxes cost {timer() - start:.2f}s") + return ocr_results + + async def __ocr_page_async(self, page_num, img, zoomin, device_id, limiter, callback): + """ + 异步OCR处理单页 + + Args: + page_num: 页码 + img: PIL Image对象 + zoomin: 放大倍数 + device_id: GPU设备ID + limiter: 并发限制器 + callback: 进度回调函数 + """ + if limiter: + async with limiter: + result = await trio.to_thread.run_sync( + lambda: self.__ocr_page(page_num, img, zoomin, device_id) + ) + else: + result = await trio.to_thread.run_sync( + lambda: self.__ocr_page(page_num, img, zoomin, device_id) + ) + + if callback and page_num % 5 == 0: + callback(prog=page_num * 0.9 / 100, msg=f"Processing page {page_num}...") + + return result + + def __convert_pdf_to_images(self, pdf_source, zoomin=3, page_from=0, page_to=299): + """ + 将PDF转换为图片 + + Args: + pdf_source: PDF文件路径(str)或二进制数据(bytes) + zoomin: 放大倍数,默认3(72*3=216 DPI) + page_from: 起始页码(从0开始) + page_to: 结束页码 + + Returns: + list: PIL Image对象列表 + """ + start = timer() + page_images = [] + + try: + with sys.modules[LOCK_KEY_pdfplumber]: + pdf = pdfplumber.open(pdf_source) if isinstance(pdf_source, str) else pdfplumber.open(BytesIO(pdf_source)) + try: + # 转换为图片,resolution = 72 * zoomin + page_images = [ + p.to_image(resolution=72 * zoomin, antialias=True).annotated + for i, p in enumerate(pdf.pages[page_from:page_to]) + ] + pdf.close() + except Exception as e: + logging.warning(f"Failed to convert PDF pages {page_from}-{page_to}: {str(e)}") + if hasattr(pdf, 'close'): + pdf.close() + except Exception as e: + logging.exception(f"Error converting PDF to images: {str(e)}") + + logging.info(f"Converted {len(page_images)} pages to images in {timer() - start:.2f}s") + return page_images + + def parse_pdf(self, pdf_source, zoomin=3, page_from=0, page_to=299, callback=None): + """ + 解析PDF文档,使用OCR识别文本 + + Args: + pdf_source: PDF文件路径(str)或二进制数据(bytes) + zoomin: 放大倍数,默认3 + page_from: 起始页码(从0开始) + page_to: 结束页码 + callback: 进度回调函数,格式: callback(prog: float, msg: str) + + Returns: + dict: 解析结果 + { + "pages": [ + { + "page_number": int, + "boxes": [ + { + "text": str, + "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], + "confidence": float + }, + ... + ] + }, + ... + ] + } + """ + if callback: + callback(0.0, "Starting PDF parsing...") + + # 1. 转换为图片 + if callback: + callback(0.1, "Converting PDF to images...") + page_images = self.__convert_pdf_to_images(pdf_source, zoomin, page_from, page_to) + + if not page_images: + logging.warning("No pages converted from PDF") + return {"pages": []} + + # 2. OCR处理 + async def process_all_pages(): + pages_result = [] + + if self.parallel_limiter: + # 并行处理(多GPU) + async with trio.open_nursery() as nursery: + tasks = [] + for i, img in enumerate(page_images): + page_num = page_from + i + 1 + device_id = i % PARALLEL_DEVICES + task = nursery.start_soon( + self.__ocr_page_async, + page_num, img, zoomin, device_id, + self.parallel_limiter[device_id], callback + ) + tasks.append(task) + + # 等待所有任务完成并收集结果 + for i, task in enumerate(tasks): + result = await task + pages_result.append({ + "page_number": page_from + i + 1, + "boxes": result + }) + else: + # 串行处理(单GPU或CPU) + for i, img in enumerate(page_images): + page_num = page_from + i + 1 + result = await trio.to_thread.run_sync( + lambda img=img, pn=page_num: self.__ocr_page(pn, img, zoomin, 0) + ) + pages_result.append({ + "page_number": page_num, + "boxes": result + }) + if callback: + callback(0.1 + (i + 1) * 0.9 / len(page_images), f"Processing page {page_num}...") + + return pages_result + + # 运行异步处理 + if callback: + callback(0.2, "Starting OCR processing...") + + start = timer() + pages_result = trio.run(process_all_pages) + logging.info(f"OCR processing completed in {timer() - start:.2f}s") + + if callback: + callback(1.0, "OCR processing completed") + + return { + "pages": pages_result + } + + +# 向后兼容的别名 +PdfParser = SimplePdfParser + diff --git a/ocr/postprocess.py b/ocr/postprocess.py new file mode 100644 index 0000000..f4577f6 --- /dev/null +++ b/ocr/postprocess.py @@ -0,0 +1,371 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import re +import numpy as np +import cv2 +from shapely.geometry import Polygon +import pyclipper + + +def build_post_process(config, global_config=None): + support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode} + + config = copy.deepcopy(config) + module_name = config.pop('name') + if module_name == "None": + return + if global_config is not None: + config.update(global_config) + module_class = support_dict.get(module_name) + if module_class is None: + raise ValueError( + 'post process only support {}'.format(list(support_dict))) + return module_class(**config) + + +class DBPostProcess: + """ + The post process for Differentiable Binarization (DB). + """ + + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + box_type='quad', + **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + self.box_type = box_type + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) + + def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + boxes = [] + scores = [] + + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:self.max_candidates]: + epsilon = 0.002 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + + score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + + if points.shape[0] > 2: + box = self.unclip(points, self.unclip_ratio) + if len(box) > 1: + continue + else: + continue + box = box.reshape(-1, 2) + + _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_size + 2: + continue + + box = np.array(box) + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.tolist()) + scores.append(score) + return boxes, scores + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + _img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + if self.score_mode == "fast": + score = self.box_score_fast(pred, points.reshape(-1, 2)) + else: + score = self.box_score_slow(pred, contour) + if self.box_thresh > score: + continue + + box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.astype("int32")) + scores.append(score) + return np.array(boxes, dtype="int32"), scores + + def unclip(self, box, unclip_ratio): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + ''' + box_score_fast: use bbox mean score as the mean score + ''' + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def box_score_slow(self, bitmap, contour): + ''' + box_score_slow: use polyon mean score as the mean score + ''' + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + if not isinstance(pred, np.ndarray): + pred = pred.numpy() + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] + if self.box_type == 'poly': + boxes, scores = self.polygons_from_bitmap(pred[batch_index], + mask, src_w, src_h) + elif self.box_type == 'quad': + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, + src_w, src_h) + else: + raise ValueError( + "box_type can only be one of ['quad', 'poly']") + + boxes_batch.append({'points': boxes}) + return boxes_batch + + +class BaseRecLabelDecode: + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False): + self.beg_str = "sos" + self.end_str = "eos" + self.reverse = False + self.character_str = [] + + if character_dict_path is None: + self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + else: + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + self.character_str.append(line) + if use_space_char: + self.character_str.append(" ") + dict_character = list(self.character_str) + if 'arabic' in character_dict_path: + self.reverse = True + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def pred_reverse(self, pred): + pred_re = [] + c_current = '' + for c in pred: + if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)): + if c_current != '': + pred_re.append(c_current) + pred_re.append(c) + c_current = '' + else: + c_current += c + if c_current != '': + pred_re.append(c_current) + + return ''.join(pred_re[::-1]) + + def add_special_char(self, dict_character): + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + selection = np.ones(len(text_index[batch_idx]), dtype=bool) + if is_remove_duplicate: + selection[1:] = text_index[batch_idx][1:] != text_index[ + batch_idx][:-1] + for ignored_token in ignored_tokens: + selection &= text_index[batch_idx] != ignored_token + + char_list = [ + self.character[text_id] + for text_id in text_index[batch_idx][selection] + ] + if text_prob is not None: + conf_list = text_prob[batch_idx][selection] + else: + conf_list = [1] * len(selection) + if len(conf_list) == 0: + conf_list = [0] + + text = ''.join(char_list) + + if self.reverse: # for arabic rec + text = self.pred_reverse(text) + + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def get_ignored_tokens(self): + return [0] # for ctc blank + + +class CTCLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(CTCLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple) or isinstance(preds, list): + preds = preds[-1] + if not isinstance(preds, np.ndarray): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + if label is None: + return text + label = self.decode(label) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character + diff --git a/ocr/requirements.txt b/ocr/requirements.txt new file mode 100644 index 0000000..fedeaed --- /dev/null +++ b/ocr/requirements.txt @@ -0,0 +1,25 @@ +# OCR PDF处理模块依赖 +# 核心依赖 +numpy>=1.21.0 +opencv-python>=4.5.0 +pillow>=8.0.0 +pdfplumber>=0.9.0 +onnxruntime>=1.12.0 +trio>=0.22.0 + +# 几何计算依赖 +shapely>=1.8.0 +pyclipper>=1.2.0 + +# Web框架依赖 +fastapi>=0.100.0 +uvicorn[standard]>=0.23.0 +pydantic>=2.0.0 + +# 模型下载依赖 +huggingface_hub>=0.16.0 + +# 可选依赖(用于GPU检测和加速) +# torch>=1.12.0 # 如果需要GPU支持,取消注释并安装 +# onnxruntime-gpu>=1.12.0 # 如果需要GPU支持,取消注释并安装 + diff --git a/ocr/utils.py b/ocr/utils.py new file mode 100644 index 0000000..90bd40b --- /dev/null +++ b/ocr/utils.py @@ -0,0 +1,40 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +OCR 模块工具函数 +""" +import os + + +def get_project_base_directory(*args): + """ + 获取项目根目录 + + Args: + *args: 可选的子路径 + + Returns: + str: 项目根目录路径 + """ + # 获取当前文件的目录 + current_dir = os.path.dirname(os.path.realpath(__file__)) + # 返回 ocr 模块的父目录(项目根目录) + base_dir = os.path.dirname(current_dir) + + if args: + return os.path.join(base_dir, *args) + return base_dir +