From 3e58c3d0e9e03197be9349ff01311542357932c0 Mon Sep 17 00:00:00 2001 From: dangzerong <429714019@qq.com> Date: Mon, 3 Nov 2025 10:22:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96OCR=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ocr/main.py => main-ocr.py | 6 + ocr/__init__.py | 6 +- ocr/api.py | 14 +- ocr/client.py | 239 +++++++++++++++ ocr/service.py | 290 ++++++++++++++++++ .../hierarchical_merger.py | 19 +- rag/flow/parser/parser.py | 7 +- rag/flow/splitter/splitter.py | 20 +- rag/nlp/__init__.py | 10 +- 9 files changed, 581 insertions(+), 30 deletions(-) rename ocr/main.py => main-ocr.py (96%) create mode 100644 ocr/client.py create mode 100644 ocr/service.py diff --git a/ocr/main.py b/main-ocr.py similarity index 96% rename from ocr/main.py rename to main-ocr.py index b03dacf..19522f3 100644 --- a/ocr/main.py +++ b/main-ocr.py @@ -25,6 +25,12 @@ import sys import signal from pathlib import Path +# 确保项目根目录在 sys.path 中 +_current_file = Path(__file__).resolve() +_project_root = _current_file.parent.parent +if str(_project_root) not in sys.path: + sys.path.insert(0, str(_project_root)) + import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware diff --git a/ocr/__init__.py b/ocr/__init__.py index 4dace37..5cf99ac 100644 --- a/ocr/__init__.py +++ b/ocr/__init__.py @@ -20,12 +20,15 @@ 可以直接作为独立模块使用。 使用方法: - from ocr import OCR + from ocr import OCR, SimplePdfParser import cv2 ocr = OCR() img = cv2.imread("image.jpg") results = ocr(img) + + parser = SimplePdfParser() + result = parser.parse_pdf("document.pdf") """ # 处理导入问题:支持直接运行和模块导入 @@ -35,3 +38,4 @@ from pathlib import Path __all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser'] + diff --git a/ocr/api.py b/ocr/api.py index 957ecbc..0ff0e61 100644 --- a/ocr/api.py +++ b/ocr/api.py @@ -57,7 +57,7 @@ class ParseResponse(BaseModel): data: Optional[dict] = None -@router.get( +@ocr_router.get( 
"/health", summary="健康检查", description="检查OCR服务的健康状态和配置信息", @@ -79,7 +79,7 @@ async def health_check(): } -@router.post( +@ocr_router.post( "/parse", response_model=ParseResponse, summary="上传并解析PDF文件", @@ -165,7 +165,7 @@ async def parse_pdf_endpoint( logger.warning(f"Failed to delete temp file {temp_file}: {e}") -@router.post( +@ocr_router.post( "/parse/bytes", response_model=ParseResponse, summary="通过二进制数据解析PDF", @@ -244,7 +244,7 @@ async def parse_pdf_bytes( logger.warning(f"Failed to delete temp file {temp_file}: {e}") -@router.post( +@ocr_router.post( "/parse/path", response_model=ParseResponse, summary="通过文件路径解析PDF", @@ -315,7 +315,7 @@ async def parse_pdf_path( ) -@router.post( +@ocr_router.post( "/parse_into_bboxes", summary="解析PDF并返回边界框", description="解析PDF文件并返回文本边界框信息,用于文档结构化处理", @@ -414,7 +414,7 @@ class RemoveTagResponse(BaseModel): text: Optional[str] = None -@router.post( +@ocr_router.post( "/remove_tag", response_model=RemoveTagResponse, summary="移除文本中的位置标签", @@ -464,7 +464,7 @@ class ExtractPositionsResponse(BaseModel): positions: Optional[list] = None -@router.post( +@ocr_router.post( "/extract_positions", response_model=ExtractPositionsResponse, summary="从文本中提取位置信息", diff --git a/ocr/client.py b/ocr/client.py new file mode 100644 index 0000000..526a20e --- /dev/null +++ b/ocr/client.py @@ -0,0 +1,239 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class OCRClient:
    """HTTP client for the OCR service API.

    Every operation is available as a coroutine and as a blocking ``*_sync``
    wrapper. Requests are sent with ``httpx`` when the module-level flag
    ``HAS_HTTPX`` is true and with ``aiohttp`` otherwise. A fresh HTTP
    client/session is opened for each request, so nothing is shared between
    calls.
    """

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Store the service address and the request timeout.

        Args:
            base_url: Base URL of the OCR service. When omitted or empty, the
                ``OCR_SERVICE_URL`` environment variable is used, falling back
                to ``http://localhost:8000/api/v1/ocr``.
            timeout: Request timeout in seconds.
        """
        resolved = base_url or os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
        # Endpoints are appended as "/name", so trailing slashes are dropped.
        self.base_url = resolved.rstrip('/')
        self.timeout = timeout

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion from synchronous code and return its result.

        ``asyncio.run`` is used when no event loop is running in the current
        thread. When one *is* running, ``asyncio.run`` would raise
        ``RuntimeError``, so the coroutine is executed on a private loop in a
        short-lived worker thread while the caller blocks on the result.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running here: the common, purely synchronous case.
            return asyncio.run(coro)

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    @staticmethod
    def _bboxes_from_result(result: dict) -> List[dict]:
        """Extract the bounding-box list from a ``/parse_into_bboxes`` reply.

        An empty list is a valid result; only a missing ``bboxes`` entry or a
        falsy ``success`` flag is treated as a failure, matching the checks in
        ``remove_tag`` and ``extract_positions``.

        Raises:
            Exception: If the reply does not report success or carries no
                ``data.bboxes`` entry.
        """
        payload = result.get("data") or {}
        bboxes = payload.get("bboxes")
        if result.get("success") and bboxes is not None:
            return bboxes
        raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
        """Send one HTTP request to ``base_url + endpoint`` and decode the JSON body.

        Args:
            method: HTTP method name, e.g. ``"POST"``.
            endpoint: Path appended to ``base_url``; expected to start with "/".
            **kwargs: Passed through unchanged to the underlying ``request`` call.

        Returns:
            The decoded JSON response.

        Raises:
            The HTTP library's status error when the response status is an
            error (via ``raise_for_status``).
        """
        url = f"{self.base_url}{endpoint}"

        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.request(method, url, **kwargs)
                response.raise_for_status()
                return response.json()

        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
            async with session.request(method, url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()

    async def remove_tag(self, text: str) -> str:
        """Strip position tags from *text* via the ``/remove_tag`` endpoint.

        Args:
            text: Text that contains position tags.

        Returns:
            The text with the tags removed.

        Raises:
            Exception: If the service does not report success or returns no text.
        """
        response = await self._make_request(
            "POST",
            "/remove_tag",
            json={"text": text},
        )
        if response.get("success") and response.get("text") is not None:
            return response["text"]
        raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")

    def remove_tag_sync(self, text: str) -> str:
        """Blocking variant of :meth:`remove_tag` for synchronous callers."""
        return self._run_sync(self.remove_tag(text))

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Extract position information from *text* via ``/extract_positions``.

        Args:
            text: Text that contains position tags.

        Returns:
            A list of ``(page_numbers, left, right, top, bottom)`` tuples built
            from the dictionaries in the service reply.

        Raises:
            Exception: If the service does not report success or returns no
                ``positions`` entry.
        """
        response = await self._make_request(
            "POST",
            "/extract_positions",
            json={"text": text},
        )
        if response.get("success") and response.get("positions") is not None:
            # Convert the JSON objects back into the tuple layout callers expect.
            return [
                (pos["page_numbers"], pos["left"], pos["right"], pos["top"], pos["bottom"])
                for pos in response["positions"]
            ]
        raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Blocking variant of :meth:`extract_positions` for synchronous callers."""
        return self._run_sync(self.extract_positions(text))

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Upload a PDF to ``/parse_into_bboxes`` and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback. It cannot be forwarded over HTTP, so
                it is never invoked; a warning is logged when one is supplied.
            zoomin: Image zoom factor, sent to the service as a form field.
            filename: File name sent with the upload.

        Returns:
            The list of bounding boxes reported by the service (possibly empty).

        Raises:
            Exception: If the service does not report success or the reply has
                no ``data.bboxes`` entry.
        """
        if callback is not None:
            logger.warning("HTTP 调用中无法使用 callback,将忽略回调函数")

        url = f"{self.base_url}/parse_into_bboxes"
        if HAS_HTTPX:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                # With httpx the upload goes in ``files`` and the plain form
                # fields in ``data``; both end up in one multipart request.
                form_data = {"filename": filename, "zoomin": str(zoomin)}
                form_files = {"pdf_bytes": (filename, pdf_bytes, "application/pdf")}
                response = await client.post(url, files=form_files, data=form_data)
                response.raise_for_status()
                result = response.json()
        else:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
                form_data = aiohttp.FormData()
                form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
                form_data.add_field('filename', filename)
                form_data.add_field('zoomin', str(zoomin))

                async with session.post(url, data=form_data) as response:
                    response.raise_for_status()
                    result = await response.json()

        return self._bboxes_from_result(result)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Blocking variant of :meth:`parse_into_bboxes` for synchronous callers.

        ``callback`` is accepted for interface compatibility only; it is
        ignored (with a logged warning) because progress cannot be streamed
        over the HTTP call.
        """
        return self._run_sync(self.parse_into_bboxes(pdf_bytes, callback, zoomin, filename))


# Process-wide client instance, created lazily on first use.
_global_client: Optional[OCRClient] = None


def get_ocr_client() -> OCRClient:
    """Return the shared :class:`OCRClient`, creating it on the first call."""
    global _global_client
    if _global_client is None:
        _global_client = OCRClient()
    return _global_client
class OCRService(ABC):
    """Abstract interface shared by the local and the HTTP OCR back ends.

    Each operation has a coroutine form and a blocking ``*_sync`` form with
    the same arguments and return value.
    """

    @abstractmethod
    async def remove_tag(self, text: str) -> str:
        """Return *text* with its position tags removed."""

    @abstractmethod
    def remove_tag_sync(self, text: str) -> str:
        """Blocking variant of :meth:`remove_tag`."""

    @abstractmethod
    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Return the positions encoded in *text*.

        Each entry is ``(page_numbers, left, right, top, bottom)``.
        """

    @abstractmethod
    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Blocking variant of :meth:`extract_positions`."""

    @abstractmethod
    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse a PDF and return its bounding boxes.

        Args:
            pdf_bytes: Raw bytes of the PDF file.
            callback: Progress callback ``(progress, message) -> None``.
            zoomin: Image zoom factor.
            filename: Name of the document.

        Returns:
            A list of bounding-box dictionaries.
        """

    @abstractmethod
    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Blocking variant of :meth:`parse_into_bboxes`.

        Implementations that go over HTTP cannot forward ``callback``.
        """


class LocalOCRService(OCRService):
    """OCR service that calls a parser object in the current process."""

    def __init__(self, parser_instance=None):
        """Wrap *parser_instance*, or build a ``SimplePdfParser`` when none is given.

        Args:
            parser_instance: Object providing ``remove_tag``,
                ``extract_positions`` and ``parse_into_bboxes``. When omitted,
                ``SimplePdfParser(model_dir=MODEL_DIR)`` is created; the
                imports are deferred so the OCR package is only loaded when it
                is actually needed.
        """
        if parser_instance is None:
            from ocr import SimplePdfParser
            from ocr.config import MODEL_DIR
            logger.info(f"Initializing local OCR parser with model_dir={MODEL_DIR}")
            self.parser = SimplePdfParser(model_dir=MODEL_DIR)
        else:
            self.parser = parser_instance

    async def remove_tag(self, text: str) -> str:
        """Delegate to the parser's ``remove_tag``; no awaiting is involved."""
        return self.parser.remove_tag(text)

    def remove_tag_sync(self, text: str) -> str:
        """Delegate to the parser's ``remove_tag``."""
        return self.parser.remove_tag(text)

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Delegate to the parser's ``extract_positions``; no awaiting is involved."""
        return self.parser.extract_positions(text)

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Delegate to the parser's ``extract_positions``."""
        return self.parser.extract_positions(text)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse the PDF on the loop's default executor and await the result.

        The parser call is synchronous, so it is pushed to a worker thread to
        keep the running asyncio loop responsive. ``filename`` is accepted for
        interface compatibility and is not passed to the parser.
        """
        import asyncio
        from io import BytesIO

        # We are inside a coroutine, so a loop is running by definition.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin),
        )

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Parse the PDF in the calling thread.

        The bytes are wrapped in ``BytesIO`` before being handed to the
        parser. ``filename`` is accepted for interface compatibility and is
        not passed to the parser.
        """
        from io import BytesIO

        return self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)


class HTTPOCRService(OCRService):
    """OCR service that forwards every call to an ``OCRClient``."""

    def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
        """Create the underlying HTTP client.

        Args:
            base_url: Base URL of the OCR service, passed to ``OCRClient``.
            timeout: Request timeout in seconds, passed to ``OCRClient``.
        """
        from ocr.client import OCRClient
        self.client = OCRClient(base_url=base_url, timeout=timeout)

    async def remove_tag(self, text: str) -> str:
        """Forward to ``OCRClient.remove_tag``."""
        return await self.client.remove_tag(text)

    def remove_tag_sync(self, text: str) -> str:
        """Forward to ``OCRClient.remove_tag_sync``."""
        return self.client.remove_tag_sync(text)

    async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Forward to ``OCRClient.extract_positions``."""
        return await self.client.extract_positions(text)

    def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
        """Forward to ``OCRClient.extract_positions_sync``."""
        return self.client.extract_positions_sync(text)

    async def parse_into_bboxes(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Forward to ``OCRClient.parse_into_bboxes``."""
        return await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)

    def parse_into_bboxes_sync(
        self,
        pdf_bytes: bytes,
        callback: Optional[Callable[[float, str], None]] = None,
        zoomin: int = 3,
        filename: str = "document.pdf"
    ) -> List[dict]:
        """Forward to ``OCRClient.parse_into_bboxes_sync``."""
        return self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)


# Process-wide service instance, created lazily on first use.
_global_service: Optional[OCRService] = None


def get_ocr_service() -> OCRService:
    """Return the shared OCR service, creating it on the first call.

    The back end is chosen from the ``OCR_MODE`` environment variable
    (case-insensitive, surrounding whitespace ignored):

    - ``http``: :class:`HTTPOCRService`, addressed by ``OCR_SERVICE_URL``
      (default ``http://localhost:8000/api/v1/ocr``).
    - ``local`` or unset: :class:`LocalOCRService`.

    Any other value also selects the local service, and a warning is logged
    so that a misspelt mode does not go unnoticed. The choice is made once;
    later changes to the environment do not affect the cached instance.
    """
    global _global_service
    if _global_service is None:
        ocr_mode = os.getenv("OCR_MODE", "local").strip().lower()

        if ocr_mode == "http":
            base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
            logger.info(f"Initializing HTTP OCR service with URL: {base_url}")
            _global_service = HTTPOCRService(base_url=base_url)
        else:
            if ocr_mode not in ("", "local"):
                logger.warning("Unrecognized OCR_MODE %r; falling back to the local OCR service", ocr_mode)
            logger.info("Initializing local OCR service")
            _global_service = LocalOCRService()

    return _global_service


def get_ocr_client() -> OCRService:
    """Backward-compatible alias for :func:`get_ocr_service`.

    New code should call :func:`get_ocr_service` directly.
    """
    return get_ocr_service()
import id2image, image2id -from deepdoc.parser.pdf_parser import RAGFlowPdfParser +from ocr.service import get_ocr_service from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream from rag.nlp import concat_img @@ -170,14 +170,17 @@ class HierarchicalMerger(ProcessBase): cks.append(txt) images.append(img) - cks = [ - { - "text": RAGFlowPdfParser.remove_tag(c), + ocr_service = get_ocr_service() + processed_cks = [] + for c, img in zip(cks, images): + cleaned_text = await ocr_service.remove_tag(c) + positions = await ocr_service.extract_positions(c) + processed_cks.append({ + "text": cleaned_text, "image": img, - "positions": RAGFlowPdfParser.extract_positions(c), - } - for c, img in zip(cks, images) - ] + "positions": positions, + }) + cks = processed_cks async with trio.open_nursery() as nursery: for d in cks: nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid()) diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index f437885..efb00f2 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -29,7 +29,8 @@ from api.db.services.llm_service import LLMBundle from api.utils import get_uuid from api.utils.base64_image import image2id from deepdoc.parser import ExcelParser -from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser +from deepdoc.parser.pdf_parser import PlainParser, VisionParser +from ocr.service import get_ocr_service from rag.app.naive import Docx from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.parser.schema import ParserFromUpstream @@ -204,7 +205,9 @@ class Parser(ProcessBase): self.set_output("output_format", conf["output_format"]) if conf.get("parse_method").lower() == "deepdoc": - bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback) + # 注意:HTTP 调用中无法传递 callback,callback 将被忽略 + ocr_service = get_ocr_service() + bboxes = 
ocr_service.parse_into_bboxes_sync(blob, callback=self.callback, filename=name) elif conf.get("parse_method").lower() == "plain_text": lines, _ = PlainParser()(blob) bboxes = [{"text": t} for t, _ in lines] diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index 9c6eb7b..e9d5878 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -19,7 +19,7 @@ import trio from api.utils import get_uuid from api.utils.base64_image import id2image, image2id -from deepdoc.parser.pdf_parser import RAGFlowPdfParser +from ocr.service import get_ocr_service from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.splitter.schema import SplitterFromUpstream from rag.nlp import naive_merge, naive_merge_with_images @@ -96,14 +96,18 @@ class Splitter(ProcessBase): deli, self._param.overlapped_percent, ) - cks = [ - { - "text": RAGFlowPdfParser.remove_tag(c), + ocr_service = get_ocr_service() + cks = [] + for c, img in zip(chunks, images): + if not c.strip(): + continue + cleaned_text = await ocr_service.remove_tag(c) + positions = await ocr_service.extract_positions(c) + cks.append({ + "text": cleaned_text, "image": img, - "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)], - } - for c, img in zip(chunks, images) if c.strip() - ] + "positions": [[pos[0][-1]+1, *pos[1:]] for pos in positions], + }) async with trio.open_nursery() as nursery: for d in cks: nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid()) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 7362986..db17998 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -578,7 +578,8 @@ def hierarchical_merge(bull, sections, depth): def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0): - from deepdoc.parser.pdf_parser import RAGFlowPdfParser + from ocr.service import get_ocr_service + ocr_service = get_ocr_service() if not sections: return [] if 
isinstance(sections, str): @@ -598,7 +599,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。; # Ensure that the length of the merged chunk does not exceed chunk_token_num if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.: if cks: - overlapped = RAGFlowPdfParser.remove_tag(cks[-1]) + overlapped = ocr_service.remove_tag_sync(cks[-1]) t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t if t.find(pos) < 0: t += pos @@ -625,7 +626,8 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。; def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0): - from deepdoc.parser.pdf_parser import RAGFlowPdfParser + from ocr.service import get_ocr_service + ocr_service = get_ocr_service() if not texts or len(texts) != len(images): return [], [] cks = [""] @@ -642,7 +644,7 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。 # Ensure that the length of the merged chunk does not exceed chunk_token_num if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.: if cks: - overlapped = RAGFlowPdfParser.remove_tag(cks[-1]) + overlapped = ocr_service.remove_tag_sync(cks[-1]) t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t if t.find(pos) < 0: t += pos