#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unified OCR service interface.

Two backends are supported and selected through configuration:

* ``LocalOCRService`` -- runs a local OCR model in-process;
* ``HTTPOCRService``  -- delegates to a remote OCR service over HTTP.

Use ``get_ocr_service()`` to obtain the process-wide instance.
"""

import asyncio
import functools
import logging
import os
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Optional, Callable, List, Tuple, Any

logger = logging.getLogger(__name__)

# One extracted position: (page numbers, left, right, top, bottom).
Position = Tuple[List[int], float, float, float, float]

# Progress callback: (progress, message) -> None.
ProgressCallback = Callable[[float, str], None]

# Recognised values of the OCR_MODE environment variable.
_OCR_MODE_LOCAL = "local"
_OCR_MODE_HTTP = "http"

# HTTP endpoint used when OCR_SERVICE_URL is unset or blank.
_DEFAULT_OCR_SERVICE_URL = "http://localhost:8000/api/v1/ocr"


class OCRService(ABC):
    """Abstract OCR service interface shared by all backends."""

    @abstractmethod
    async def remove_tag(self, text: str) -> str:
        """
        Remove position tags from text.

        Args:
            text: Text containing position tags.

        Returns:
            The text with the position tags removed.
        """

    @abstractmethod
    def remove_tag_sync(self, text: str) -> str:
        """
        Synchronous variant of ``remove_tag`` for non-async callers.

        Args:
            text: Text containing position tags.

        Returns:
            The text with the position tags removed.
        """

    @abstractmethod
    async def extract_positions(self, text: str) -> List[Position]:
        """
        Extract position information from text.

        Args:
            text: Text containing position tags.

        Returns:
            A list of positions, each ``(page_numbers, left, right, top, bottom)``.
        """

    @abstractmethod
    def extract_positions_sync(self, text: str) -> List[Position]:
        """
        Synchronous variant of ``extract_positions`` for non-async callers.

        Args:
            text: Text containing position tags.

        Returns:
            A list of positions, each ``(page_numbers, left, right, top, bottom)``.
        """

    @abstractmethod
    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[ProgressCallback] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """
        Parse a PDF and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Optional progress callback ``(progress, message) -> None``.
            zoomin: Image zoom factor (1-5, default 3).
            filename: File name, used only for logging.

        Returns:
            A list of bounding-box dicts.
        """

    @abstractmethod
    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[ProgressCallback] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """
        Synchronous variant of ``parse_into_bboxes`` for non-async callers.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Optional progress callback. HTTP-backed implementations
                may be unable to relay progress in real time and may ignore it.
            zoomin: Image zoom factor (1-5, default 3).
            filename: File name, used only for logging.

        Returns:
            A list of bounding-box dicts.
        """


class LocalOCRService(OCRService):
    """OCR service that calls a local OCR parser in-process."""

    def __init__(self, parser_instance: Optional[Any] = None):
        """
        Initialize the local OCR service.

        Args:
            parser_instance: A ``SimplePdfParser``-compatible parser. When None,
                a ``SimplePdfParser`` is built from ``ocr.config.MODEL_DIR``.
                Those modules are imported only in that case.
        """
        if parser_instance is not None:
            self.parser = parser_instance
            return

        from ocr import SimplePdfParser
        from ocr.config import MODEL_DIR

        logger.info("Initializing local OCR parser with model_dir=%s", MODEL_DIR)
        self.parser = SimplePdfParser(model_dir=MODEL_DIR)

    async def remove_tag(self, text: str) -> str:
        """Remove position tags using the local parser (no I/O, runs inline)."""
        return self.remove_tag_sync(text)

    def remove_tag_sync(self, text: str) -> str:
        """Remove position tags using the local parser."""
        # SimplePdfParser.remove_tag is a static method; calling it through the
        # instance also works for injected parser objects.
        return self.parser.remove_tag(text)

    async def extract_positions(self, text: str) -> List[Position]:
        """Extract positions using the local parser (no I/O, runs inline)."""
        return self.extract_positions_sync(text)

    def extract_positions_sync(self, text: str) -> List[Position]:
        """Extract positions using the local parser."""
        # SimplePdfParser.extract_positions is a static method as well.
        return self.parser.extract_positions(text)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[ProgressCallback] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """
        Parse a PDF with the local parser without blocking the event loop.

        The synchronous parser runs in the running loop's default executor.
        As a result, ``callback`` (if given) is invoked from a worker thread,
        not from the event-loop thread.
        """
        loop = asyncio.get_running_loop()
        job = functools.partial(
            self.parse_into_bboxes_sync, pdf_bytes, callback, zoomin, filename
        )
        return await loop.run_in_executor(None, job)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[ProgressCallback] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """Parse a PDF with the local parser, blocking until it finishes."""
        logger.debug("Parsing %s with local OCR parser (zoomin=%s)", filename, zoomin)
        # The local parser takes a file-like object, so wrap the raw bytes.
        return self.parser.parse_into_bboxes(
            BytesIO(pdf_bytes), callback=callback, zoomin=zoomin
        )


class HTTPOCRService(OCRService):
    """OCR service that delegates every call to a remote OCR HTTP API."""

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """
        Initialize the HTTP OCR service.

        Args:
            base_url: Base URL of the OCR service. When None, URL resolution
                is left to ``OCRClient``, which is documented as reading
                OCR_SERVICE_URL.
            timeout: Request timeout in seconds (default 300).
        """
        from ocr.client import OCRClient

        self.client = OCRClient(base_url=base_url, timeout=timeout)

    async def remove_tag(self, text: str) -> str:
        """Remove position tags via the HTTP API."""
        return await self.client.remove_tag(text)

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous variant of ``remove_tag``."""
        return self.client.remove_tag_sync(text)

    async def extract_positions(self, text: str) -> List[Position]:
        """Extract positions via the HTTP API."""
        return await self.client.extract_positions(text)

    def extract_positions_sync(self, text: str) -> List[Position]:
        """Synchronous variant of ``extract_positions``."""
        return self.client.extract_positions_sync(text)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[ProgressCallback] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """Parse a PDF via the HTTP API."""
        return await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[ProgressCallback] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """Synchronous variant of ``parse_into_bboxes``."""
        return self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)


