340 lines
12 KiB
Python
340 lines
12 KiB
Python
#
|
||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
|
||
"""
|
||
简化的PDF解析器,只使用OCR处理PDF文档
|
||
|
||
从 RAGFlow 的 RAGFlowPdfParser 中提取OCR相关功能,移除了:
|
||
- 布局识别(Layout Recognition)
|
||
- 表格结构识别(Table Structure Recognition)
|
||
- 文本合并和语义分析
|
||
- RAG相关功能
|
||
|
||
只保留:
|
||
- PDF转图片
|
||
- OCR文本检测和识别
|
||
- 基本的文本和位置信息返回
|
||
"""
|
||
|
||
import logging
|
||
import sys
|
||
import threading
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from timeit import default_timer as timer
|
||
|
||
import numpy as np
|
||
import pdfplumber
|
||
import trio
|
||
|
||
# Resolve the project imports for both ways this file can be loaded:
# executed directly, or imported as a module.
try:
    _package = __package__
except NameError:
    _package = None

if _package is not None:
    # Imported as a module: pull the dependencies in by their bare module names.
    # NOTE(review): the original comment called these "relative imports", but the
    # statements are absolute (`config`, `ocr`) - confirm this matches the layout.
    from config import PARALLEL_DEVICES
    from ocr import OCR
else:
    # Run directly: make the parent directory importable (added only once),
    # then import through the `ocr` package path.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))
    from ocr.config import PARALLEL_DEVICES
    from ocr.ocr import OCR
