优化OCR解析

This commit is contained in:
2025-11-03 10:22:28 +08:00
parent 4603a86df4
commit 3e58c3d0e9
9 changed files with 581 additions and 30 deletions

View File

@@ -20,12 +20,15 @@
可以直接作为独立模块使用。
使用方法:
from ocr import OCR
from ocr import OCR, SimplePdfParser
import cv2
ocr = OCR()
img = cv2.imread("image.jpg")
results = ocr(img)
parser = SimplePdfParser()
result = parser.parse_pdf("document.pdf")
"""
# 处理导入问题:支持直接运行和模块导入
@@ -35,3 +38,4 @@ from pathlib import Path
__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser']

View File

@@ -57,7 +57,7 @@ class ParseResponse(BaseModel):
data: Optional[dict] = None
@router.get(
@ocr_router.get(
"/health",
summary="健康检查",
description="检查OCR服务的健康状态和配置信息",
@@ -79,7 +79,7 @@ async def health_check():
}
@router.post(
@ocr_router.post(
"/parse",
response_model=ParseResponse,
summary="上传并解析PDF文件",
@@ -165,7 +165,7 @@ async def parse_pdf_endpoint(
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@router.post(
@ocr_router.post(
"/parse/bytes",
response_model=ParseResponse,
summary="通过二进制数据解析PDF",
@@ -244,7 +244,7 @@ async def parse_pdf_bytes(
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@router.post(
@ocr_router.post(
"/parse/path",
response_model=ParseResponse,
summary="通过文件路径解析PDF",
@@ -315,7 +315,7 @@ async def parse_pdf_path(
)
@router.post(
@ocr_router.post(
"/parse_into_bboxes",
summary="解析PDF并返回边界框",
description="解析PDF文件并返回文本边界框信息用于文档结构化处理",
@@ -414,7 +414,7 @@ class RemoveTagResponse(BaseModel):
text: Optional[str] = None
@router.post(
@ocr_router.post(
"/remove_tag",
response_model=RemoveTagResponse,
summary="移除文本中的位置标签",
@@ -464,7 +464,7 @@ class ExtractPositionsResponse(BaseModel):
positions: Optional[list] = None
@router.post(
@ocr_router.post(
"/extract_positions",
response_model=ExtractPositionsResponse,
summary="从文本中提取位置信息",

239
ocr/client.py Normal file
View File

@@ -0,0 +1,239 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP 客户端工具类
用于通过 HTTP 接口调用 OCR 服务
"""
import logging
import os
from typing import Optional, Callable, List, Tuple, Any
try:
import httpx
HAS_HTTPX = True
except ImportError:
HAS_HTTPX = False
import aiohttp
logger = logging.getLogger(__name__)
class OCRClient:
    """Asynchronous HTTP client for the OCR service, with blocking wrappers.

    Requests are sent with ``httpx`` when the module-level ``HAS_HTTPX`` flag
    is true and with ``aiohttp`` otherwise.  Every call opens its own
    short-lived HTTP client/session, so an instance holds no open connections.
    """

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Store the service location and the request timeout.

        Args:
            base_url: Base URL of the OCR service.  When omitted (or empty)
                the ``OCR_SERVICE_URL`` environment variable is used, falling
                back to ``http://localhost:8000/api/v1/ocr``.
            timeout: Request timeout in seconds (default 300).
        """
        resolved = base_url or os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
        # Endpoints are appended as "/name", so trailing slashes are dropped.
        self.base_url = resolved.rstrip('/')
        self.timeout = timeout

    @staticmethod
    def _run_sync(make_coro: Callable[[], Any]) -> Any:
        """Run the coroutine built by ``make_coro`` to completion and return its result.

        ``make_coro`` is a zero-argument factory so that the coroutine object
        is only created on the path that actually awaits it.

        When no event loop is running in the calling thread the coroutine is
        driven with ``asyncio.run``.  When a loop *is* running (the caller is
        synchronous code invoked from a coroutine) neither
        ``run_until_complete`` nor ``asyncio.run`` may be used in this thread,
        so the coroutine is executed on a fresh loop in a worker thread.  The
        caller still blocks until the result is available, which also blocks
        the caller's own loop for that time.

        Raises:
            Whatever the coroutine raises; it is re-raised in the caller.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: the straightforward path.
            return asyncio.run(make_coro())

        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(make_coro())).result()

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
        """Send one HTTP request and return the decoded JSON body.

        Args:
            method: HTTP method name, e.g. ``"POST"``.
            endpoint: Path appended to ``base_url`` (expected to start with "/").
            **kwargs: Passed through unchanged to the underlying ``request`` call.

        Raises:
            The HTTP library's status error when ``raise_for_status()`` rejects
            the response.
        """
        url = f"{self.base_url}{endpoint}"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()

        session_timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=session_timeout) as session:
            async with session.request(method, url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()

    async def remove_tag(self, text: str) -> str:
        """Ask the service to strip position tags from ``text``.

        Args:
            text: Text containing position tags.

        Returns:
            The ``text`` field of the service response.

        Raises:
            Exception: If the response is not marked successful or has no text.
        """
        response = await self._make_request(
            "POST",
            "/remove_tag",
            json={"text": text},
        )
        if response.get("success") and response.get("text") is not None:
            return response["text"]
        raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")

    def remove_tag_sync(self, text: str) -> str:
        """Blocking variant of :meth:`remove_tag` for synchronous callers.

        Args:
            text: Text containing position tags.

        Returns:
            The text with tags removed.
        """
        return self._run_sync(lambda: self.remove_tag(text))

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Ask the service to extract position information from ``text``.

        Args:
            text: Text containing position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples built
            from the dictionaries in the response's ``positions`` field.

        Raises:
            Exception: If the response is not marked successful or has no
                ``positions`` field.
        """
        response = await self._make_request(
            "POST",
            "/extract_positions",
            json={"text": text},
        )
        if response.get("success") and response.get("positions") is not None:
            # Convert the JSON objects back into the tuple layout callers expect.
            return [
                (
                    pos["page_numbers"],
                    pos["left"],
                    pos["right"],
                    pos["top"],
                    pos["bottom"],
                )
                for pos in response["positions"]
            ]
        raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Blocking variant of :meth:`extract_positions` for synchronous callers.

        Args:
            text: Text containing position tags.

        Returns:
            The list of position tuples.
        """
        return self._run_sync(lambda: self.extract_positions(text))

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Upload a PDF to the service and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback ``(progress, message)``.  It is never
                invoked over HTTP; a warning is logged when one is supplied.
            zoomin: Image zoom factor, sent to the service as a form field.
            filename: File name sent with the upload.

        Returns:
            The ``data.bboxes`` list of the response.  An empty list is a
            valid result and is returned as-is.

        Raises:
            Exception: If the response is not marked successful or carries no
                ``bboxes`` entry.
        """
        if callback:
            logger.warning("HTTP 调用中无法使用 callback将忽略回调函数")

        url = f"{self.base_url}/parse_into_bboxes"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                # Multipart form: the PDF as a file part plus plain form fields.
                form_data = {"filename": filename, "zoomin": str(zoomin)}
                form_files = {"pdf_bytes": (filename, pdf_bytes, "application/pdf")}
                response = await client.post(url, files=form_files, data=form_data)
                response.raise_for_status()
                result = response.json()
        else:
            session_timeout = aiohttp.ClientTimeout(total=self.timeout)
            async with aiohttp.ClientSession(timeout=session_timeout) as session:
                form_data = aiohttp.FormData()
                form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
                form_data.add_field('filename', filename)
                form_data.add_field('zoomin', str(zoomin))
                async with session.post(url, data=form_data) as response:
                    response.raise_for_status()
                    result = await response.json()

        data = result.get("data")
        bboxes = data.get("bboxes") if data else None
        # "is not None" keeps an empty result from being reported as a failure,
        # matching the checks in remove_tag and extract_positions.
        if result.get("success") and bboxes is not None:
            return bboxes
        raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Blocking variant of :meth:`parse_into_bboxes` for synchronous callers.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback; ignored over HTTP (a warning is
                logged by :meth:`parse_into_bboxes` when one is supplied).
            zoomin: Image zoom factor, sent to the service as a form field.
            filename: File name sent with the upload.

        Returns:
            The list of bounding boxes.
        """
        return self._run_sync(
            lambda: self.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)
        )
