Files
TERES_fastapi_backend/ocr/api.py
2025-11-03 10:22:28 +08:00

525 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理的FastAPI路由
提供HTTP接口用于PDF的OCR识别
"""
import asyncio
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from ocr import SimplePdfParser
from ocr.config import MODEL_DIR
logger = logging.getLogger(__name__)
ocr_router = APIRouter(prefix="/api/v1/ocr", tags=["OCR"])
# 全局解析器实例(懒加载)
_parser_instance: Optional[SimplePdfParser] = None
def get_parser() -> SimplePdfParser:
"""获取全局解析器实例(单例模式)"""
global _parser_instance
if _parser_instance is None:
logger.info(f"Initializing OCR parser with model_dir={MODEL_DIR}")
_parser_instance = SimplePdfParser(model_dir=MODEL_DIR)
return _parser_instance
class ParseResponse(BaseModel):
"""解析响应模型"""
success: bool
message: str
data: Optional[dict] = None
@ocr_router.get(
"/health",
summary="健康检查",
description="检查OCR服务的健康状态和配置信息",
response_description="返回服务状态和模型目录信息"
)
async def health_check():
"""
健康检查端点
用于检查OCR服务的运行状态和配置信息。
Returns:
dict: 包含服务状态和模型目录的信息
"""
return {
"status": "healthy",
"service": "OCR PDF Parser",
"model_dir": MODEL_DIR
}
@ocr_router.post(
"/parse",
response_model=ParseResponse,
summary="上传并解析PDF文件",
description="上传PDF文件并通过OCR识别提取文本内容",
response_description="返回OCR识别结果"
)
async def parse_pdf_endpoint(
file: UploadFile = File(..., description="PDF文件支持上传任意PDF文档"),
page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
"""
上传并解析PDF文件
通过上传PDF文件使用OCR技术识别并提取其中的文本内容。
支持指定解析的页码范围,以及调整图像放大倍数以平衡识别精度和速度。
Args:
file: 上传的PDF文件multipart/form-data格式
page_from: 起始页码从1开始最小值为1
page_to: 结束页码0表示解析到最后一页最小值为0
zoomin: 图像放大倍数1-5之间数值越大识别精度越高但处理速度越慢
Returns:
ParseResponse: 包含解析结果的响应对象,包括:
- success: 是否成功
- message: 操作结果消息
- data: OCR识别的文本内容和元数据
Raises:
HTTPException: 400 - 如果文件不是PDF格式或文件为空
HTTPException: 500 - 如果解析过程中发生错误
"""
if not file.filename.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="只支持PDF文件")
# 保存上传的文件到临时目录
temp_file = None
try:
# 读取文件内容
content = await file.read()
if not content:
raise HTTPException(status_code=400, detail="文件为空")
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(content)
temp_file = tmp.name
logger.info(f"Parsing PDF file: {file.filename}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
# 解析PDFparse_pdf是同步方法使用to_thread在线程池中执行
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_pdf,
temp_file,
zoomin,
page_from - 1, # 转换为从0开始的索引
(page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引
None # callback
)
return ParseResponse(
success=True,
message=f"成功解析PDF: {file.filename}",
data=result
)
except Exception as e:
logger.error(f"Error parsing PDF: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"解析PDF时发生错误: {str(e)}"
)
finally:
# 清理临时文件
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except Exception as e:
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@ocr_router.post(
"/parse/bytes",
response_model=ParseResponse,
summary="通过二进制数据解析PDF",
description="直接通过二进制数据解析PDF文件无需上传文件",
response_description="返回OCR识别结果"
)
async def parse_pdf_bytes(
pdf_bytes: bytes = File(..., description="PDF文件的二进制数据multipart/form-data格式"),
filename: str = Form("document.pdf", description="文件名(仅用于日志记录,不影响解析)"),
page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
"""
直接通过二进制数据解析PDF
适用于已获取PDF二进制数据的场景无需文件上传步骤。
直接将PDF的二进制数据提交即可进行OCR识别。
Args:
pdf_bytes: PDF文件的二进制数据以文件形式提交
filename: 文件名(仅用于日志记录,不影响实际解析过程)
page_from: 起始页码从1开始最小值为1
page_to: 结束页码0表示解析到最后一页最小值为0
zoomin: 图像放大倍数1-5之间数值越大识别精度越高但处理速度越慢
Returns:
ParseResponse: 包含解析结果的响应对象
Raises:
HTTPException: 400 - 如果PDF数据为空
HTTPException: 500 - 如果解析过程中发生错误
"""
if not pdf_bytes:
raise HTTPException(status_code=400, detail="PDF数据为空")
# 保存到临时文件
temp_file = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(pdf_bytes)
temp_file = tmp.name
logger.info(f"Parsing PDF bytes (filename: {filename}), pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
# 解析PDFparse_pdf是同步方法使用to_thread在线程池中执行
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_pdf,
temp_file,
zoomin,
page_from - 1, # 转换为从0开始的索引
(page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引
None # callback
)
return ParseResponse(
success=True,
message=f"成功解析PDF: {filename}",
data=result
)
except Exception as e:
logger.error(f"Error parsing PDF bytes: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"解析PDF时发生错误: {str(e)}"
)
finally:
# 清理临时文件
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except Exception as e:
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@ocr_router.post(
"/parse/path",
response_model=ParseResponse,
summary="通过文件路径解析PDF",
description="通过服务器本地文件路径解析PDF文件",
response_description="返回OCR识别结果"
)
async def parse_pdf_path(
file_path: str = Form(..., description="PDF文件在服务器上的本地路径必须是可访问的绝对路径"),
page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
"""
通过文件路径解析PDF
适用于PDF文件已经存在于服务器上的场景。
通过提供文件路径直接进行OCR识别无需上传文件。
Args:
file_path: PDF文件在服务器上的本地路径必须是服务器可访问的绝对路径
page_from: 起始页码从1开始最小值为1
page_to: 结束页码0表示解析到最后一页最小值为0
zoomin: 图像放大倍数1-5之间数值越大识别精度越高但处理速度越慢
Returns:
ParseResponse: 包含解析结果的响应对象
Raises:
HTTPException: 400 - 如果文件不是PDF格式
HTTPException: 404 - 如果文件不存在
HTTPException: 500 - 如果解析过程中发生错误
Note:
此端点需要确保提供的文件路径在服务器上可访问。
建议仅在内网环境或受信任的环境中使用,避免路径遍历安全风险。
"""
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
if not file_path.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="只支持PDF文件")
try:
logger.info(f"Parsing PDF from path: {file_path}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
# 解析PDFparse_pdf是同步方法使用to_thread在线程池中执行
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_pdf,
file_path,
zoomin,
page_from - 1, # 转换为从0开始的索引
(page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引
None # callback
)
return ParseResponse(
success=True,
message=f"成功解析PDF: {file_path}",
data=result
)
except Exception as e:
logger.error(f"Error parsing PDF from path: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"解析PDF时发生错误: {str(e)}"
)
@ocr_router.post(
"/parse_into_bboxes",
summary="解析PDF并返回边界框",
description="解析PDF文件并返回文本边界框信息用于文档结构化处理",
response_description="返回包含文本边界框的列表"
)
async def parse_into_bboxes_endpoint(
pdf_bytes: bytes = File(..., description="PDF文件的二进制数据"),
filename: str = Form("document.pdf", description="文件名(仅用于日志)"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5默认为3")
):
"""
解析PDF并返回边界框
此接口用于将PDF文档解析为结构化文本边界框每个边界框包含
- 文本内容
- 页面编号
- 坐标信息x0, x1, top, bottom
- 布局类型(如 text, table, figure 等)
- 图像数据(如果有)
Args:
pdf_bytes: PDF文件的二进制数据
filename: 文件名(仅用于日志记录)
zoomin: 图像放大倍数1-5之间
Returns:
dict: 包含解析结果的对象data字段为边界框列表
Raises:
HTTPException: 400 - 如果PDF数据为空
HTTPException: 500 - 如果解析过程中发生错误
"""
if not pdf_bytes:
raise HTTPException(status_code=400, detail="PDF数据为空")
temp_file = None
try:
# 保存到临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(pdf_bytes)
temp_file = tmp.name
logger.info(f"Parsing PDF into bboxes: {filename}, zoomin={zoomin}")
# 定义一个简单的callback包装器用于处理进度回调记录日志
def progress_callback(prog, msg):
logger.info(f"Progress: {prog:.2%} - {msg}")
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_into_bboxes,
temp_file,
progress_callback,
zoomin
)
# 将图像数据转换为base64或None
processed_result = []
for bbox in result:
processed_bbox = dict(bbox)
# 如果有图像转换为base64如果需要的话可以在这里处理
# 但为了保持兼容性,我们保留原始格式
processed_result.append(processed_bbox)
return ParseResponse(
success=True,
message=f"成功解析PDF为边界框: {filename}",
data={"bboxes": processed_result}
)
except Exception as e:
logger.error(f"Error parsing PDF into bboxes: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"解析PDF为边界框时发生错误: {str(e)}"
)
finally:
# 清理临时文件
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except Exception as e:
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
class TextRequest(BaseModel):
"""文本处理请求模型"""
text: str = Field(..., description="需要处理的文本内容")
class RemoveTagResponse(BaseModel):
"""移除标签响应模型"""
success: bool
message: str
text: Optional[str] = None
@ocr_router.post(
"/remove_tag",
response_model=RemoveTagResponse,
summary="移除文本中的位置标签",
description="从文本中移除PDF解析生成的位置标签格式@@页码\t坐标##",
response_description="返回移除标签后的文本"
)
async def remove_tag_endpoint(request: TextRequest):
"""
移除文本中的位置标签
此接口用于从包含位置标签的文本中移除标签信息。
位置标签格式为:@@页码\t坐标##,例如:@@1\t100.0\t200.0\t50.0\t60.0##
Args:
request: 包含待处理文本的请求对象
Returns:
RemoveTagResponse: 包含处理结果的响应对象
Raises:
HTTPException: 400 - 如果文本为空
"""
if not request.text:
raise HTTPException(status_code=400, detail="文本内容不能为空")
try:
cleaned_text = SimplePdfParser.remove_tag(request.text)
return RemoveTagResponse(
success=True,
message="成功移除文本标签",
text=cleaned_text
)
except Exception as e:
logger.error(f"Error removing tag: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"移除标签时发生错误: {str(e)}"
)
class ExtractPositionsResponse(BaseModel):
"""提取位置信息响应模型"""
success: bool
message: str
positions: Optional[list] = None
@ocr_router.post(
"/extract_positions",
response_model=ExtractPositionsResponse,
summary="从文本中提取位置信息",
description="从包含位置标签的文本中提取所有位置坐标信息",
response_description="返回提取到的位置信息列表"
)
async def extract_positions_endpoint(request: TextRequest):
"""
从文本中提取位置信息
此接口用于从包含位置标签的文本中提取所有位置坐标信息。
位置标签格式为:@@页码\t坐标##
返回的位置信息格式为:
[
([页码列表], left, right, top, bottom),
...
]
Args:
request: 包含待处理文本的请求对象
Returns:
ExtractPositionsResponse: 包含提取结果的响应对象
Raises:
HTTPException: 400 - 如果文本为空
"""
if not request.text:
raise HTTPException(status_code=400, detail="文本内容不能为空")
try:
positions = SimplePdfParser.extract_positions(request.text)
# 将位置信息转换为可序列化的格式
serializable_positions = [
{
"page_numbers": pos[0],
"left": pos[1],
"right": pos[2],
"top": pos[3],
"bottom": pos[4]
}
for pos in positions
]
return ExtractPositionsResponse(
success=True,
message=f"成功提取 {len(positions)} 个位置信息",
positions=serializable_positions
)
except Exception as e:
logger.error(f"Error extracting positions: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"提取位置信息时发生错误: {str(e)}"
)