# Process-wide service instance, created lazily by get_ocr_service().
_global_service: Optional[OCRService] = None


def _resolve_ocr_mode() -> str:
    """
    Read OCR_MODE and normalize it to "local" or "http".

    Case and surrounding whitespace are ignored. An unset or blank value means
    "local". Any other unrecognized value also falls back to "local", but a
    warning is logged so a typo such as ``OCR_MODE=htp`` does not go unnoticed.
    """
    raw_mode = os.getenv("OCR_MODE", _OCR_MODE_LOCAL)
    mode = raw_mode.strip().lower() or _OCR_MODE_LOCAL
    if mode not in (_OCR_MODE_LOCAL, _OCR_MODE_HTTP):
        logger.warning(
            "Unrecognized OCR_MODE=%r; falling back to local OCR service", raw_mode
        )
        mode = _OCR_MODE_LOCAL
    return mode


def get_ocr_service() -> OCRService:
    """
    Return the process-wide OCR service, creating it on first use.

    The backend is chosen from the OCR_MODE environment variable:

    - ``OCR_MODE=local``, unset, or unrecognized: local OCR model.
    - ``OCR_MODE=http``: HTTP OCR service.

    In http mode, OCR_SERVICE_URL sets the service address. If it is unset or
    blank, ``http://localhost:8000/api/v1/ocr`` is used.

    Returns:
        The shared ``OCRService`` instance.
    """
    global _global_service
    if _global_service is None:
        if _resolve_ocr_mode() == _OCR_MODE_HTTP:
            base_url = os.getenv("OCR_SERVICE_URL", "").strip() or _DEFAULT_OCR_SERVICE_URL
            logger.info("Initializing HTTP OCR service with URL: %s", base_url)
            _global_service = HTTPOCRService(base_url=base_url)
        else:
            logger.info("Initializing local OCR service")
            _global_service = LocalOCRService()
    return _global_service


def get_ocr_client() -> OCRService:
    """
    Backward-compatible alias for ``get_ocr_service()``.

    Prefer ``get_ocr_service()`` in new code.

    Returns:
        The shared ``OCRService`` instance.
    """
    return get_ocr_service()