Files
TERES_fastapi_backend/ocr/client.py
2025-11-03 10:22:28 +08:00

240 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP 客户端工具类
用于通过 HTTP 接口调用 OCR 服务
"""
import logging
import os
from typing import Optional, Callable, List, Tuple, Any
# Prefer httpx as the async HTTP backend; aiohttp is only needed (and only
# imported) when httpx is unavailable, so it is no longer a hard dependency.
try:
    import httpx
    HAS_HTTPX = True
except ImportError:
    HAS_HTTPX = False
    try:
        import aiohttp
    except ImportError as err:
        raise ImportError(
            "OCRClient requires either 'httpx' or 'aiohttp' to be installed"
        ) from err

logger = logging.getLogger(__name__)
class OCRClient:
    """HTTP client for the OCR service API.

    Every request opens a short-lived ``httpx.AsyncClient`` or, when httpx is
    not installed, an ``aiohttp.ClientSession``. Each async method has a
    ``*_sync`` counterpart for use from synchronous code.
    """

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Initialise the OCR client.

        Args:
            base_url: Base URL of the OCR service. When falsy, the
                ``OCR_SERVICE_URL`` environment variable is used. When that is
                unset or empty, the default is
                ``http://localhost:8000/api/v1/ocr``. Trailing slashes are
                stripped so endpoint paths can be appended directly.
            timeout: Total request timeout in seconds (default 300).
        """
        # `or` (rather than a getenv default) also covers OCR_SERVICE_URL="".
        resolved = (
            base_url
            or os.getenv("OCR_SERVICE_URL")
            or "http://localhost:8000/api/v1/ocr"
        )
        self.base_url = resolved.rstrip('/')
        self.timeout = timeout

    @staticmethod
    def _run_coroutine_sync(coro):
        """Run ``coro`` to completion from synchronous code and return its result.

        If no event loop is running in the calling thread, the coroutine is
        executed with ``asyncio.run``. If a loop *is* running (a sync helper
        called from async code), neither ``run_until_complete`` nor
        ``asyncio.run`` is allowed in this thread. In that case the coroutine
        runs on a fresh loop in a worker thread, and this call blocks until it
        finishes.

        Exceptions raised by the coroutine propagate unchanged; the coroutine
        is never executed twice.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the normal case for sync callers.
            return asyncio.run(coro)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
        """Send an HTTP request to ``base_url + endpoint`` and return the decoded JSON body.

        Extra keyword arguments (for example ``json=...``) are forwarded
        unchanged to the underlying client's ``request`` call.

        Raises:
            The backend's HTTP status error (via ``raise_for_status``) on a
            4xx/5xx response.
        """
        url = f"{self.base_url}{endpoint}"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
            async with session.request(method, url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()

    async def remove_tag(self, text: str) -> str:
        """Remove position tags from ``text`` using the ``/remove_tag`` endpoint.

        Args:
            text: Text containing position tags.

        Returns:
            The text with the tags removed, as returned by the service.

        Raises:
            Exception: If the response lacks ``success`` or ``text``.
        """
        response = await self._make_request("POST", "/remove_tag", json={"text": text})
        if response.get("success") and response.get("text") is not None:
            return response["text"]
        raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")

    def remove_tag_sync(self, text: str) -> str:
        """Blocking version of :meth:`remove_tag` for synchronous code."""
        return self._run_coroutine_sync(self.remove_tag(text))

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text`` using ``/extract_positions``.

        Args:
            text: Text containing position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples.

        Raises:
            Exception: If the response lacks ``success`` or ``positions``.
        """
        response = await self._make_request("POST", "/extract_positions", json={"text": text})
        if response.get("success") and response.get("positions") is not None:
            # Convert the service's dict records into the tuple shape callers expect.
            return [
                (pos["page_numbers"], pos["left"], pos["right"], pos["top"], pos["bottom"])
                for pos in response["positions"]
            ]
        raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Blocking version of :meth:`extract_positions` for synchronous code."""
        return self._run_coroutine_sync(self.extract_positions(text))

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """Upload a PDF to ``/parse_into_bboxes`` and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback ``(progress, message)``. Progress cannot
                be relayed over a single HTTP call, so a warning is logged and
                the callback is never invoked.
            zoomin: Image zoom factor, sent as a form field. The documented
                range is 1-5; it is not validated client-side.
            filename: Sent both as the multipart file name and as the
                ``filename`` form field.

        Returns:
            The ``data.bboxes`` list from the service response. This may be
            empty when the service reports success but finds no boxes.

        Raises:
            Exception: If the response lacks ``success``, ``data`` or ``bboxes``.
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback将忽略回调函数")
        url = f"{self.base_url}/parse_into_bboxes"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                # httpx builds a multipart body from `files` (file parts) plus `data` (plain fields).
                response = await client.post(
                    url,
                    files={"pdf_bytes": (filename, pdf_bytes, "application/pdf")},
                    data={"filename": filename, "zoomin": str(zoomin)},
                )
                response.raise_for_status()
                result = response.json()
        else:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
                form_data = aiohttp.FormData()
                form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
                form_data.add_field('filename', filename)
                form_data.add_field('zoomin', str(zoomin))
                async with session.post(url, data=form_data) as response:
                    response.raise_for_status()
                    result = await response.json()
        data = result.get("data") if result.get("success") else None
        # `is not None` (not truthiness) so a successful empty result is returned,
        # matching the checks in remove_tag / extract_positions.
        if data and data.get("bboxes") is not None:
            return data["bboxes"]
        raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf",
    ) -> List[dict]:
        """Blocking version of :meth:`parse_into_bboxes` for synchronous code.

        ``callback`` cannot be honoured over HTTP. A warning is logged here,
        and ``None`` is passed on so the async method does not warn a second time.
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback将忽略回调函数")
        return self._run_coroutine_sync(self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))
# Module-wide OCRClient shared by get_ocr_client(); created on first use.
_global_client: Optional[OCRClient] = None


def get_ocr_client() -> OCRClient:
    """Return the shared OCRClient, constructing it with default settings on first call."""
    global _global_client
    if _global_client is not None:
        return _global_client
    _global_client = OCRClient()
    return _global_client