将ocr解析模块独立出来
This commit is contained in:
175
deepdoc/parser/ocr_http_client.py
Normal file
175
deepdoc/parser/ocr_http_client.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
OCR HTTP 客户端
|
||||
用于调用独立的 OCR 服务的 HTTP API
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional, Union, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OCRHttpClient:
|
||||
"""OCR HTTP 客户端,用于调用独立的 OCR 服务"""
|
||||
|
||||
def __init__(self, base_url: Optional[str] = None, timeout: int = 300):
|
||||
"""
|
||||
初始化 OCR HTTP 客户端
|
||||
|
||||
Args:
|
||||
base_url: OCR 服务的基础 URL,如果不提供则从环境变量 OCR_SERVICE_URL 读取
|
||||
默认值为 http://localhost:8000
|
||||
timeout: 请求超时时间(秒),默认 300 秒
|
||||
"""
|
||||
if base_url is None:
|
||||
base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
|
||||
|
||||
# 确保 URL 不包含尾随斜杠
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.api_prefix = "/api/v1/ocr"
|
||||
|
||||
logger.info(f"Initialized OCR HTTP client with base_url: {self.base_url}")
|
||||
|
||||
def parse_pdf_by_path(self, file_path: str, page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
|
||||
"""
|
||||
通过文件路径解析 PDF
|
||||
|
||||
Args:
|
||||
file_path: PDF 文件的本地路径
|
||||
page_from: 起始页码(从1开始)
|
||||
page_to: 结束页码(0表示最后一页)
|
||||
zoomin: 图像放大倍数(1-5)
|
||||
|
||||
Returns:
|
||||
dict: 解析结果,格式:
|
||||
{
|
||||
"success": bool,
|
||||
"message": str,
|
||||
"data": {
|
||||
"pages": [
|
||||
{
|
||||
"page_number": int,
|
||||
"boxes": [
|
||||
{
|
||||
"text": str,
|
||||
"bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
|
||||
"confidence": float
|
||||
},
|
||||
...
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
requests.RequestException: HTTP 请求失败
|
||||
ValueError: 响应格式不正确
|
||||
"""
|
||||
url = f"{self.base_url}{self.api_prefix}/parse/path"
|
||||
|
||||
data = {
|
||||
"file_path": file_path,
|
||||
"page_from": page_from,
|
||||
"page_to": page_to,
|
||||
"zoomin": zoomin
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"Calling OCR service: {url} for file: {file_path}")
|
||||
response = requests.post(url, data=data, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
if not result.get("success", False):
|
||||
raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to call OCR service: {e}")
|
||||
raise
|
||||
|
||||
def parse_pdf_by_bytes(self, pdf_bytes: bytes, filename: str = "document.pdf",
|
||||
page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
|
||||
"""
|
||||
通过二进制数据解析 PDF
|
||||
|
||||
Args:
|
||||
pdf_bytes: PDF 文件的二进制数据
|
||||
filename: 文件名(仅用于日志)
|
||||
page_from: 起始页码(从1开始)
|
||||
page_to: 结束页码(0表示最后一页)
|
||||
zoomin: 图像放大倍数(1-5)
|
||||
|
||||
Returns:
|
||||
dict: 解析结果,格式同 parse_pdf_by_path
|
||||
|
||||
Raises:
|
||||
requests.RequestException: HTTP 请求失败
|
||||
ValueError: 响应格式不正确
|
||||
"""
|
||||
url = f"{self.base_url}{self.api_prefix}/parse/bytes"
|
||||
|
||||
files = {
|
||||
"pdf_bytes": (filename, pdf_bytes, "application/pdf")
|
||||
}
|
||||
|
||||
data = {
|
||||
"filename": filename,
|
||||
"page_from": page_from,
|
||||
"page_to": page_to,
|
||||
"zoomin": zoomin
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"Calling OCR service: {url} with {len(pdf_bytes)} bytes")
|
||||
response = requests.post(url, files=files, data=data, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
if not result.get("success", False):
|
||||
raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to call OCR service: {e}")
|
||||
raise
|
||||
|
||||
def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查 OCR 服务健康状态
|
||||
|
||||
Returns:
|
||||
dict: 健康状态信息
|
||||
"""
|
||||
url = f"{self.base_url}{self.api_prefix}/health"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to check OCR service health: {e}")
|
||||
raise
|
||||
|
||||
@@ -35,6 +35,7 @@ from pypdf import PdfReader as pdf2_read
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
||||
from deepdoc.parser.ocr_http_client import OCRHttpClient
|
||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
@@ -58,10 +59,24 @@ class RAGFlowPdfParser:
|
||||
^_-
|
||||
|
||||
"""
|
||||
|
||||
self.ocr = OCR()
|
||||
|
||||
# 检查是否使用 HTTP OCR 服务
|
||||
use_http_ocr = os.getenv("USE_OCR_HTTP", "false").lower() in ("true", "1", "yes")
|
||||
ocr_service_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
|
||||
|
||||
if use_http_ocr:
|
||||
logging.info(f"Using HTTP OCR service: {ocr_service_url}")
|
||||
self.ocr = None # 不使用本地 OCR
|
||||
self.ocr_http_client = OCRHttpClient(base_url=ocr_service_url)
|
||||
self.use_http_ocr = True
|
||||
else:
|
||||
logging.info("Using local OCR")
|
||||
self.ocr = OCR()
|
||||
self.ocr_http_client = None
|
||||
self.use_http_ocr = False
|
||||
|
||||
self.parallel_limiter = None
|
||||
if PARALLEL_DEVICES > 1:
|
||||
if not self.use_http_ocr and PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
|
||||
|
||||
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
|
||||
@@ -276,7 +291,97 @@ class RAGFlowPdfParser:
|
||||
b["H_right"] = spans[ii]["x1"]
|
||||
b["SP"] = ii
|
||||
|
||||
def _convert_http_ocr_result(self, ocr_result: dict, zoomin: int = 3):
|
||||
"""
|
||||
将 HTTP OCR API 返回的结果转换为 RAGFlow 内部格式
|
||||
|
||||
Args:
|
||||
ocr_result: HTTP API 返回的结果,格式:
|
||||
{
|
||||
"success": bool,
|
||||
"data": {
|
||||
"pages": [
|
||||
{
|
||||
"page_number": int,
|
||||
"boxes": [
|
||||
{
|
||||
"text": str,
|
||||
"bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
|
||||
"confidence": float
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
zoomin: 放大倍数
|
||||
"""
|
||||
if not ocr_result.get("success", False) or "data" not in ocr_result:
|
||||
logging.warning("Invalid OCR HTTP result")
|
||||
return
|
||||
|
||||
pages_data = ocr_result["data"].get("pages", [])
|
||||
self.boxes = []
|
||||
|
||||
for page_data in pages_data:
|
||||
page_num = page_data.get("page_number", 0) # HTTP API 返回的页码(从1开始)
|
||||
boxes = page_data.get("boxes", [])
|
||||
|
||||
# 转换为 RAGFlow 格式的 boxes
|
||||
ragflow_boxes = []
|
||||
# 计算在 page_chars 中的索引:HTTP API 返回的页码是从1开始的,需要转换为相对于 page_from 的索引
|
||||
page_index = page_num - (self.page_from + 1) # page_from 是从0开始,所以需要 +1
|
||||
chars_for_page = self.page_chars[page_index] if hasattr(self, 'page_chars') and 0 <= page_index < len(self.page_chars) else []
|
||||
|
||||
for box in boxes:
|
||||
bbox = box.get("bbox", [])
|
||||
if len(bbox) != 4:
|
||||
continue
|
||||
|
||||
# 从 bbox 提取坐标(bbox 格式: [[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
|
||||
x0 = min(bbox[0][0], bbox[3][0]) / zoomin
|
||||
x1 = max(bbox[1][0], bbox[2][0]) / zoomin
|
||||
top = min(bbox[0][1], bbox[1][1]) / zoomin
|
||||
bottom = max(bbox[2][1], bbox[3][1]) / zoomin
|
||||
|
||||
# 创建 RAGFlow 格式的 box
|
||||
ragflow_box = {
|
||||
"x0": x0,
|
||||
"x1": x1,
|
||||
"top": top,
|
||||
"bottom": bottom,
|
||||
"text": box.get("text", ""),
|
||||
"page_number": page_num,
|
||||
"layoutno": "",
|
||||
"layout_type": ""
|
||||
}
|
||||
|
||||
ragflow_boxes.append(ragflow_box)
|
||||
|
||||
# 计算 mean_height
|
||||
if ragflow_boxes:
|
||||
heights = [b["bottom"] - b["top"] for b in ragflow_boxes]
|
||||
self.mean_height.append(np.median(heights) if heights else 0)
|
||||
else:
|
||||
self.mean_height.append(0)
|
||||
|
||||
# 计算 mean_width
|
||||
if chars_for_page:
|
||||
widths = [c.get("width", 8) for c in chars_for_page]
|
||||
self.mean_width.append(np.median(widths) if widths else 8)
|
||||
else:
|
||||
self.mean_width.append(8)
|
||||
|
||||
self.boxes.append(ragflow_boxes)
|
||||
|
||||
logging.info(f"Converted {len(pages_data)} pages from HTTP OCR result")
|
||||
|
||||
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
|
||||
# 如果使用 HTTP OCR,这个方法不会被调用
|
||||
if self.use_http_ocr:
|
||||
logging.warning("__ocr called when using HTTP OCR, this should not happen")
|
||||
return
|
||||
|
||||
start = timer()
|
||||
bxs = self.ocr.detect(np.array(img), device_id)
|
||||
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
||||
@@ -927,6 +1032,7 @@ class RAGFlowPdfParser:
|
||||
self.page_cum_height = [0]
|
||||
self.page_layout = []
|
||||
self.page_from = page_from
|
||||
|
||||
start = timer()
|
||||
try:
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
@@ -945,6 +1051,42 @@ class RAGFlowPdfParser:
|
||||
except Exception:
|
||||
logging.exception("RAGFlowPdfParser __images__")
|
||||
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
|
||||
|
||||
# 如果使用 HTTP OCR,在获取图片和字符信息后调用 HTTP API 获取 OCR 结果
|
||||
if self.use_http_ocr:
|
||||
try:
|
||||
if callback:
|
||||
callback(0.1, "Calling OCR HTTP service...")
|
||||
|
||||
# 调用 HTTP OCR 服务
|
||||
if isinstance(fnm, str):
|
||||
# 文件路径
|
||||
ocr_result = self.ocr_http_client.parse_pdf_by_path(
|
||||
fnm,
|
||||
page_from=page_from + 1, # HTTP API 使用从1开始的页码
|
||||
page_to=(page_to + 1) if page_to < 299 else 0, # 转换为从1开始,0 表示最后一页
|
||||
zoomin=zoomin
|
||||
)
|
||||
else:
|
||||
# 二进制数据
|
||||
ocr_result = self.ocr_http_client.parse_pdf_by_bytes(
|
||||
fnm,
|
||||
filename="document.pdf",
|
||||
page_from=page_from + 1,
|
||||
page_to=(page_to + 1) if page_to < 299 else 0,
|
||||
zoomin=zoomin
|
||||
)
|
||||
|
||||
# 将 HTTP API 返回的结果转换为 RAGFlow 格式
|
||||
self._convert_http_ocr_result(ocr_result, zoomin)
|
||||
|
||||
if callback:
|
||||
callback(0.4, "OCR HTTP service completed")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to call OCR HTTP service: {e}", exc_info=True)
|
||||
# 如果 HTTP OCR 失败,回退到空结果或抛出异常
|
||||
raise
|
||||
|
||||
self.outlines = []
|
||||
try:
|
||||
@@ -999,29 +1141,34 @@ class RAGFlowPdfParser:
|
||||
if callback and i % 6 == 5:
|
||||
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
||||
|
||||
async def __img_ocr_launcher():
|
||||
def __ocr_preprocess():
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
|
||||
self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
return chars
|
||||
# 如果使用 HTTP OCR,已经在上面的代码中获取了结果,跳过本地 OCR
|
||||
if not self.use_http_ocr:
|
||||
async def __img_ocr_launcher():
|
||||
def __ocr_preprocess():
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
|
||||
self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
return chars
|
||||
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
await __img_ocr(i, 0, img, chars, None)
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
await __img_ocr(i, 0, img, chars, None)
|
||||
|
||||
start = timer()
|
||||
|
||||
trio.run(__img_ocr_launcher)
|
||||
start = timer()
|
||||
trio.run(__img_ocr_launcher)
|
||||
else:
|
||||
# HTTP OCR 模式:初始化 page_cum_height
|
||||
for i, img in enumerate(self.page_images):
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
|
||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user