291 lines
9.4 KiB
Python
291 lines
9.4 KiB
Python
|
|
#
|
|||
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|||
|
|
#
|
|||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
|
# you may not use this file except in compliance with the License.
|
|||
|
|
# You may obtain a copy of the License at
|
|||
|
|
#
|
|||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|
#
|
|||
|
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
|
# See the License for the specific language governing permissions and
|
|||
|
|
# limitations under the License.
|
|||
|
|
#
|
|||
|
|
"""
|
|||
|
|
OCR 服务统一接口
|
|||
|
|
支持本地OCR模型和HTTP接口两种方式,可通过配置选择
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
from typing import Optional, Callable, List, Tuple, Any
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OCRService(ABC):
    """Abstract interface every OCR backend has to implement.

    Each operation is declared twice: as a coroutine for async callers and
    as a ``*_sync`` method for synchronous callers.
    """

    @abstractmethod
    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text``.

        Args:
            text: Text that contains position tags.

        Returns:
            The cleaned text.
        """

    @abstractmethod
    def remove_tag_sync(self, text: str) -> str:
        """Synchronous counterpart of ``remove_tag`` (for synchronous code).

        Args:
            text: Text that contains position tags.

        Returns:
            The cleaned text.
        """

    @abstractmethod
    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text``.

        Args:
            text: Text that contains position tags.

        Returns:
            A list of positions shaped like
            ``[(page_numbers, left, right, top, bottom), ...]``.
        """

    @abstractmethod
    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous counterpart of ``extract_positions`` (for synchronous code).

        Args:
            text: Text that contains position tags.

        Returns:
            A list of positions.
        """

    @abstractmethod
    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF and return its bounding boxes.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Progress callback ``(progress: float, message: str) -> None``.
            zoomin: Image magnification factor (1-5, defaults to 3).
            filename: File name (used for logging only).

        Returns:
            A list of bounding boxes.
        """

    @abstractmethod
    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous counterpart of ``parse_into_bboxes`` (for synchronous code).

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Progress callback. Note: a callback cannot be relayed in
                real time over an HTTP call, so it is ignored in that case.
            zoomin: Image magnification factor (1-5, defaults to 3).
            filename: File name (used for logging only).

        Returns:
            A list of bounding boxes.
        """
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LocalOCRService(OCRService):
    """OCR service backed by a local OCR model (calls the parser in-process)."""

    def __init__(self, parser_instance=None):
        """Set up the local OCR service.

        Args:
            parser_instance: A ``SimplePdfParser`` instance. When omitted, one
                is created from ``ocr.config.MODEL_DIR``.
        """
        if parser_instance is None:
            # Imported lazily so the OCR package is only loaded when a parser
            # actually has to be built here.
            from ocr import SimplePdfParser
            from ocr.config import MODEL_DIR

            logger.info("Initializing local OCR parser with model_dir=%s", MODEL_DIR)
            parser_instance = SimplePdfParser(model_dir=MODEL_DIR)
        self.parser = parser_instance

    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text`` using the local parser.

        Args:
            text: Text that contains position tags.

        Returns:
            Whatever ``self.parser.remove_tag(text)`` returns.
        """
        # The call is made inline (not off-loaded to a thread), exactly like
        # the synchronous variant.
        return self.remove_tag_sync(text)

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous variant of ``remove_tag``.

        Args:
            text: Text that contains position tags.

        Returns:
            Whatever ``self.parser.remove_tag(text)`` returns.
        """
        return self.parser.remove_tag(text)

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text`` using the local parser.

        Args:
            text: Text that contains position tags.

        Returns:
            Whatever ``self.parser.extract_positions(text)`` returns.
        """
        return self.extract_positions_sync(text)

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous variant of ``extract_positions``.

        Args:
            text: Text that contains position tags.

        Returns:
            Whatever ``self.parser.extract_positions(text)`` returns.
        """
        return self.parser.extract_positions(text)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF with the local parser without blocking the event loop.

        The synchronous parse runs in the running loop's default executor.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Optional progress callback forwarded to the parser.
                NOTE(review): it is handed to code running in an executor
                thread, so presumably it is invoked off the event-loop
                thread; confirm against the parser.
            zoomin: Image magnification factor forwarded to the parser.
            filename: Accepted for interface compatibility; not used here.

        Returns:
            The bounding boxes returned by the parser.
        """
        import asyncio
        from functools import partial

        job = partial(
            self.parse_into_bboxes_sync,
            pdf_bytes,
            callback=callback,
            zoomin=zoomin,
            filename=filename,
        )
        # get_running_loop() is the recommended way to reach the loop from
        # inside a coroutine; get_event_loop() is discouraged there.
        return await asyncio.get_running_loop().run_in_executor(None, job)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous variant of ``parse_into_bboxes``.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Optional progress callback forwarded to the parser.
            zoomin: Image magnification factor forwarded to the parser.
            filename: Accepted for interface compatibility; not used here.

        Returns:
            The bounding boxes returned by the parser.
        """
        from io import BytesIO

        # The parser is given a file-like object wrapping the raw bytes.
        return self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class HTTPOCRService(OCRService):
    """OCR service that forwards every call to an HTTP OCR endpoint."""

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Create the underlying HTTP client.

        Args:
            base_url: Base URL of the OCR service. The original note says the
                URL falls back to the ``OCR_SERVICE_URL`` environment variable
                when omitted; that lookup is not done in this class, so
                presumably ``OCRClient`` handles it (to be confirmed).
            timeout: Request timeout in seconds (defaults to 300).
        """
        # Imported lazily, at construction time.
        from ocr.client import OCRClient

        http_client = OCRClient(base_url=base_url, timeout=timeout)
        self.client = http_client

    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text`` via the HTTP client."""
        cleaned = await self.client.remove_tag(text)
        return cleaned

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous variant of ``remove_tag``."""
        cleaned = self.client.remove_tag_sync(text)
        return cleaned

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text`` via the HTTP client."""
        positions = await self.client.extract_positions(text)
        return positions

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous variant of ``extract_positions``."""
        positions = self.client.extract_positions_sync(text)
        return positions

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF via the HTTP client and return its bounding boxes.

        All four arguments are forwarded positionally to
        ``self.client.parse_into_bboxes``.
        """
        boxes = await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)
        return boxes

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous variant of ``parse_into_bboxes``.

        All four arguments are forwarded positionally to
        ``self.client.parse_into_bboxes_sync``.
        """
        boxes = self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)
        return boxes
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Process-wide OCR service instance; created lazily by get_ocr_service().
_global_service: Optional[OCRService] = None


def get_ocr_service() -> OCRService:
    """Return the shared OCR service instance, creating it on first use.

    The backend is chosen from the ``OCR_MODE`` environment variable
    (case-insensitive, surrounding whitespace ignored):

    - ``local`` or unset: use the local OCR model.
    - ``http``: use the HTTP API. The service address is read from
      ``OCR_SERVICE_URL`` (only consulted in this mode) and defaults to
      ``http://localhost:8000/api/v1/ocr``.

    Any other value falls back to the local backend, and a warning is logged
    so that a mistyped mode does not go unnoticed.

    The environment is only read when the instance is first created; later
    calls return the cached instance.

    Returns:
        The ``OCRService`` instance.
    """
    global _global_service
    if _global_service is not None:
        return _global_service

    ocr_mode = os.getenv("OCR_MODE", "local").strip().lower()

    if ocr_mode == "http":
        base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
        logger.info("Initializing HTTP OCR service with URL: %s", base_url)
        _global_service = HTTPOCRService(base_url=base_url)
    else:
        # An empty value is treated like an unset variable; anything else
        # that is not "local" is most likely a typo worth surfacing.
        if ocr_mode not in ("", "local"):
            logger.warning(
                "Unrecognized OCR_MODE %r; falling back to the local OCR service",
                ocr_mode,
            )
        logger.info("Initializing local OCR service")
        _global_service = LocalOCRService()

    return _global_service
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Kept for backward compatibility: get_ocr_client simply redirects to
# get_ocr_service.
def get_ocr_client() -> OCRService:
    """Return the OCR service instance (backward-compatible alias).

    Prefer calling ``get_ocr_service()`` directly.

    Returns:
        The ``OCRService`` instance.
    """
    service = get_ocr_service()
    return service
|
|||
|
|
|