Compare commits

...

15 Commits

Author  SHA1  Message  Date
dangzerong  3e58c3d0e9  Optimize OCR parsing  2025-11-03 10:22:28 +08:00
dzr  4603a86df4  Split the OCR parsing module into a standalone module  2025-10-31 17:50:25 +08:00
dangzerong  4318179904  Split the OCR parsing module into a standalone module  2025-10-31 14:38:37 +08:00
dangzerong  d78f1fe91d  Improve docker-compose-base.yml  2025-10-30 20:58:41 +08:00
dangzerong  83e2cd58e2  Update the startup script  2025-10-30 14:47:57 +08:00
dangzerong  cb3c94ec50  Add logging to task_executor  2025-10-30 14:24:20 +08:00
dangzerong  0c31eabf20  Refactor chunk_app.py  2025-10-29 13:14:56 +08:00
dangzerong  7d0d65a0ac  Refactor chunk_app.py  2025-10-29 11:26:35 +08:00
dangzerong  6f2f26be10  Fix /kb/update  2025-10-29 11:09:36 +08:00
dangzerong  d174de94b6  Update upload in document_app.py  2025-10-28 16:28:41 +08:00
dangzerong  cbe0477ba1  Refactor llm.py to FastAPI  2025-10-28 11:49:07 +08:00
dangzerong  45f69ab3d5  Split the Dockerfile into two files  2025-10-27 17:46:40 +08:00
dangzerong  50925f98ce  Refactor tenant_app.py to FastAPI  2025-10-27 17:03:16 +08:00
dangzerong  4b95be9762  Support both the standard Bearer format and the raw token format (  2025-10-27 16:31:17 +08:00
dangzerong  8086a73f9f  Update the Dockerfile  2025-10-27 10:34:16 +08:00
42 changed files with 5562 additions and 9234 deletions


@@ -1,196 +1,11 @@
# base stage
FROM ubuntu:22.04 AS base
USER root
SHELL ["/bin/bash", "-c"]
ARG NEED_MIRROR=0
ARG LIGHTEN=0
ENV LIGHTEN=${LIGHTEN}
WORKDIR /ragflow
# Copy models downloaded via download_deps.py
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
tar --exclude='.*' -cf - \
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
/huggingface.co/InfiniFlow/deepdoc \
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
if [ "$LIGHTEN" != "1" ]; then \
(tar -cf - \
/huggingface.co/BAAI/bge-large-zh-v1.5 \
/huggingface.co/maidalun1020/bce-embedding-base_v1 \
| tar -xf - --strip-components=2 -C /root/.ragflow) \
fi
# https://github.com/chrismattmann/tika-python
# This is the only way to run python-tika without internet access. Without this set, the default is to check the Tika version and pull the latest from Apache every time.
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
cp -r /deps/nltk_data /root/ && \
cp /deps/tika-server-standard-3.0.0.jar /deps/tika-server-standard-3.0.0.jar.md5 /ragflow/ && \
cp /deps/cl100k_base.tiktoken /ragflow/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.0.0.jar"
ENV DEBIAN_FRONTEND=noninteractive
# Setup apt
# Python package and implicit dependencies:
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# python-pptx: default-jdk tika-server-standard-3.0.0.jar
# selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
if [ "$NEED_MIRROR" == "1" ]; then \
sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
fi; \
rm -f /etc/apt/apt.conf.d/docker-clean && \
echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
chmod 1777 /tmp && \
apt update && \
apt --no-install-recommends install -y ca-certificates && \
apt update && \
apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
apt install -y pkg-config libicu-dev libgdiplus && \
apt install -y default-jdk && \
apt install -y libatk-bridge2.0-0 && \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple && \
pip3 config set global.trusted-host mirrors.aliyun.com; \
mkdir -p /etc/uv && \
echo "[[index]]" > /etc/uv/uv.toml && \
echo 'url = "https://mirrors.aliyun.com/pypi/simple"' >> /etc/uv/uv.toml && \
echo "default = true" >> /etc/uv/uv.toml; \
fi; \
pipx install uv
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
ENV PATH=/root/.local/bin:$PATH
# nodejs 12.22 on Ubuntu 22.04 is too old
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
apt purge -y nodejs npm cargo && \
apt autoremove -y && \
apt update && \
apt install -y nodejs
# A modern version of cargo is needed for the latest version of the Rust compiler.
RUN apt update && apt install -y curl build-essential \
&& if [ "$NEED_MIRROR" == "1" ]; then \
# Use TUNA mirrors for rustup/rust dist files
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
echo "Using TUNA mirrors for Rustup."; \
fi; \
# Force curl to use HTTP/1.1
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo --version && rustc --version
# Add MSSQL ODBC driver
# macOS ARM64 environment, install msodbcsql18.
# general x86_64 environment, install msodbcsql17.
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add - && \
curl https://packages.microsoft.com/config/ubuntu/22.04/prod.list > /etc/apt/sources.list.d/mssql-release.list && \
apt update && \
arch="$(uname -m)"; \
if [ "$arch" = "arm64" ] || [ "$arch" = "aarch64" ]; then \
# ARM64 (macOS/Apple Silicon or Linux aarch64)
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql18; \
else \
# x86_64 or others
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17; \
fi || \
{ echo "Failed to install ODBC driver"; exit 1; }
# Add dependencies of selenium
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chrome-linux64-121-0-6167-85,target=/chrome-linux64.zip \
unzip /chrome-linux64.zip && \
mv chrome-linux64 /opt/chrome && \
ln -s /opt/chrome/chrome /usr/local/bin/
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chromedriver-linux64-121-0-6167-85,target=/chromedriver-linux64.zip \
unzip -j /chromedriver-linux64.zip chromedriver-linux64/chromedriver && \
mv chromedriver /usr/local/bin/ && \
rm -f /usr/bin/google-chrome
# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
# aspose-slides on linux/arm64 is unavailable
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
if [ "$(uname -m)" = "x86_64" ]; then \
dpkg -i /deps/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
elif [ "$(uname -m)" = "aarch64" ]; then \
dpkg -i /deps/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \
fi
# builder stage
FROM base AS builder
# Application stage - builds on top of the base image
# First build the base image using: docker build -f Dockerfile.base -t ragflow-base:latest .
FROM ragflow-base:latest AS production
USER root
WORKDIR /ragflow
# install dependencies from uv.lock file
COPY pyproject.toml uv.lock ./
# https://github.com/astral-sh/uv/issues/10462
# uv records index url into uv.lock but doesn't failover among multiple indexes
RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
if [ "$NEED_MIRROR" == "1" ]; then \
sed -i 's|pypi.org|mirrors.aliyun.com/pypi|g' uv.lock; \
else \
sed -i 's|mirrors.aliyun.com/pypi|pypi.org|g' uv.lock; \
fi; \
if [ "$LIGHTEN" == "1" ]; then \
uv sync --python 3.10 --frozen; \
else \
uv sync --python 3.10 --frozen --all-extras; \
fi
COPY web web
COPY docs docs
RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
cd web && npm install && npm run build
COPY .git /ragflow/.git
RUN version_info=$(git describe --tags --match=v* --first-parent --always); \
if [ "$LIGHTEN" == "1" ]; then \
version_info="$version_info slim"; \
else \
version_info="$version_info full"; \
fi; \
echo "RAGFlow version: $version_info"; \
echo $version_info > /ragflow/VERSION
# production stage
FROM base AS production
USER root
WORKDIR /ragflow
# Copy Python environment and packages
ENV VIRTUAL_ENV=/ragflow/.venv
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV PYTHONPATH=/ragflow/
COPY web web
# Copy application source code (these files change frequently)
COPY api api
COPY conf conf
COPY deepdoc deepdoc
@@ -198,16 +13,14 @@ COPY rag rag
COPY agent agent
COPY graphrag graphrag
COPY agentic_reasoning agentic_reasoning
COPY pyproject.toml uv.lock ./
COPY pyproject.toml ./
COPY mcp mcp
COPY plugin plugin
# Copy configuration templates and entrypoint
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
COPY docker/entrypoint.sh ./
RUN chmod +x ./entrypoint*.sh
# Copy compiled web pages
COPY --from=builder /ragflow/web/dist /ragflow/web/dist
COPY --from=builder /ragflow/VERSION /ragflow/VERSION
ENTRYPOINT ["./entrypoint.sh"]
# Set the entrypoint
ENTRYPOINT ["./entrypoint.sh"]

Dockerfile.base (new file, 189 lines)

@@ -0,0 +1,189 @@
# base stage
FROM ubuntu:22.04 AS base
USER root
SHELL ["/bin/bash", "-c"]
ARG NEED_MIRROR=0
ARG LIGHTEN=0
ENV LIGHTEN=${LIGHTEN}
WORKDIR /ragflow
# Copy models downloaded via download_deps.py
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
tar --exclude='.*' -cf - \
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
/huggingface.co/InfiniFlow/deepdoc \
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
if [ "$LIGHTEN" != "1" ]; then \
(tar -cf - \
/huggingface.co/BAAI/bge-large-zh-v1.5 \
/huggingface.co/maidalun1020/bce-embedding-base_v1 \
| tar -xf - --strip-components=2 -C /root/.ragflow) \
fi
# https://github.com/chrismattmann/tika-python
# This is the only way to run python-tika without internet access. Without this set, the default is to check the Tika version and pull the latest from Apache every time.
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
cp -r /deps/nltk_data /root/ && \
cp /deps/tika-server-standard-3.0.0.jar /deps/tika-server-standard-3.0.0.jar.md5 /ragflow/ && \
cp /deps/cl100k_base.tiktoken /ragflow/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.0.0.jar"
ENV DEBIAN_FRONTEND=noninteractive
# Setup apt
# Python package and implicit dependencies:
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# python-pptx: default-jdk tika-server-standard-3.0.0.jar
# selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
if [ "$NEED_MIRROR" == "1" ]; then \
sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
fi; \
rm -f /etc/apt/apt.conf.d/docker-clean && \
echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
chmod 1777 /tmp && \
apt update && \
apt --no-install-recommends install -y ca-certificates && \
apt update && \
apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
apt install -y pkg-config libicu-dev libgdiplus && \
apt install -y default-jdk && \
apt install -y libatk-bridge2.0-0 && \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple && \
pip3 config set global.trusted-host mirrors.aliyun.com; \
mkdir -p /etc/uv && \
echo "[[index]]" > /etc/uv/uv.toml && \
echo 'url = "https://mirrors.aliyun.com/pypi/simple"' >> /etc/uv/uv.toml && \
echo "default = true" >> /etc/uv/uv.toml; \
fi; \
pipx install uv
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
ENV PATH=/root/.local/bin:$PATH
# nodejs 12.22 on Ubuntu 22.04 is too old
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
apt purge -y nodejs npm cargo && \
apt autoremove -y && \
apt update && \
apt install -y nodejs
# A modern version of cargo is needed for the latest version of the Rust compiler.
RUN apt update && apt install -y curl build-essential \
&& if [ "$NEED_MIRROR" == "1" ]; then \
# Use TUNA mirrors for rustup/rust dist files
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
echo "Using TUNA mirrors for Rustup."; \
fi; \
# Force curl to use HTTP/1.1
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo --version && rustc --version
# Add MSSQL ODBC driver
# macOS ARM64 environment, install msodbcsql18.
# general x86_64 environment, install msodbcsql17.
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add - && \
curl https://packages.microsoft.com/config/ubuntu/22.04/prod.list > /etc/apt/sources.list.d/mssql-release.list && \
apt update && \
arch="$(uname -m)"; \
if [ "$arch" = "arm64" ] || [ "$arch" = "aarch64" ]; then \
# ARM64 (macOS/Apple Silicon or Linux aarch64)
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql18; \
else \
# x86_64 or others
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17; \
fi || \
{ echo "Failed to install ODBC driver"; exit 1; }
# Add dependencies of selenium
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chrome-linux64-121-0-6167-85,target=/chrome-linux64.zip \
unzip /chrome-linux64.zip && \
mv chrome-linux64 /opt/chrome && \
ln -s /opt/chrome/chrome /usr/local/bin/
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chromedriver-linux64-121-0-6167-85,target=/chromedriver-linux64.zip \
unzip -j /chromedriver-linux64.zip chromedriver-linux64/chromedriver && \
mv chromedriver /usr/local/bin/ && \
rm -f /usr/bin/google-chrome
# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
# aspose-slides on linux/arm64 is unavailable
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
if [ "$(uname -m)" = "x86_64" ]; then \
dpkg -i /deps/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
elif [ "$(uname -m)" = "aarch64" ]; then \
dpkg -i /deps/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \
fi
# builder stage
FROM base AS builder
USER root
WORKDIR /ragflow
# install dependencies from uv.lock file
COPY pyproject.toml ./
RUN uv lock --python 3.10
# https://github.com/astral-sh/uv/issues/10462
# uv records index url into uv.lock but doesn't failover among multiple indexes
RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
if [ "$NEED_MIRROR" == "1" ]; then \
sed -i 's|pypi.org|mirrors.aliyun.com/pypi|g' uv.lock; \
else \
sed -i 's|mirrors.aliyun.com/pypi|pypi.org|g' uv.lock; \
fi; \
if [ "$LIGHTEN" == "1" ]; then \
uv sync --python 3.10 --frozen; \
else \
uv sync --python 3.10 --frozen --all-extras; \
fi
RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked
COPY .git /ragflow/.git
RUN version_info=$(git describe --tags --match=v* --first-parent --always); \
if [ "$LIGHTEN" == "1" ]; then \
version_info="$version_info slim"; \
else \
version_info="$version_info full"; \
fi; \
echo "RAGFlow version: $version_info"; \
echo $version_info > /ragflow/VERSION
# Final base image with Python environment
FROM base AS ragflow-base
USER root
WORKDIR /ragflow
# Copy Python environment and packages from builder
ENV VIRTUAL_ENV=/ragflow/.venv
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV PYTHONPATH=/ragflow/
# Copy version info
COPY --from=builder /ragflow/VERSION /ragflow/VERSION


@@ -127,6 +127,9 @@ def setup_routes(app: FastAPI):
from api.apps.file_app import router as file_router
from api.apps.file2document_app import router as file2document_router
from api.apps.mcp_server_app import router as mcp_router
from api.apps.tenant_app import router as tenant_router
from api.apps.llm_app import router as llm_router
from api.apps.chunk_app import router as chunk_router
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
@@ -134,6 +137,9 @@ def setup_routes(app: FastAPI):
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
def get_current_user_from_token(authorization: str):
"""从token获取当前用户"""


@@ -16,10 +16,10 @@
import datetime
import json
import re
from typing import Optional, List
import xxhash
from flask import request
from flask_login import current_user, login_required
from fastapi import APIRouter, Depends, Query, HTTPException
from api import settings
from api.db import LLMType, ParserType
@@ -29,7 +29,17 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.models.chunk_models import (
ListChunkRequest,
GetChunkRequest,
SetChunkRequest,
SwitchChunkRequest,
RemoveChunkRequest,
CreateChunkRequest,
RetrievalTestRequest,
KnowledgeGraphRequest
)
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -37,18 +47,21 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
# Create the FastAPI router
router = APIRouter()
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id")
def list_chunk():
req = request.json
doc_id = req["doc_id"]
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req.get("keywords", "")
@router.post('/list')
async def list_chunk(
request: ListChunkRequest,
current_user = Depends(get_current_user)
):
doc_id = request.doc_id
page = request.page
size = request.size
question = request.keywords
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, doc = DocumentService.get_by_id(doc_id)
@@ -58,8 +71,8 @@ def list_chunk():
query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
if request.available_int is not None:
query["available_int"] = int(request.available_int)
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
@@ -87,10 +100,11 @@ def list_chunk():
return server_error_response(e)
@manager.route('/get', methods=['GET']) # noqa: F821
@login_required
def get():
chunk_id = request.args["chunk_id"]
@router.get('/get')
async def get(
chunk_id: str = Query(..., description="Chunk ID"),
current_user = Depends(get_current_user)
):
try:
chunk = None
tenants = UserTenantService.query(user_id=current_user.id)
@@ -119,42 +133,42 @@ def get():
return server_error_response(e)
@manager.route('/set', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id", "chunk_id", "content_with_weight")
def set():
req = request.json
@router.post('/set')
async def set(
request: SetChunkRequest,
current_user = Depends(get_current_user)
):
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
"id": request.chunk_id,
"content_with_weight": request.content_with_weight}
d["content_ltks"] = rag_tokenizer.tokenize(request.content_with_weight)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_kwd" in req:
if not isinstance(req["important_kwd"], list):
if request.important_kwd is not None:
if not isinstance(request.important_kwd, list):
return get_data_error_result(message="`important_kwd` should be a list")
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
if "question_kwd" in req:
if not isinstance(req["question_kwd"], list):
d["important_kwd"] = request.important_kwd
d["important_tks"] = rag_tokenizer.tokenize(" ".join(request.important_kwd))
if request.question_kwd is not None:
if not isinstance(request.question_kwd, list):
return get_data_error_result(message="`question_kwd` should be a list")
d["question_kwd"] = req["question_kwd"]
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
if "tag_kwd" in req:
d["tag_kwd"] = req["tag_kwd"]
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
if "available_int" in req:
d["available_int"] = req["available_int"]
d["question_kwd"] = request.question_kwd
d["question_tks"] = rag_tokenizer.tokenize("\n".join(request.question_kwd))
if request.tag_kwd is not None:
d["tag_kwd"] = request.tag_kwd
if request.tag_feas is not None:
d["tag_feas"] = request.tag_feas
if request.available_int is not None:
d["available_int"] = request.available_int
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
tenant_id = DocumentService.get_tenant_id(request.doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_id = DocumentService.get_embd_id(request.doc_id)
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
e, doc = DocumentService.get_by_id(req["doc_id"])
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
@@ -162,33 +176,33 @@ def set():
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
request.content_with_weight) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"id": request.chunk_id}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/switch', methods=['POST']) # noqa: F821
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
def switch():
req = request.json
@router.post('/switch')
async def switch(
request: SwitchChunkRequest,
current_user = Depends(get_current_user)
):
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
for cid in request.chunk_ids:
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
{"available_int": int(request.available_int)},
search.index_name(DocumentService.get_tenant_id(request.doc_id)),
doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
@@ -196,21 +210,21 @@ def switch():
return server_error_response(e)
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("chunk_ids", "doc_id")
def rm():
@router.post('/rm')
async def rm(
request: RemoveChunkRequest,
current_user = Depends(get_current_user)
):
from rag.utils.storage_factory import STORAGE_IMPL
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
if not settings.docStoreConn.delete({"id": request.chunk_ids},
search.index_name(DocumentService.get_tenant_id(request.doc_id)),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
deleted_chunk_ids = request.chunk_ids
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
@@ -221,32 +235,30 @@ def rm():
return server_error_response(e)
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id", "content_with_weight")
def create():
req = request.json
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@router.post('/create')
async def create(
request: CreateChunkRequest,
current_user = Depends(get_current_user)
):
chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
"content_with_weight": request.content_with_weight}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_kwd"] = request.important_kwd
if not isinstance(d["important_kwd"], list):
return get_data_error_result(message="`important_kwd` is required to be a list")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
d["question_kwd"] = req.get("question_kwd", [])
d["question_kwd"] = request.question_kwd
if not isinstance(d["question_kwd"], list):
return get_data_error_result(message="`question_kwd` is required to be a list")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
if request.tag_feas is not None:
d["tag_feas"] = request.tag_feas
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
e, doc = DocumentService.get_by_id(request.doc_id)
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = [doc.kb_id]
@@ -254,7 +266,7 @@ def create():
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
tenant_id = DocumentService.get_tenant_id(request.doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
@@ -264,10 +276,10 @@ def create():
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_id = DocumentService.get_embd_id(request.doc_id)
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v, c = embd_mdl.encode([doc.name, request.content_with_weight if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
@@ -279,29 +291,29 @@ def create():
return server_error_response(e)
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
@login_required
@validate_request("kb_id", "question")
def retrieval_test():
req = request.json
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
kb_ids = req["kb_id"]
@router.post('/retrieval_test')
async def retrieval_test(
request: RetrievalTestRequest,
current_user = Depends(get_current_user)
):
page = request.page
size = request.size
question = request.question
kb_ids = request.kb_id
if isinstance(kb_ids, str):
kb_ids = [kb_ids]
if not kb_ids:
return get_json_result(data=False, message='Please specify dataset firstly.',
code=settings.RetCode.DATA_ERROR)
doc_ids = req.get("doc_ids", [])
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
doc_ids = request.doc_ids
use_kg = request.use_kg
top = request.top_k
langs = request.cross_languages
tenant_ids = []
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
if request.search_id:
search_config = SearchService.get_detail(request.search_id).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
@@ -338,19 +350,19 @@ def retrieval_test():
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if request.rerank_id:
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=request.rerank_id)
if req.get("keyword", False):
if request.keyword:
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
float(request.similarity_threshold),
float(request.vector_similarity_weight),
top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
rank_feature=labels
)
if use_kg:
@@ -374,10 +386,11 @@ def retrieval_test():
return server_error_response(e)
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
def knowledge_graph():
doc_id = request.args["doc_id"]
@router.get('/knowledge_graph')
async def knowledge_graph(
doc_id: str = Query(..., description="Document ID"),
current_user = Depends(get_current_user)
):
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = {
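
The conversion pattern repeated throughout this file: Flask's @manager.route / @login_required / @validate_request trio plus request.json becomes a typed FastAPI endpoint, where a Pydantic body model does the validation and Depends injects the authenticated user. A minimal sketch, with get_current_user stubbed out for illustration:

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field

router = APIRouter()

class ListChunkRequest(BaseModel):
    doc_id: str = Field(..., description="Document ID")
    page: int = 1
    size: int = 30

async def get_current_user():
    # Placeholder for the project's token-based dependency.
    return {"id": "demo-user"}

@router.post("/list")
async def list_chunk(request: ListChunkRequest, current_user=Depends(get_current_user)):
    # FastAPI has already validated doc_id/page/size before the handler runs.
    return {"doc_id": request.doc_id, "page": request.page, "user": current_user["id"]}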


@@ -23,7 +23,8 @@ from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from api import settings
from api.common.check_team_permission import check_kb_team_permission
@@ -53,7 +54,6 @@ from pydantic import BaseModel
from api.db.db_models import User
# Security
security = HTTPBearer()
# Pydantic models for request/response
class WebCrawlRequest(BaseModel):
@@ -89,7 +89,7 @@ class RemoveDocumentRequest(BaseModel):
class RunDocumentRequest(BaseModel):
doc_ids: List[str]
run: str
run: int
delete: bool = False
class RenameDocumentRequest(BaseModel):
@@ -182,17 +182,17 @@ router = APIRouter()
@router.post("/upload")
async def upload(
kb_id: str = Form(...),
files: List[UploadFile] = File(...),
file: List[UploadFile] = File(...),
current_user = Depends(get_current_user)
):
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
if not files:
if not file:
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
# Use UploadFile directly
file_objs = files
# Use UploadFile directly - file is already a list from multiple file fields
file_objs = file
for file_obj in file_objs:
if file_obj.filename == "":
@@ -580,7 +580,7 @@ async def run(
kb_table_num_map = {}
for id in req.doc_ids:
info = {"run": str(req.run), "progress": 0}
if str(req.run) == TaskStatus.RUNNING.value and req.delete:
if req.run == int(TaskStatus.RUNNING.value) and req.delete:
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
@@ -592,12 +592,12 @@ async def run(
if not e:
return get_data_error_result(message="Document not found!")
if str(req.run) == TaskStatus.CANCEL.value:
if req.run == int(TaskStatus.CANCEL.value):
if str(doc.run) == TaskStatus.RUNNING.value:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([req.delete, str(req.run) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
if all([req.delete, req.run == int(TaskStatus.RUNNING.value), str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
DocumentService.update_by_id(id, info)
@@ -606,7 +606,7 @@ async def run(
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req.run) == TaskStatus.RUNNING.value:
if req.run == int(TaskStatus.RUNNING.value):
doc = doc.to_dict()
doc["tenant_id"] = tenant_id


@@ -18,7 +18,8 @@ from pathlib import Path
from typing import List
from fastapi import APIRouter, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
@@ -33,7 +34,6 @@ from api.utils.api_utils import get_json_result
from pydantic import BaseModel
# Security
security = HTTPBearer()
# Pydantic models for request/response
class ConvertRequest(BaseModel):


@@ -20,7 +20,8 @@ from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
@@ -38,7 +39,6 @@ from rag.utils.storage_factory import STORAGE_IMPL
from pydantic import BaseModel
# Security
security = HTTPBearer()
# Pydantic models for request/response
class CreateFileRequest(BaseModel):


@@ -174,10 +174,26 @@ async def update(
return get_data_error_result(
message="Duplicated knowledgebase name.")
# Build the update payload with every updatable field
update_data = {
"name": name,
"pagerank": request.pagerank
}
# Add optional fields when they are provided
if request.description is not None:
update_data["description"] = request.description
if request.permission is not None:
update_data["permission"] = request.permission
if request.avatar is not None:
update_data["avatar"] = request.avatar
if request.parser_id is not None:
update_data["parser_id"] = request.parser_id
if request.embd_id is not None:
update_data["embd_id"] = request.embd_id
if request.parser_config is not None:
update_data["parser_config"] = request.parser_config
if not KnowledgebaseService.update_by_id(kb.id, update_data):
return get_data_error_result()
@@ -195,7 +211,26 @@ async def update(
return get_data_error_result(
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
kb.update(update_data)
# Merge the full request data into the returned result, matching the behavior of the original code
request_data = {
"name": name,
"pagerank": request.pagerank
}
if request.description is not None:
request_data["description"] = request.description
if request.permission is not None:
request_data["permission"] = request.permission
if request.avatar is not None:
request_data["avatar"] = request.avatar
if request.parser_id is not None:
request_data["parser_id"] = request.parser_id
if request.embd_id is not None:
request_data["embd_id"] = request.embd_id
if request.parser_config is not None:
request_data["parser_config"] = request.parser_config
kb.update(request_data)
return get_json_result(data=kb)
except Exception as e:
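
The field-by-field copying above preserves partial-update semantics: only fields the client actually sent get written. Assuming UpdateKnowledgeBaseRequest stays a plain Pydantic model, a roughly equivalent compact form would be (hypothetical alternative, not what the commit does):

# Drop unset optional fields and the key field, then overwrite the normalized values.
update_data = request.dict(exclude_none=True, exclude={"kb_id"})  # .model_dump() on Pydantic v2
update_data["name"] = name
update_data["pagerank"] = request.pagerank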


@@ -15,22 +15,26 @@
#
import logging
import json
from flask import request
from flask_login import login_required, current_user
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from api.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
from api.utils.api_utils import get_current_user
# Create the FastAPI router
router = APIRouter()
@manager.route('/factories', methods=['GET']) # noqa: F821
@login_required
def factories():
@router.get('/factories')
async def factories(current_user = Depends(get_current_user)):
try:
fac = LLMFactoriesService.get_all()
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
@@ -50,21 +54,18 @@ def factories():
return server_error_response(e)
@manager.route('/set_api_key', methods=['POST']) # noqa: F821
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
req = request.json
@router.post('/set_api_key')
async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
factory = request.llm_factory
extra = {"provider": factory}
msg = ""
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
request.api_key, llm.llm_name, base_url=request.base_url)
try:
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
@@ -75,7 +76,7 @@ def set_api_key():
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
request.api_key, llm.llm_name, base_url=request.base_url, **extra)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9, 'max_tokens': 50})
@@ -88,7 +89,7 @@ def set_api_key():
elif not rerank_passed and llm.model_type == LLMType.RERANK:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
request.api_key, llm.llm_name, base_url=request.base_url)
try:
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
@@ -106,12 +107,9 @@ def set_api_key():
return get_data_error_result(message=msg)
llm_config = {
"api_key": req["api_key"],
"api_base": req.get("base_url", "")
"api_key": request.api_key,
"api_base": request.base_url or ""
}
for n in ["model_type", "llm_name"]:
if n in req:
llm_config[n] = req[n]
for llm in LLMService.query(fid=factory):
llm_config["max_tokens"]=llm.max_tokens
@@ -133,18 +131,15 @@ def set_api_key():
return get_json_result(data=True)
@manager.route('/add_llm', methods=['POST']) # noqa: F821
@login_required
@validate_request("llm_factory")
def add_llm():
req = request.json
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
@router.post('/add_llm')
async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
factory = request.llm_factory
api_key = request.api_key or "x"
llm_name = request.llm_name
def apikey_json(keys):
nonlocal req
return json.dumps({k: req.get(k, "") for k in keys})
nonlocal request
return json.dumps({k: getattr(request, k, "") for k in keys})
if factory == "VolcEngine":
# For VolcEngine, due to its special authentication method
@@ -152,12 +147,21 @@ def add_llm():
api_key = apikey_json(["ark_api_key", "endpoint_id"])
elif factory == "Tencent Hunyuan":
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
return set_api_key()
# Create a temporary request object for set_api_key
temp_request = SetApiKeyRequest(
llm_factory=factory,
api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
base_url=request.api_base
)
return await set_api_key(temp_request, current_user)
elif factory == "Tencent Cloud":
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
return set_api_key()
temp_request = SetApiKeyRequest(
llm_factory=factory,
api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
base_url=request.api_base
)
return await set_api_key(temp_request, current_user)
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
@@ -177,9 +181,9 @@ def add_llm():
llm_name += "___VLLM"
elif factory == "XunFei Spark":
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "")
elif req["model_type"] == "tts":
if request.model_type == "chat":
api_key = request.spark_api_password or ""
elif request.model_type == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
elif factory == "BaiduYiyan":
@@ -197,11 +201,11 @@ def add_llm():
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
"model_type": req["model_type"],
"model_type": request.model_type,
"llm_name": llm_name,
"api_base": req.get("api_base", ""),
"api_base": request.api_base or "",
"api_key": api_key,
"max_tokens": req.get("max_tokens")
"max_tokens": request.max_tokens
}
msg = ""
@@ -290,33 +294,27 @@ def add_llm():
return get_json_result(data=True)
@manager.route('/delete_llm', methods=['POST']) # noqa: F821
@login_required
@validate_request("llm_factory", "llm_name")
def delete_llm():
req = request.json
@router.post('/delete_llm')
async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
TenantLLMService.filter_delete(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
TenantLLM.llm_name == req["llm_name"]])
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
TenantLLM.llm_name == request.llm_name])
return get_json_result(data=True)
@manager.route('/delete_factory', methods=['POST']) # noqa: F821
@login_required
@validate_request("llm_factory")
def delete_factory():
req = request.json
@router.post('/delete_factory')
async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(get_current_user)):
TenantLLMService.filter_delete(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory])
return get_json_result(data=True)
@manager.route('/my_llms', methods=['GET']) # noqa: F821
@login_required
def my_llms():
@router.get('/my_llms')
async def my_llms(
include_details: bool = Query(False, description="Whether to include details"),
current_user = Depends(get_current_user)
):
try:
include_details = request.args.get('include_details', 'false').lower() == 'true'
if include_details:
res = {}
objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -362,12 +360,13 @@ def my_llms():
return server_error_response(e)
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def list_app():
@router.get('/list')
async def list_app(
model_type: Optional[str] = Query(None, description="Model type"),
current_user = Depends(get_current_user)
):
self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type")
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
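
The read-only endpoints follow the same recipe: Flask's request.args.get(...) string parsing becomes typed Query parameters that FastAPI coerces and documents automatically. A minimal sketch of the my_llms/list conversions (handler bodies are illustrative):

from typing import Optional
from fastapi import APIRouter, Query

router = APIRouter()

@router.get("/my_llms")
async def my_llms(include_details: bool = Query(False, description="Include details")):
    # ?include_details=true arrives as a real bool; no manual parsing needed.
    return {"include_details": include_details}

@router.get("/list")
async def list_app(model_type: Optional[str] = Query(None, description="Model type")):
    return {"model_type": model_type}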


@@ -15,7 +15,8 @@
#
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from api import settings
from api.db import VALID_MCP_SERVER_TYPES
@@ -31,7 +32,6 @@ from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_
from pydantic import BaseModel
# Security
security = HTTPBearer()
# Pydantic models for request/response
class ListMCPRequest(BaseModel):


@@ -14,8 +14,9 @@
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from fastapi import APIRouter, Depends, HTTPException, Query, status
from api.models.tenant_models import InviteUserRequest, UserTenantResponse
from api.utils.api_utils import get_current_user
from api import settings
from api.apps import smtp_mail_server
@@ -24,13 +25,18 @@ from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
from api.utils import get_uuid, delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.api_utils import get_json_result, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
# Create the FastAPI router
router = APIRouter()
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
@login_required
def user_list(tenant_id):
@router.get("/{tenant_id}/user/list")
async def user_list(
tenant_id: str,
current_user = Depends(get_current_user)
):
if current_user.id != tenant_id:
return get_json_result(
data=False,
@@ -46,18 +52,19 @@ def user_list(tenant_id):
return server_error_response(e)
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
@login_required
@validate_request("email")
def create(tenant_id):
@router.post('/{tenant_id}/user')
async def create(
tenant_id: str,
request: InviteUserRequest,
current_user = Depends(get_current_user)
):
if current_user.id != tenant_id:
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
invite_user_email = req["email"]
invite_user_email = request.email
invite_users = UserService.query(email=invite_user_email)
if not invite_users:
return get_data_error_result(message="User not found.")
@@ -100,9 +107,12 @@ def create(tenant_id):
return get_json_result(data=usr)
@manager.route('/<tenant_id>/user/<user_id>', methods=['DELETE']) # noqa: F821
@login_required
def rm(tenant_id, user_id):
@router.delete('/{tenant_id}/user/{user_id}')
async def rm(
tenant_id: str,
user_id: str,
current_user = Depends(get_current_user)
):
if current_user.id != tenant_id and current_user.id != user_id:
return get_json_result(
data=False,
@@ -116,9 +126,10 @@ def rm(tenant_id, user_id):
return server_error_response(e)
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def tenant_list():
@router.get("/list")
async def tenant_list(
current_user = Depends(get_current_user)
):
try:
users = UserTenantService.get_tenants_by_user_id(current_user.id)
for u in users:
@@ -128,9 +139,11 @@ def tenant_list():
return server_error_response(e)
@manager.route("/agree/<tenant_id>", methods=["PUT"]) # noqa: F821
@login_required
def agree(tenant_id):
@router.put("/agree/{tenant_id}")
async def agree(
tenant_id: str,
current_user = Depends(get_current_user)
):
try:
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
return get_json_result(data=True)
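
Path variables convert the same way: Flask's /<tenant_id>/... converters become FastAPI {tenant_id} path parameters injected as typed arguments. A minimal sketch (handler body is illustrative):

from fastapi import APIRouter

router = APIRouter()

# Flask: @manager.route("/<tenant_id>/user/list")  ->  FastAPI path parameter
@router.get("/{tenant_id}/user/list")
async def user_list(tenant_id: str):
    # tenant_id is extracted from the URL and passed in as a typed argument.
    return {"tenant_id": tenant_id}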


@@ -21,7 +21,8 @@ from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security import HTTPAuthorizationCredentials
from api.utils.api_utils import security
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, EmailStr
try:
@@ -65,7 +66,6 @@ from api.utils.crypt import decrypt
router = APIRouter()
# Security scheme
security = HTTPBearer()
# Pydantic models
class LoginRequest(BaseModel):


@@ -0,0 +1,88 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
class ListChunkRequest(BaseModel):
    """Request model for listing chunks"""
    doc_id: str = Field(..., description="Document ID")
    page: Optional[int] = Field(1, description="Page number")
    size: Optional[int] = Field(30, description="Page size")
    keywords: Optional[str] = Field("", description="Keywords")
    available_int: Optional[int] = Field(None, description="Availability status")

class GetChunkRequest(BaseModel):
    """Request model for getting a chunk"""
    chunk_id: str = Field(..., description="Chunk ID")

class SetChunkRequest(BaseModel):
    """Request model for updating a chunk"""
    doc_id: str = Field(..., description="Document ID")
    chunk_id: str = Field(..., description="Chunk ID")
    content_with_weight: str = Field(..., description="Content with weight")
    important_kwd: Optional[List[str]] = Field(None, description="Important keywords")
    question_kwd: Optional[List[str]] = Field(None, description="Question keywords")
    tag_kwd: Optional[str] = Field(None, description="Tag keywords")
    tag_feas: Optional[Any] = Field(None, description="Tag features")
    available_int: Optional[int] = Field(None, description="Availability status")

class SwitchChunkRequest(BaseModel):
    """Request model for switching chunk availability"""
    chunk_ids: List[str] = Field(..., description="List of chunk IDs")
    available_int: int = Field(..., description="Availability status")
    doc_id: str = Field(..., description="Document ID")

class RemoveChunkRequest(BaseModel):
    """Request model for removing chunks"""
    chunk_ids: List[str] = Field(..., description="List of chunk IDs")
    doc_id: str = Field(..., description="Document ID")

class CreateChunkRequest(BaseModel):
    """Request model for creating a chunk"""
    doc_id: str = Field(..., description="Document ID")
    content_with_weight: str = Field(..., description="Content with weight")
    important_kwd: Optional[List[str]] = Field([], description="Important keywords")
    question_kwd: Optional[List[str]] = Field([], description="Question keywords")
    tag_feas: Optional[Any] = Field(None, description="Tag features")

class RetrievalTestRequest(BaseModel):
    """Request model for retrieval testing"""
    kb_id: List[str] = Field(..., description="List of knowledge base IDs")
    question: str = Field(..., description="Question")
    page: Optional[int] = Field(1, description="Page number")
    size: Optional[int] = Field(30, description="Page size")
    doc_ids: Optional[List[str]] = Field([], description="List of document IDs")
    use_kg: Optional[bool] = Field(False, description="Whether to use the knowledge graph")
    top_k: Optional[int] = Field(1024, description="Number of candidates to return")
    cross_languages: Optional[List[str]] = Field([], description="List of cross-search languages")
    search_id: Optional[str] = Field("", description="Search ID")
    rerank_id: Optional[str] = Field(None, description="Rerank model ID")
    keyword: Optional[bool] = Field(False, description="Whether to apply keyword extraction")
    similarity_threshold: Optional[float] = Field(0.0, description="Similarity threshold")
    vector_similarity_weight: Optional[float] = Field(0.3, description="Vector similarity weight")
    highlight: Optional[bool] = Field(None, description="Whether to highlight matches")

class KnowledgeGraphRequest(BaseModel):
    """Request model for the knowledge graph"""
    doc_id: str = Field(..., description="Document ID")
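
A quick sketch of how these models behave on their own, assuming they are imported from api.models.chunk_models: fields declared with Field(...) are required, everything else falls back to its default:

from pydantic import ValidationError

from api.models.chunk_models import ListChunkRequest, SetChunkRequest

req = ListChunkRequest(doc_id="d1")  # page=1, size=30, keywords="" by default
assert req.page == 1 and req.available_int is None

try:
    SetChunkRequest(doc_id="d1", chunk_id="c1")  # content_with_weight is missing
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('content_with_weight',)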


@@ -30,6 +30,12 @@ class UpdateKnowledgeBaseRequest(BaseModel):
"""更新知识库请求模型"""
kb_id: str = Field(..., description="知识库ID")
name: str = Field(..., description="知识库名称")
description: Optional[str] = Field(None, description="知识库描述")
permission: Optional[str] = Field(None, description="权限设置")
avatar: Optional[str] = Field(None, description="头像base64字符串")
parser_id: Optional[str] = Field(None, description="解析器ID")
embd_id: Optional[str] = Field(None, description="嵌入模型ID")
parser_config: Optional[Dict[str, Any]] = Field(None, description="解析器配置")
pagerank: Optional[int] = Field(0, description="页面排名")

api/models/llm_models.py (new file, 84 lines)

@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
class SetApiKeyRequest(BaseModel):
"""Set API key request model"""
llm_factory: str = Field(..., description="LLM factory name")
api_key: str = Field(..., description="API key")
base_url: Optional[str] = Field(None, description="Base URL")
class AddLLMRequest(BaseModel):
"""Add LLM request model"""
llm_factory: str = Field(..., description="LLM factory name")
api_key: Optional[str] = Field("x", description="API key")
llm_name: Optional[str] = Field(None, description="LLM name")
model_type: Optional[str] = Field(None, description="Model type")
api_base: Optional[str] = Field(None, description="API base URL")
max_tokens: Optional[int] = Field(None, description="Maximum number of tokens")
# VolcEngine specific fields
ark_api_key: Optional[str] = Field(None, description="VolcEngine ARK API key")
endpoint_id: Optional[str] = Field(None, description="VolcEngine endpoint ID")
# Tencent Hunyuan specific fields
hunyuan_sid: Optional[str] = Field(None, description="Tencent Hunyuan SID")
hunyuan_sk: Optional[str] = Field(None, description="Tencent Hunyuan SK")
# Tencent Cloud specific fields
tencent_cloud_sid: Optional[str] = Field(None, description="Tencent Cloud SID")
tencent_cloud_sk: Optional[str] = Field(None, description="Tencent Cloud SK")
# Bedrock specific fields
bedrock_ak: Optional[str] = Field(None, description="Bedrock access key")
bedrock_sk: Optional[str] = Field(None, description="Bedrock secret key")
bedrock_region: Optional[str] = Field(None, description="Bedrock region")
# XunFei Spark specific fields
spark_api_password: Optional[str] = Field(None, description="XunFei Spark API password")
spark_app_id: Optional[str] = Field(None, description="XunFei Spark app ID")
spark_api_secret: Optional[str] = Field(None, description="XunFei Spark API secret")
spark_api_key: Optional[str] = Field(None, description="XunFei Spark API key")
# BaiduYiyan specific fields
yiyan_ak: Optional[str] = Field(None, description="Baidu Yiyan AK")
yiyan_sk: Optional[str] = Field(None, description="Baidu Yiyan SK")
# Fish Audio specific fields
fish_audio_ak: Optional[str] = Field(None, description="Fish Audio AK")
fish_audio_refid: Optional[str] = Field(None, description="Fish Audio reference ID")
# Google Cloud specific fields
google_project_id: Optional[str] = Field(None, description="Google Cloud project ID")
google_region: Optional[str] = Field(None, description="Google Cloud region")
google_service_account_key: Optional[str] = Field(None, description="Google Cloud service account key")
# Azure OpenAI specific fields
api_version: Optional[str] = Field(None, description="Azure OpenAI API version")
class DeleteLLMRequest(BaseModel):
"""Delete LLM request model"""
llm_factory: str = Field(..., description="LLM factory name")
llm_name: str = Field(..., description="LLM name")
class DeleteFactoryRequest(BaseModel):
"""Delete factory request model"""
llm_factory: str = Field(..., description="LLM factory name")
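For reference, a quick sanity check of the request models above; the factory names and key values are purely illustrative:

from api.models.llm_models import SetApiKeyRequest, AddLLMRequest  # module path per the diff header

req = SetApiKeyRequest(llm_factory="OpenAI", api_key="sk-xxxx")
add = AddLLMRequest(llm_factory="VolcEngine", ark_api_key="ak-xxxx", endpoint_id="ep-xxxx")
print(req.model_dump())  # pydantic v2; on v1 use req.dict()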

View File

@@ -0,0 +1,30 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from pydantic import BaseModel, Field
class InviteUserRequest(BaseModel):
"""Invite user request model"""
email: str = Field(..., description="User email")
class UserTenantResponse(BaseModel):
"""User tenant response model"""
id: str = Field(..., description="User ID")
avatar: Optional[str] = Field(None, description="User avatar")
email: str = Field(..., description="User email")
nickname: Optional[str] = Field(None, description="User nickname")

View File

@@ -38,6 +38,8 @@ from fastapi import Request, Response as FastAPIResponse, HTTPException, status
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
from fastapi import Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.security.base import SecurityBase
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from itsdangerous import URLSafeTimedSerializer
from peewee import OperationalError
from werkzeug.http import HTTP_STATUS_CODES
@@ -51,8 +53,35 @@ from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils.json import CustomJSONEncoder, json_dumps
# Custom auth scheme: accepts both the standard Bearer format and a bare token
class CustomHTTPBearer(SecurityBase):
def __init__(self, *, scheme_name: str | None = None, auto_error: bool = True):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
# Expose a `model` attribute for OpenAPI schema generation
# (use the OpenAPI model carried by HTTPBearer, not the dependency object itself)
self.model = HTTPBearer().model
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials:
authorization: str = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated"
)
else:
return None
# Accept both "Bearer <token>" and a bare token
if authorization.startswith("Bearer "):
token = authorization[7:]  # strip the "Bearer " prefix
else:
token = authorization  # use the raw header value as the token
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
# FastAPI security scheme
security = HTTPBearer()
security = CustomHTTPBearer()
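A minimal usage sketch for the scheme above (the route and app instance are hypothetical): both header formats resolve to the same credentials object.

from fastapi import Depends, FastAPI

app = FastAPI()  # hypothetical app, for illustration only

@app.get("/whoami")
async def whoami(creds: HTTPAuthorizationCredentials = Depends(security)):
    # "Authorization: Bearer abc123" and "Authorization: abc123"
    # both arrive here with creds.credentials == "abc123"
    return {"token": creds.credentials}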
from api.utils import get_uuid
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

View File

@@ -0,0 +1,175 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP 客户端
用于调用独立的 OCR 服务的 HTTP API
"""
import os
import logging
import requests
from typing import Optional, Union, Dict, Any
logger = logging.getLogger(__name__)
class OCRHttpClient:
"""OCR HTTP 客户端,用于调用独立的 OCR 服务"""
def __init__(self, base_url: Optional[str] = None, timeout: int = 300):
"""
初始化 OCR HTTP 客户端
Args:
base_url: OCR 服务的基础 URL如果不提供则从环境变量 OCR_SERVICE_URL 读取
默认值为 http://localhost:8000
timeout: 请求超时时间(秒),默认 300 秒
"""
if base_url is None:
base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000")
# 确保 URL 不包含尾随斜杠
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.api_prefix = "/api/v1/ocr"
logger.info(f"Initialized OCR HTTP client with base_url: {self.base_url}")
def parse_pdf_by_path(self, file_path: str, page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
"""
通过文件路径解析 PDF
Args:
file_path: PDF 文件的本地路径
page_from: 起始页码从1开始
page_to: 结束页码0表示最后一页
zoomin: 图像放大倍数1-5
Returns:
dict: 解析结果,格式:
{
"success": bool,
"message": str,
"data": {
"pages": [
{
"page_number": int,
"boxes": [
{
"text": str,
"bbox": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]],
"confidence": float
},
...
]
},
...
]
}
}
Raises:
requests.RequestException: HTTP 请求失败
ValueError: 响应格式不正确
"""
url = f"{self.base_url}{self.api_prefix}/parse/path"
data = {
"file_path": file_path,
"page_from": page_from,
"page_to": page_to,
"zoomin": zoomin
}
try:
logger.info(f"Calling OCR service: {url} for file: {file_path}")
response = requests.post(url, data=data, timeout=self.timeout)
response.raise_for_status()
result = response.json()
if not result.get("success", False):
raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
return result
except requests.RequestException as e:
logger.error(f"Failed to call OCR service: {e}")
raise
def parse_pdf_by_bytes(self, pdf_bytes: bytes, filename: str = "document.pdf",
page_from: int = 1, page_to: int = 0, zoomin: int = 3) -> Dict[str, Any]:
"""
通过二进制数据解析 PDF
Args:
pdf_bytes: PDF 文件的二进制数据
filename: 文件名(仅用于日志)
page_from: 起始页码从1开始
page_to: 结束页码0表示最后一页
zoomin: 图像放大倍数1-5
Returns:
dict: 解析结果,格式同 parse_pdf_by_path
Raises:
requests.RequestException: HTTP 请求失败
ValueError: 响应格式不正确
"""
url = f"{self.base_url}{self.api_prefix}/parse/bytes"
files = {
"pdf_bytes": (filename, pdf_bytes, "application/pdf")
}
data = {
"filename": filename,
"page_from": page_from,
"page_to": page_to,
"zoomin": zoomin
}
try:
logger.info(f"Calling OCR service: {url} with {len(pdf_bytes)} bytes")
response = requests.post(url, files=files, data=data, timeout=self.timeout)
response.raise_for_status()
result = response.json()
if not result.get("success", False):
raise ValueError(f"OCR service returned error: {result.get('message', 'Unknown error')}")
return result
except requests.RequestException as e:
logger.error(f"Failed to call OCR service: {e}")
raise
def health_check(self) -> Dict[str, Any]:
"""
检查 OCR 服务健康状态
Returns:
dict: 健康状态信息
"""
url = f"{self.base_url}{self.api_prefix}/health"
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
logger.error(f"Failed to check OCR service health: {e}")
raise
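A hedged usage sketch for the client above; the file path is illustrative and assumes the OCR service from this change set is reachable:

client = OCRHttpClient()  # honors OCR_SERVICE_URL, defaults to http://localhost:8000
if client.health_check().get("status") == "healthy":
    result = client.parse_pdf_by_path("/data/sample.pdf")
    for page in result["data"]["pages"]:
        print(page["page_number"], len(page["boxes"]), "text boxes")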

File diff suppressed because it is too large

View File

@@ -94,7 +94,7 @@ SVR_HTTP_PORT=9380
# The RAGFlow Docker image to download.
# Defaults to the v0.20.5-slim edition, which is the RAGFlow Docker image without embedding models.
RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5-slim
RAGFLOW_IMAGE=infiniflow/ragflow:fastapi
#
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5
@@ -198,4 +198,6 @@ POSTGRES_DBNAME=rag_flow
POSTGRES_USER=rag_flow
POSTGRES_PASSWORD=infini_rag_flow
POSTGRES_PORT=5432
DB_TYPE=postgres
DB_TYPE=postgres
USE_OCR_HTTP=true

View File

@@ -22,6 +22,7 @@ services:
- cluster.routing.allocation.disk.watermark.flood_stage=2gb
- TZ=${TIMEZONE}
- http.port=9201
- OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx512m
mem_limit: ${MEM_LIMIT}
ulimits:
memlock:
@@ -106,19 +107,19 @@ services:
volumes:
esdata01:
driver: local
name: ragflow_esdata01
osdata01:
driver: local
name: ragflow_osdata01
infinity_data:
driver: local
name: ragflow_infinity_data
mysql_data:
driver: local
name: ragflow_mysql_data
minio_data:
driver: local
name: ragflow_minio_data
redis_data:
driver: local
name: ragflow_redis_data
postgres_data:
driver: local
name: ragflow_postgres_data
networks:
ragflow:

View File

@@ -155,6 +155,8 @@ function task_exe() {
while true; do
LD_PRELOAD="$JEMALLOC_PATH" \
"$PY" rag/svr/task_executor.py "${host_id}_${consumer_id}"
echo "task_executor exited. Sleeping 5s before restart."
sleep 5
done
}
@@ -181,7 +183,7 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then
echo "Starting ragflow_server..."
while true; do
"$PY" api/ragflow_server.py
"$PY" api/ragflow_server_fastapi.py
done &
fi

View File

@@ -21,7 +21,7 @@ if ! docker network ls --format "{{.Name}}" | grep -q "ragflow-20250916_ragflow"
fi
echo "启动 ragflow 服务..."
docker-compose -p ragflow -f docker-compose.yml up -d ragflow
docker compose -p ragflow -f docker-compose.yml up -d ragflow
echo "ragflow 服务启动完成!"
echo "访问地址: http://localhost:${SVR_HTTP_PORT:-9380}"

View File

@@ -3,7 +3,10 @@
# Stop script: stop only the ragflow service, keep the base services running
echo "Stopping the ragflow service..."
docker-compose -f docker-compose.yml down
docker compose -p ragflow -f docker-compose.yml down
# Wait a moment to make sure everything has fully stopped
sleep 10
echo "ragflow service stopped"
echo "Base services (postgres, redis, minio, opensearch) are still running"

main-ocr.py Normal file
View File

@@ -0,0 +1,191 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理服务的主程序入口
独立运行不依赖RAGFlow的其他部分
"""
import argparse
import logging
import os
import sys
import signal
from pathlib import Path
# Make sure the project root is on sys.path
_current_file = Path(__file__).resolve()
_project_root = _current_file.parent.parent
if str(_project_root) not in sys.path:
sys.path.insert(0, str(_project_root))
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ocr.api import ocr_router
from ocr.config import MODEL_DIR
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
"""Create the FastAPI application instance"""
app = FastAPI(
title="OCR PDF Parser API",
description="Standalone OCR PDF processing service that provides OCR recognition for PDF documents",
version="1.0.0",
docs_url="/apidocs",  # Swagger UI docs
redoc_url="/redoc",  # ReDoc docs (alternative)
openapi_url="/openapi.json"  # OpenAPI JSON schema
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],  # restrict to specific domains in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Register the OCR routes
app.include_router(ocr_router)
# Root path
@app.get("/")
async def root():
return {
"service": "OCR PDF Parser",
"version": "1.0.0",
"docs": "/apidocs",
"health": "/api/v1/ocr/health"
}
return app
def signal_handler(sig, frame):
"""Signal handler for graceful shutdown"""
logger.info("Received shutdown signal, exiting...")
sys.exit(0)
def main():
"""Main entry point"""
parser = argparse.ArgumentParser(description="OCR PDF processing service")
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="服务器监听地址 (default: 0.0.0.0)"
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="服务器端口 (default: 8000)"
)
parser.add_argument(
"--reload",
action="store_true",
help="开发模式:自动重载代码"
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="工作进程数 (default: 1)"
)
parser.add_argument(
"--log-level",
type=str,
default="info",
choices=["critical", "error", "warning", "info", "debug", "trace"],
help="日志级别 (default: info)"
)
parser.add_argument(
"--model-dir",
type=str,
default=None,
help=f"OCR模型目录路径 (default: {MODEL_DIR})"
)
args = parser.parse_args()
# Set the model directory if one was provided
if args.model_dir:
os.environ["OCR_MODEL_DIR"] = args.model_dir
logger.info(f"Using custom model directory: {args.model_dir}")
# Check the model directory
model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
if model_dir and not os.path.exists(model_dir):
logger.warning(f"Model directory does not exist: {model_dir}")
logger.info("Models will be downloaded on first use")
# Register the signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Print startup information
logger.info("=" * 60)
logger.info("OCR PDF Parser Service")
logger.info("=" * 60)
logger.info(f"Host: {args.host}")
logger.info(f"Port: {args.port}")
logger.info(f"Model Directory: {model_dir}")
logger.info(f"Workers: {args.workers}")
logger.info(f"Reload: {args.reload}")
logger.info(f"Log Level: {args.log_level}")
logger.info("=" * 60)
logger.info(f"API Documentation (Swagger): http://{args.host}:{args.port}/apidocs")
logger.info(f"API Documentation (ReDoc): http://{args.host}:{args.port}/redoc")
logger.info(f"Health Check: http://{args.host}:{args.port}/api/v1/ocr/health")
logger.info("=" * 60)
# Create the application
app = create_app()
# Start the server
try:
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level=args.log_level,
reload=args.reload,
workers=args.workers if not args.reload else 1,  # reload mode does not support multiple workers
access_log=True
)
except KeyboardInterrupt:
logger.info("Server stopped by user")
except Exception as e:
logger.error(f"Server error: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()
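For reference, a minimal programmatic launch equivalent to `python main-ocr.py --port 8000` (a sketch reusing create_app from above):

import uvicorn

app = create_app()
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")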

ocr/__init__.py Normal file
View File

@@ -0,0 +1,41 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
独立的 OCR 模块
此模块从 RAGFlow 项目中提取,已经移除了对 RAGFlow 特定模块的依赖。
可以直接作为独立模块使用。
使用方法:
from ocr import OCR, SimplePdfParser
import cv2
ocr = OCR()
img = cv2.imread("image.jpg")
results = ocr(img)
parser = SimplePdfParser()
result = parser.parse_pdf("document.pdf")
"""
# Import handling: support both direct execution and module import
import sys
from pathlib import Path
# Re-export the public API so `from ocr import OCR, SimplePdfParser` works.
# NOTE: the module that defines SimplePdfParser is not shown in this diff;
# the `ocr.parser` path below is an assumption.
from ocr.ocr import OCR, TextDetector, TextRecognizer  # noqa: E402
from ocr.parser import SimplePdfParser  # noqa: E402
__all__ = ['OCR', 'TextDetector', 'TextRecognizer', 'SimplePdfParser']

ocr/api.py Normal file
View File

@@ -0,0 +1,525 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理的FastAPI路由
提供HTTP接口用于PDF的OCR识别
"""
import asyncio
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from ocr import SimplePdfParser
from ocr.config import MODEL_DIR
logger = logging.getLogger(__name__)
ocr_router = APIRouter(prefix="/api/v1/ocr", tags=["OCR"])
# Global parser instance (lazy-loaded)
_parser_instance: Optional[SimplePdfParser] = None
def get_parser() -> SimplePdfParser:
"""Return the global parser instance (singleton)"""
global _parser_instance
if _parser_instance is None:
logger.info(f"Initializing OCR parser with model_dir={MODEL_DIR}")
_parser_instance = SimplePdfParser(model_dir=MODEL_DIR)
return _parser_instance
class ParseResponse(BaseModel):
"""Parse response model"""
success: bool
message: str
data: Optional[dict] = None
@ocr_router.get(
"/health",
summary="健康检查",
description="检查OCR服务的健康状态和配置信息",
response_description="返回服务状态和模型目录信息"
)
async def health_check():
"""
健康检查端点
用于检查OCR服务的运行状态和配置信息。
Returns:
dict: 包含服务状态和模型目录的信息
"""
return {
"status": "healthy",
"service": "OCR PDF Parser",
"model_dir": MODEL_DIR
}
@ocr_router.post(
"/parse",
response_model=ParseResponse,
summary="上传并解析PDF文件",
description="上传PDF文件并通过OCR识别提取文本内容",
response_description="返回OCR识别结果"
)
async def parse_pdf_endpoint(
file: UploadFile = File(..., description="PDF文件支持上传任意PDF文档"),
page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
"""
上传并解析PDF文件
通过上传PDF文件使用OCR技术识别并提取其中的文本内容。
支持指定解析的页码范围,以及调整图像放大倍数以平衡识别精度和速度。
Args:
file: 上传的PDF文件multipart/form-data格式
page_from: 起始页码从1开始最小值为1
page_to: 结束页码0表示解析到最后一页最小值为0
zoomin: 图像放大倍数1-5之间数值越大识别精度越高但处理速度越慢
Returns:
ParseResponse: 包含解析结果的响应对象,包括:
- success: 是否成功
- message: 操作结果消息
- data: OCR识别的文本内容和元数据
Raises:
HTTPException: 400 - 如果文件不是PDF格式或文件为空
HTTPException: 500 - 如果解析过程中发生错误
"""
if not file.filename.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="只支持PDF文件")
# Save the uploaded file to a temporary location
temp_file = None
try:
# Read the file content
content = await file.read()
if not content:
raise HTTPException(status_code=400, detail="File is empty")
# Create a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(content)
temp_file = tmp.name
logger.info(f"Parsing PDF file: {file.filename}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
# parse_pdf is synchronous, so run it in a thread pool via to_thread
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_pdf,
temp_file,
zoomin,
page_from - 1,  # convert to a 0-based index
(page_to - 1) if page_to > 0 else 299,  # convert to a 0-based index
None  # callback
)
return ParseResponse(
success=True,
message=f"Successfully parsed PDF: {file.filename}",
data=result
)
except Exception as e:
logger.error(f"Error parsing PDF: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error while parsing PDF: {str(e)}"
)
finally:
# Clean up the temporary file
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except Exception as e:
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
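A hedged smoke test for the upload endpoint above, using requests; the file name is illustrative and the service is assumed to listen on localhost:8000:

import requests

with open("sample.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/ocr/parse",
        files={"file": ("sample.pdf", f, "application/pdf")},
        data={"page_from": 1, "page_to": 0, "zoomin": 3},
        timeout=300,
    )
resp.raise_for_status()
print(resp.json()["message"])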
@ocr_router.post(
"/parse/bytes",
response_model=ParseResponse,
summary="通过二进制数据解析PDF",
description="直接通过二进制数据解析PDF文件无需上传文件",
response_description="返回OCR识别结果"
)
async def parse_pdf_bytes(
pdf_bytes: bytes = File(..., description="PDF文件的二进制数据multipart/form-data格式"),
filename: str = Form("document.pdf", description="文件名(仅用于日志记录,不影响解析)"),
page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
"""
直接通过二进制数据解析PDF
适用于已获取PDF二进制数据的场景无需文件上传步骤。
直接将PDF的二进制数据提交即可进行OCR识别。
Args:
pdf_bytes: PDF文件的二进制数据以文件形式提交
filename: 文件名(仅用于日志记录,不影响实际解析过程)
page_from: 起始页码从1开始最小值为1
page_to: 结束页码0表示解析到最后一页最小值为0
zoomin: 图像放大倍数1-5之间数值越大识别精度越高但处理速度越慢
Returns:
ParseResponse: 包含解析结果的响应对象
Raises:
HTTPException: 400 - 如果PDF数据为空
HTTPException: 500 - 如果解析过程中发生错误
"""
if not pdf_bytes:
raise HTTPException(status_code=400, detail="PDF数据为空")
# Save to a temporary file
temp_file = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(pdf_bytes)
temp_file = tmp.name
logger.info(f"Parsing PDF bytes (filename: {filename}), pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
# parse_pdf is synchronous, so run it in a thread pool via to_thread
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_pdf,
temp_file,
zoomin,
page_from - 1,  # convert to a 0-based index
(page_to - 1) if page_to > 0 else 299,  # convert to a 0-based index
None  # callback
)
return ParseResponse(
success=True,
message=f"Successfully parsed PDF: {filename}",
data=result
)
except Exception as e:
logger.error(f"Error parsing PDF bytes: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error while parsing PDF: {str(e)}"
)
finally:
# Clean up the temporary file
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except Exception as e:
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
@ocr_router.post(
"/parse/path",
response_model=ParseResponse,
summary="通过文件路径解析PDF",
description="通过服务器本地文件路径解析PDF文件",
response_description="返回OCR识别结果"
)
async def parse_pdf_path(
file_path: str = Form(..., description="PDF文件在服务器上的本地路径必须是可访问的绝对路径"),
page_from: int = Form(1, ge=1, description="起始页码从1开始默认为1"),
page_to: int = Form(0, ge=0, description="结束页码0表示解析到最后一页默认为0"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5数值越大识别精度越高但速度越慢默认为3")
):
"""
通过文件路径解析PDF
适用于PDF文件已经存在于服务器上的场景。
通过提供文件路径直接进行OCR识别无需上传文件。
Args:
file_path: PDF文件在服务器上的本地路径必须是服务器可访问的绝对路径
page_from: 起始页码从1开始最小值为1
page_to: 结束页码0表示解析到最后一页最小值为0
zoomin: 图像放大倍数1-5之间数值越大识别精度越高但处理速度越慢
Returns:
ParseResponse: 包含解析结果的响应对象
Raises:
HTTPException: 400 - 如果文件不是PDF格式
HTTPException: 404 - 如果文件不存在
HTTPException: 500 - 如果解析过程中发生错误
Note:
此端点需要确保提供的文件路径在服务器上可访问。
建议仅在内网环境或受信任的环境中使用,避免路径遍历安全风险。
"""
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
if not file_path.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="只支持PDF文件")
try:
logger.info(f"Parsing PDF from path: {file_path}, pages {page_from}-{page_to or 'end'}, zoomin={zoomin}")
# parse_pdf is synchronous, so run it in a thread pool via to_thread
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_pdf,
file_path,
zoomin,
page_from - 1,  # convert to a 0-based index
(page_to - 1) if page_to > 0 else 299,  # convert to a 0-based index
None  # callback
)
return ParseResponse(
success=True,
message=f"Successfully parsed PDF: {file_path}",
data=result
)
except Exception as e:
logger.error(f"Error parsing PDF from path: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error while parsing PDF: {str(e)}"
)
@ocr_router.post(
"/parse_into_bboxes",
summary="解析PDF并返回边界框",
description="解析PDF文件并返回文本边界框信息用于文档结构化处理",
response_description="返回包含文本边界框的列表"
)
async def parse_into_bboxes_endpoint(
pdf_bytes: bytes = File(..., description="PDF文件的二进制数据"),
filename: str = Form("document.pdf", description="文件名(仅用于日志)"),
zoomin: int = Form(3, ge=1, le=5, description="图像放大倍数1-5默认为3")
):
"""
解析PDF并返回边界框
此接口用于将PDF文档解析为结构化文本边界框每个边界框包含
- 文本内容
- 页面编号
- 坐标信息x0, x1, top, bottom
- 布局类型(如 text, table, figure 等)
- 图像数据(如果有)
Args:
pdf_bytes: PDF文件的二进制数据
filename: 文件名(仅用于日志记录)
zoomin: 图像放大倍数1-5之间
Returns:
dict: 包含解析结果的对象data字段为边界框列表
Raises:
HTTPException: 400 - 如果PDF数据为空
HTTPException: 500 - 如果解析过程中发生错误
"""
if not pdf_bytes:
raise HTTPException(status_code=400, detail="PDF数据为空")
temp_file = None
try:
# Save to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(pdf_bytes)
temp_file = tmp.name
logger.info(f"Parsing PDF into bboxes: {filename}, zoomin={zoomin}")
# Simple callback wrapper that logs progress updates
def progress_callback(prog, msg):
logger.info(f"Progress: {prog:.2%} - {msg}")
parser = get_parser()
result = await asyncio.to_thread(
parser.parse_into_bboxes,
temp_file,
progress_callback,
zoomin
)
# Convert image data to base64 or None
processed_result = []
for bbox in result:
processed_bbox = dict(bbox)
# If an image is attached it could be converted to base64 here,
# but the original format is kept for compatibility
processed_result.append(processed_bbox)
return ParseResponse(
success=True,
message=f"Successfully parsed PDF into bounding boxes: {filename}",
data={"bboxes": processed_result}
)
except Exception as e:
logger.error(f"Error parsing PDF into bboxes: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error while parsing PDF into bounding boxes: {str(e)}"
)
finally:
# Clean up the temporary file
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except Exception as e:
logger.warning(f"Failed to delete temp file {temp_file}: {e}")
class TextRequest(BaseModel):
"""Text processing request model"""
text: str = Field(..., description="Text content to process")
class RemoveTagResponse(BaseModel):
"""Remove-tag response model"""
success: bool
message: str
text: Optional[str] = None
@ocr_router.post(
"/remove_tag",
response_model=RemoveTagResponse,
summary="移除文本中的位置标签",
description="从文本中移除PDF解析生成的位置标签格式@@页码\t坐标##",
response_description="返回移除标签后的文本"
)
async def remove_tag_endpoint(request: TextRequest):
"""
移除文本中的位置标签
此接口用于从包含位置标签的文本中移除标签信息。
位置标签格式为:@@页码\t坐标##,例如:@@1\t100.0\t200.0\t50.0\t60.0##
Args:
request: 包含待处理文本的请求对象
Returns:
RemoveTagResponse: 包含处理结果的响应对象
Raises:
HTTPException: 400 - 如果文本为空
"""
if not request.text:
raise HTTPException(status_code=400, detail="文本内容不能为空")
try:
cleaned_text = SimplePdfParser.remove_tag(request.text)
return RemoveTagResponse(
success=True,
message="成功移除文本标签",
text=cleaned_text
)
except Exception as e:
logger.error(f"Error removing tag: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"移除标签时发生错误: {str(e)}"
)
class ExtractPositionsResponse(BaseModel):
"""Extract-positions response model"""
success: bool
message: str
positions: Optional[list] = None
@ocr_router.post(
"/extract_positions",
response_model=ExtractPositionsResponse,
summary="从文本中提取位置信息",
description="从包含位置标签的文本中提取所有位置坐标信息",
response_description="返回提取到的位置信息列表"
)
async def extract_positions_endpoint(request: TextRequest):
"""
从文本中提取位置信息
此接口用于从包含位置标签的文本中提取所有位置坐标信息。
位置标签格式为:@@页码\t坐标##
返回的位置信息格式为:
[
([页码列表], left, right, top, bottom),
...
]
Args:
request: 包含待处理文本的请求对象
Returns:
ExtractPositionsResponse: 包含提取结果的响应对象
Raises:
HTTPException: 400 - 如果文本为空
"""
if not request.text:
raise HTTPException(status_code=400, detail="文本内容不能为空")
try:
positions = SimplePdfParser.extract_positions(request.text)
# 将位置信息转换为可序列化的格式
serializable_positions = [
{
"page_numbers": pos[0],
"left": pos[1],
"right": pos[2],
"top": pos[3],
"bottom": pos[4]
}
for pos in positions
]
return ExtractPositionsResponse(
success=True,
message=f"成功提取 {len(positions)} 个位置信息",
positions=serializable_positions
)
except Exception as e:
logger.error(f"Error extracting positions: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"提取位置信息时发生错误: {str(e)}"
)
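To make the tag format concrete, a hedged illustration of what the two endpoints above would return for a one-tag fragment; the coordinate order follows the docstring example and the values are illustrative:

tagged = "Hello world@@1\t100.0\t200.0\t50.0\t60.0##"
# POST {"text": tagged} to /remove_tag        -> text == "Hello world"
# POST {"text": tagged} to /extract_positions -> positions ==
#   [{"page_numbers": [1], "left": 100.0, "right": 200.0, "top": 50.0, "bottom": 60.0}]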

ocr/client.py Normal file
View File

@@ -0,0 +1,239 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR HTTP 客户端工具类
用于通过 HTTP 接口调用 OCR 服务
"""
import logging
import os
from typing import Optional, Callable, List, Tuple, Any
try:
import httpx
HAS_HTTPX = True
except ImportError:
HAS_HTTPX = False
import aiohttp
logger = logging.getLogger(__name__)
class OCRClient:
"""OCR HTTP 客户端,用于调用 OCR API"""
def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
"""
初始化 OCR 客户端
Args:
base_url: OCR 服务的基础 URL如果不提供则从环境变量 OCR_SERVICE_URL 获取,
如果仍未设置则默认为 http://localhost:8000/api/v1/ocr
timeout: 请求超时时间(秒),默认 300 秒
"""
self.base_url = base_url or os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
self.timeout = timeout
# 移除末尾的斜杠
if self.base_url.endswith('/'):
self.base_url = self.base_url.rstrip('/')
async def _make_request(self, method: str, endpoint: str, **kwargs) -> dict:
"""内部方法:发送 HTTP 请求"""
url = f"{self.base_url}{endpoint}"
if HAS_HTTPX:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
else:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
async with session.request(method, url, **kwargs) as response:
response.raise_for_status()
return await response.json()
async def remove_tag(self, text: str) -> str:
"""
移除文本中的位置标签
Args:
text: 包含位置标签的文本
Returns:
移除标签后的文本
"""
response = await self._make_request(
"POST",
"/remove_tag",
json={"text": text}
)
if response.get("success") and response.get("text") is not None:
return response["text"]
raise Exception(f"移除标签失败: {response.get('message', '未知错误')}")
def remove_tag_sync(self, text: str) -> str:
"""
同步版本的 remove_tag用于同步代码
Args:
text: 包含位置标签的文本
Returns:
移除标签后的文本
"""
import asyncio
try:
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.remove_tag(text))
except RuntimeError:
# No event loop available: create a new one
return asyncio.run(self.remove_tag(text))
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""
从文本中提取位置信息
Args:
text: 包含位置标签的文本
Returns:
位置信息列表,格式为 [(页码列表, left, right, top, bottom), ...]
"""
response = await self._make_request(
"POST",
"/extract_positions",
json={"text": text}
)
if response.get("success") and response.get("positions") is not None:
# Convert the response back into the original tuple format
positions = []
for pos in response["positions"]:
positions.append((
pos["page_numbers"],
pos["left"],
pos["right"],
pos["top"],
pos["bottom"]
))
return positions
raise Exception(f"提取位置信息失败: {response.get('message', '未知错误')}")
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""
同步版本的 extract_positions用于同步代码
Args:
text: 包含位置标签的文本
Returns:
位置信息列表
"""
import asyncio
try:
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.extract_positions(text))
except RuntimeError:
return asyncio.run(self.extract_positions(text))
async def parse_into_bboxes(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""
解析 PDF 并返回边界框
Args:
pdf_bytes: PDF 文件的二进制数据
callback: 进度回调函数 (progress: float, message: str) -> None
zoomin: 图像放大倍数1-5默认为3
filename: 文件名(仅用于日志)
Returns:
边界框列表
"""
if HAS_HTTPX:
async with httpx.AsyncClient(timeout=self.timeout) as client:
# Note: with httpx, the file goes in `files` and the form fields in `data`
form_data = {"filename": filename, "zoomin": str(zoomin)}
form_files = {"pdf_bytes": (filename, pdf_bytes, "application/pdf")}
response = await client.post(
f"{self.base_url}/parse_into_bboxes",
files=form_files,
data=form_data
)
response.raise_for_status()
result = response.json()
else:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
form_data = aiohttp.FormData()
form_data.add_field('pdf_bytes', pdf_bytes, filename=filename, content_type='application/pdf')
form_data.add_field('filename', filename)
form_data.add_field('zoomin', str(zoomin))
async with session.post(
f"{self.base_url}/parse_into_bboxes",
data=form_data
) as response:
response.raise_for_status()
result = await response.json()
if result.get("success") and result.get("data") and result["data"].get("bboxes"):
return result["data"]["bboxes"]
raise Exception(f"解析 PDF 失败: {result.get('message', '未知错误')}")
def parse_into_bboxes_sync(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""
同步版本的 parse_into_bboxes用于同步代码
Args:
pdf_bytes: PDF 文件的二进制数据
callback: 进度回调函数注意HTTP 调用中无法实时传递回调,此参数将被忽略)
zoomin: 图像放大倍数1-5默认为3
filename: 文件名(仅用于日志)
Returns:
边界框列表
"""
if callback:
logger.warning("Progress callbacks are not supported over HTTP; the callback will be ignored")
import asyncio
try:
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))
except RuntimeError:
return asyncio.run(self.parse_into_bboxes(pdf_bytes, None, zoomin, filename))
# Global client instance (lazy-loaded)
_global_client: Optional[OCRClient] = None
def get_ocr_client() -> OCRClient:
"""获取全局 OCR 客户端实例(单例模式)"""
global _global_client
if _global_client is None:
_global_client = OCRClient()
return _global_client
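A hedged usage sketch for the async client above (file name illustrative; assumes the OCR service is running):

import asyncio

async def main():
    client = get_ocr_client()  # reads OCR_SERVICE_URL or uses the default
    with open("sample.pdf", "rb") as f:
        bboxes = await client.parse_into_bboxes(f.read(), zoomin=3)
    print(len(bboxes), "bounding boxes")

asyncio.run(main())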

ocr/config.py Normal file
View File

@@ -0,0 +1,42 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 模块配置文件
"""
import os
import logging
# 并行设备数量GPU数量0表示使用CPU
PARALLEL_DEVICES = 0
try:
import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count()
logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
logging.info("can't import package 'torch', using CPU mode")
# Model directory
# Taken from the environment variable, with a default path as fallback
MODEL_DIR = os.getenv("OCR_MODEL_DIR", None)
if MODEL_DIR is None:
# Default model directory: models/deepdoc under the project root.
# If it does not exist, the OCR class will try to download from HuggingFace on init.
_base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MODEL_DIR = os.path.join(_base_dir, "models", "deepdoc")
# If the directory does not exist, set it to None and let the OCR class handle downloading
if not os.path.exists(MODEL_DIR):
MODEL_DIR = None

ocr/ocr.py Normal file
View File

@@ -0,0 +1,785 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import logging
import copy
import time
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
# Import handling: support both direct execution and package import
# (__package__ is None or '' when this file is run directly)
if __package__ in (None, ""):
# Run directly: add the parent directory to sys.path and use absolute imports
parent_dir = Path(__file__).parent.parent
if str(parent_dir) not in sys.path:
sys.path.insert(0, str(parent_dir))
from ocr.utils import get_project_base_directory
from ocr.config import PARALLEL_DEVICES, MODEL_DIR
from ocr.operators import *  # noqa: F403
import ocr.operators as operators
from ocr.postprocess import build_post_process
else:
# Imported as part of the `ocr` package: use relative imports
from .utils import get_project_base_directory
from .config import PARALLEL_DEVICES, MODEL_DIR
from .operators import *  # noqa: F403
from . import operators
from .postprocess import build_post_process
import math
import numpy as np
import cv2
import onnxruntime as ort
loaded_models = {}
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(
op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = getattr(operators, op_name)(**param)
ops.append(op)
return ops
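For illustration, an op_param_list in the shape this function expects, mirroring the detector's preprocessing config further below:

ops = create_operators([
    {"DetResizeForTest": {"limit_side_len": 960, "limit_type": "max"}},
    {"NormalizeImage": {"std": [0.229, 0.224, 0.225], "mean": [0.485, 0.456, 0.406],
                        "scale": "1./255.", "order": "hwc"}},
    {"ToCHWImage": None},
])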
def load_model(model_dir, nm, device_id: int | None = None):
model_file_path = os.path.join(model_dir, nm + ".onnx")
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
global loaded_models
loaded_model = loaded_models.get(model_cached_tag)
if loaded_model:
logging.info(f"load_model {model_file_path} reuses cached model")
return loaded_model
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
def cuda_is_available():
try:
import torch
if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
return True
except Exception:
return False
return False
options = ort.SessionOptions()
options.enable_cpu_mem_arena = False
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
# Shrink GPU memory after execution
run_options = ort.RunOptions()
if cuda_is_available():
cuda_provider_options = {
"device_id": device_id, # Use specific GPU
"gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory
"arena_extend_strategy": "kNextPowerOfTwo", # gpu memory allocation strategy
}
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
logging.info(f"load_model {model_file_path} uses GPU")
else:
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CPUExecutionProvider'])
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
logging.info(f"load_model {model_file_path} uses CPU")
loaded_model = (sess, run_options)
loaded_models[model_cached_tag] = loaded_model
return loaded_model
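A small sketch of the caching behavior above; the model directory is hypothetical:

sess, run_opts = load_model("/models/deepdoc", "det")  # builds an ONNX session (CPU or GPU)
sess2, _ = load_model("/models/deepdoc", "det")        # second call is served from loaded_models
assert sess is sess2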
class TextRecognizer:
def __init__(self, model_dir, device_id: int | None = None):
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
self.rec_batch_num = 16
postprocess_params = {
'name': 'CTCLabelDecode',
"character_dict_path": os.path.join(model_dir, "ocr.res"),
"use_space_char": True
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
w = self.input_tensor.shape[3:][0]
if isinstance(w, str):
pass
elif w is not None and w > 0:
imgW = w
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_spin(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
mean = [127.5]
std = [127.5]
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
mean = np.float32(mean.reshape(1, -1))
stdinv = 1 / np.float32(std.reshape(1, -1))
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_abinet(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (
resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype('float32')
return resized_image
def norm_img_can(self, img, image_shape):
img = cv2.cvtColor(
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
if self.rec_image_shape[0] == 1:
h, w = img.shape
_, imgH, imgW = self.rec_image_shape
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
'constant',
constant_values=(255))
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
img = img.astype('float32')
return img
def close(self):
# close session and release manually
logging.info('Close text recognizer.')
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
st = time.time()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
for i in range(100000):  # retry loop; gives up after 4 consecutive failures
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res, time.time() - st
def __del__(self):
self.close()
class TextDetector:
def __init__(self, model_dir, device_id: int | None = None):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
'limit_type': "max",
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
img_h, img_w = self.input_tensor.shape[2:]
if isinstance(img_h, str) or isinstance(img_w, str):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
pre_process_list[0] = {
'DetResizeForTest': {
'image_shape': [img_h, img_w]
}
}
self.preprocess_op = create_operators(pre_process_list)
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def close(self):
logging.info("Close text detector.")
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
st = time.time()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
input_dict = {}
input_dict[self.input_tensor.name] = img
for i in range(100000):  # retry loop; gives up after 4 consecutive failures
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
return dt_boxes, time.time() - st
def __del__(self):
self.close()
class OCR:
def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
try:
# Use MODEL_DIR from the config; fall back to the default path if it is unset or missing
if MODEL_DIR and os.path.exists(MODEL_DIR):
model_dir = MODEL_DIR
else:
model_dir = os.path.join(
get_project_base_directory(),
"models", "deepdoc")
# Append one detector/recognizer per GPU to the lists
if PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
except Exception:
# The model directory is missing: try downloading from HuggingFace
default_model_dir = os.path.join(
get_project_base_directory(), "models", "deepdoc")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=default_model_dir,
local_dir_use_symlinks=False)
if PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
else:
# A model_dir was given explicitly: use it directly
if PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
self.drop_score = 0.5
self.crop_image_res_index = 0
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
# Try original orientation
rec_result = self.text_recognizer[0]([dst_img])
text, score = rec_result[0][0]
best_score = score
best_img = dst_img
# Try clockwise 90° rotation
rotated_cw = np.rot90(dst_img, k=3)
rec_result = self.text_recognizer[0]([rotated_cw])
rotated_cw_text, rotated_cw_score = rec_result[0][0]
if rotated_cw_score > best_score:
best_score = rotated_cw_score
best_img = rotated_cw
# Try counter-clockwise 90° rotation
rotated_ccw = np.rot90(dst_img, k=1)
rec_result = self.text_recognizer[0]([rotated_ccw])
rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
if rotated_ccw_score > best_score:
best_img = rotated_ccw
# Use the best image
dst_img = best_img
return dst_img
def sorted_boxes(self, dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def detect(self, img, device_id: int | None = None):
if device_id is None:
device_id = 0
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if img is None:
return None, None, time_dict
start = time.time()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
return zip(self.sorted_boxes(dt_boxes), [
("", 0) for _ in range(len(dt_boxes))])
def recognize(self, ori_im, box, device_id: int | None = None):
if device_id is None:
device_id = 0
img_crop = self.get_rotate_crop_image(ori_im, box)
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
text, score = rec_res[0]
if score < self.drop_score:
return ""
return text
def recognize_batch(self, img_list, device_id: int | None = None):
if device_id is None:
device_id = 0
rec_res, elapse = self.text_recognizer[device_id](img_list)
texts = []
for i in range(len(rec_res)):
text, score = rec_res[i]
if score < self.drop_score:
text = ""
texts.append(text)
return texts
def __call__(self, img, device_id=0, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if device_id is None:
device_id = 0
if img is None:
return None, None, time_dict
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
img_crop_list = []
dt_boxes = self.sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
time_dict['rec'] = elapse
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
# for bno in range(len(img_crop_list)):
# print(f"{bno}, {rec_res[bno]}")
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))
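A minimal end-to-end sketch of the class above (image path illustrative; models must be available locally or downloadable from HuggingFace):

import cv2

ocr = OCR()
img = cv2.imread("page.png")
for box, (text, score) in ocr(img):
    print(round(score, 2), text)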

ocr/operators.py Normal file
View File

@@ -0,0 +1,726 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import sys
import six
import cv2
import numpy as np
import math
from PIL import Image
class DecodeImage:
""" decode image """
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
if six.PY2:
assert isinstance(img, str) and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert isinstance(img, bytes) and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class StandardizeImag:
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class NormalizeImage:
""" normalize image such as subtract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
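# Configs may pass scale as a string expression such as "1./255."; eval turns it into a float.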
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage:
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class KeepKeys:
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class Pad:
def __init__(self, size=None, size_div=32, **kwargs):
if size is not None and not isinstance(size, (int, list, tuple)):
raise TypeError("Type of target_size is invalid. Now is {}".format(
type(size)))
if isinstance(size, int):
size = [size, size]
self.size = size
self.size_div = size_div
def __call__(self, data):
img = data['image']
img_h, img_w = img.shape[0], img.shape[1]
if self.size:
resize_h2, resize_w2 = self.size
assert (
img_h < resize_h2 and img_w < resize_w2
), '(h, w) of target size should be greater than (img_h, img_w)'
else:
resize_h2 = max(
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
self.size_div)
resize_w2 = max(
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
self.size_div)
img = cv2.copyMakeBorder(
img,
0,
resize_h2 - img_h,
0,
resize_w2 - img_w,
cv2.BORDER_CONSTANT,
value=0)
data['image'] = img
return data
class LinearResize:
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
_im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
_im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class Resize:
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
if 'polys' in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
return data
class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'keep_ratio' in kwargs:
self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if sum([src_h, src_w]) < 64:
img = self.image_padding(img)
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def image_padding(self, im, value=0):
h, w, c = im.shape
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
im_pad[:h, :w, :] = im
return im_pad
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('unsupported limit_type: {}'.format(self.limit_type))
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except BaseException:
logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest:
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize:
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
class SRResize:
def __init__(self,
imgH=32,
imgW=128,
down_sample_scale=4,
keep_ratio=False,
min_ratio=1,
mask=False,
infer_mode=False,
**kwargs):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
self.down_sample_scale = down_sample_scale
self.mask = mask
self.infer_mode = infer_mode
def __call__(self, data):
imgH = self.imgH
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
return data
images_HR = data["image_hr"]
_label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR
return data
class ResizeNormalize:
def __init__(self, size, interpolation=Image.BICUBIC):
self.size = size
self.interpolation = interpolation
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
class GrayImageChannelFormat:
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
inverse: inverse gray image
"""
def __init__(self, inverse=False, **kwargs):
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_expanded = np.expand_dims(img_single_channel, 0)
if self.inverse:
data['image'] = np.abs(img_expanded - 1)
else:
data['image'] = img_expanded
data['src_image'] = img
return data
class Permute:
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride:
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
def nms(bboxes, scores, iou_thresh):
import numpy as np
x1 = bboxes[:, 0]
y1 = bboxes[:, 1]
x2 = bboxes[:, 2]
y2 = bboxes[:, 3]
areas = (y2 - y1) * (x2 - x1)
indices = []
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0]
indices.append(i)
x11 = np.maximum(x1[i], x1[index[1:]])
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1)
h = np.maximum(0, y22 - y11 + 1)
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= iou_thresh)[0]
index = index[idx + 1]
return indices
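A quick sanity check of `nms` on hand-made inputs (illustrative only, not part of the module):

import numpy as np

# Two heavily overlapping boxes plus one disjoint box; with iou_thresh=0.5
# the lower-scoring overlapping box is suppressed, leaving indices [0, 2].
bboxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(nms(bboxes, scores, iou_thresh=0.5))  # [0, 2]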

ocr/pdf_parser.py Normal file

File diff suppressed because it is too large

ocr/postprocess.py Normal file

@@ -0,0 +1,371 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
def build_post_process(config, global_config=None):
support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode}
config = copy.deepcopy(config)
module_name = config.pop('name')
if module_name == "None":
return
if global_config is not None:
config.update(global_config)
module_class = support_dict.get(module_name)
if module_class is None:
raise ValueError(
'post process only support {}'.format(list(support_dict)))
return module_class(**config)
class DBPostProcess:
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
box_type='quad',
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.box_type = box_type
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
_img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype("int32"))
scores.append(score)
return np.array(boxes, dtype="int32"), scores
def unclip(self, box, unclip_ratio):
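# Expand the detected polygon outward; the offset distance scales with
# area/perimeter, so the same ratio dilates thin boxes less.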
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use the mean score inside the box's axis-aligned bounding rectangle as the box score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use the mean score inside the polygon as the box score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if not isinstance(pred, np.ndarray):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
if self.box_type == 'poly':
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
elif self.box_type == 'quad':
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
else:
raise ValueError(
"box_type can only be one of ['quad', 'poly']")
boxes_batch.append({'points': boxes})
return boxes_batch
class BaseRecLabelDecode:
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
self.reverse = False
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
if 'arabic' in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def pred_reverse(self, pred):
pred_re = []
c_current = ''
for c in pred:
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
if c_current != '':
pred_re.append(c_current)
pred_re.append(c)
c_current = ''
else:
c_current += c
if c_current != '':
pred_re.append(c_current)
return ''.join(pred_re[::-1])
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
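# CTC decoding: first collapse consecutive duplicate indices, then drop blank tokens.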
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if not isinstance(preds, np.ndarray):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
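A minimal sketch of driving the post-processor via `build_post_process`; the config values and the random probability map below are illustrative only:

import numpy as np

post_op = build_post_process({
    "name": "DBPostProcess",
    "thresh": 0.3,
    "box_thresh": 0.6,
    "unclip_ratio": 1.5,
})
# Fake a (batch=1, channel=1, H, W) probability map and one
# [src_h, src_w, ratio_h, ratio_w] row per image.
outs = {"maps": np.random.rand(1, 1, 64, 64).astype(np.float32)}
shape_list = np.array([[640.0, 640.0, 0.1, 0.1]])
boxes_batch = post_op(outs, shape_list)  # [{'points': int32 boxes}, ...]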

ocr/requirements.txt Normal file

@@ -0,0 +1,25 @@
# Dependencies for the OCR PDF-processing module
# Core dependencies
numpy>=1.21.0
opencv-python>=4.5.0
pillow>=8.0.0
pdfplumber>=0.9.0
onnxruntime>=1.12.0
trio>=0.22.0
# Geometry dependencies
shapely>=1.8.0
pyclipper>=1.2.0
# Web framework dependencies
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# Model download dependency
huggingface_hub>=0.16.0
# Optional dependencies (for GPU detection and acceleration)
# torch>=1.12.0  # uncomment and install if GPU support is needed
# onnxruntime-gpu>=1.12.0  # uncomment and install if GPU support is needed

ocr/service.py Normal file

@@ -0,0 +1,290 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 服务统一接口
支持本地OCR模型和HTTP接口两种方式可通过配置选择
"""
import logging
import os
from abc import ABC, abstractmethod
from typing import Optional, Callable, List, Tuple, Any
logger = logging.getLogger(__name__)
class OCRService(ABC):
"""OCR服务抽象接口"""
@abstractmethod
async def remove_tag(self, text: str) -> str:
"""
Strip position tags from text.
Args:
text: text containing position tags
Returns:
the cleaned text
"""
pass
@abstractmethod
def remove_tag_sync(self, text: str) -> str:
"""
Synchronous version of remove_tag, for use from synchronous code.
Args:
text: text containing position tags
Returns:
the cleaned text
"""
pass
@abstractmethod
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""
Extract position information from text.
Args:
text: text containing position tags
Returns:
a list of positions, formatted as [(page numbers, left, right, top, bottom), ...]
"""
pass
@abstractmethod
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""
Synchronous version of extract_positions, for use from synchronous code.
Args:
text: text containing position tags
Returns:
a list of positions
"""
pass
@abstractmethod
async def parse_into_bboxes(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""
Parse a PDF and return its bounding boxes.
Args:
pdf_bytes: raw bytes of the PDF file
callback: progress callback, (progress: float, message: str) -> None
zoomin: image zoom factor (1-5), default 3
filename: file name (used for logging only)
Returns:
a list of bounding boxes
"""
pass
@abstractmethod
def parse_into_bboxes_sync(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""
Synchronous version of parse_into_bboxes, for use from synchronous code.
Args:
pdf_bytes: raw bytes of the PDF file
callback: progress callback (note: callbacks cannot be forwarded over HTTP, so this parameter is ignored in HTTP mode)
zoomin: image zoom factor (1-5), default 3
filename: file name (used for logging only)
Returns:
a list of bounding boxes
"""
pass
class LocalOCRService(OCRService):
"""本地OCR服务实现直接调用本地OCR模型"""
def __init__(self, parser_instance=None):
"""
初始化本地OCR服务
Args:
parser_instance: SimplePdfParser 实例,如果不提供则自动创建
"""
if parser_instance is None:
from ocr import SimplePdfParser
from ocr.config import MODEL_DIR
logger.info(f"Initializing local OCR parser with model_dir={MODEL_DIR}")
self.parser = SimplePdfParser(model_dir=MODEL_DIR)
else:
self.parser = parser_instance
async def remove_tag(self, text: str) -> str:
"""Strip tags using the local parser's static method."""
# SimplePdfParser.remove_tag is a static method and can be called directly.
return self.parser.remove_tag(text)
def remove_tag_sync(self, text: str) -> str:
"""Synchronous version of remove_tag."""
return self.parser.remove_tag(text)
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""Extract positions using the local parser's static method."""
# SimplePdfParser.extract_positions is a static method and can be called directly.
return self.parser.extract_positions(text)
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""Synchronous version of extract_positions."""
return self.parser.extract_positions(text)
async def parse_into_bboxes(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""使用本地解析器解析PDF"""
# 本地解析器可以直接接受BytesIO
import asyncio
from io import BytesIO
# 在后台线程中运行同步方法
loop = asyncio.get_event_loop()
bboxes = await loop.run_in_executor(
None,
lambda: self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)
)
return bboxes
def parse_into_bboxes_sync(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""同步版本的 parse_into_bboxes"""
from io import BytesIO
# 本地解析器可以直接接受BytesIO
return self.parser.parse_into_bboxes(BytesIO(pdf_bytes), callback=callback, zoomin=zoomin)
class HTTPOCRService(OCRService):
"""HTTP OCR服务实现通过HTTP接口调用OCR服务"""
def __init__(self, base_url: Optional[str] = None, timeout: float = 300.0):
"""
初始化HTTP OCR服务
Args:
base_url: OCR 服务的基础 URL如果不提供则从环境变量 OCR_SERVICE_URL 获取
timeout: 请求超时时间(秒),默认 300 秒
"""
from ocr.client import OCRClient
self.client = OCRClient(base_url=base_url, timeout=timeout)
async def remove_tag(self, text: str) -> str:
"""Strip tags via the HTTP interface."""
return await self.client.remove_tag(text)
def remove_tag_sync(self, text: str) -> str:
"""Synchronous version of remove_tag."""
return self.client.remove_tag_sync(text)
async def extract_positions(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""Extract positions via the HTTP interface."""
return await self.client.extract_positions(text)
def extract_positions_sync(self, text: str) -> List[Tuple[List[int], float, float, float, float]]:
"""Synchronous version of extract_positions."""
return self.client.extract_positions_sync(text)
async def parse_into_bboxes(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""通过HTTP接口解析PDF"""
return await self.client.parse_into_bboxes(pdf_bytes, callback, zoomin, filename)
def parse_into_bboxes_sync(
self,
pdf_bytes: bytes,
callback: Optional[Callable[[float, str], None]] = None,
zoomin: int = 3,
filename: str = "document.pdf"
) -> List[dict]:
"""同步版本的 parse_into_bboxes"""
return self.client.parse_into_bboxes_sync(pdf_bytes, callback, zoomin, filename)
# Global service instance (lazily initialized)
_global_service: Optional[OCRService] = None
def get_ocr_service() -> OCRService:
"""
获取全局 OCR 服务实例(单例模式)
根据环境变量 OCR_MODE 选择使用本地或HTTP方式
- OCR_MODE=local 或未设置使用本地OCR模型
- OCR_MODE=http使用HTTP接口
也可以通过环境变量 OCR_SERVICE_URL 配置HTTP服务的地址仅在OCR_MODE=http时生效
Returns:
OCRService 实例
"""
global _global_service
if _global_service is None:
ocr_mode = os.getenv("OCR_MODE", "local").lower()
if ocr_mode == "http":
base_url = os.getenv("OCR_SERVICE_URL", "http://localhost:8000/api/v1/ocr")
logger.info(f"Initializing HTTP OCR service with URL: {base_url}")
_global_service = HTTPOCRService(base_url=base_url)
else:
logger.info("Initializing local OCR service")
_global_service = LocalOCRService()
return _global_service
# For backward compatibility, keep get_ocr_client (it redirects to get_ocr_service).
def get_ocr_client() -> OCRService:
"""
Get the OCR service instance (backward-compatibility shim).
Prefer get_ocr_service() instead.
Returns:
an OCRService instance
"""
return get_ocr_service()
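A minimal caller-side sketch of the factory above (the PDF path is a placeholder, the environment variables and defaults are as documented in `get_ocr_service`, and the bbox dicts are assumed to carry a "text" key):

import os

os.environ.setdefault("OCR_MODE", "local")  # or "http" together with OCR_SERVICE_URL

from ocr.service import get_ocr_service

service = get_ocr_service()  # singleton; repeated calls return the same instance
with open("document.pdf", "rb") as f:
    bboxes = service.parse_into_bboxes_sync(f.read(), zoomin=3, filename="document.pdf")
texts = [service.remove_tag_sync(b["text"]) for b in bboxes if "text" in b]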

ocr/utils.py Normal file

@@ -0,0 +1,40 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR 模块工具函数
"""
import os
def get_project_base_directory(*args):
"""
获取项目根目录
Args:
*args: 可选的子路径
Returns:
str: 项目根目录路径
"""
# 获取当前文件的目录
current_dir = os.path.dirname(os.path.realpath(__file__))
# 返回 ocr 模块的父目录(项目根目录)
base_dir = os.path.dirname(current_dir)
if args:
return os.path.join(base_dir, *args)
return base_dir
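For example (an illustrative call; the sub-path components are arbitrary):

from ocr.utils import get_project_base_directory

# Resolve a path relative to the repository root.
model_dir = get_project_base_directory("rag", "res", "deepdoc")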


@@ -30,6 +30,7 @@ dependencies = [
     "demjson3==3.0.6",
     "discord-py==2.3.2",
     "duckduckgo-search>=7.2.0,<8.0.0",
+    "email-validator==2.3.0",
     "editdistance==0.8.1",
     "elastic-transport==8.12.0",
     "elasticsearch==8.12.1",
@@ -40,6 +41,7 @@ dependencies = [
     "flask-cors==5.0.0",
     "flask-login==0.6.3",
     "flask-session==0.8.0",
+    "fastapi==0.118.2",
     "google-search-results==2.4.2",
     "groq==0.9.0",
     "hanziconv==0.3.2",


@@ -22,7 +22,7 @@ import trio
 from api.utils import get_uuid
 from api.utils.base64_image import id2image, image2id
-from deepdoc.parser.pdf_parser import RAGFlowPdfParser
+from ocr.service import get_ocr_service
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
 from rag.nlp import concat_img
@@ -170,14 +170,17 @@ class HierarchicalMerger(ProcessBase):
             cks.append(txt)
             images.append(img)
-        cks = [
-            {
-                "text": RAGFlowPdfParser.remove_tag(c),
-                "image": img,
-                "positions": RAGFlowPdfParser.extract_positions(c),
-            }
-            for c, img in zip(cks, images)
-        ]
+        ocr_service = get_ocr_service()
+        processed_cks = []
+        for c, img in zip(cks, images):
+            cleaned_text = await ocr_service.remove_tag(c)
+            positions = await ocr_service.extract_positions(c)
+            processed_cks.append({
+                "text": cleaned_text,
+                "image": img,
+                "positions": positions,
+            })
+        cks = processed_cks
         async with trio.open_nursery() as nursery:
             for d in cks:
                 nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())


@@ -29,7 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.utils import get_uuid
 from api.utils.base64_image import image2id
 from deepdoc.parser import ExcelParser
-from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
+from deepdoc.parser.pdf_parser import PlainParser, VisionParser
+from ocr.service import get_ocr_service
 from rag.app.naive import Docx
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.parser.schema import ParserFromUpstream
@@ -204,7 +205,9 @@ class Parser(ProcessBase):
         self.set_output("output_format", conf["output_format"])
         if conf.get("parse_method").lower() == "deepdoc":
-            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
+            # Note: callbacks cannot be forwarded over HTTP, so callback is ignored in HTTP mode
+            ocr_service = get_ocr_service()
+            bboxes = ocr_service.parse_into_bboxes_sync(blob, callback=self.callback, filename=name)
         elif conf.get("parse_method").lower() == "plain_text":
             lines, _ = PlainParser()(blob)
             bboxes = [{"text": t} for t, _ in lines]


@@ -19,7 +19,7 @@ import trio
 from api.utils import get_uuid
 from api.utils.base64_image import id2image, image2id
-from deepdoc.parser.pdf_parser import RAGFlowPdfParser
+from ocr.service import get_ocr_service
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.splitter.schema import SplitterFromUpstream
 from rag.nlp import naive_merge, naive_merge_with_images
@@ -96,14 +96,18 @@ class Splitter(ProcessBase):
             deli,
             self._param.overlapped_percent,
         )
-        cks = [
-            {
-                "text": RAGFlowPdfParser.remove_tag(c),
-                "image": img,
-                "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
-            }
-            for c, img in zip(chunks, images) if c.strip()
-        ]
+        ocr_service = get_ocr_service()
+        cks = []
+        for c, img in zip(chunks, images):
+            if not c.strip():
+                continue
+            cleaned_text = await ocr_service.remove_tag(c)
+            positions = await ocr_service.extract_positions(c)
+            cks.append({
+                "text": cleaned_text,
+                "image": img,
+                "positions": [[pos[0][-1]+1, *pos[1:]] for pos in positions],
+            })
         async with trio.open_nursery() as nursery:
             for d in cks:
                 nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())


@@ -578,7 +578,8 @@ def hierarchical_merge(bull, sections, depth):
 def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
-    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
+    from ocr.service import get_ocr_service
+    ocr_service = get_ocr_service()
     if not sections:
         return []
     if isinstance(sections, str):
@@ -598,7 +599,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
         # Ensure that the length of the merged chunk does not exceed chunk_token_num
         if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
             if cks:
-                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
+                overlapped = ocr_service.remove_tag_sync(cks[-1])
                 t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
             if t.find(pos) < 0:
                 t += pos
@@ -625,7 +626,8 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
 def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
-    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
+    from ocr.service import get_ocr_service
+    ocr_service = get_ocr_service()
     if not texts or len(texts) != len(images):
         return [], []
     cks = [""]
@@ -642,7 +644,7 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
         # Ensure that the length of the merged chunk does not exceed chunk_token_num
         if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
             if cks:
-                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
+                overlapped = ocr_service.remove_tag_sync(cks[-1])
                 t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
             if t.find(pos) < 0:
                 t += pos

uv.lock generated

File diff suppressed because it is too large