将ocr解析模块独立出来
This commit is contained in:
175
deepdoc/parser/ocr_http_client.py
Normal file
175
deepdoc/parser/ocr_http_client.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
OCR HTTP 客户端
|
||||
用于调用独立的 OCR 服务的 HTTP API
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional, Union, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OCRHttpClient:
|
||||
"""OCR HTTP 客户端,用于调用独立的 OCR 服务"""
|
||||
|
||||
def __init__(self, base_url: Optional[str] = None, timeout: int = 300):
|
||||
"""
|
||||
初始化 OCR HTTP 客户端
|
||||
|
||||
Args:
|
||||
base_url: OCR 服务的基础 URL,如果不提供则从环境变量 OCR_SERVICE_URL 读取
|
||||
默认值为 http://localhost:8000
|
||||
timeout: 请求超时时间(秒),默认 300 秒
|
||||
"""
|
||||
if base_url is None:
|
||||
base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
|
||||
|
||||
# 确保 URL 不包含尾随斜杠
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.api_prefix = "/api/v1/ocr"
|
||||
|
||||
logger.info(f"Initialized OCR HTTP client with base_url: {self.base_url}")
|
||||
|
||||
def parse_pdf_by_path(self, file_path: str, page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
|
||||
"""
|
||||
通过文件路径解析 PDF
|
||||
|
||||
Args:
|
||||
file_path: PDF 文件的本地路径
|
||||
page_from: 起始页码(从1开始)
|
||||
page_to: 结束页码(0表示最后一页)
|
||||
zoomin: 图像放大倍数(1-5)
|
||||
|
||||
Returns:
|
||||
dict: 解析结果,格式:
|
||||
{
|
||||
"success": bool,
|
||||
"message": str,
|
||||
"data": {
|
||||
"pages": [
|
||||
{
|
||||
"page_number": int,
|
||||
"boxes": [
|
||||
{
|
||||
"text": str,
|
||||
"bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
|
||||
"confidence": float
|
||||
},
|
||||
...
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
requests.RequestException: HTTP 请求失败
|
||||
ValueError: 响应格式不正确
|
||||
"""
|
||||
url = f"{self.base_url}{self.api_prefix}/parse/path"
|
||||
|
||||
data = {
|
||||
"file_path": file_path,
|
||||
"page_from": page_from,
|
||||
"page_to": page_to,
|
||||
"zoomin": zoomin
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"Calling OCR service: {url} for file: {file_path}")
|
||||
response = requests.post(url, data=data, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
if not result.get("success", False):
|
||||
raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to call OCR service: {e}")
|
||||
raise
|
||||
|
||||
def parse_pdf_by_bytes(self, pdf_bytes: bytes, filename: str = "document.pdf",
|
||||
page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
|
||||
"""
|
||||
通过二进制数据解析 PDF
|
||||
|
||||
Args:
|
||||
pdf_bytes: PDF 文件的二进制数据
|
||||
filename: 文件名(仅用于日志)
|
||||
page_from: 起始页码(从1开始)
|
||||
page_to: 结束页码(0表示最后一页)
|
||||
zoomin: 图像放大倍数(1-5)
|
||||
|
||||
Returns:
|
||||
dict: 解析结果,格式同 parse_pdf_by_path
|
||||
|
||||
Raises:
|
||||
requests.RequestException: HTTP 请求失败
|
||||
ValueError: 响应格式不正确
|
||||
"""
|
||||
url = f"{self.base_url}{self.api_prefix}/parse/bytes"
|
||||
|
||||
files = {
|
||||
"pdf_bytes": (filename, pdf_bytes, "application/pdf")
|
||||
}
|
||||
|
||||
data = {
|
||||
"filename": filename,
|
||||
"page_from": page_from,
|
||||
"page_to": page_to,
|
||||
"zoomin": zoomin
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"Calling OCR service: {url} with {len(pdf_bytes)} bytes")
|
||||
response = requests.post(url, files=files, data=data, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
if not result.get("success", False):
|
||||
raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to call OCR service: {e}")
|
||||
raise
|
||||
|
||||
def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
检查 OCR 服务健康状态
|
||||
|
||||
Returns:
|
||||
dict: 健康状态信息
|
||||
"""
|
||||
url = f"{self.base_url}{self.api_prefix}/health"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to check OCR service health: {e}")
|
||||
raise
|
||||
|
||||
@@ -35,6 +35,7 @@ from pypdf import PdfReader as pdf2_read
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
||||
from deepdoc.parser.ocr_http_client import OCRHttpClient
|
||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
@@ -58,10 +59,24 @@ class RAGFlowPdfParser:
|
||||
^_-
|
||||
|
||||
"""
|
||||
|
||||
self.ocr = OCR()
|
||||
|
||||
# 检查是否使用 HTTP OCR 服务
|
||||
use_http_ocr = os.getenv("USE_OCR_HTTP", "false").lower() in ("true", "1", "yes")
|
||||
ocr_service_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
|
||||
|
||||
if use_http_ocr:
|
||||
logging.info(f"Using HTTP OCR service: {ocr_service_url}")
|
||||
self.ocr = None # 不使用本地 OCR
|
||||
self.ocr_http_client = OCRHttpClient(base_url=ocr_service_url)
|
||||
self.use_http_ocr = True
|
||||
else:
|
||||
logging.info("Using local OCR")
|
||||
self.ocr = OCR()
|
||||
self.ocr_http_client = None
|
||||
self.use_http_ocr = False
|
||||
|
||||
self.parallel_limiter = None
|
||||
if PARALLEL_DEVICES > 1:
|
||||
if not self.use_http_ocr and PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
|
||||
|
||||
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
|
||||
@@ -276,7 +291,97 @@ class RAGFlowPdfParser:
|
||||
b["H_right"] = spans[ii]["x1"]
|
||||
b["SP"] = ii
|
||||
|
||||
def _convert_http_ocr_result(self, ocr_result: dict, zoomin: int = 3):
|
||||
"""
|
||||
将 HTTP OCR API 返回的结果转换为 RAGFlow 内部格式
|
||||
|
||||
Args:
|
||||
ocr_result: HTTP API 返回的结果,格式:
|
||||
{
|
||||
"success": bool,
|
||||
"data": {
|
||||
"pages": [
|
||||
{
|
||||
"page_number": int,
|
||||
"boxes": [
|
||||
{
|
||||
"text": str,
|
||||
"bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
|
||||
"confidence": float
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
zoomin: 放大倍数
|
||||
"""
|
||||
if not ocr_result.get("success", False) or "data" not in ocr_result:
|
||||
logging.warning("Invalid OCR HTTP result")
|
||||
return
|
||||
|
||||
pages_data = ocr_result["data"].get("pages", [])
|
||||
self.boxes = []
|
||||
|
||||
for page_data in pages_data:
|
||||
page_num = page_data.get("page_number", 0) # HTTP API 返回的页码(从1开始)
|
||||
boxes = page_data.get("boxes", [])
|
||||
|
||||
# 转换为 RAGFlow 格式的 boxes
|
||||
ragflow_boxes = []
|
||||
# 计算在 page_chars 中的索引:HTTP API 返回的页码是从1开始的,需要转换为相对于 page_from 的索引
|
||||
page_index = page_num - (self.page_from + 1) # page_from 是从0开始,所以需要 +1
|
||||
chars_for_page = self.page_chars[page_index] if hasattr(self, 'page_chars') and 0 <= page_index < len(self.page_chars) else []
|
||||
|
||||
for box in boxes:
|
||||
bbox = box.get("bbox", [])
|
||||
if len(bbox) != 4:
|
||||
continue
|
||||
|
||||
# 从 bbox 提取坐标(bbox 格式: [[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
|
||||
x0 = min(bbox[0][0], bbox[3][0]) / zoomin
|
||||
x1 = max(bbox[1][0], bbox[2][0]) / zoomin
|
||||
top = min(bbox[0][1], bbox[1][1]) / zoomin
|
||||
bottom = max(bbox[2][1], bbox[3][1]) / zoomin
|
||||
|
||||
# 创建 RAGFlow 格式的 box
|
||||
ragflow_box = {
|
||||
"x0": x0,
|
||||
"x1": x1,
|
||||
"top": top,
|
||||
"bottom": bottom,
|
||||
"text": box.get("text", ""),
|
||||
"page_number": page_num,
|
||||
"layoutno": "",
|
||||
"layout_type": ""
|
||||
}
|
||||
|
||||
ragflow_boxes.append(ragflow_box)
|
||||
|
||||
# 计算 mean_height
|
||||
if ragflow_boxes:
|
||||
heights = [b["bottom"] - b["top"] for b in ragflow_boxes]
|
||||
self.mean_height.append(np.median(heights) if heights else 0)
|
||||
else:
|
||||
self.mean_height.append(0)
|
||||
|
||||
# 计算 mean_width
|
||||
if chars_for_page:
|
||||
widths = [c.get("width", 8) for c in chars_for_page]
|
||||
self.mean_width.append(np.median(widths) if widths else 8)
|
||||
else:
|
||||
self.mean_width.append(8)
|
||||
|
||||
self.boxes.append(ragflow_boxes)
|
||||
|
||||
logging.info(f"Converted {len(pages_data)} pages from HTTP OCR result")
|
||||
|
||||
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
|
||||
# 如果使用 HTTP OCR,这个方法不会被调用
|
||||
if self.use_http_ocr:
|
||||
logging.warning("__ocr called when using HTTP OCR, this should not happen")
|
||||
return
|
||||
|
||||
start = timer()
|
||||
bxs = self.ocr.detect(np.array(img), device_id)
|
||||
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
||||
@@ -927,6 +1032,7 @@ class RAGFlowPdfParser:
|
||||
self.page_cum_height = [0]
|
||||
self.page_layout = []
|
||||
self.page_from = page_from
|
||||
|
||||
start = timer()
|
||||
try:
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
@@ -945,6 +1051,42 @@ class RAGFlowPdfParser:
|
||||
except Exception:
|
||||
logging.exception("RAGFlowPdfParser __images__")
|
||||
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
|
||||
|
||||
# 如果使用 HTTP OCR,在获取图片和字符信息后调用 HTTP API 获取 OCR 结果
|
||||
if self.use_http_ocr:
|
||||
try:
|
||||
if callback:
|
||||
callback(0.1, "Calling OCR HTTP service...")
|
||||
|
||||
# 调用 HTTP OCR 服务
|
||||
if isinstance(fnm, str):
|
||||
# 文件路径
|
||||
ocr_result = self.ocr_http_client.parse_pdf_by_path(
|
||||
fnm,
|
||||
page_from=page_from + 1, # HTTP API 使用从1开始的页码
|
||||
page_to=(page_to + 1) if page_to < 299 else 0, # 转换为从1开始,0 表示最后一页
|
||||
zoomin=zoomin
|
||||
)
|
||||
else:
|
||||
# 二进制数据
|
||||
ocr_result = self.ocr_http_client.parse_pdf_by_bytes(
|
||||
fnm,
|
||||
filename="document.pdf",
|
||||
page_from=page_from + 1,
|
||||
page_to=(page_to + 1) if page_to < 299 else 0,
|
||||
zoomin=zoomin
|
||||
)
|
||||
|
||||
# 将 HTTP API 返回的结果转换为 RAGFlow 格式
|
||||
self._convert_http_ocr_result(ocr_result, zoomin)
|
||||
|
||||
if callback:
|
||||
callback(0.4, "OCR HTTP service completed")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to call OCR HTTP service: {e}", exc_info=True)
|
||||
# 如果 HTTP OCR 失败,回退到空结果或抛出异常
|
||||
raise
|
||||
|
||||
self.outlines = []
|
||||
try:
|
||||
@@ -999,29 +1141,34 @@ class RAGFlowPdfParser:
|
||||
if callback and i % 6 == 5:
|
||||
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
||||
|
||||
async def __img_ocr_launcher():
|
||||
def __ocr_preprocess():
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
|
||||
self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
return chars
|
||||
# 如果使用 HTTP OCR,已经在上面的代码中获取了结果,跳过本地 OCR
|
||||
if not self.use_http_ocr:
|
||||
async def __img_ocr_launcher():
|
||||
def __ocr_preprocess():
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
|
||||
self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
return chars
|
||||
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
await __img_ocr(i, 0, img, chars, None)
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
await __img_ocr(i, 0, img, chars, None)
|
||||
|
||||
start = timer()
|
||||
|
||||
trio.run(__img_ocr_launcher)
|
||||
start = timer()
|
||||
trio.run(__img_ocr_launcher)
|
||||
else:
|
||||
# HTTP OCR 模式:初始化 page_cum_height
|
||||
for i, img in enumerate(self.page_images):
|
||||
self.page_cum_height.append(img.size[1] / zoomin)
|
||||
|
||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||
|
||||
|
||||
@@ -198,4 +198,6 @@ POSTGRES_DBNAME=rag_flow
|
||||
POSTGRES_USER=rag_flow
|
||||
POSTGRES_PASSWORD=infini_rag_flow
|
||||
POSTGRES_PORT=5432
|
||||
DB_TYPE=postgres
|
||||
DB_TYPE=postgres
|
||||
|
||||
USE_OCR_HTTP=true
|
||||
53
ocr/__init__.py
Normal file
53
ocr/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
独立的 OCR 模块
|
||||
|
||||
此模块从 RAGFlow 项目中提取,已经移除了对 RAGFlow 特定模块的依赖。
|
||||
可以直接作为独立模块使用。
|
||||
|
||||
使用方法:
|
||||
from ocr import OCR
|
||||
import cv2
|
||||
|
||||
ocr = OCR()
|
||||
img = cv2.imread("image.jpg")
|
||||
results = ocr(img)
|
||||
"""
|
||||
|
||||
# 处理导入问题:支持直接运行和模块导入
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
_package = __package__
|
||||
except NameError:
|
||||
_package = None
|
||||
|
||||
if _package is None:
|
||||
# 直接运行时,添加父目录到路径并使用绝对导入
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
if str(parent_dir) not in sys.path:
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
from ocr.ocr import OCR, TextDetector, TextRecognizer
|
||||
from ocr.pdf_parser import SimplePdfParser
|
||||
else:
|
||||
# 作为模块导入时使用相对导入
|
||||
from .ocr import OCR, TextDetector, TextRecognizer
|
||||
from .pdf_parser import SimplePdfParser
|
||||
|
||||
__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser']
|
||||
|
||||
332
ocr/api.py
Normal file
332
ocr/api.py
Normal file
@@ -0,0 +1,332 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
OCR PDF处理的FastAPI路由
|
||||
提供HTTP接口用于PDF的OCR识别
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 处理导入问题:支持直接运行和模块导入
|
||||
|
||||
try:
|
||||
_package = __package__
|
||||
except NameError:
|
||||
_package = None
|
||||
|
||||
if _package is None:
|
||||
# 直接运行时,添加父目录到路径并使用绝对导入
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
if str(parent_dir) not in sys.path:
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
from ocr.pdf_parser import SimplePdfParser
|
||||
from ocr.config import MODEL_DIR
|
||||
else:
|
||||
# 作为模块导入时使用相对导入
|
||||
from pdf_parser import SimplePdfParser
|
||||
from config import MODEL_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr", tags=["OCR"])
|
||||
|
||||
# 全局解析器实例(懒加载)
|
||||
_parser_instance: Optional[SimplePdfParser] = None
|
||||
|
||||
|
||||
def get_parser() -> SimplePdfParser:
|
||||
"""获取全局解析器实例(单例模式)"""
|
||||
global _parser_instance
|
||||
if _parser_instance is None:
|
||||
logger.info(f"Initializing OCR parser with model_dir={MODEL_DIR}")
|
||||
_parser_instance = SimplePdfParser(model_dir=MODEL_DIR)
|
||||
return _parser_instance
|
||||
|
||||
|
||||
class ParseResponse(BaseModel):
|
||||
"""解析响应模型"""
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[dict] = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
summary="健康检查",
|
||||
description="检查OCR服务的健康状态和配置信息",
|
||||
response_description="返回服务状态和模型目录信息"
|
||||
)
|
||||
async def health_check():
|
||||
"""
|
||||
健康检查端点
|
||||
|
||||
用于检查OCR服务的运行状态和配置信息。
|
||||
|
||||
Returns:
|
||||
dict: 包含服务状态和模型目录的信息
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "OCR PDF Parser",
|
||||
"model_dir": MODEL_DIR
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
response_model=ParseResponse,
|
||||
summary="上传并解析PDF文件",
|
||||
description="上传PDF文件并通过OCR识别提取文本内容",
|
||||
response_description="返回OCR识别结果"
|
||||
)
|
||||
async def parse_pdf_endpoint(
|
||||
file: UploadFile = File(..., description="PDF文件,支持上传任意PDF文档"),
|
||||
page_from: int = Form(1, ge=1, description="起始页码(从1开始,默认为1)"),
|
||||
page_to: int = Form(0, ge=0, description="结束页码(0表示解析到最后一页,默认为0)"),
|
||||
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数(1-5,数值越大识别精度越高但速度越慢,默认为3)")
|
||||
):
|
||||
"""
|
||||
上传并解析PDF文件
|
||||
|
||||
通过上传PDF文件,使用OCR技术识别并提取其中的文本内容。
|
||||
支持指定解析的页码范围,以及调整图像放大倍数以平衡识别精度和速度。
|
||||
|
||||
Args:
|
||||
file: 上传的PDF文件(multipart/form-data格式)
|
||||
page_from: 起始页码(从1开始,最小值为1)
|
||||
page_to: 结束页码(0表示解析到最后一页,最小值为0)
|
||||
zoomin: 图像放大倍数(1-5之间,数值越大识别精度越高但处理速度越慢)
|
||||
|
||||
Returns:
|
||||
ParseResponse: 包含解析结果的响应对象,包括:
|
||||
- success: 是否成功
|
||||
- message: 操作结果消息
|
||||
- data: OCR识别的文本内容和元数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 - 如果文件不是PDF格式或文件为空
|
||||
HTTPException: 500 - 如果解析过程中发生错误
|
||||
"""
|
||||
if not file.filename.lower().endswith('.pdf'):
|
||||
raise HTTPException(status_code=400, detail="只支持PDF文件")
|
||||
|
||||
# 保存上传的文件到临时目录
|
||||
temp_file = None
|
||||
try:
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="文件为空")
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
|
||||
tmp.write(content)
|
||||
temp_file = tmp.name
|
||||
|
||||
logger.info(f"Parsing PDF file: {file.filename}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
|
||||
|
||||
# 解析PDF(parse_pdf是同步方法,使用to_thread在线程池中执行)
|
||||
parser = get_parser()
|
||||
result = await asyncio.to_thread(
|
||||
parser.parse_pdf,
|
||||
temp_file,
|
||||
zoomin,
|
||||
page_from - 1, # 转换为从0开始的索引
|
||||
(page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引
|
||||
None # callback
|
||||
)
|
||||
|
||||
return ParseResponse(
|
||||
success=True,
|
||||
message=f"成功解析PDF: {file.filename}",
|
||||
data=result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing PDF: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"解析PDF时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if temp_file and os.path.exists(temp_file):
|
||||
try:
|
||||
os.unlink(temp_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse/bytes",
|
||||
response_model=ParseResponse,
|
||||
summary="通过二进制数据解析PDF",
|
||||
description="直接通过二进制数据解析PDF文件,无需上传文件",
|
||||
response_description="返回OCR识别结果"
|
||||
)
|
||||
async def parse_pdf_bytes(
|
||||
pdf_bytes: bytes = File(..., description="PDF文件的二进制数据(multipart/form-data格式)"),
|
||||
filename: str = Form("document.pdf", description="文件名(仅用于日志记录,不影响解析)"),
|
||||
page_from: int = Form(1, ge=1, description="起始页码(从1开始,默认为1)"),
|
||||
page_to: int = Form(0, ge=0, description="结束页码(0表示解析到最后一页,默认为0)"),
|
||||
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数(1-5,数值越大识别精度越高但速度越慢,默认为3)")
|
||||
):
|
||||
"""
|
||||
直接通过二进制数据解析PDF
|
||||
|
||||
适用于已获取PDF二进制数据的场景,无需文件上传步骤。
|
||||
直接将PDF的二进制数据提交即可进行OCR识别。
|
||||
|
||||
Args:
|
||||
pdf_bytes: PDF文件的二进制数据(以文件形式提交)
|
||||
filename: 文件名(仅用于日志记录,不影响实际解析过程)
|
||||
page_from: 起始页码(从1开始,最小值为1)
|
||||
page_to: 结束页码(0表示解析到最后一页,最小值为0)
|
||||
zoomin: 图像放大倍数(1-5之间,数值越大识别精度越高但处理速度越慢)
|
||||
|
||||
Returns:
|
||||
ParseResponse: 包含解析结果的响应对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 - 如果PDF数据为空
|
||||
HTTPException: 500 - 如果解析过程中发生错误
|
||||
"""
|
||||
if not pdf_bytes:
|
||||
raise HTTPException(status_code=400, detail="PDF数据为空")
|
||||
|
||||
# 保存到临时文件
|
||||
temp_file = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
|
||||
tmp.write(pdf_bytes)
|
||||
temp_file = tmp.name
|
||||
|
||||
logger.info(f"Parsing PDF bytes (filename: {filename}), pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
|
||||
|
||||
# 解析PDF(parse_pdf是同步方法,使用to_thread在线程池中执行)
|
||||
parser = get_parser()
|
||||
result = await asyncio.to_thread(
|
||||
parser.parse_pdf,
|
||||
temp_file,
|
||||
zoomin,
|
||||
page_from - 1, # 转换为从0开始的索引
|
||||
(page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引
|
||||
None # callback
|
||||
)
|
||||
|
||||
return ParseResponse(
|
||||
success=True,
|
||||
message=f"成功解析PDF: {filename}",
|
||||
data=result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing PDF bytes: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"解析PDF时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if temp_file and os.path.exists(temp_file):
|
||||
try:
|
||||
os.unlink(temp_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse/path",
|
||||
response_model=ParseResponse,
|
||||
summary="通过文件路径解析PDF",
|
||||
description="通过服务器本地文件路径解析PDF文件",
|
||||
response_description="返回OCR识别结果"
|
||||
)
|
||||
async def parse_pdf_path(
|
||||
file_path: str = Form(..., description="PDF文件在服务器上的本地路径(必须是可访问的绝对路径)"),
|
||||
page_from: int = Form(1, ge=1, description="起始页码(从1开始,默认为1)"),
|
||||
page_to: int = Form(0, ge=0, description="结束页码(0表示解析到最后一页,默认为0)"),
|
||||
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数(1-5,数值越大识别精度越高但速度越慢,默认为3)")
|
||||
):
|
||||
"""
|
||||
通过文件路径解析PDF
|
||||
|
||||
适用于PDF文件已经存在于服务器上的场景。
|
||||
通过提供文件路径直接进行OCR识别,无需上传文件。
|
||||
|
||||
Args:
|
||||
file_path: PDF文件在服务器上的本地路径(必须是服务器可访问的绝对路径)
|
||||
page_from: 起始页码(从1开始,最小值为1)
|
||||
page_to: 结束页码(0表示解析到最后一页,最小值为0)
|
||||
zoomin: 图像放大倍数(1-5之间,数值越大识别精度越高但处理速度越慢)
|
||||
|
||||
Returns:
|
||||
ParseResponse: 包含解析结果的响应对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 - 如果文件不是PDF格式
|
||||
HTTPException: 404 - 如果文件不存在
|
||||
HTTPException: 500 - 如果解析过程中发生错误
|
||||
|
||||
Note:
|
||||
此端点需要确保提供的文件路径在服务器上可访问。
|
||||
建议仅在内网环境或受信任的环境中使用,避免路径遍历安全风险。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
|
||||
|
||||
if not file_path.lower().endswith('.pdf'):
|
||||
raise HTTPException(status_code=400, detail="只支持PDF文件")
|
||||
|
||||
try:
|
||||
logger.info(f"Parsing PDF from path: {file_path}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
|
||||
|
||||
# 解析PDF(parse_pdf是同步方法,使用to_thread在线程池中执行)
|
||||
parser = get_parser()
|
||||
result = await asyncio.to_thread(
|
||||
parser.parse_pdf,
|
||||
file_path,
|
||||
zoomin,
|
||||
page_from - 1, # 转换为从0开始的索引
|
||||
(page_to - 1) if page_to > 0 else 299, # 转换为从0开始的索引
|
||||
None # callback
|
||||
)
|
||||
|
||||
return ParseResponse(
|
||||
success=True,
|
||||
message=f"成功解析PDF: {file_path}",
|
||||
data=result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing PDF from path: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"解析PDF时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
42
ocr/config.py
Normal file
42
ocr/config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
OCR 模块配置文件
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
|
||||
# 并行设备数量(GPU数量,0表示使用CPU)
|
||||
PARALLEL_DEVICES = 0
|
||||
try:
|
||||
import torch.cuda
|
||||
PARALLEL_DEVICES = torch.cuda.device_count()
|
||||
logging.info(f"found {PARALLEL_DEVICES} gpus")
|
||||
except Exception:
|
||||
logging.info("can't import package 'torch', using CPU mode")
|
||||
|
||||
# 模型目录
|
||||
# 可以从环境变量获取,或使用默认路径
|
||||
MODEL_DIR = os.getenv("OCR_MODEL_DIR", None)
|
||||
if MODEL_DIR is None:
|
||||
# 默认模型目录:当前项目根目录下的 models/deepdoc 目录
|
||||
# 如果不存在,将在 OCR 类初始化时尝试从 HuggingFace 下载
|
||||
_base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
MODEL_DIR = os.path.join(_base_dir, "models", "deepdoc")
|
||||
# 如果目录不存在,设置为 None,让 OCR 类处理下载逻辑
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
MODEL_DIR = None
|
||||
|
||||
202
ocr/main.py
Normal file
202
ocr/main.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
OCR PDF处理服务的主程序入口
|
||||
独立运行,不依赖RAGFlow的其他部分
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# 处理直接运行时的导入问题
|
||||
# 当直接运行 python ocr/main.py 时,__package__ 为 None
|
||||
# 当作为模块运行时(python -m ocr.main),__package__ 为 'ocr'
|
||||
try:
|
||||
_package = __package__
|
||||
except NameError:
|
||||
_package = None
|
||||
|
||||
if _package is None:
|
||||
# 直接运行脚本时,添加父目录到路径
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
if str(parent_dir) not in sys.path:
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
from api import router as ocr_router
|
||||
from config import MODEL_DIR
|
||||
else:
|
||||
# 作为模块导入时使用相对导入
|
||||
from api import router as ocr_router
|
||||
from config import MODEL_DIR
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""创建FastAPI应用实例"""
|
||||
app = FastAPI(
|
||||
title="OCR PDF Parser API",
|
||||
description="独立的OCR PDF处理服务,提供PDF文档的OCR识别功能",
|
||||
version="1.0.0",
|
||||
docs_url="/apidocs", # Swagger UI 文档地址
|
||||
redoc_url="/redoc", # ReDoc 文档地址(备用)
|
||||
openapi_url="/openapi.json" # OpenAPI JSON schema 地址
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境中应该设置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册OCR路由
|
||||
app.include_router(ocr_router)
|
||||
|
||||
# 根路径
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "OCR PDF Parser",
|
||||
"version": "1.0.0",
|
||||
"docs": "/apidocs",
|
||||
"health": "/api/v1/ocr/health"
|
||||
}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""信号处理器,用于优雅关闭"""
|
||||
logger.info("Received shutdown signal, exiting...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description="OCR PDF处理服务")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="服务器监听地址 (default: 0.0.0.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="服务器端口 (default: 8000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reload",
|
||||
action="store_true",
|
||||
help="开发模式:自动重载代码"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="工作进程数 (default: 1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["critical", "error", "warning", "info", "debug", "trace"],
|
||||
help="日志级别 (default: info)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=f"OCR模型目录路径 (default: {MODEL_DIR})"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 设置模型目录(如果提供)
|
||||
if args.model_dir:
|
||||
os.environ["OCR_MODEL_DIR"] = args.model_dir
|
||||
logger.info(f"Using custom model directory: {args.model_dir}")
|
||||
|
||||
# 检查模型目录
|
||||
model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
|
||||
if model_dir and not os.path.exists(model_dir):
|
||||
logger.warning(f"Model directory does not exist: {model_dir}")
|
||||
logger.info("Models will be downloaded on first use")
|
||||
|
||||
# 注册信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# 显示启动信息
|
||||
logger.info("=" * 60)
|
||||
logger.info("OCR PDF Parser Service")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Host: {args.host}")
|
||||
logger.info(f"Port: {args.port}")
|
||||
logger.info(f"Model Directory: {model_dir}")
|
||||
logger.info(f"Workers: {args.workers}")
|
||||
logger.info(f"Reload: {args.reload}")
|
||||
logger.info(f"Log Level: {args.log_level}")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"API Documentation (Swagger): http://{args.host}:{args.port}/apidocs")
|
||||
logger.info(f"API Documentation (ReDoc): http://{args.host}:{args.port}/redoc")
|
||||
logger.info(f"Health Check: http://{args.host}:{args.port}/api/v1/ocr/health")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 创建应用
|
||||
app = create_app()
|
||||
|
||||
# 启动服务器
|
||||
try:
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.log_level,
|
||||
reload=args.reload,
|
||||
workers=args.workers if not args.reload else 1, # reload模式不支持多进程
|
||||
access_log=True
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Server error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
785
ocr/ocr.py
Normal file
785
ocr/ocr.py
Normal file
@@ -0,0 +1,785 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import gc
|
||||
import logging
|
||||
import copy
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# 处理导入问题:支持直接运行和模块导入
|
||||
try:
|
||||
_package = __package__
|
||||
except NameError:
|
||||
_package = None
|
||||
|
||||
if _package is None:
|
||||
# 直接运行时,添加父目录到路径并使用绝对导入
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
if str(parent_dir) not in sys.path:
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
from ocr.utils import get_project_base_directory
|
||||
from ocr.config import PARALLEL_DEVICES, MODEL_DIR
|
||||
from ocr.operators import * # noqa: F403
|
||||
import ocr.operators as operators
|
||||
from ocr.postprocess import build_post_process
|
||||
else:
|
||||
# 作为模块导入时使用相对导入
|
||||
from utils import get_project_base_directory
|
||||
from config import PARALLEL_DEVICES, MODEL_DIR
|
||||
from operators import * # noqa: F403
|
||||
import operators
|
||||
from postprocess import build_post_process
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
import onnxruntime as ort
|
||||
|
||||
loaded_models = {}
|
||||
|
||||
def transform(data, ops=None):
|
||||
""" transform """
|
||||
if ops is None:
|
||||
ops = []
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def create_operators(op_param_list, global_config=None):
|
||||
"""
|
||||
create operators based on the config
|
||||
|
||||
Args:
|
||||
params(list): a dict list, used to create some operators
|
||||
"""
|
||||
assert isinstance(
|
||||
op_param_list, list), ('operator config should be a list')
|
||||
ops = []
|
||||
for operator in op_param_list:
|
||||
assert isinstance(operator,
|
||||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
if global_config is not None:
|
||||
param.update(global_config)
|
||||
op = getattr(operators, op_name)(**param)
|
||||
ops.append(op)
|
||||
return ops
|
||||
|
||||
|
||||
def load_model(model_dir, nm, device_id: int | None = None):
|
||||
model_file_path = os.path.join(model_dir, nm + ".onnx")
|
||||
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
|
||||
|
||||
global loaded_models
|
||||
loaded_model = loaded_models.get(model_cached_tag)
|
||||
if loaded_model:
|
||||
logging.info(f"load_model {model_file_path} reuses cached model")
|
||||
return loaded_model
|
||||
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(
|
||||
model_file_path))
|
||||
|
||||
def cuda_is_available():
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.enable_cpu_mem_arena = False
|
||||
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
options.intra_op_num_threads = 2
|
||||
options.inter_op_num_threads = 2
|
||||
|
||||
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
|
||||
# Shrink GPU memory after execution
|
||||
run_options = ort.RunOptions()
|
||||
if cuda_is_available():
|
||||
cuda_provider_options = {
|
||||
"device_id": device_id, # Use specific GPU
|
||||
"gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory
|
||||
"arena_extend_strategy": "kNextPowerOfTwo", # gpu memory allocation strategy
|
||||
}
|
||||
sess = ort.InferenceSession(
|
||||
model_file_path,
|
||||
options=options,
|
||||
providers=['CUDAExecutionProvider'],
|
||||
provider_options=[cuda_provider_options]
|
||||
)
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
|
||||
logging.info(f"load_model {model_file_path} uses GPU")
|
||||
else:
|
||||
sess = ort.InferenceSession(
|
||||
model_file_path,
|
||||
options=options,
|
||||
providers=['CPUExecutionProvider'])
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
|
||||
logging.info(f"load_model {model_file_path} uses CPU")
|
||||
loaded_model = (sess, run_options)
|
||||
loaded_models[model_cached_tag] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
class TextRecognizer:
|
||||
def __init__(self, model_dir, device_id: int | None = None):
|
||||
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
||||
self.rec_batch_num = 16
|
||||
postprocess_params = {
|
||||
'name': 'CTCLabelDecode',
|
||||
"character_dict_path": os.path.join(model_dir, "ocr.res"),
|
||||
"use_space_char": True
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
|
||||
self.input_tensor = self.predictor.get_inputs()[0]
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
|
||||
assert imgC == img.shape[2]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
w = self.input_tensor.shape[3:][0]
|
||||
if isinstance(w, str):
|
||||
pass
|
||||
elif w is not None and w > 0:
|
||||
imgW = w
|
||||
h, w = img.shape[:2]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_vl(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
img = img[:, :, ::-1] # bgr2rgb
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(
|
||||
gsrm_slf_attn_bias1,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(
|
||||
gsrm_slf_attn_bias2,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
|
||||
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def resize_norm_img_sar(self, img, image_shape,
|
||||
width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def resize_norm_img_spin(self, img):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# return padding_im
|
||||
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
|
||||
img = np.array(img, np.float32)
|
||||
img = np.expand_dims(img, -1)
|
||||
img = img.transpose((2, 0, 1))
|
||||
mean = [127.5]
|
||||
std = [127.5]
|
||||
mean = np.array(mean, dtype=np.float32)
|
||||
std = np.array(std, dtype=np.float32)
|
||||
mean = np.float32(mean.reshape(1, -1))
|
||||
stdinv = 1 / np.float32(std.reshape(1, -1))
|
||||
img -= mean
|
||||
img *= stdinv
|
||||
return img
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_abinet(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image / 255.
|
||||
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
resized_image = (
|
||||
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
resized_image = resized_image.astype('float32')
|
||||
|
||||
return resized_image
|
||||
|
||||
def norm_img_can(self, img, image_shape):
|
||||
|
||||
img = cv2.cvtColor(
|
||||
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
|
||||
|
||||
if self.rec_image_shape[0] == 1:
|
||||
h, w = img.shape
|
||||
_, imgH, imgW = self.rec_image_shape
|
||||
if h < imgH or w < imgW:
|
||||
padding_h = max(imgH - h, 0)
|
||||
padding_w = max(imgW - w, 0)
|
||||
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
|
||||
'constant',
|
||||
constant_values=(255))
|
||||
img = img_padded
|
||||
|
||||
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
|
||||
img = img.astype('float32')
|
||||
|
||||
return img
|
||||
|
||||
def close(self):
|
||||
# close session and release manually
|
||||
logging.info('Close text recognizer.')
|
||||
if hasattr(self, "predictor"):
|
||||
del self.predictor
|
||||
gc.collect()
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||
# Sorting can speed up the recognition process
|
||||
indices = np.argsort(np.array(width_list))
|
||||
rec_res = [['', 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
st = time.time()
|
||||
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||
max_wh_ratio = imgW / imgH
|
||||
# max_wh_ratio = 0
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = norm_img_batch
|
||||
for i in range(100000):
|
||||
try:
|
||||
outputs = self.predictor.run(None, input_dict, self.run_options)
|
||||
break
|
||||
except Exception as e:
|
||||
if i >= 3:
|
||||
raise e
|
||||
time.sleep(5)
|
||||
preds = outputs[0]
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
|
||||
return rec_res, time.time() - st
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class TextDetector:
|
||||
def __init__(self, model_dir, device_id: int | None = None):
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': 960,
|
||||
'limit_type': "max",
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
|
||||
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
|
||||
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
|
||||
self.input_tensor = self.predictor.get_inputs()[0]
|
||||
|
||||
img_h, img_w = self.input_tensor.shape[2:]
|
||||
if isinstance(img_h, str) or isinstance(img_w, str):
|
||||
pass
|
||||
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
|
||||
pre_process_list[0] = {
|
||||
'DetResizeForTest': {
|
||||
'image_shape': [img_h, img_w]
|
||||
}
|
||||
}
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
rect[0] = pts[np.argmin(s)]
|
||||
rect[2] = pts[np.argmax(s)]
|
||||
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
|
||||
diff = np.diff(np.array(tmp), axis=1)
|
||||
rect[1] = tmp[np.argmin(diff)]
|
||||
rect[3] = tmp[np.argmax(diff)]
|
||||
return rect
|
||||
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
||||
def filter_tag_det_res(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if isinstance(box, list):
|
||||
box = np.array(box)
|
||||
box = self.order_points_clockwise(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
rect_width = int(np.linalg.norm(box[0] - box[1]))
|
||||
rect_height = int(np.linalg.norm(box[0] - box[3]))
|
||||
if rect_width <= 3 or rect_height <= 3:
|
||||
continue
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if isinstance(box, list):
|
||||
box = np.array(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def close(self):
|
||||
logging.info("Close text detector.")
|
||||
if hasattr(self, "predictor"):
|
||||
del self.predictor
|
||||
gc.collect()
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
|
||||
st = time.time()
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = img
|
||||
for i in range(100000):
|
||||
try:
|
||||
outputs = self.predictor.run(None, input_dict, self.run_options)
|
||||
break
|
||||
except Exception as e:
|
||||
if i >= 3:
|
||||
raise e
|
||||
time.sleep(5)
|
||||
|
||||
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
|
||||
dt_boxes = post_result[0]['points']
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
||||
return dt_boxes, time.time() - st
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class OCR:
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not model_dir:
|
||||
try:
|
||||
# 使用配置中的 MODEL_DIR,如果不存在则尝试默认路径
|
||||
if MODEL_DIR and os.path.exists(MODEL_DIR):
|
||||
model_dir = MODEL_DIR
|
||||
else:
|
||||
model_dir = os.path.join(
|
||||
get_project_base_directory(),
|
||||
"models", "deepdoc")
|
||||
|
||||
# Append muti-gpus task to the list
|
||||
if PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir)]
|
||||
|
||||
except Exception:
|
||||
# 如果模型目录不存在,尝试从 HuggingFace 下载
|
||||
default_model_dir = os.path.join(
|
||||
get_project_base_directory(), "models", "deepdoc")
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=default_model_dir,
|
||||
local_dir_use_symlinks=False)
|
||||
|
||||
if PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir)]
|
||||
else:
|
||||
# 如果指定了 model_dir,直接使用
|
||||
if PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir)]
|
||||
|
||||
self.drop_score = 0.5
|
||||
self.crop_image_res_index = 0
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
'''
|
||||
img_height, img_width = img.shape[0:2]
|
||||
left = int(np.min(points[:, 0]))
|
||||
right = int(np.max(points[:, 0]))
|
||||
top = int(np.min(points[:, 1]))
|
||||
bottom = int(np.max(points[:, 1]))
|
||||
img_crop = img[top:bottom, left:right, :].copy()
|
||||
points[:, 0] = points[:, 0] - left
|
||||
points[:, 1] = points[:, 1] - top
|
||||
'''
|
||||
assert len(points) == 4, "shape of points must be 4*2"
|
||||
img_crop_width = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[1]),
|
||||
np.linalg.norm(points[2] - points[3])))
|
||||
img_crop_height = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[3]),
|
||||
np.linalg.norm(points[1] - points[2])))
|
||||
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height]])
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img,
|
||||
M, (img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
# Try original orientation
|
||||
rec_result = self.text_recognizer[0]([dst_img])
|
||||
text, score = rec_result[0][0]
|
||||
best_score = score
|
||||
best_img = dst_img
|
||||
|
||||
# Try clockwise 90° rotation
|
||||
rotated_cw = np.rot90(dst_img, k=3)
|
||||
rec_result = self.text_recognizer[0]([rotated_cw])
|
||||
rotated_cw_text, rotated_cw_score = rec_result[0][0]
|
||||
if rotated_cw_score > best_score:
|
||||
best_score = rotated_cw_score
|
||||
best_img = rotated_cw
|
||||
|
||||
# Try counter-clockwise 90° rotation
|
||||
rotated_ccw = np.rot90(dst_img, k=1)
|
||||
rec_result = self.text_recognizer[0]([rotated_ccw])
|
||||
rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
|
||||
if rotated_ccw_score > best_score:
|
||||
best_img = rotated_ccw
|
||||
|
||||
# Use the best image
|
||||
dst_img = best_img
|
||||
return dst_img
|
||||
|
||||
def sorted_boxes(self, dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
_boxes[j] = _boxes[j + 1]
|
||||
_boxes[j + 1] = tmp
|
||||
else:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
def detect(self, img, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
|
||||
return zip(self.sorted_boxes(dt_boxes), [
|
||||
("", 0) for _ in range(len(dt_boxes))])
|
||||
|
||||
def recognize(self, ori_im, box, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
img_crop = self.get_rotate_crop_image(ori_im, box)
|
||||
|
||||
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
|
||||
text, score = rec_res[0]
|
||||
if score < self.drop_score:
|
||||
return ""
|
||||
return text
|
||||
|
||||
def recognize_batch(self, img_list, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
rec_res, elapse = self.text_recognizer[device_id](img_list)
|
||||
texts = []
|
||||
for i in range(len(rec_res)):
|
||||
text, score = rec_res[i]
|
||||
if score < self.drop_score:
|
||||
text = ""
|
||||
texts.append(text)
|
||||
return texts
|
||||
|
||||
def __call__(self, img, device_id = 0, cls=True):
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
|
||||
img_crop_list = []
|
||||
|
||||
dt_boxes = self.sorted_boxes(dt_boxes)
|
||||
|
||||
for bno in range(len(dt_boxes)):
|
||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
|
||||
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
|
||||
|
||||
time_dict['rec'] = elapse
|
||||
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_result in zip(dt_boxes, rec_res):
|
||||
text, score = rec_result
|
||||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_result)
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
|
||||
# for bno in range(len(img_crop_list)):
|
||||
# print(f"{bno}, {rec_res[bno]}")
|
||||
|
||||
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))
|
||||
|
||||
726
ocr/operators.py
Normal file
726
ocr/operators.py
Normal file
@@ -0,0 +1,726 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class DecodeImage:
|
||||
""" decode image """
|
||||
|
||||
def __init__(self,
|
||||
img_mode='RGB',
|
||||
channel_first=False,
|
||||
ignore_orientation=False,
|
||||
**kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
self.ignore_orientation = ignore_orientation
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert isinstance(img, str) and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert isinstance(img, bytes) and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
if self.ignore_orientation:
|
||||
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
|
||||
cv2.IMREAD_COLOR)
|
||||
else:
|
||||
img = cv2.imdecode(img, 1)
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
|
||||
img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class StandardizeImag:
|
||||
"""normalize image
|
||||
Args:
|
||||
mean (list): im - mean
|
||||
std (list): im / std
|
||||
is_scale (bool): whether need im / 255
|
||||
norm_type (str): type in ['mean_std', 'none']
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.is_scale = is_scale
|
||||
self.norm_type = norm_type
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.astype(np.float32, copy=False)
|
||||
if self.is_scale:
|
||||
scale = 1.0 / 255.0
|
||||
im *= scale
|
||||
|
||||
if self.norm_type == 'mean_std':
|
||||
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
|
||||
std = np.array(self.std)[np.newaxis, np.newaxis, :]
|
||||
im -= mean
|
||||
im /= std
|
||||
return im, im_info
|
||||
|
||||
|
||||
class NormalizeImage:
|
||||
""" normalize image such as subtract mean, divide std
|
||||
"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
class ToCHWImage:
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
data['image'] = img.transpose((2, 0, 1))
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys:
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
def __call__(self, data):
|
||||
data_list = []
|
||||
for key in self.keep_keys:
|
||||
data_list.append(data[key])
|
||||
return data_list
|
||||
|
||||
|
||||
class Pad:
|
||||
def __init__(self, size=None, size_div=32, **kwargs):
|
||||
if size is not None and not isinstance(size, (int, list, tuple)):
|
||||
raise TypeError("Type of target_size is invalid. Now is {}".format(
|
||||
type(size)))
|
||||
if isinstance(size, int):
|
||||
size = [size, size]
|
||||
self.size = size
|
||||
self.size_div = size_div
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
img = data['image']
|
||||
img_h, img_w = img.shape[0], img.shape[1]
|
||||
if self.size:
|
||||
resize_h2, resize_w2 = self.size
|
||||
assert (
|
||||
img_h < resize_h2 and img_w < resize_w2
|
||||
), '(h, w) of target size should be greater than (img_h, img_w)'
|
||||
else:
|
||||
resize_h2 = max(
|
||||
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
resize_w2 = max(
|
||||
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
img = cv2.copyMakeBorder(
|
||||
img,
|
||||
0,
|
||||
resize_h2 - img_h,
|
||||
0,
|
||||
resize_w2 - img_w,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=0)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class LinearResize:
|
||||
"""resize image by target_size and max_size
|
||||
Args:
|
||||
target_size (int): the target size of image
|
||||
keep_ratio (bool): whether keep_ratio or not, default true
|
||||
interp (int): method of resize
|
||||
"""
|
||||
|
||||
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
|
||||
if isinstance(target_size, int):
|
||||
target_size = [target_size, target_size]
|
||||
self.target_size = target_size
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interp = interp
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
_im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
im = cv2.resize(
|
||||
im,
|
||||
None,
|
||||
None,
|
||||
fx=im_scale_x,
|
||||
fy=im_scale_y,
|
||||
interpolation=self.interp)
|
||||
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
|
||||
im_info['scale_factor'] = np.array(
|
||||
[im_scale_y, im_scale_x]).astype('float32')
|
||||
return im, im_info
|
||||
|
||||
def generate_scale(self, im):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
Returns:
|
||||
im_scale_x: the resize ratio of X
|
||||
im_scale_y: the resize ratio of Y
|
||||
"""
|
||||
origin_shape = im.shape[:2]
|
||||
_im_c = im.shape[2]
|
||||
if self.keep_ratio:
|
||||
im_size_min = np.min(origin_shape)
|
||||
im_size_max = np.max(origin_shape)
|
||||
target_size_min = np.min(self.target_size)
|
||||
target_size_max = np.max(self.target_size)
|
||||
im_scale = float(target_size_min) / float(im_size_min)
|
||||
if np.round(im_scale * im_size_max) > target_size_max:
|
||||
im_scale = float(target_size_max) / float(im_size_max)
|
||||
im_scale_x = im_scale
|
||||
im_scale_y = im_scale
|
||||
else:
|
||||
resize_h, resize_w = self.target_size
|
||||
im_scale_y = resize_h / float(origin_shape[0])
|
||||
im_scale_x = resize_w / float(origin_shape[1])
|
||||
return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class Resize:
|
||||
def __init__(self, size=(640, 640), **kwargs):
|
||||
self.size = size
|
||||
|
||||
def resize_image(self, img):
|
||||
resize_h, resize_w = self.size
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'polys' in data:
|
||||
text_polys = data['polys']
|
||||
|
||||
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
|
||||
if 'polys' in data:
|
||||
new_boxes = []
|
||||
for box in text_polys:
|
||||
new_box = []
|
||||
for cord in box:
|
||||
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
|
||||
new_boxes.append(new_box)
|
||||
data['polys'] = np.array(new_boxes, dtype=np.float32)
|
||||
data['image'] = img_resize
|
||||
return data
|
||||
|
||||
|
||||
class DetResizeForTest:
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
self.keep_ratio = False
|
||||
if 'image_shape' in kwargs:
|
||||
self.image_shape = kwargs['image_shape']
|
||||
self.resize_type = 1
|
||||
if 'keep_ratio' in kwargs:
|
||||
self.keep_ratio = kwargs['keep_ratio']
|
||||
elif 'limit_side_len' in kwargs:
|
||||
self.limit_side_len = kwargs['limit_side_len']
|
||||
self.limit_type = kwargs.get('limit_type', 'min')
|
||||
elif 'resize_long' in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get('resize_long', 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = 'min'
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if sum([src_h, src_w]) < 64:
|
||||
img = self.image_padding(img)
|
||||
|
||||
if self.resize_type == 0:
|
||||
# img, shape = self.resize_image_type0(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
# img, shape = self.resize_image_type1(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
data['image'] = img
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def image_padding(self, im, value=0):
|
||||
h, w, c = im.shape
|
||||
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
|
||||
im_pad[:h, :w, :] = im
|
||||
return im_pad
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
if self.keep_ratio is True:
|
||||
resize_w = ori_w * resize_h / ori_h
|
||||
N = math.ceil(resize_w / 32)
|
||||
resize_w = N * 32
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
# return img, np.array([ori_h, ori_w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
resize image to a size multiple of 32 which is required by the network
|
||||
args:
|
||||
img(array): array with shape [h, w, c]
|
||||
return(tuple):
|
||||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
if max(h, w) > limit_side_len:
|
||||
if h > w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h, w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
||||
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
return None, (None, None)
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
except BaseException:
|
||||
logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
|
||||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest:
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
self.valid_set = kwargs['valid_set']
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if self.valid_set == 'totaltext':
|
||||
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
||||
img, max_side_len=self.max_side_len)
|
||||
else:
|
||||
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
||||
img, max_side_len=self.max_side_len)
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_for_totaltext(self, im, max_side_len=512):
|
||||
h, w, _ = im.shape
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image(self, im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
class KieResize:
|
||||
def __init__(self, **kwargs):
|
||||
super(KieResize, self).__init__()
|
||||
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
||||
'img_scale'][1]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
points = data['points']
|
||||
src_h, src_w, _ = img.shape
|
||||
im_resized, scale_factor, [ratio_h, ratio_w
|
||||
], [new_h, new_w] = self.resize_image(img)
|
||||
resize_points = self.resize_boxes(img, points, scale_factor)
|
||||
data['ori_image'] = img
|
||||
data['ori_boxes'] = points
|
||||
data['points'] = resize_points
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([new_h, new_w])
|
||||
return data
|
||||
|
||||
def resize_image(self, img):
|
||||
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
|
||||
scale = [512, 1024]
|
||||
h, w = img.shape[:2]
|
||||
max_long_edge = max(scale)
|
||||
max_short_edge = min(scale)
|
||||
scale_factor = min(max_long_edge / max(h, w),
|
||||
max_short_edge / min(h, w))
|
||||
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
|
||||
scale_factor) + 0.5)
|
||||
max_stride = 32
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(img, (resize_w, resize_h))
|
||||
new_h, new_w = im.shape[:2]
|
||||
w_scale = new_w / w
|
||||
h_scale = new_h / h
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
norm_img[:new_h, :new_w, :] = im
|
||||
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
||||
|
||||
def resize_boxes(self, im, points, scale_factor):
|
||||
points = points * scale_factor
|
||||
img_shape = im.shape[:2]
|
||||
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
||||
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
||||
return points
|
||||
|
||||
|
||||
class SRResize:
|
||||
def __init__(self,
|
||||
imgH=32,
|
||||
imgW=128,
|
||||
down_sample_scale=4,
|
||||
keep_ratio=False,
|
||||
min_ratio=1,
|
||||
mask=False,
|
||||
infer_mode=False,
|
||||
**kwargs):
|
||||
self.imgH = imgH
|
||||
self.imgW = imgW
|
||||
self.keep_ratio = keep_ratio
|
||||
self.min_ratio = min_ratio
|
||||
self.down_sample_scale = down_sample_scale
|
||||
self.mask = mask
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
imgH = self.imgH
|
||||
imgW = self.imgW
|
||||
images_lr = data["image_lr"]
|
||||
transform2 = ResizeNormalize(
|
||||
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
|
||||
images_lr = transform2(images_lr)
|
||||
data["img_lr"] = images_lr
|
||||
if self.infer_mode:
|
||||
return data
|
||||
|
||||
images_HR = data["image_hr"]
|
||||
_label_strs = data["label"]
|
||||
transform = ResizeNormalize((imgW, imgH))
|
||||
images_HR = transform(images_HR)
|
||||
data["img_hr"] = images_HR
|
||||
return data
|
||||
|
||||
|
||||
class ResizeNormalize:
|
||||
def __init__(self, size, interpolation=Image.BICUBIC):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
img = img.resize(self.size, self.interpolation)
|
||||
img_numpy = np.array(img).astype("float32")
|
||||
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
|
||||
return img_numpy
|
||||
|
||||
|
||||
class GrayImageChannelFormat:
|
||||
"""
|
||||
format gray scale image's channel: (3,h,w) -> (1,h,w)
|
||||
Args:
|
||||
inverse: inverse gray image
|
||||
"""
|
||||
|
||||
def __init__(self, inverse=False, **kwargs):
|
||||
self.inverse = inverse
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img_expanded = np.expand_dims(img_single_channel, 0)
|
||||
|
||||
if self.inverse:
|
||||
data['image'] = np.abs(img_expanded - 1)
|
||||
else:
|
||||
data['image'] = img_expanded
|
||||
|
||||
data['src_image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class Permute:
|
||||
"""permute image
|
||||
Args:
|
||||
to_bgr (bool): whether convert RGB to BGR
|
||||
channel_first (bool): whether convert HWC to CHW
|
||||
"""
|
||||
|
||||
def __init__(self, ):
|
||||
super(Permute, self).__init__()
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.transpose((2, 0, 1)).copy()
|
||||
return im, im_info
|
||||
|
||||
|
||||
class PadStride:
|
||||
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
|
||||
Args:
|
||||
stride (bool): model with FPN need image shape % stride == 0
|
||||
"""
|
||||
|
||||
def __init__(self, stride=0):
|
||||
self.coarsest_stride = stride
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
coarsest_stride = self.coarsest_stride
|
||||
if coarsest_stride <= 0:
|
||||
return im, im_info
|
||||
im_c, im_h, im_w = im.shape
|
||||
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
|
||||
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
|
||||
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = im
|
||||
return padding_im, im_info
|
||||
|
||||
|
||||
def decode_image(im_file, im_info):
|
||||
"""read rgb image
|
||||
Args:
|
||||
im_file (str|np.ndarray): input can be image path or np.ndarray
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
if isinstance(im_file, str):
|
||||
with open(im_file, 'rb') as f:
|
||||
im_read = f.read()
|
||||
data = np.frombuffer(im_read, dtype='uint8')
|
||||
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
im = im_file
|
||||
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
|
||||
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
|
||||
return im, im_info
|
||||
|
||||
|
||||
def preprocess(im, preprocess_ops):
|
||||
# process image by preprocess_ops
|
||||
im_info = {
|
||||
'scale_factor': np.array(
|
||||
[1., 1.], dtype=np.float32),
|
||||
'im_shape': None,
|
||||
}
|
||||
im, im_info = decode_image(im, im_info)
|
||||
for operator in preprocess_ops:
|
||||
im, im_info = operator(im, im_info)
|
||||
return im, im_info
|
||||
|
||||
|
||||
def nms(bboxes, scores, iou_thresh):
|
||||
import numpy as np
|
||||
x1 = bboxes[:, 0]
|
||||
y1 = bboxes[:, 1]
|
||||
x2 = bboxes[:, 2]
|
||||
y2 = bboxes[:, 3]
|
||||
areas = (y2 - y1) * (x2 - x1)
|
||||
|
||||
indices = []
|
||||
index = scores.argsort()[::-1]
|
||||
while index.size > 0:
|
||||
i = index[0]
|
||||
indices.append(i)
|
||||
x11 = np.maximum(x1[i], x1[index[1:]])
|
||||
y11 = np.maximum(y1[i], y1[index[1:]])
|
||||
x22 = np.minimum(x2[i], x2[index[1:]])
|
||||
y22 = np.minimum(y2[i], y2[index[1:]])
|
||||
w = np.maximum(0, x22 - x11 + 1)
|
||||
h = np.maximum(0, y22 - y11 + 1)
|
||||
overlaps = w * h
|
||||
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
|
||||
idx = np.where(ious <= iou_thresh)[0]
|
||||
index = index[idx + 1]
|
||||
return indices
|
||||
|
||||
339
ocr/pdf_parser.py
Normal file
339
ocr/pdf_parser.py
Normal file
@@ -0,0 +1,339 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
简化的PDF解析器,只使用OCR处理PDF文档
|
||||
|
||||
从 RAGFlow 的 RAGFlowPdfParser 中提取OCR相关功能,移除了:
|
||||
- 布局识别(Layout Recognition)
|
||||
- 表格结构识别(Table Structure Recognition)
|
||||
- 文本合并和语义分析
|
||||
- RAG相关功能
|
||||
|
||||
只保留:
|
||||
- PDF转图片
|
||||
- OCR文本检测和识别
|
||||
- 基本的文本和位置信息返回
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from timeit import default_timer as timer
|
||||
|
||||
import numpy as np
|
||||
import pdfplumber
|
||||
import trio
|
||||
|
||||
# 处理导入问题:支持直接运行和模块导入
|
||||
try:
|
||||
_package = __package__
|
||||
except NameError:
|
||||
_package = None
|
||||
|
||||
if _package is None:
|
||||
# 直接运行时,添加父目录到路径并使用绝对导入
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
if str(parent_dir) not in sys.path:
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
from ocr.config import PARALLEL_DEVICES
|
||||
from ocr.ocr import OCR
|
||||
else:
|
||||
# 作为模块导入时使用相对导入
|
||||
from config import PARALLEL_DEVICES
|
||||
from ocr import OCR
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
class SimplePdfParser:
|
||||
"""
|
||||
简化的PDF解析器,只使用OCR处理PDF
|
||||
|
||||
使用方法:
|
||||
parser = SimplePdfParser()
|
||||
result = parser.parse_pdf("file.pdf") # 或传入二进制数据
|
||||
# result 格式:
|
||||
# {
|
||||
# "pages": [
|
||||
# {
|
||||
# "page_number": 1,
|
||||
# "boxes": [
|
||||
# {
|
||||
# "text": "识别到的文本",
|
||||
# "bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
|
||||
# "confidence": 0.95
|
||||
# },
|
||||
# ...
|
||||
# ]
|
||||
# },
|
||||
# ...
|
||||
# ]
|
||||
# }
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
初始化PDF解析器
|
||||
|
||||
Args:
|
||||
model_dir: OCR模型目录,如果为None则使用默认路径
|
||||
"""
|
||||
self.ocr = OCR(model_dir=model_dir)
|
||||
self.parallel_limiter = None
|
||||
if PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
|
||||
|
||||
def __ocr_page(self, page_num, img, zoomin=3, device_id=None):
|
||||
"""
|
||||
对单页进行OCR处理
|
||||
|
||||
Args:
|
||||
page_num: 页码
|
||||
img: PIL Image对象
|
||||
zoomin: 放大倍数(用于坐标缩放)
|
||||
device_id: GPU设备ID
|
||||
|
||||
Returns:
|
||||
list: OCR结果列表,每个元素为 {"text": str, "bbox": list, "confidence": float}
|
||||
"""
|
||||
start = timer()
|
||||
img_np = np.array(img)
|
||||
|
||||
# 文本检测
|
||||
# detect方法返回: zip对象,格式为 (box_coords, (text, score))
|
||||
# 但检测阶段text和score都是默认值,需要后续识别
|
||||
detection_result = self.ocr.detect(img_np, device_id)
|
||||
|
||||
if detection_result is None:
|
||||
return []
|
||||
|
||||
# 转换为列表并提取box坐标
|
||||
# detect返回的格式是zip,每个元素是 (box_coords, (text, score))
|
||||
# 在检测阶段,text是空字符串,score是0
|
||||
bxs = list(detection_result)
|
||||
|
||||
logging.info(f"Page {page_num}: OCR detection found {len(bxs)} boxes in {timer() - start:.2f}s")
|
||||
|
||||
if not bxs:
|
||||
return []
|
||||
|
||||
# 解析检测结果并准备识别
|
||||
boxes_to_reg = []
|
||||
|
||||
start = timer()
|
||||
for box_coords, _, _ in bxs:
|
||||
# box_coords 是四边形坐标: [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]
|
||||
# 转换为原始坐标(考虑zoomin)
|
||||
box_coords_np = np.array(box_coords, dtype=np.float32)
|
||||
original_coords = box_coords_np / zoomin # 缩放回原始坐标
|
||||
|
||||
# 裁剪图像用于识别
|
||||
# 使用放大后的坐标裁剪(因为img_np是放大后的图像)
|
||||
crop_box = box_coords_np
|
||||
crop_img = self.ocr.get_rotate_crop_image(img_np, crop_box)
|
||||
boxes_to_reg.append({
|
||||
"bbox": original_coords.tolist(),
|
||||
"crop_img": crop_img
|
||||
})
|
||||
|
||||
# 批量识别文本
|
||||
ocr_results = []
|
||||
if boxes_to_reg:
|
||||
crop_imgs = [b["crop_img"] for b in boxes_to_reg]
|
||||
texts = self.ocr.recognize_batch(crop_imgs, device_id)
|
||||
|
||||
# 组装结果
|
||||
for i, b in enumerate(boxes_to_reg):
|
||||
if i < len(texts) and texts[i]: # 过滤空文本
|
||||
ocr_results.append({
|
||||
"text": texts[i],
|
||||
"bbox": b["bbox"],
|
||||
"confidence": 0.9 # 简化版本,不计算具体置信度
|
||||
})
|
||||
|
||||
logging.info(f"Page {page_num}: OCR recognition {len(ocr_results)} boxes cost {timer() - start:.2f}s")
|
||||
return ocr_results
|
||||
|
||||
async def __ocr_page_async(self, page_num, img, zoomin, device_id, limiter, callback):
|
||||
"""
|
||||
异步OCR处理单页
|
||||
|
||||
Args:
|
||||
page_num: 页码
|
||||
img: PIL Image对象
|
||||
zoomin: 放大倍数
|
||||
device_id: GPU设备ID
|
||||
limiter: 并发限制器
|
||||
callback: 进度回调函数
|
||||
"""
|
||||
if limiter:
|
||||
async with limiter:
|
||||
result = await trio.to_thread.run_sync(
|
||||
lambda: self.__ocr_page(page_num, img, zoomin, device_id)
|
||||
)
|
||||
else:
|
||||
result = await trio.to_thread.run_sync(
|
||||
lambda: self.__ocr_page(page_num, img, zoomin, device_id)
|
||||
)
|
||||
|
||||
if callback and page_num % 5 == 0:
|
||||
callback(prog=page_num * 0.9 / 100, msg=f"Processing page {page_num}...")
|
||||
|
||||
return result
|
||||
|
||||
def __convert_pdf_to_images(self, pdf_source, zoomin=3, page_from=0, page_to=299):
|
||||
"""
|
||||
将PDF转换为图片
|
||||
|
||||
Args:
|
||||
pdf_source: PDF文件路径(str)或二进制数据(bytes)
|
||||
zoomin: 放大倍数,默认3(72*3=216 DPI)
|
||||
page_from: 起始页码(从0开始)
|
||||
page_to: 结束页码
|
||||
|
||||
Returns:
|
||||
list: PIL Image对象列表
|
||||
"""
|
||||
start = timer()
|
||||
page_images = []
|
||||
|
||||
try:
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(pdf_source) if isinstance(pdf_source, str) else pdfplumber.open(BytesIO(pdf_source))
|
||||
try:
|
||||
# 转换为图片,resolution = 72 * zoomin
|
||||
page_images = [
|
||||
p.to_image(resolution=72 * zoomin, antialias=True).annotated
|
||||
for i, p in enumerate(pdf.pages[page_from:page_to])
|
||||
]
|
||||
pdf.close()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to convert PDF pages {page_from}-{page_to}: {str(e)}")
|
||||
if hasattr(pdf, 'close'):
|
||||
pdf.close()
|
||||
except Exception as e:
|
||||
logging.exception(f"Error converting PDF to images: {str(e)}")
|
||||
|
||||
logging.info(f"Converted {len(page_images)} pages to images in {timer() - start:.2f}s")
|
||||
return page_images
|
||||
|
||||
def parse_pdf(self, pdf_source, zoomin=3, page_from=0, page_to=299, callback=None):
|
||||
"""
|
||||
解析PDF文档,使用OCR识别文本
|
||||
|
||||
Args:
|
||||
pdf_source: PDF文件路径(str)或二进制数据(bytes)
|
||||
zoomin: 放大倍数,默认3
|
||||
page_from: 起始页码(从0开始)
|
||||
page_to: 结束页码
|
||||
callback: 进度回调函数,格式: callback(prog: float, msg: str)
|
||||
|
||||
Returns:
|
||||
dict: 解析结果
|
||||
{
|
||||
"pages": [
|
||||
{
|
||||
"page_number": int,
|
||||
"boxes": [
|
||||
{
|
||||
"text": str,
|
||||
"bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
|
||||
"confidence": float
|
||||
},
|
||||
...
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
if callback:
|
||||
callback(0.0, "Starting PDF parsing...")
|
||||
|
||||
# 1. 转换为图片
|
||||
if callback:
|
||||
callback(0.1, "Converting PDF to images...")
|
||||
page_images = self.__convert_pdf_to_images(pdf_source, zoomin, page_from, page_to)
|
||||
|
||||
if not page_images:
|
||||
logging.warning("No pages converted from PDF")
|
||||
return {"pages": []}
|
||||
|
||||
# 2. OCR处理
|
||||
async def process_all_pages():
|
||||
pages_result = []
|
||||
|
||||
if self.parallel_limiter:
|
||||
# 并行处理(多GPU)
|
||||
async with trio.open_nursery() as nursery:
|
||||
tasks = []
|
||||
for i, img in enumerate(page_images):
|
||||
page_num = page_from + i + 1
|
||||
device_id = i % PARALLEL_DEVICES
|
||||
task = nursery.start_soon(
|
||||
self.__ocr_page_async,
|
||||
page_num, img, zoomin, device_id,
|
||||
self.parallel_limiter[device_id], callback
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# 等待所有任务完成并收集结果
|
||||
for i, task in enumerate(tasks):
|
||||
result = await task
|
||||
pages_result.append({
|
||||
"page_number": page_from + i + 1,
|
||||
"boxes": result
|
||||
})
|
||||
else:
|
||||
# 串行处理(单GPU或CPU)
|
||||
for i, img in enumerate(page_images):
|
||||
page_num = page_from + i + 1
|
||||
result = await trio.to_thread.run_sync(
|
||||
lambda img=img, pn=page_num: self.__ocr_page(pn, img, zoomin, 0)
|
||||
)
|
||||
pages_result.append({
|
||||
"page_number": page_num,
|
||||
"boxes": result
|
||||
})
|
||||
if callback:
|
||||
callback(0.1 + (i + 1) * 0.9 / len(page_images), f"Processing page {page_num}...")
|
||||
|
||||
return pages_result
|
||||
|
||||
# 运行异步处理
|
||||
if callback:
|
||||
callback(0.2, "Starting OCR processing...")
|
||||
|
||||
start = timer()
|
||||
pages_result = trio.run(process_all_pages)
|
||||
logging.info(f"OCR processing completed in {timer() - start:.2f}s")
|
||||
|
||||
if callback:
|
||||
callback(1.0, "OCR processing completed")
|
||||
|
||||
return {
|
||||
"pages": pages_result
|
||||
}
|
||||
|
||||
|
||||
# 向后兼容的别名
|
||||
PdfParser = SimplePdfParser
|
||||
|
||||
371
ocr/postprocess.py
Normal file
371
ocr/postprocess.py
Normal file
@@ -0,0 +1,371 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode}
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
if module_name == "None":
|
||||
return
|
||||
if global_config is not None:
|
||||
config.update(global_config)
|
||||
module_class = support_dict.get(module_name)
|
||||
if module_class is None:
|
||||
raise ValueError(
|
||||
'post process only support {}'.format(list(support_dict)))
|
||||
return module_class(**config)
|
||||
|
||||
|
||||
class DBPostProcess:
|
||||
"""
|
||||
The post process for Differentiable Binarization (DB).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
thresh=0.3,
|
||||
box_thresh=0.7,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
box_type='quad',
|
||||
**kwargs):
|
||||
self.thresh = thresh
|
||||
self.box_thresh = box_thresh
|
||||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.score_mode = score_mode
|
||||
self.box_type = box_type
|
||||
assert score_mode in [
|
||||
"slow", "fast"
|
||||
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
|
||||
|
||||
self.dilation_kernel = None if not use_dilation else np.array(
|
||||
[[1, 1], [1, 1]])
|
||||
|
||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
|
||||
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
|
||||
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
for contour in contours[:self.max_candidates]:
|
||||
epsilon = 0.002 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape((-1, 2))
|
||||
if points.shape[0] < 4:
|
||||
continue
|
||||
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
if points.shape[0] > 2:
|
||||
box = self.unclip(points, self.unclip_ratio)
|
||||
if len(box) > 1:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
box = box.reshape(-1, 2)
|
||||
|
||||
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
|
||||
box = np.array(box)
|
||||
box[:, 0] = np.clip(
|
||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(
|
||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.tolist())
|
||||
scores.append(score)
|
||||
return boxes, scores
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
if len(outs) == 3:
|
||||
_img, contours, _ = outs[0], outs[1], outs[2]
|
||||
elif len(outs) == 2:
|
||||
contours, _ = outs[0], outs[1]
|
||||
|
||||
num_contours = min(len(contours), self.max_candidates)
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
for index in range(num_contours):
|
||||
contour = contours[index]
|
||||
points, sside = self.get_mini_boxes(contour)
|
||||
if sside < self.min_size:
|
||||
continue
|
||||
points = np.array(points)
|
||||
if self.score_mode == "fast":
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
else:
|
||||
score = self.box_score_slow(pred, contour)
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
|
||||
box, sside = self.get_mini_boxes(box)
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
box = np.array(box)
|
||||
|
||||
box[:, 0] = np.clip(
|
||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(
|
||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.astype("int32"))
|
||||
scores.append(score)
|
||||
return np.array(boxes, dtype="int32"), scores
|
||||
|
||||
def unclip(self, box, unclip_ratio):
|
||||
poly = Polygon(box)
|
||||
distance = poly.area * unclip_ratio / poly.length
|
||||
offset = pyclipper.PyclipperOffset()
|
||||
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
expanded = np.array(offset.Execute(distance))
|
||||
return expanded
|
||||
|
||||
def get_mini_boxes(self, contour):
|
||||
bounding_box = cv2.minAreaRect(contour)
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
||||
if points[1][1] > points[0][1]:
|
||||
index_1 = 0
|
||||
index_4 = 1
|
||||
else:
|
||||
index_1 = 1
|
||||
index_4 = 0
|
||||
if points[3][1] > points[2][1]:
|
||||
index_2 = 2
|
||||
index_3 = 3
|
||||
else:
|
||||
index_2 = 3
|
||||
index_3 = 2
|
||||
|
||||
box = [
|
||||
points[index_1], points[index_2], points[index_3], points[index_4]
|
||||
]
|
||||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
'''
|
||||
box_score_fast: use bbox mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
||||
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
|
||||
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
|
||||
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
box[:, 0] = box[:, 0] - xmin
|
||||
box[:, 1] = box[:, 1] - ymin
|
||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def box_score_slow(self, bitmap, contour):
|
||||
'''
|
||||
box_score_slow: use polyon mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
contour = contour.copy()
|
||||
contour = np.reshape(contour, (-1, 2))
|
||||
|
||||
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
|
||||
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
|
||||
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
|
||||
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
|
||||
contour[:, 0] = contour[:, 0] - xmin
|
||||
contour[:, 1] = contour[:, 1] - ymin
|
||||
|
||||
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict['maps']
|
||||
if not isinstance(pred, np.ndarray):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
boxes_batch = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
||||
if self.dilation_kernel is not None:
|
||||
mask = cv2.dilate(
|
||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||
self.dilation_kernel)
|
||||
else:
|
||||
mask = segmentation[batch_index]
|
||||
if self.box_type == 'poly':
|
||||
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
|
||||
mask, src_w, src_h)
|
||||
elif self.box_type == 'quad':
|
||||
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
|
||||
src_w, src_h)
|
||||
else:
|
||||
raise ValueError(
|
||||
"box_type can only be one of ['quad', 'poly']")
|
||||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class BaseRecLabelDecode:
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.reverse = False
|
||||
self.character_str = []
|
||||
|
||||
if character_dict_path is None:
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
if 'arabic' in character_dict_path:
|
||||
self.reverse = True
|
||||
|
||||
dict_character = self.add_special_char(dict_character)
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def pred_reverse(self, pred):
|
||||
pred_re = []
|
||||
c_current = ''
|
||||
for c in pred:
|
||||
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
pred_re.append(c)
|
||||
c_current = ''
|
||||
else:
|
||||
c_current += c
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
|
||||
return ''.join(pred_re[::-1])
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
||||
if is_remove_duplicate:
|
||||
selection[1:] = text_index[batch_idx][1:] != text_index[
|
||||
batch_idx][:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index[batch_idx] != ignored_token
|
||||
|
||||
char_list = [
|
||||
self.character[text_id]
|
||||
for text_id in text_index[batch_idx][selection]
|
||||
]
|
||||
if text_prob is not None:
|
||||
conf_list = text_prob[batch_idx][selection]
|
||||
else:
|
||||
conf_list = [1] * len(selection)
|
||||
if len(conf_list) == 0:
|
||||
conf_list = [0]
|
||||
|
||||
text = ''.join(char_list)
|
||||
|
||||
if self.reverse: # for arabic rec
|
||||
text = self.pred_reverse(text)
|
||||
|
||||
result_list.append((text, np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [0] # for ctc blank
|
||||
|
||||
|
||||
class CTCLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(CTCLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||
preds = preds[-1]
|
||||
if not isinstance(preds, np.ndarray):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank'] + dict_character
|
||||
return dict_character
|
||||
|
||||
25
ocr/requirements.txt
Normal file
25
ocr/requirements.txt
Normal file
@@ -0,0 +1,25 @@
|
||||
# OCR PDF处理模块依赖
|
||||
# 核心依赖
|
||||
numpy>=1.21.0
|
||||
opencv-python>=4.5.0
|
||||
pillow>=8.0.0
|
||||
pdfplumber>=0.9.0
|
||||
onnxruntime>=1.12.0
|
||||
trio>=0.22.0
|
||||
|
||||
# 几何计算依赖
|
||||
shapely>=1.8.0
|
||||
pyclipper>=1.2.0
|
||||
|
||||
# Web框架依赖
|
||||
fastapi>=0.100.0
|
||||
uvicorn[standard]>=0.23.0
|
||||
pydantic>=2.0.0
|
||||
|
||||
# 模型下载依赖
|
||||
huggingface_hub>=0.16.0
|
||||
|
||||
# 可选依赖(用于GPU检测和加速)
|
||||
# torch>=1.12.0 # 如果需要GPU支持,取消注释并安装
|
||||
# onnxruntime-gpu>=1.12.0 # 如果需要GPU支持,取消注释并安装
|
||||
|
||||
40
ocr/utils.py
Normal file
40
ocr/utils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
OCR 模块工具函数
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
"""
|
||||
获取项目根目录
|
||||
|
||||
Args:
|
||||
*args: 可选的子路径
|
||||
|
||||
Returns:
|
||||
str: 项目根目录路径
|
||||
"""
|
||||
# 获取当前文件的目录
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
# 返回 ocr 模块的父目录(项目根目录)
|
||||
base_dir = os.path.dirname(current_dir)
|
||||
|
||||
if args:
|
||||
return os.path.join(base_dir, *args)
|
||||
return base_dir
|
||||
|
||||
Reference in New Issue
Block a user