#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP client utilities.

Used to call the OCR service through its HTTP interface.
"""

import asyncio
import logging
import os
from typing import Optional, Callable, List, Tuple, Any

try:
    import httpx
    HAS_HTTPX = True
except ImportError:
    # httpx is preferred; aiohttp is the fallback transport.
    HAS_HTTPX = False
    import aiohttp

logger = logging.getLogger(__name__)


class OCRServiceError(Exception):
    """Raised when the OCR service answers but reports a failed operation."""


def _run_sync(coro_factory: Callable[[], Any]) -> Any:
    """Run the coroutine produced by ``coro_factory`` to completion, blocking.

    The coroutine is created lazily so that nothing is left un-awaited when
    the call is rejected.

    Raises:
        RuntimeError: if called from a thread that already has a running
            event loop (blocking there is not possible; await the async
            variant instead).
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_running = False
    else:
        loop_running = True

    if loop_running:
        raise RuntimeError(
            "OCRClient *_sync methods cannot be called from a running event "
            "loop; await the async variant instead"
        )
    # The run is deliberately NOT wrapped in ``except RuntimeError``: an error
    # raised by the request itself must propagate instead of silently
    # triggering a second HTTP request.
    return asyncio.run(coro_factory())


class OCRClient:
    """HTTP client used to call the OCR API."""

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """
        Initialise the OCR client.

        Args:
            base_url: Base URL of the OCR service. When not provided it is read
                from the OCR_SERVICE_URL environment variable, and when that is
                unset it defaults to http://localhost:8000/api/v1/ocr
            timeout: Request timeout in seconds, 300 by default
        """
        resolved_url = base_url or os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
        # Drop trailing slashes so endpoints can be appended as "/name".
        self.base_url = resolved_url.rstrip('/')
        self.timeout = timeout

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
        """
        Send one HTTP request and return the decoded JSON body.

        Args:
            method: HTTP method name, e.g. "POST"
            endpoint: Path appended to ``base_url`` (starting with "/")
            **kwargs: Passed through to the underlying httpx / aiohttp
                ``request`` call

        Returns:
            The JSON-decoded response body

        Raises:
            The transport library's HTTP error when the response status
            indicates failure (``raise_for_status``).
        """
        url = f"{self.base_url}{endpoint}"

        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()

        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
            async with session.request(method, url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()

    async def remove_tag(self, text: str) -> str:
        """
        Remove position tags from a text.

        Args:
            text: Text containing position tags

        Returns:
            The text with the tags removed

        Raises:
            OCRServiceError: if the service does not report success or
                returns no text
        """
        response = await self._make_request(
            "POST",
            "/remove_tag",
            json={"text": text}
        )

        # ``is not None`` so that an empty string is still a valid result.
        if response.get("success") and response.get("text") is not None:
            return response["text"]
        raise OCRServiceError(f"移除标签失败: {response.get('message', '未知错误')}")

    def remove_tag_sync(self, text: str) -> str:
        """
        Synchronous version of remove_tag (for synchronous code).

        Args:
            text: Text containing position tags

        Returns:
            The text with the tags removed

        Raises:
            RuntimeError: if called from a running event loop
            OCRServiceError: if the service reports a failure
        """
        return _run_sync(lambda: self.remove_tag(text))

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """
        Extract position information from a text.

        Args:
            text: Text containing position tags

        Returns:
            List of positions, each in the form
            (page_numbers, left, right, top, bottom)

        Raises:
            OCRServiceError: if the service does not report success or
                returns no positions
        """
        response = await self._make_request(
            "POST",
            "/extract_positions",
            json={"text": text}
        )

        if response.get("success") and response.get("positions") is not None:
            # Convert the response objects back into plain tuples.
            return [
                (
                    pos["page_numbers"],
                    pos["left"],
                    pos["right"],
                    pos["top"],
                    pos["bottom"],
                )
                for pos in response["positions"]
            ]
        raise OCRServiceError(f"提取位置信息失败: {response.get('message', '未知错误')}")

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """
        Synchronous version of extract_positions (for synchronous code).

        Args:
            text: Text containing position tags

        Returns:
            List of positions

        Raises:
            RuntimeError: if called from a running event loop
            OCRServiceError: if the service reports a failure
        """
        return _run_sync(lambda: self.extract_positions(text))

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """
        Parse a PDF and return its bounding boxes.

        Args:
            pdf_bytes: Binary content of the PDF file
            callback: Progress callback (progress: float, message: str) -> None.
                It is never invoked by this HTTP client; a warning is logged
                when one is supplied
            zoomin: Image zoom factor, 3 by default (documented range is 1-5;
                not validated client-side)
            filename: Sent as the upload file name and as the "filename"
                form field

        Returns:
            List of bounding boxes (possibly empty)

        Raises:
            OCRServiceError: if the service does not report success or
                returns no bounding-box list
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback,将忽略回调函数")

        if HAS_HTTPX:
            # httpx takes the file part via ``files`` and plain fields via ``data``.
            request_kwargs = {
                "files": {"pdf_bytes": (filename, pdf_bytes, "application/pdf")},
                "data": {"filename": filename, "zoomin": str(zoomin)},
            }
        else:
            form_data = aiohttp.FormData()
            form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
            form_data.add_field('filename', filename)
            form_data.add_field('zoomin', str(zoomin))
            request_kwargs = {"data": form_data}

        result = await self._make_request("POST", "/parse_into_bboxes", **request_kwargs)

        data = result.get("data")
        # ``is not None`` (as in the sibling methods): a successful parse that
        # finds no boxes returns an empty list instead of raising.
        if result.get("success") and data and data.get("bboxes") is not None:
            return data["bboxes"]
        raise OCRServiceError(f"解析 PDF 失败: {result.get('message', '未知错误')}")

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """
        Synchronous version of parse_into_bboxes (for synchronous code).

        Args:
            pdf_bytes: Binary content of the PDF file
            callback: Progress callback (note: progress cannot be reported
                over the HTTP call, so this parameter is ignored)
            zoomin: Image zoom factor, 3 by default (documented range is 1-5)
            filename: Sent as the upload file name and as the "filename"
                form field

        Returns:
            List of bounding boxes

        Raises:
            RuntimeError: if called from a running event loop
            OCRServiceError: if the service reports a failure
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback,将忽略回调函数")

        # The callback is dropped here (already warned about above).
        return _run_sync(lambda: self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))


# Global client instance (created lazily)
_global_client: Optional[OCRClient] = None


def get_ocr_client() -> OCRClient:
    """Return the global OCR client instance, creating it on first use (singleton)."""
    global _global_client
    if _global_client is None:
        _global_client = OCRClient()
    return _global_client