# Process-wide client instance, created lazily on first request.
_global_client: Optional[OCRClient] = None


def get_ocr_client() -> OCRClient:
    """Return the shared ``OCRClient``, constructing it with defaults on first use."""
    global _global_client
    shared = _global_client
    if shared is not None:
        return shared
    shared = OCRClient()
    _global_client = shared
    return shared

View File

@@ -1,185 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理服务的主程序入口
独立运行不依赖RAGFlow的其他部分
"""
import argparse
import logging
import os
import sys
import signal
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ocr.api import ocr_router
from ocr.config import MODEL_DIR
# Logging setup: INFO-level records are written to stdout in a
# "time - logger - level - message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
    """Build the FastAPI application for the OCR service.

    The app gets CORS middleware, the OCR router, and a root endpoint that
    returns basic service information.

    Returns:
        The configured ``FastAPI`` instance.
    """
    doc_routes = {
        "docs_url": "/apidocs",          # Swagger UI
        "redoc_url": "/redoc",           # ReDoc (alternative UI)
        "openapi_url": "/openapi.json",  # OpenAPI JSON schema
    }
    app = FastAPI(
        title="OCR PDF Parser API",
        description="独立的OCR PDF处理服务提供PDF文档的OCR识别功能",
        version="1.0.0",
        **doc_routes,
    )

    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive; production deployments should list explicit origins.
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    # Mount the OCR endpoints.
    app.include_router(ocr_router)

    service_info = {
        "service": "OCR PDF Parser",
        "version": "1.0.0",
        "docs": "/apidocs",
        "health": "/api/v1/ocr/health",
    }

    @app.get("/")
    async def root():
        return dict(service_info)

    return app
def signal_handler(sig, frame):
    """Log the shutdown request and terminate the process with exit status 0.

    Args:
        sig: Signal number passed by the ``signal`` module (unused).
        frame: Interrupted stack frame passed by the ``signal`` module (unused).
    """
    logger.info("Received shutdown signal, exiting...")
    raise SystemExit(0)
def main():
    """Parse the command line, log the startup banner and run the server.

    A ``--model-dir`` argument is exported as the ``OCR_MODEL_DIR``
    environment variable.  SIGINT and SIGTERM are routed to
    ``signal_handler``.  Any exception other than ``KeyboardInterrupt``
    raised by ``uvicorn.run`` is logged and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(description="OCR PDF处理服务")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="服务器监听地址 (default: 0.0.0.0)")
    cli.add_argument("--port", type=int, default=8000, help="服务器端口 (default: 8000)")
    cli.add_argument("--reload", action="store_true", help="开发模式:自动重载代码")
    cli.add_argument("--workers", type=int, default=1, help="工作进程数 (default: 1)")
    cli.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
        help="日志级别 (default: info)",
    )
    cli.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help=f"OCR模型目录路径 (default: {MODEL_DIR})",
    )
    args = cli.parse_args()

    # An explicit --model-dir is published through the environment.
    custom_dir = args.model_dir
    if custom_dir:
        os.environ["OCR_MODEL_DIR"] = custom_dir
        logger.info(f"Using custom model directory: {custom_dir}")

    # A missing model directory is only reported, not treated as an error.
    model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
    if model_dir and not os.path.exists(model_dir):
        logger.warning(f"Model directory does not exist: {model_dir}")
        logger.info("Models will be downloaded on first use")

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal_handler)

    # Startup banner, one log record per line.
    rule = "=" * 60
    banner = (
        rule,
        "OCR PDF Parser Service",
        rule,
        f"Host: {args.host}",
        f"Port: {args.port}",
        f"Model Directory: {model_dir}",
        f"Workers: {args.workers}",
        f"Reload: {args.reload}",
        f"Log Level: {args.log_level}",
        rule,
        f"API Documentation (Swagger): http://{args.host}:{args.port}/apidocs",
        f"API Documentation (ReDoc): http://{args.host}:{args.port}/redoc",
        f"Health Check: http://{args.host}:{args.port}/api/v1/ocr/health",
        rule,
    )
    for line in banner:
        logger.info(line)

    app = create_app()

    # Reload mode is always run with a single worker.
    worker_count = 1 if args.reload else args.workers

    # NOTE(review): uvicorn documents that `reload` and `workers` need the
    # application as an import string; confirm these flags take effect when
    # an app object is passed as it is here.
    try:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level=args.log_level,
            reload=args.reload,
            workers=worker_count,
            access_log=True,
        )
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()

