Files
TERES_fastapi_backend/ocr/client.py

240 lines
8.5 KiB
Python
Raw Permalink Normal View History

2025-11-03 10:22:28 +08:00
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP 客户端工具类
用于通过 HTTP 接口调用 OCR 服务
"""
import logging
import os
from typing import Optional, Callable, List, Tuple, Any
try:
import httpx
HAS_HTTPX = True
except ImportError:
HAS_HTTPX = False
import aiohttp
logger = logging.getLogger(__name__)
class OCRClient:
    """HTTP client for the OCR service API.

    Every request opens a short-lived HTTP session, using ``httpx`` when it
    was importable at module load (``HAS_HTTPX``) and ``aiohttp`` otherwise.
    Each async method has a blocking ``*_sync`` counterpart.
    """

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Initialise the client.

        Args:
            base_url: Base URL of the OCR service. When omitted or empty, the
                ``OCR_SERVICE_URL`` environment variable is used; when that
                is unset or empty too, ``http://localhost:8000/api/v1/ocr``.
            timeout: Total request timeout handed to the HTTP backend
                (seconds), 300 by default.
        """
        resolved_url = (
            base_url
            or os.getenv("OCR_SERVICE_URL")
            or "http://localhost:8000/api/v1/ocr"
        )
        # Endpoints are appended as "/name", so drop any trailing slashes.
        self.base_url = resolved_url.rstrip("/")
        self.timeout = timeout

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
        """Send one HTTP request and return the decoded JSON body.

        Args:
            method: HTTP method name, e.g. ``"POST"``.
            endpoint: Path appended verbatim to ``self.base_url``.
            **kwargs: Forwarded unchanged to the backend's ``request`` call.

        Raises:
            The backend's HTTP error when ``raise_for_status()`` rejects the
            response status.
        """
        url = f"{self.base_url}{endpoint}"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()

        client_timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.request(method, url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()

    async def _post_pdf(
        self, endpoint: str, pdf_bytes: bytes, filename: str, zoomin: int
    ) -> dict:
        """Upload a PDF as multipart form data and return the decoded JSON body.

        The form carries the file under ``pdf_bytes`` plus the text fields
        ``filename`` and ``zoomin`` (sent as a string).
        """
        url = f"{self.base_url}{endpoint}"
        if HAS_HTTPX:
            # httpx builds the multipart body from `files`, with the plain
            # form fields supplied alongside through `data`.
            form_fields = {"filename": filename, "zoomin": str(zoomin)}
            upload = {"pdf_bytes": (filename, pdf_bytes, "application/pdf")}
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(url, files=upload, data=form_fields)
                response.raise_for_status()
                return response.json()

        form = aiohttp.FormData()
        form.add_field(
            'pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf'
        )
        form.add_field('filename', filename)
        form.add_field('zoomin', str(zoomin))
        client_timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(url, data=form) as response:
                response.raise_for_status()
                return await response.json()

    @staticmethod
    def _run_sync(coro_func: Callable[..., Any], *args: Any) -> Any:
        """Run ``coro_func(*args)`` to completion and return its result.

        With no event loop running in the calling thread the coroutine is
        driven by ``asyncio.run``. If a loop is already running (a sync
        wrapper called from async code), ``asyncio.run`` would raise, so the
        coroutine is executed on a worker thread with its own loop. The
        calling thread blocks until the result is available either way.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: safe to start one here.
            return asyncio.run(coro_func(*args))

        from concurrent.futures import ThreadPoolExecutor

        # The coroutine is created inside the worker so it is always awaited
        # by the loop that owns it.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(coro_func(*args))).result()

    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text`` via the ``/remove_tag`` endpoint.

        Args:
            text: Text containing position tags.

        Returns:
            The text with tags removed (may be an empty string).

        Raises:
            Exception: If the service does not report success or returns no
                ``text`` field.
        """
        response = await self._make_request(
            "POST",
            "/remove_tag",
            json={"text": text},
        )
        cleaned = response.get("text")
        if response.get("success") and cleaned is not None:
            return cleaned
        raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")

    def remove_tag_sync(self, text: str) -> str:
        """Blocking version of :meth:`remove_tag` for synchronous callers.

        Args:
            text: Text containing position tags.

        Returns:
            The text with tags removed.
        """
        return self._run_sync(self.remove_tag, text)

    async def extract_positions(
        self, text: str
    ) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position records from ``text`` via ``/extract_positions``.

        Args:
            text: Text containing position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples, one
            per entry of the response's ``positions`` array.

        Raises:
            Exception: If the service does not report success or returns no
                ``positions`` field.
        """
        response = await self._make_request(
            "POST",
            "/extract_positions",
            json={"text": text},
        )
        raw_positions = response.get("positions")
        if response.get("success") and raw_positions is not None:
            # Convert the JSON objects back into the tuple layout callers expect.
            return [
                (
                    pos["page_numbers"],
                    pos["left"],
                    pos["right"],
                    pos["top"],
                    pos["bottom"],
                )
                for pos in raw_positions
            ]
        raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")

    def extract_positions_sync(
        self, text: str
    ) -> List[Tuple[List[int], float, float, float, float]]:
        """Blocking version of :meth:`extract_positions` for synchronous callers.

        Args:
            text: Text containing position tags.

        Returns:
            The list of position tuples.
        """
        return self._run_sync(self.extract_positions, text)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF through ``/parse_into_bboxes`` and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback ``(progress, message) -> None``. It is
                accepted for interface compatibility but never invoked here.
            zoomin: Image zoom factor, default 3 (documented range 1-5; not
                validated by this client).
            filename: Sent as the upload file name and as the ``filename``
                form field.

        Returns:
            The ``data.bboxes`` list from the response; an empty list is a
            valid result.

        Raises:
            Exception: If the service does not report success or the response
                carries no ``data.bboxes`` field.
        """
        result = await self._post_pdf(
            "/parse_into_bboxes", pdf_bytes, filename, zoomin
        )
        data = result.get("data")
        bboxes = data.get("bboxes") if isinstance(data, dict) else None
        # Compare against None: a document with no boxes is still a success.
        if result.get("success") and bboxes is not None:
            return bboxes
        raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Blocking version of :meth:`parse_into_bboxes` for synchronous callers.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback; it is ignored (a warning is logged
                when one is supplied).
            zoomin: Image zoom factor, default 3 (documented range 1-5).
            filename: Sent as the upload file name and as the ``filename``
                form field.

        Returns:
            The list of bounding boxes.
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback将忽略回调函数")
        return self._run_sync(
            self.parse_into_bboxes, pdf_bytes, None, zoomin, filename
        )
# Module-wide client instance, created lazily on first request.
_global_client: Optional[OCRClient] = None


def get_ocr_client() -> OCRClient:
    """Return the shared module-level OCR client, creating it on the first call."""
    global _global_client
    if _global_client is not None:
        return _global_client
    # First use: build the client with its default configuration and keep it.
    _global_client = OCRClient()
    return _global_client