Files
TERES_fastapi_backend/ocr/service.py
2025-11-03 10:22:28 +08:00

291 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 服务统一接口
支持本地OCR模型和HTTP接口两种方式可通过配置选择
"""
import logging
import os
from abc import ABC, abstractmethod
from typing import Optional, Callable, List, Tuple, Any
logger = logging.getLogger(__name__)
class OCRService(ABC):
    """Abstract interface that every OCR backend implements.

    Each operation is declared twice: once as a coroutine for async callers
    and once with a ``_sync`` suffix for synchronous callers.
    """

    @abstractmethod
    async def remove_tag(self, text: str) -> str:
        """Strip the position tags out of a piece of text.

        Args:
            text: Text that contains position tags.

        Returns:
            The cleaned text.
        """

    @abstractmethod
    def remove_tag_sync(self, text: str) -> str:
        """Synchronous counterpart of ``remove_tag`` for non-async code.

        Args:
            text: Text that contains position tags.

        Returns:
            The cleaned text.
        """

    @abstractmethod
    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Pull the position information out of a piece of text.

        Args:
            text: Text that contains position tags.

        Returns:
            A list of positions shaped like
            ``[(page_numbers, left, right, top, bottom), ...]``.
        """

    @abstractmethod
    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous counterpart of ``extract_positions`` for non-async code.

        Args:
            text: Text that contains position tags.

        Returns:
            A list of positions.
        """

    @abstractmethod
    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF and return its bounding boxes.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Progress callback ``(progress: float, message: str) -> None``.
            zoomin: Image zoom factor (1-5); defaults to 3.
            filename: File name (used for logging only).

        Returns:
            A list of bounding boxes.
        """

    @abstractmethod
    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous counterpart of ``parse_into_bboxes`` for non-async code.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Progress callback. Note: an HTTP call cannot relay the
                callback in real time, so HTTP-based implementations ignore
                this argument.
            zoomin: Image zoom factor (1-5); defaults to 3.
            filename: File name (used for logging only).

        Returns:
            A list of bounding boxes.
        """
class LocalOCRService(OCRService):
    """OCR service that calls an in-process parser (the local OCR model)."""

    def __init__(self, parser_instance=None):
        """Set up the local OCR service.

        Args:
            parser_instance: Parser object to use. When omitted, a
                ``SimplePdfParser`` is created with ``MODEL_DIR`` from
                ``ocr.config``.
        """
        if parser_instance is None:
            # Deferred imports: only needed when no parser is injected.
            from ocr import SimplePdfParser
            from ocr.config import MODEL_DIR
            logger.info(f"Initializing local OCR parser with model_dir={MODEL_DIR}")
            self.parser = SimplePdfParser(model_dir=MODEL_DIR)
        else:
            self.parser = parser_instance

    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text`` using the local parser.

        The parser call is made directly (not awaited, not offloaded); the
        original author notes ``SimplePdfParser.remove_tag`` is a static
        method that can be called as-is.
        """
        return self.parser.remove_tag(text)

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous counterpart of ``remove_tag``."""
        return self.parser.remove_tag(text)

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text`` using the local parser.

        The parser call is made directly (not awaited, not offloaded); the
        original author notes ``SimplePdfParser.extract_positions`` is a
        static method that can be called as-is.
        """
        return self.parser.extract_positions(text)

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous counterpart of ``extract_positions``."""
        return self.parser.extract_positions(text)

    def _parse_pdf(self, pdf_bytes, callback, zoomin, filename):
        """Run the blocking parser call shared by the async and sync entry points."""
        from io import BytesIO

        # ``filename`` is documented as log-only, so record it here.
        logger.debug("Parsing %s with local OCR parser (zoomin=%s)", filename, zoomin)
        # The local parser is handed the PDF wrapped in a BytesIO stream.
        return self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF with the local parser without blocking the event loop.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Progress callback forwarded to the parser. It is invoked
                from wherever the parser calls it, i.e. the executor thread.
            zoomin: Image zoom factor forwarded to the parser; defaults to 3.
            filename: File name (used for logging only).

        Returns:
            Whatever the parser's ``parse_into_bboxes`` returns (a list of
            bounding boxes per the interface).
        """
        import asyncio

        # Inside a coroutine the running loop is the one to use;
        # get_running_loop() is the documented replacement for get_event_loop().
        loop = asyncio.get_running_loop()
        # The parser is synchronous, so run it in the default executor.
        return await loop.run_in_executor(
            None, self._parse_pdf, pdf_bytes, callback, zoomin, filename
        )

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous counterpart of ``parse_into_bboxes``.

        Args:
            pdf_bytes: Raw binary content of the PDF file.
            callback: Progress callback forwarded to the parser.
            zoomin: Image zoom factor forwarded to the parser; defaults to 3.
            filename: File name (used for logging only).

        Returns:
            Whatever the parser's ``parse_into_bboxes`` returns.
        """
        return self._parse_pdf(pdf_bytes, callback, zoomin, filename)
class HTTPOCRService(OCRService):
    """OCR service that forwards every operation to an ``OCRClient`` (HTTP backend)."""

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Create the underlying HTTP client.

        Args:
            base_url: Base URL of the OCR service. Passed straight to
                ``OCRClient``; per the original note, when omitted it is taken
                from the ``OCR_SERVICE_URL`` environment variable (that
                fallback is not implemented in this class).
            timeout: Request timeout in seconds; defaults to 300.
        """
        from ocr.client import OCRClient

        http_client = OCRClient(base_url=base_url, timeout=timeout)
        self.client = http_client

    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text`` via the client."""
        cleaned = await self.client.remove_tag(text)
        return cleaned

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous counterpart of ``remove_tag``."""
        cleaned = self.client.remove_tag_sync(text)
        return cleaned

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text`` via the client."""
        positions = await self.client.extract_positions(text)
        return positions

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous counterpart of ``extract_positions``."""
        positions = self.client.extract_positions_sync(text)
        return positions

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF via the client; all arguments are forwarded positionally."""
        bboxes = await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)
        return bboxes

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous counterpart of ``parse_into_bboxes``."""
        bboxes = self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)
        return bboxes