|
||
|
||
# Key under which a single threading.Lock is stored in sys.modules, so that any
# code looking up the same key reuses one lock object around pdfplumber calls.
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
# setdefault keeps an already-registered lock and only installs a new one when
# the key is absent.
sys.modules.setdefault(LOCK_KEY_pdfplumber, threading.Lock())
|
||
|
||
|
||
class SimplePdfParser:
    """
    Simplified PDF parser that processes PDF documents with OCR only.

    Usage:
        parser = SimplePdfParser()
        result = parser.parse_pdf("file.pdf")  # or pass the PDF as bytes
        # result format:
        # {
        #     "pages": [
        #         {
        #             "page_number": 1,
        #             "boxes": [
        #                 {
        #                     "text": "recognized text",
        #                     "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
        #                     "confidence": 0.95
        #                 },
        #                 ...
        #             ]
        #         },
        #         ...
        #     ]
        # }
    """

    def __init__(self, model_dir=None):
        """
        Initialize the PDF parser.

        Args:
            model_dir: OCR model directory, forwarded to ``OCR(model_dir=...)``.
        """
        self.ocr = OCR(model_dir=model_dir)
        # One single-slot limiter per device; stays None when there is at most
        # one device, which selects the serial path in parse_pdf.
        self.parallel_limiter = None
        if PARALLEL_DEVICES > 1:
            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]

    def __ocr_page(self, page_num, img, zoomin=3, device_id=None):
        """
        Run OCR (detection, then batch recognition) on a single page image.

        Args:
            page_num: Page number, used only in log messages.
            img: Page image; converted with ``np.array(img)``.
            zoomin: Zoom factor the image was rendered with; detected box
                coordinates are divided by it before being returned.
            device_id: Device id forwarded to ``detect`` and ``recognize_batch``.

        Returns:
            list: One ``{"text": str, "bbox": list, "confidence": float}`` dict
            per detected box whose recognized text is non-empty.
        """
        start = timer()
        img_np = np.array(img)

        # Text detection. Each detection carries the box coordinates first;
        # text and score are only placeholders at this stage.
        detection_result = self.ocr.detect(img_np, device_id)
        if detection_result is None:
            return []

        # Materialize the iterable so it can be counted and iterated.
        detections = list(detection_result)
        logging.info(f"Page {page_num}: OCR detection found {len(detections)} boxes in {timer() - start:.2f}s")
        if not detections:
            return []

        start = timer()
        bboxes = []
        crop_imgs = []
        for detection in detections:
            # Only the first field (the quadrilateral) is needed. Indexing rather
            # than tuple-unpacking avoids failing on the (box, (text, score))
            # pairs that detect yields.
            box_coords_np = np.array(detection[0], dtype=np.float32)
            # Scale back to un-zoomed page coordinates for the returned bbox.
            bboxes.append((box_coords_np / zoomin).tolist())
            # Crop with the zoomed coordinates because img_np is the zoomed image.
            crop_imgs.append(self.ocr.get_rotate_crop_image(img_np, box_coords_np))

        # Batch text recognition over all crops of this page.
        texts = self.ocr.recognize_batch(crop_imgs, device_id)

        # Pair boxes with texts and drop boxes with empty text. The confidence
        # is a fixed placeholder; no per-box score is computed here.
        ocr_results = [
            {"text": text, "bbox": bbox, "confidence": 0.9}
            for bbox, text in zip(bboxes, texts)
            if text
        ]

        logging.info(f"Page {page_num}: OCR recognition {len(ocr_results)} boxes cost {timer() - start:.2f}s")
        return ocr_results

    async def __ocr_page_async(self, page_num, img, zoomin, device_id, limiter, callback):
        """
        Run OCR for one page in a worker thread, optionally gated by a limiter.

        Args:
            page_num: Page number.
            img: Page image.
            zoomin: Zoom factor.
            device_id: Device id forwarded to the OCR call.
            limiter: Async context manager held while the page is processed,
                or a falsy value for no gating.
            callback: Optional progress callback, called as ``callback(prog, msg)``.

        Returns:
            list: The OCR results of ``__ocr_page`` for this page.
        """
        if limiter:
            async with limiter:
                result = await trio.to_thread.run_sync(
                    self.__ocr_page, page_num, img, zoomin, device_id
                )
        else:
            result = await trio.to_thread.run_sync(
                self.__ocr_page, page_num, img, zoomin, device_id
            )

        # Report on every 5th page. The total page count is not known here, so
        # progress is estimated against a fixed 100 pages and clamped to 1.0.
        if callback and page_num % 5 == 0:
            callback(min(page_num * 0.9 / 100, 1.0), f"Processing page {page_num}...")

        return result

    def __convert_pdf_to_images(self, pdf_source, zoomin=3, page_from=0, page_to=299):
        """
        Render PDF pages to images.

        Args:
            pdf_source: PDF file path (str or pathlib.Path) or the PDF content
                as bytes.
            zoomin: Zoom factor; pages are rendered at ``72 * zoomin``
                resolution (default 3, i.e. 216).
            page_from: Index of the first page (0-based).
            page_to: End index of the page slice (exclusive).

        Returns:
            list: Rendered page images; empty if opening or rendering failed
            (failures are logged, not raised).
        """
        start = timer()
        page_images = []

        try:
            # Hold the shared lock for the whole open/render/close sequence.
            with sys.modules[LOCK_KEY_pdfplumber]:
                # Paths are opened directly; anything else is treated as raw bytes.
                source = pdf_source if isinstance(pdf_source, (str, Path)) else BytesIO(pdf_source)
                pdf = pdfplumber.open(source)
                try:
                    # Render to images, resolution = 72 * zoomin.
                    page_images = [
                        page.to_image(resolution=72 * zoomin, antialias=True).annotated
                        for page in pdf.pages[page_from:page_to]
                    ]
                except Exception as e:
                    logging.warning(f"Failed to convert PDF pages {page_from}-{page_to}: {str(e)}")
                finally:
                    # Always release the document, on success and on failure.
                    pdf.close()
        except Exception as e:
            logging.exception(f"Error converting PDF to images: {str(e)}")

        logging.info(f"Converted {len(page_images)} pages to images in {timer() - start:.2f}s")
        return page_images

    def parse_pdf(self, pdf_source, zoomin=3, page_from=0, page_to=299, callback=None):
        """
        Parse a PDF document and recognize its text with OCR.

        Args:
            pdf_source: PDF file path (str or pathlib.Path) or the PDF content
                as bytes.
            zoomin: Zoom factor used for rendering, default 3.
            page_from: Index of the first page (0-based).
            page_to: End index of the page slice (exclusive).
            callback: Optional progress callback, called as
                ``callback(prog: float, msg: str)``.

        Returns:
            dict: Parsing result
                {
                    "pages": [
                        {
                            "page_number": int,
                            "boxes": [
                                {
                                    "text": str,
                                    "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
                                    "confidence": float
                                },
                                ...
                            ]
                        },
                        ...
                    ]
                }
            ``"pages"`` is empty when no page could be rendered.
        """
        if callback:
            callback(0.0, "Starting PDF parsing...")

        # 1. Render pages to images.
        if callback:
            callback(0.1, "Converting PDF to images...")
        page_images = self.__convert_pdf_to_images(pdf_source, zoomin, page_from, page_to)

        if not page_images:
            logging.warning("No pages converted from PDF")
            return {"pages": []}

        # 2. OCR processing.
        async def process_all_pages():
            total = len(page_images)
            # One slot per page so results stay in page order no matter which
            # task finishes first.
            boxes_by_index = [None] * total

            if self.parallel_limiter:
                # Parallel processing: pages are assigned to devices round-robin.
                async def run_page(index, image):
                    device_id = index % PARALLEL_DEVICES
                    boxes_by_index[index] = await self.__ocr_page_async(
                        page_from + index + 1, image, zoomin, device_id,
                        self.parallel_limiter[device_id], callback,
                    )

                # start_soon returns nothing to await; each task stores its own
                # result, and leaving the nursery block waits for all tasks.
                async with trio.open_nursery() as nursery:
                    for index, image in enumerate(page_images):
                        nursery.start_soon(run_page, index, image)
            else:
                # Serial processing on device 0.
                for index, image in enumerate(page_images):
                    page_num = page_from + index + 1
                    boxes_by_index[index] = await trio.to_thread.run_sync(
                        self.__ocr_page, page_num, image, zoomin, 0
                    )
                    if callback:
                        callback(0.1 + (index + 1) * 0.9 / total, f"Processing page {page_num}...")

            return [
                {"page_number": page_from + index + 1, "boxes": boxes}
                for index, boxes in enumerate(boxes_by_index)
            ]

        # Run the async processing to completion.
        if callback:
            callback(0.2, "Starting OCR processing...")

        start = timer()
        pages_result = trio.run(process_all_pages)
        logging.info(f"OCR processing completed in {timer() - start:.2f}s")

        if callback:
            callback(1.0, "OCR processing completed")

        return {
            "pages": pages_result
        }


# Backward-compatible alias.
PdfParser = SimplePdfParser
|
||
|