290
ocr/service.py Normal file
View File

@@ -0,0 +1,290 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 服务统一接口
支持本地OCR模型和HTTP接口两种方式可通过配置选择
"""
import logging
import os
from abc import ABC, abstractmethod
from typing import Optional, Callable, List, Tuple, Any
logger = logging.getLogger(__name__)
class OCRService(ABC):
    """Abstract interface shared by every OCR backend.

    Each operation exists as a coroutine and as a ``*_sync`` twin for
    synchronous callers.
    """

    @abstractmethod
    async def remove_tag(self, text: str) -> str:
        """Strip position tags from ``text``.

        Args:
            text: Text containing position tags.

        Returns:
            The cleaned text.
        """

    @abstractmethod
    def remove_tag_sync(self, text: str) -> str:
        """Synchronous twin of :meth:`remove_tag`.

        Args:
            text: Text containing position tags.

        Returns:
            The cleaned text.
        """

    @abstractmethod
    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from ``text``.

        Args:
            text: Text containing position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples.
        """

    @abstractmethod
    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous twin of :meth:`extract_positions`.

        Args:
            text: Text containing position tags.

        Returns:
            The list of position tuples.
        """

    @abstractmethod
    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback taking ``(progress, message)``.
            zoomin: Image zoom factor (documented range 1-5, default 3).
            filename: File name (documented as used for logging only).

        Returns:
            The list of bounding boxes.
        """

    @abstractmethod
    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous twin of :meth:`parse_into_bboxes`.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback; HTTP-based implementations cannot
                relay progress and ignore it.
            zoomin: Image zoom factor (documented range 1-5, default 3).
            filename: File name (documented as used for logging only).

        Returns:
            The list of bounding boxes.
        """
class LocalOCRService(OCRService):
    """OCR backend that calls a parser object living in this process."""

    def __init__(self, parser_instance=None):
        """Keep the supplied parser, or build a ``SimplePdfParser``.

        Args:
            parser_instance: Parser to delegate to.  When ``None`` a
                ``SimplePdfParser`` is created with ``model_dir=MODEL_DIR``
                from ``ocr.config``.
        """
        if parser_instance is not None:
            self.parser = parser_instance
            return

        # Imported lazily, only when a default parser has to be built.
        from ocr import SimplePdfParser
        from ocr.config import MODEL_DIR

        logger.info(f"Initializing local OCR parser with model_dir={MODEL_DIR}")
        self.parser = SimplePdfParser(model_dir=MODEL_DIR)

    async def remove_tag(self, text: str) -> str:
        """Return ``self.parser.remove_tag(text)``; the call is made inline."""
        cleaned = self.parser.remove_tag(text)
        return cleaned

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous variant of :meth:`remove_tag`."""
        cleaned = self.parser.remove_tag(text)
        return cleaned

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Return ``self.parser.extract_positions(text)``; the call is made inline."""
        found = self.parser.extract_positions(text)
        return found

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous variant of :meth:`extract_positions`."""
        found = self.parser.extract_positions(text)
        return found

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse the PDF with the local parser in the loop's default executor.

        Args:
            pdf_bytes: Raw bytes of the PDF; wrapped in ``BytesIO`` for the parser.
            callback: Progress callback forwarded to the parser.
            zoomin: Zoom factor forwarded to the parser.
            filename: Accepted for interface compatibility; not used here.

        Returns:
            Whatever the parser's ``parse_into_bboxes`` returns.
        """
        import asyncio
        from io import BytesIO

        def _parse_blocking():
            stream = BytesIO(pdf_bytes)
            return self.parser.parse_into_bboxes(stream, callback=callback, zoomin=zoomin)

        # The parser call is synchronous, so it is handed to the executor.
        return await asyncio.get_event_loop().run_in_executor(None, _parse_blocking)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous variant of :meth:`parse_into_bboxes`; runs in the caller's thread."""
        from io import BytesIO

        stream = BytesIO(pdf_bytes)
        return self.parser.parse_into_bboxes(stream, callback=callback, zoomin=zoomin)