# Process-wide service instance, created lazily on first use.
_global_service: Optional[OCRService] = None


def get_ocr_service() -> OCRService:
    """Return the process-wide OCR service, creating it on first call.

    The backend is chosen from the ``OCR_MODE`` environment variable
    (case-insensitive, surrounding whitespace ignored):

    - ``local`` or unset: use the local OCR model.
    - ``http``: use the HTTP interface. The service address is read from
      ``OCR_SERVICE_URL`` (default ``http://localhost:8000/api/v1/ocr``);
      that variable is only consulted in this mode.

    Any other value falls back to the local backend, with a warning so the
    misconfiguration is visible in the logs.

    The instance is cached in a module global without locking, so the
    environment is only read when the first instance is built.

    Returns:
        The shared ``OCRService`` instance.
    """
    global _global_service
    if _global_service is not None:
        return _global_service

    # strip() so values such as "http " are not silently treated as local.
    ocr_mode = os.getenv("OCR_MODE", "local").strip().lower()
    if ocr_mode == "http":
        base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
        logger.info(f"Initializing HTTP OCR service with URL: {base_url}")
        _global_service = HTTPOCRService(base_url=base_url)
    else:
        if ocr_mode not in ("", "local"):
            logger.warning(
                "Unrecognized OCR_MODE=%r; falling back to local OCR service",
                ocr_mode,
            )
        logger.info("Initializing local OCR service")
        _global_service = LocalOCRService()
    return _global_service
# Kept for backward compatibility: get_ocr_client simply forwards to get_ocr_service.
def get_ocr_client() -> OCRService:
    """Return the OCR service instance (backward-compatibility wrapper).

    Prefer calling ``get_ocr_service()`` directly.

    Returns:
        The ``OCRService`` instance returned by ``get_ocr_service()``.
    """
    service = get_ocr_service()
    return service