优化OCR解析
This commit is contained in:
@@ -25,6 +25,12 @@ import sys
|
|||||||
import signal
|
import signal
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 确保项目根目录在 sys.path 中
|
||||||
|
_current_file = Path(__file__).resolve()
|
||||||
|
_project_root = _current_file.parent.parent
|
||||||
|
if str(_project_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_project_root))
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -20,12 +20,15 @@
|
|||||||
可以直接作为独立模块使用。
|
可以直接作为独立模块使用。
|
||||||
|
|
||||||
使用方法:
|
使用方法:
|
||||||
from ocr import OCR
|
from ocr import OCR, SimplePdfParser
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
ocr = OCR()
|
ocr = OCR()
|
||||||
img = cv2.imread("image.jpg")
|
img = cv2.imread("image.jpg")
|
||||||
results = ocr(img)
|
results = ocr(img)
|
||||||
|
|
||||||
|
parser = SimplePdfParser()
|
||||||
|
result = parser.parse_pdf("document.pdf")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 处理导入问题:支持直接运行和模块导入
|
# 处理导入问题:支持直接运行和模块导入
|
||||||
@@ -35,3 +38,4 @@ from pathlib import Path
|
|||||||
|
|
||||||
__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser']
|
__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser']
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
14
ocr/api.py
14
ocr/api.py
@@ -57,7 +57,7 @@ class ParseResponse(BaseModel):
|
|||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@ocr_router.get(
|
||||||
"/health",
|
"/health",
|
||||||
summary="健康检查",
|
summary="健康检查",
|
||||||
description="检查OCR服务的健康状态和配置信息",
|
description="检查OCR服务的健康状态和配置信息",
|
||||||
@@ -79,7 +79,7 @@ async def health_check():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@ocr_router.post(
|
||||||
"/parse",
|
"/parse",
|
||||||
response_model=ParseResponse,
|
response_model=ParseResponse,
|
||||||
summary="上传并解析PDF文件",
|
summary="上传并解析PDF文件",
|
||||||
@@ -165,7 +165,7 @@ async def parse_pdf_endpoint(
|
|||||||
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
|
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@ocr_router.post(
|
||||||
"/parse/bytes",
|
"/parse/bytes",
|
||||||
response_model=ParseResponse,
|
response_model=ParseResponse,
|
||||||
summary="通过二进制数据解析PDF",
|
summary="通过二进制数据解析PDF",
|
||||||
@@ -244,7 +244,7 @@ async def parse_pdf_bytes(
|
|||||||
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
|
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@ocr_router.post(
|
||||||
"/parse/path",
|
"/parse/path",
|
||||||
response_model=ParseResponse,
|
response_model=ParseResponse,
|
||||||
summary="通过文件路径解析PDF",
|
summary="通过文件路径解析PDF",
|
||||||
@@ -315,7 +315,7 @@ async def parse_pdf_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@ocr_router.post(
|
||||||
"/parse_into_bboxes",
|
"/parse_into_bboxes",
|
||||||
summary="解析PDF并返回边界框",
|
summary="解析PDF并返回边界框",
|
||||||
description="解析PDF文件并返回文本边界框信息,用于文档结构化处理",
|
description="解析PDF文件并返回文本边界框信息,用于文档结构化处理",
|
||||||
@@ -414,7 +414,7 @@ class RemoveTagResponse(BaseModel):
|
|||||||
text: Optional[str] = None
|
text: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@ocr_router.post(
|
||||||
"/remove_tag",
|
"/remove_tag",
|
||||||
response_model=RemoveTagResponse,
|
response_model=RemoveTagResponse,
|
||||||
summary="移除文本中的位置标签",
|
summary="移除文本中的位置标签",
|
||||||
@@ -464,7 +464,7 @@ class ExtractPositionsResponse(BaseModel):
|
|||||||
positions: Optional[list] = None
|
positions: Optional[list] = None
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@ocr_router.post(
|
||||||
"/extract_positions",
|
"/extract_positions",
|
||||||
response_model=ExtractPositionsResponse,
|
response_model=ExtractPositionsResponse,
|
||||||
summary="从文本中提取位置信息",
|
summary="从文本中提取位置信息",
|
||||||
|
|||||||
239
ocr/client.py
Normal file
239
ocr/client.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
OCR HTTP 客户端工具类
|
||||||
|
用于通过 HTTP 接口调用 OCR 服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional, Callable, List, Tuple, Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
HAS_HTTPX = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_HTTPX = False
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OCRClient:
|
||||||
|
"""OCR HTTP 客户端,用于调用 OCR API"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
|
||||||
|
"""
|
||||||
|
初始化 OCR 客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: OCR 服务的基础 URL,如果不提供则从环境变量 OCR_SERVICE_URL 获取,
|
||||||
|
如果仍未设置则默认为 http://localhost:8000/api/v1/ocr
|
||||||
|
timeout: 请求超时时间(秒),默认 300 秒
|
||||||
|
"""
|
||||||
|
self.base_url = base_url or os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
|
||||||
|
self.timeout = timeout
|
||||||
|
# 移除末尾的斜杠
|
||||||
|
if self.base_url.endswith('/'):
|
||||||
|
self.base_url = self.base_url.rstrip('/')
|
||||||
|
|
||||||
|
async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
|
||||||
|
"""内部方法:发送 HTTP 请求"""
|
||||||
|
url = f"{self.base_url}{endpoint}"
|
||||||
|
|
||||||
|
if HAS_HTTPX:
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
|
response = await client.request(method, url, **kwargs)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
||||||
|
async with session.request(method, url, **kwargs) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
async def remove_tag(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
移除文本中的位置标签
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
移除标签后的文本
|
||||||
|
"""
|
||||||
|
response = await self._make_request(
|
||||||
|
"POST",
|
||||||
|
"/remove_tag",
|
||||||
|
json={"text": text}
|
||||||
|
)
|
||||||
|
if response.get("success") and response.get("text") is not None:
|
||||||
|
return response["text"]
|
||||||
|
raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
def remove_tag_sync(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
同步版本的 remove_tag(用于同步代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
移除标签后的文本
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(self.remove_tag(text))
|
||||||
|
except RuntimeError:
|
||||||
|
# 如果没有事件循环,创建一个新的
|
||||||
|
return asyncio.run(self.remove_tag(text))
|
||||||
|
|
||||||
|
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
从文本中提取位置信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
位置信息列表,格式为 [(页码列表, left, right, top, bottom), ...]
|
||||||
|
"""
|
||||||
|
response = await self._make_request(
|
||||||
|
"POST",
|
||||||
|
"/extract_positions",
|
||||||
|
json={"text": text}
|
||||||
|
)
|
||||||
|
if response.get("success") and response.get("positions") is not None:
|
||||||
|
# 将响应格式转换为原始格式
|
||||||
|
positions = []
|
||||||
|
for pos in response["positions"]:
|
||||||
|
positions.append((
|
||||||
|
pos["page_numbers"],
|
||||||
|
pos["left"],
|
||||||
|
pos["right"],
|
||||||
|
pos["top"],
|
||||||
|
pos["bottom"]
|
||||||
|
))
|
||||||
|
return positions
|
||||||
|
raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
同步版本的 extract_positions(用于同步代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
位置信息列表
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(self.extract_positions(text))
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(self.extract_positions(text))
|
||||||
|
|
||||||
|
async def parse_into_bboxes(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
解析 PDF 并返回边界框
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_bytes: PDF 文件的二进制数据
|
||||||
|
callback: 进度回调函数 (progress: float, message: str) -> None
|
||||||
|
zoomin: 图像放大倍数(1-5,默认为3)
|
||||||
|
filename: 文件名(仅用于日志)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
边界框列表
|
||||||
|
"""
|
||||||
|
if HAS_HTTPX:
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
|
# 注意:httpx 需要将文件和数据合并到 files 参数中
|
||||||
|
form_data = {"filename": filename, "zoomin": str(zoomin)}
|
||||||
|
form_files = {"pdf_bytes": (filename, pdf_bytes, "application/pdf")}
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/parse_into_bboxes",
|
||||||
|
files=form_files,
|
||||||
|
data=form_data
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
else:
|
||||||
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
||||||
|
form_data = aiohttp.FormData()
|
||||||
|
form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
|
||||||
|
form_data.add_field('filename', filename)
|
||||||
|
form_data.add_field('zoomin', str(zoomin))
|
||||||
|
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/parse_into_bboxes",
|
||||||
|
data=form_data
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
result = await response.json()
|
||||||
|
|
||||||
|
if result.get("success") and result.get("data") and result["data"].get("bboxes"):
|
||||||
|
return result["data"]["bboxes"]
|
||||||
|
raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
def parse_into_bboxes_sync(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
同步版本的 parse_into_bboxes(用于同步代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_bytes: PDF 文件的二进制数据
|
||||||
|
callback: 进度回调函数(注意:HTTP 调用中无法实时传递回调,此参数将被忽略)
|
||||||
|
zoomin: 图像放大倍数(1-5,默认为3)
|
||||||
|
filename: 文件名(仅用于日志)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
边界框列表
|
||||||
|
"""
|
||||||
|
if callback:
|
||||||
|
logger.warning("HTTP 调用中无法使用 callback,将忽略回调函数")
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))
|
||||||
|
|
||||||
|
|
||||||
|
# 全局客户端实例(懒加载)
|
||||||
|
_global_client: Optional[OCRClient] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ocr_client() -> OCRClient:
|
||||||
|
"""获取全局 OCR 客户端实例(单例模式)"""
|
||||||
|
global _global_client
|
||||||
|
if _global_client is None:
|
||||||
|
_global_client = OCRClient()
|
||||||
|
return _global_client
|
||||||
|
|
||||||
290
ocr/service.py
Normal file
290
ocr/service.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
OCR 服务统一接口
|
||||||
|
支持本地OCR模型和HTTP接口两种方式,可通过配置选择
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Callable, List, Tuple, Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OCRService(ABC):
|
||||||
|
"""OCR服务抽象接口"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def remove_tag(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
移除文本中的位置标签
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的文本
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_tag_sync(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
同步版本的 remove_tag(用于同步代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的文本
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
从文本中提取位置信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
位置信息列表,格式为 [(页码列表, left, right, top, bottom), ...]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
同步版本的 extract_positions(用于同步代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含位置标签的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
位置信息列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def parse_into_bboxes(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
解析 PDF 并返回边界框
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_bytes: PDF 文件的二进制数据
|
||||||
|
callback: 进度回调函数 (progress: float, message: str) -> None
|
||||||
|
zoomin: 图像放大倍数(1-5,默认为3)
|
||||||
|
filename: 文件名(仅用于日志)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
边界框列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_into_bboxes_sync(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
同步版本的 parse_into_bboxes(用于同步代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_bytes: PDF 文件的二进制数据
|
||||||
|
callback: 进度回调函数(注意:HTTP 调用中无法实时传递回调,此参数将被忽略)
|
||||||
|
zoomin: 图像放大倍数(1-5,默认为3)
|
||||||
|
filename: 文件名(仅用于日志)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
边界框列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LocalOCRService(OCRService):
|
||||||
|
"""本地OCR服务实现(直接调用本地OCR模型)"""
|
||||||
|
|
||||||
|
def __init__(self, parser_instance=None):
|
||||||
|
"""
|
||||||
|
初始化本地OCR服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser_instance: SimplePdfParser 实例,如果不提供则自动创建
|
||||||
|
"""
|
||||||
|
if parser_instance is None:
|
||||||
|
from ocr import SimplePdfParser
|
||||||
|
from ocr.config import MODEL_DIR
|
||||||
|
logger.info(f"Initializing local OCR parser with model_dir={MODEL_DIR}")
|
||||||
|
self.parser = SimplePdfParser(model_dir=MODEL_DIR)
|
||||||
|
else:
|
||||||
|
self.parser = parser_instance
|
||||||
|
|
||||||
|
async def remove_tag(self, text: str) -> str:
|
||||||
|
"""使用本地解析器的静态方法移除标签"""
|
||||||
|
# SimplePdfParser.remove_tag 是静态方法,可以直接调用
|
||||||
|
return self.parser.remove_tag(text)
|
||||||
|
|
||||||
|
def remove_tag_sync(self, text: str) -> str:
|
||||||
|
"""同步版本的 remove_tag"""
|
||||||
|
return self.parser.remove_tag(text)
|
||||||
|
|
||||||
|
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""使用本地解析器的静态方法提取位置"""
|
||||||
|
# SimplePdfParser.extract_positions 是静态方法,可以直接调用
|
||||||
|
return self.parser.extract_positions(text)
|
||||||
|
|
||||||
|
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""同步版本的 extract_positions"""
|
||||||
|
return self.parser.extract_positions(text)
|
||||||
|
|
||||||
|
async def parse_into_bboxes(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""使用本地解析器解析PDF"""
|
||||||
|
# 本地解析器可以直接接受BytesIO
|
||||||
|
import asyncio
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
# 在后台线程中运行同步方法
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
bboxes = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)
|
||||||
|
)
|
||||||
|
return bboxes
|
||||||
|
|
||||||
|
def parse_into_bboxes_sync(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""同步版本的 parse_into_bboxes"""
|
||||||
|
from io import BytesIO
|
||||||
|
# 本地解析器可以直接接受BytesIO
|
||||||
|
return self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPOCRService(OCRService):
|
||||||
|
"""HTTP OCR服务实现(通过HTTP接口调用OCR服务)"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
|
||||||
|
"""
|
||||||
|
初始化HTTP OCR服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: OCR 服务的基础 URL,如果不提供则从环境变量 OCR_SERVICE_URL 获取
|
||||||
|
timeout: 请求超时时间(秒),默认 300 秒
|
||||||
|
"""
|
||||||
|
from ocr.client import OCRClient
|
||||||
|
self.client = OCRClient(base_url=base_url, timeout=timeout)
|
||||||
|
|
||||||
|
async def remove_tag(self, text: str) -> str:
|
||||||
|
"""通过HTTP接口移除标签"""
|
||||||
|
return await self.client.remove_tag(text)
|
||||||
|
|
||||||
|
def remove_tag_sync(self, text: str) -> str:
|
||||||
|
"""同步版本的 remove_tag"""
|
||||||
|
return self.client.remove_tag_sync(text)
|
||||||
|
|
||||||
|
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""通过HTTP接口提取位置"""
|
||||||
|
return await self.client.extract_positions(text)
|
||||||
|
|
||||||
|
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
|
||||||
|
"""同步版本的 extract_positions"""
|
||||||
|
return self.client.extract_positions_sync(text)
|
||||||
|
|
||||||
|
async def parse_into_bboxes(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""通过HTTP接口解析PDF"""
|
||||||
|
return await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)
|
||||||
|
|
||||||
|
def parse_into_bboxes_sync(
|
||||||
|
self,
|
||||||
|
pdf_bytes: bytes,
|
||||||
|
callback: Optional[Callable[[float, str], None]] = None,
|
||||||
|
zoomin: int = 3,
|
||||||
|
filename: str = "document.pdf"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""同步版本的 parse_into_bboxes"""
|
||||||
|
return self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局服务实例(懒加载)
|
||||||
|
_global_service: Optional[OCRService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ocr_service() -> OCRService:
|
||||||
|
"""
|
||||||
|
获取全局 OCR 服务实例(单例模式)
|
||||||
|
根据环境变量 OCR_MODE 选择使用本地或HTTP方式:
|
||||||
|
- OCR_MODE=local 或未设置:使用本地OCR模型
|
||||||
|
- OCR_MODE=http:使用HTTP接口
|
||||||
|
|
||||||
|
也可以通过环境变量 OCR_SERVICE_URL 配置HTTP服务的地址(仅在OCR_MODE=http时生效)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OCRService 实例
|
||||||
|
"""
|
||||||
|
global _global_service
|
||||||
|
if _global_service is None:
|
||||||
|
ocr_mode = os.getenv("OCR_MODE", "local").lower()
|
||||||
|
|
||||||
|
if ocr_mode == "http":
|
||||||
|
base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
|
||||||
|
logger.info(f"Initializing HTTP OCR service with URL: {base_url}")
|
||||||
|
_global_service = HTTPOCRService(base_url=base_url)
|
||||||
|
else:
|
||||||
|
logger.info("Initializing local OCR service")
|
||||||
|
_global_service = LocalOCRService()
|
||||||
|
|
||||||
|
return _global_service
|
||||||
|
|
||||||
|
|
||||||
|
# 为了向后兼容,保留 get_ocr_client 函数(但重定向到 get_ocr_service)
|
||||||
|
def get_ocr_client() -> OCRService:
|
||||||
|
"""
|
||||||
|
获取OCR服务实例(向后兼容函数)
|
||||||
|
建议使用 get_ocr_service() 替代
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OCRService 实例
|
||||||
|
"""
|
||||||
|
return get_ocr_service()
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ import trio
|
|||||||
|
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.base64_image import id2image, image2id
|
from api.utils.base64_image import id2image, image2id
|
||||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
from ocr.service import get_ocr_service
|
||||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||||
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
|
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
|
||||||
from rag.nlp import concat_img
|
from rag.nlp import concat_img
|
||||||
@@ -170,14 +170,17 @@ class HierarchicalMerger(ProcessBase):
|
|||||||
cks.append(txt)
|
cks.append(txt)
|
||||||
images.append(img)
|
images.append(img)
|
||||||
|
|
||||||
cks = [
|
ocr_service = get_ocr_service()
|
||||||
{
|
processed_cks = []
|
||||||
"text": RAGFlowPdfParser.remove_tag(c),
|
for c, img in zip(cks, images):
|
||||||
|
cleaned_text = await ocr_service.remove_tag(c)
|
||||||
|
positions = await ocr_service.extract_positions(c)
|
||||||
|
processed_cks.append({
|
||||||
|
"text": cleaned_text,
|
||||||
"image": img,
|
"image": img,
|
||||||
"positions": RAGFlowPdfParser.extract_positions(c),
|
"positions": positions,
|
||||||
}
|
})
|
||||||
for c, img in zip(cks, images)
|
cks = processed_cks
|
||||||
]
|
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for d in cks:
|
for d in cks:
|
||||||
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.base64_image import image2id
|
from api.utils.base64_image import image2id
|
||||||
from deepdoc.parser import ExcelParser
|
from deepdoc.parser import ExcelParser
|
||||||
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
|
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
|
||||||
|
from ocr.service import get_ocr_service
|
||||||
from rag.app.naive import Docx
|
from rag.app.naive import Docx
|
||||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||||
from rag.flow.parser.schema import ParserFromUpstream
|
from rag.flow.parser.schema import ParserFromUpstream
|
||||||
@@ -204,7 +205,9 @@ class Parser(ProcessBase):
|
|||||||
self.set_output("output_format", conf["output_format"])
|
self.set_output("output_format", conf["output_format"])
|
||||||
|
|
||||||
if conf.get("parse_method").lower() == "deepdoc":
|
if conf.get("parse_method").lower() == "deepdoc":
|
||||||
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
|
# 注意:HTTP 调用中无法传递 callback,callback 将被忽略
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
|
bboxes = ocr_service.parse_into_bboxes_sync(blob, callback=self.callback, filename=name)
|
||||||
elif conf.get("parse_method").lower() == "plain_text":
|
elif conf.get("parse_method").lower() == "plain_text":
|
||||||
lines, _ = PlainParser()(blob)
|
lines, _ = PlainParser()(blob)
|
||||||
bboxes = [{"text": t} for t, _ in lines]
|
bboxes = [{"text": t} for t, _ in lines]
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import trio
|
|||||||
|
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.base64_image import id2image, image2id
|
from api.utils.base64_image import id2image, image2id
|
||||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
from ocr.service import get_ocr_service
|
||||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||||
from rag.flow.splitter.schema import SplitterFromUpstream
|
from rag.flow.splitter.schema import SplitterFromUpstream
|
||||||
from rag.nlp import naive_merge, naive_merge_with_images
|
from rag.nlp import naive_merge, naive_merge_with_images
|
||||||
@@ -96,14 +96,18 @@ class Splitter(ProcessBase):
|
|||||||
deli,
|
deli,
|
||||||
self._param.overlapped_percent,
|
self._param.overlapped_percent,
|
||||||
)
|
)
|
||||||
cks = [
|
ocr_service = get_ocr_service()
|
||||||
{
|
cks = []
|
||||||
"text": RAGFlowPdfParser.remove_tag(c),
|
for c, img in zip(chunks, images):
|
||||||
|
if not c.strip():
|
||||||
|
continue
|
||||||
|
cleaned_text = await ocr_service.remove_tag(c)
|
||||||
|
positions = await ocr_service.extract_positions(c)
|
||||||
|
cks.append({
|
||||||
|
"text": cleaned_text,
|
||||||
"image": img,
|
"image": img,
|
||||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in positions],
|
||||||
}
|
})
|
||||||
for c, img in zip(chunks, images) if c.strip()
|
|
||||||
]
|
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for d in cks:
|
for d in cks:
|
||||||
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
||||||
|
|||||||
@@ -578,7 +578,8 @@ def hierarchical_merge(bull, sections, depth):
|
|||||||
|
|
||||||
|
|
||||||
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
from ocr.service import get_ocr_service
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
if not sections:
|
if not sections:
|
||||||
return []
|
return []
|
||||||
if isinstance(sections, str):
|
if isinstance(sections, str):
|
||||||
@@ -598,7 +599,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
|
|||||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||||
if cks:
|
if cks:
|
||||||
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
|
overlapped = ocr_service.remove_tag_sync(cks[-1])
|
||||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||||
if t.find(pos) < 0:
|
if t.find(pos) < 0:
|
||||||
t += pos
|
t += pos
|
||||||
@@ -625,7 +626,8 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
|
|||||||
|
|
||||||
|
|
||||||
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
from ocr.service import get_ocr_service
|
||||||
|
ocr_service = get_ocr_service()
|
||||||
if not texts or len(texts) != len(images):
|
if not texts or len(texts) != len(images):
|
||||||
return [], []
|
return [], []
|
||||||
cks = [""]
|
cks = [""]
|
||||||
@@ -642,7 +644,7 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
|
|||||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||||
if cks:
|
if cks:
|
||||||
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
|
overlapped = ocr_service.remove_tag_sync(cks[-1])
|
||||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||||
if t.find(pos) < 0:
|
if t.find(pos) < 0:
|
||||||
t += pos
|
t += pos
|
||||||
|
|||||||
Reference in New Issue
Block a user