class HTTPOCRService(OCRService):
    """OCR backend that forwards every call to an ``OCRClient``."""

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Create the underlying HTTP client.

        Args:
            base_url: Base URL of the OCR service, passed to ``OCRClient``.
            timeout: Request timeout in seconds (default 300).
        """
        from ocr.client import OCRClient

        self.client = OCRClient(timeout=timeout, base_url=base_url)

    async def remove_tag(self, text: str) -> str:
        """Strip position tags via the HTTP client."""
        cleaned = await self.client.remove_tag(text)
        return cleaned

    def remove_tag_sync(self, text: str) -> str:
        """Synchronous variant of :meth:`remove_tag`."""
        cleaned = self.client.remove_tag_sync(text)
        return cleaned

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract positions via the HTTP client."""
        found = await self.client.extract_positions(text)
        return found

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Synchronous variant of :meth:`extract_positions`."""
        found = self.client.extract_positions_sync(text)
        return found

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF via the HTTP client; all arguments are forwarded as given."""
        boxes = await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)
        return boxes

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Synchronous variant of :meth:`parse_into_bboxes`."""
        boxes = self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)
        return boxes
# Process-wide service instance, created lazily on first request.
_global_service: Optional[OCRService] = None


def get_ocr_service() -> OCRService:
    """Return the shared OCR service, creating it on first use.

    The backend is chosen from the ``OCR_MODE`` environment variable,
    compared case-insensitively:

    - ``http``: an ``HTTPOCRService`` whose URL comes from
      ``OCR_SERVICE_URL`` (default ``http://localhost:8000/api/v1/ocr``).
    - anything else, or unset: a ``LocalOCRService``.

    Returns:
        The shared ``OCRService`` instance.
    """
    global _global_service
    if _global_service is not None:
        return _global_service

    mode = os.getenv("OCR_MODE", "local").lower()
    if mode != "http":
        logger.info("Initializing local OCR service")
        _global_service = LocalOCRService()
        return _global_service

    service_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
    logger.info(f"Initializing HTTP OCR service with URL: {service_url}")
    _global_service = HTTPOCRService(base_url=service_url)
    return _global_service
# Kept for backward compatibility: the old accessor name now hands out the
# shared service returned by get_ocr_service().
def get_ocr_client() -> OCRService:
    """Backward-compatible alias for :func:`get_ocr_service`.

    New code should call ``get_ocr_service()`` directly.

    Returns:
        The shared ``OCRService`` instance.
    """
    service = get_ocr_service()
    return service