240 lines
8.5 KiB
Python
240 lines
8.5 KiB
Python
|
|
#
|
|||
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|||
|
|
#
|
|||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
|
# you may not use this file except in compliance with the License.
|
|||
|
|
# You may obtain a copy of the License at
|
|||
|
|
#
|
|||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|
#
|
|||
|
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
|
# See the License for the specific language governing permissions and
|
|||
|
|
# limitations under the License.
|
|||
|
|
#
|
|||
|
|
"""
|
|||
|
|
OCR HTTP 客户端工具类
|
|||
|
|
用于通过 HTTP 接口调用 OCR 服务
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
from typing import Optional, Callable, List, Tuple, Any
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import httpx
|
|||
|
|
HAS_HTTPX = True
|
|||
|
|
except ImportError:
|
|||
|
|
HAS_HTTPX = False
|
|||
|
|
import aiohttp
|
|||
|
|
|
|||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OCRClient:
    """HTTP client for calling the OCR service API.

    Every public coroutine has a ``*_sync`` counterpart for synchronous
    callers. Requests use ``httpx`` when it is importable (``HAS_HTTPX``)
    and fall back to ``aiohttp`` otherwise. A new HTTP client/session is
    opened for each request.
    """

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Initialise the OCR client.

        Args:
            base_url: Base URL of the OCR service. If omitted, the
                ``OCR_SERVICE_URL`` environment variable is used, and if that
                is unset the default ``http://localhost:8000/api/v1/ocr``.
                Trailing slashes are removed.
            timeout: Total request timeout in seconds (default 300).
        """
        url = base_url or os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
        # Endpoint paths start with "/", so strip trailing slashes to avoid "//".
        self.base_url = url.rstrip('/')
        self.timeout = timeout

    @staticmethod
    def _run_sync(coro: Any) -> Any:
        """Run ``coro`` to completion from synchronous code and return its result.

        If no event loop is running in the current thread, the coroutine is
        run with ``asyncio.run``. If a loop *is* running (the sync wrapper was
        called from async code), ``asyncio.run`` would raise. In that case the
        coroutine runs on a fresh loop in a worker thread, and this call
        blocks until it finishes.

        The coroutine is awaited exactly once. Any exception it raises,
        including ``RuntimeError``, propagates unchanged rather than
        triggering a second request.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: the normal synchronous case.
            return asyncio.run(coro)

        # A loop is already running in this thread; use another thread's loop.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
        """Send an HTTP request to ``base_url + endpoint`` and return the decoded JSON.

        Extra keyword arguments are forwarded to the underlying
        ``request`` call (e.g. ``json=...``).

        Raises:
            The HTTP library's status error if the response status is an error.
        """
        url = f"{self.base_url}{endpoint}"

        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()

        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
            async with session.request(method, url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()

    async def remove_tag(self, text: str) -> str:
        """Remove position tags from ``text`` via the ``/remove_tag`` endpoint.

        Args:
            text: Text containing position tags.

        Returns:
            The text with tags removed.

        Raises:
            Exception: if the service reports failure or returns no ``text``.
        """
        response = await self._make_request("POST", "/remove_tag", json={"text": text})
        if response.get("success") and response.get("text") is not None:
            return response["text"]
        raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous version of :meth:`remove_tag` for non-async callers.

        Args:
            text: Text containing position tags.

        Returns:
            The text with tags removed.
        """
        return self._run_sync(self.remove_tag(text))

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text`` via ``/extract_positions``.

        Args:
            text: Text containing position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples.

        Raises:
            Exception: if the service reports failure or returns no ``positions``.
        """
        response = await self._make_request("POST", "/extract_positions", json={"text": text})
        if response.get("success") and response.get("positions") is not None:
            # Convert the JSON objects into the legacy tuple format.
            return [
                (pos["page_numbers"], pos["left"], pos["right"], pos["top"], pos["bottom"])
                for pos in response["positions"]
            ]
        raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous version of :meth:`extract_positions` for non-async callers.

        Args:
            text: Text containing position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples.
        """
        return self._run_sync(self.extract_positions(text))

    @staticmethod
    def _bboxes_from_result(result: dict) -> List[dict]:
        """Return ``result["data"]["bboxes"]`` from a ``/parse_into_bboxes`` response.

        An empty ``bboxes`` list from a successful call is returned as-is.
        Only a failed call, or a missing ``data``/``bboxes`` field, raises.

        Raises:
            Exception: if the service reports failure or omits ``data.bboxes``.
        """
        data = result.get("data") if result.get("success") else None
        bboxes = data.get("bboxes") if isinstance(data, dict) else None
        if bboxes is None:
            raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")
        return bboxes

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Upload a PDF to ``/parse_into_bboxes`` and return the bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback ``(progress, message)``. Progress
                cannot be relayed over this HTTP call, so it is ignored and a
                warning is logged.
            zoomin: Image zoom factor (documented range 1-5, default 3), sent
                as a form field.
            filename: File name sent with the upload and as a form field.

        Returns:
            The list of bounding boxes, which may be empty.

        Raises:
            Exception: if the service reports failure or omits ``data.bboxes``.
        """
        if callback is not None:
            logger.warning("HTTP 调用中无法使用 callback,将忽略回调函数")

        url = f"{self.base_url}/parse_into_bboxes"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                # Plain form fields go in ``data``; the PDF goes in ``files``.
                # httpx combines both into one multipart/form-data body.
                response = await client.post(
                    url,
                    files={"pdf_bytes": (filename, pdf_bytes, "application/pdf")},
                    data={"filename": filename, "zoomin": str(zoomin)},
                )
                response.raise_for_status()
                result = response.json()
        else:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
                form_data = aiohttp.FormData()
                form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
                form_data.add_field('filename', filename)
                form_data.add_field('zoomin', str(zoomin))

                async with session.post(url, data=form_data) as response:
                    response.raise_for_status()
                    result = await response.json()

        return self._bboxes_from_result(result)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous version of :meth:`parse_into_bboxes` for non-async callers.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback. It cannot be used over HTTP, so it is
                ignored and a warning is logged.
            zoomin: Image zoom factor (documented range 1-5, default 3).
            filename: File name sent with the upload.

        Returns:
            The list of bounding boxes, which may be empty.
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback,将忽略回调函数")
        # Pass None so the async method does not log the warning a second time.
        return self._run_sync(self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))
|||
|
|
|
|||
|
|
|
|||
|
|
# Process-wide client instance, created lazily by get_ocr_client().
_global_client: Optional[OCRClient] = None


def get_ocr_client() -> OCRClient:
    """Return the shared OCR client, creating it on first use.

    The instance is built with OCRClient's default arguments, so its base
    URL comes from OCR_SERVICE_URL or the built-in default.
    """
    global _global_client
    client = _global_client
    if client is None:
        client = OCRClient()
        _global_client = client
    return client
|
|||
|
|
|