将ocr解析模块独立出来

This commit is contained in:
2025-10-31 14:38:37 +08:00
parent d78f1fe91d
commit 4318179904
13 changed files with 3262 additions and 23 deletions

View File

@@ -35,6 +35,7 @@ from pypdf import PdfReader as pdf2_read
from api import settings
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
from deepdoc.parser.ocr_http_client import OCRHttpClient
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer
from rag.prompts.generator import vision_llm_describe_prompt
@@ -58,10 +59,24 @@ class RAGFlowPdfParser:
^_-
"""
self.ocr = OCR()
# 检查是否使用 HTTP OCR 服务
use_http_ocr = os.getenv("USE_OCR_HTTP", "false").lower() in ("true", "1", "yes")
ocr_service_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
if use_http_ocr:
logging.info(f"Using HTTP OCR service: {ocr_service_url}")
self.ocr = None # 不使用本地 OCR
self.ocr_http_client = OCRHttpClient(base_url=ocr_service_url)
self.use_http_ocr = True
else:
logging.info("Using local OCR")
self.ocr = OCR()
self.ocr_http_client = None
self.use_http_ocr = False
self.parallel_limiter = None
if PARALLEL_DEVICES > 1:
if not self.use_http_ocr and PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
@@ -276,7 +291,97 @@ class RAGFlowPdfParser:
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
def _convert_http_ocr_result(self, ocr_result: dict, zoomin: int = 3) -> None:
    """Convert an HTTP OCR API response into RAGFlow's internal box format.

    On a valid response ``self.boxes`` is replaced by a list holding one list
    of box dicts per returned page, and one value per returned page is
    appended to ``self.mean_height`` and ``self.mean_width``. An invalid
    response is logged as a warning and leaves all state untouched.

    Args:
        ocr_result: Response of the HTTP API, shaped like::

            {
                "success": bool,
                "data": {
                    "pages": [
                        {
                            "page_number": int,
                            "boxes": [
                                {
                                    "text": str,
                                    "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
                                    "confidence": float
                                }
                            ]
                        }
                    ]
                }
            }
        zoomin: Zoom factor; every bbox coordinate is divided by it.
    """
    payload = ocr_result.get("data") if isinstance(ocr_result, dict) else None
    # A missing/None/non-dict payload is treated the same as success == False.
    if not isinstance(payload, dict) or not ocr_result.get("success", False):
        logging.warning("Invalid OCR HTTP result")
        return

    pages_data = payload.get("pages") or []
    page_chars = getattr(self, "page_chars", None) or []
    # The service reports 1-based page numbers while page_from is 0-based,
    # hence the +1 when mapping a page number to an index into page_chars.
    first_page_number = self.page_from + 1

    self.boxes = []
    for page_data in pages_data:
        page_num = page_data.get("page_number", 0)
        page_index = page_num - first_page_number
        chars_for_page = page_chars[page_index] if 0 <= page_index < len(page_chars) else []

        ragflow_boxes = []
        for box in page_data.get("boxes") or []:
            if not isinstance(box, dict):
                continue
            bbox = box.get("bbox") or []
            if len(bbox) != 4:
                continue
            # bbox corners are ordered: top-left, top-right, bottom-right, bottom-left.
            try:
                x0 = min(bbox[0][0], bbox[3][0]) / zoomin
                x1 = max(bbox[1][0], bbox[2][0]) / zoomin
                top = min(bbox[0][1], bbox[1][1]) / zoomin
                bottom = max(bbox[2][1], bbox[3][1]) / zoomin
            except (TypeError, IndexError, KeyError):
                # Malformed corner (too short / non-numeric): skip this box
                # rather than abort the conversion of the whole document.
                continue
            ragflow_boxes.append(
                {
                    "x0": x0,
                    "x1": x1,
                    "top": top,
                    "bottom": bottom,
                    "text": box.get("text", ""),
                    # Stored exactly as reported by the service.
                    # NOTE(review): confirm downstream code expects this numbering
                    # when page_from > 0.
                    "page_number": page_num,
                    "layoutno": "",
                    "layout_type": "",
                }
            )

        # Median box height of the page; 0 when the page produced no boxes.
        if ragflow_boxes:
            self.mean_height.append(np.median([b["bottom"] - b["top"] for b in ragflow_boxes]))
        else:
            self.mean_height.append(0)

        # Median character width from the page's chars; 8 when none are known.
        if chars_for_page:
            self.mean_width.append(np.median([c.get("width", 8) for c in chars_for_page]))
        else:
            self.mean_width.append(8)

        self.boxes.append(ragflow_boxes)

    logging.info(f"Converted {len(pages_data)} pages from HTTP OCR result")
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
# 如果使用 HTTP OCR这个方法不会被调用
if self.use_http_ocr:
logging.warning("__ocr called when using HTTP OCR, this should not happen")
return
start = timer()
bxs = self.ocr.detect(np.array(img), device_id)
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
@@ -927,6 +1032,7 @@ class RAGFlowPdfParser:
self.page_cum_height = [0]
self.page_layout = []
self.page_from = page_from
start = timer()
try:
with sys.modules[LOCK_KEY_pdfplumber]:
@@ -945,6 +1051,42 @@ class RAGFlowPdfParser:
except Exception:
logging.exception("RAGFlowPdfParser __images__")
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
# 如果使用 HTTP OCR在获取图片和字符信息后调用 HTTP API 获取 OCR 结果
if self.use_http_ocr:
try:
if callback:
callback(0.1, "Calling OCR HTTP service...")
# 调用 HTTP OCR 服务
if isinstance(fnm, str):
# 文件路径
ocr_result = self.ocr_http_client.parse_pdf_by_path(
fnm,
page_from=page_from + 1, # HTTP API 使用从1开始的页码
page_to=(page_to + 1) if page_to < 299 else 0, # 转换为从1开始0 表示最后一页
zoomin=zoomin
)
else:
# 二进制数据
ocr_result = self.ocr_http_client.parse_pdf_by_bytes(
fnm,
filename="document.pdf",
page_from=page_from + 1,
page_to=(page_to + 1) if page_to < 299 else 0,
zoomin=zoomin
)
# 将 HTTP API 返回的结果转换为 RAGFlow 格式
self._convert_http_ocr_result(ocr_result, zoomin)
if callback:
callback(0.4, "OCR HTTP service completed")
except Exception as e:
logging.error(f"Failed to call OCR HTTP service: {e}", exc_info=True)
# The HTTP OCR call failed: re-raise so the caller sees the error; no empty-result fallback is attempted.
raise
self.outlines = []
try:
@@ -999,29 +1141,34 @@ class RAGFlowPdfParser:
if callback and i % 6 == 5:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
async def __img_ocr_launcher():
def __ocr_preprocess():
chars = self.page_chars[i] if not self.is_english else []
self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
self.page_cum_height.append(img.size[1] / zoomin)
return chars
# 如果使用 HTTP OCR已经在上面的代码中获取了结果跳过本地 OCR
if not self.use_http_ocr:
async def __img_ocr_launcher():
def __ocr_preprocess():
chars = self.page_chars[i] if not self.is_english else []
self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
self.page_cum_height.append(img.size[1] / zoomin)
return chars
if self.parallel_limiter:
async with trio.open_nursery() as nursery:
if self.parallel_limiter:
async with trio.open_nursery() as nursery:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
await trio.sleep(0.1)
else:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
await __img_ocr(i, 0, img, chars, None)
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
await trio.sleep(0.1)
else:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
await __img_ocr(i, 0, img, chars, None)
start = timer()
trio.run(__img_ocr_launcher)
start = timer()
trio.run(__img_ocr_launcher)
else:
# HTTP OCR 模式:初始化 page_cum_height
for i, img in enumerate(self.page_images):
self.page_cum_height.append(img.size[1] / zoomin)
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")