Fix SSE route dependency and align architecture docs

This commit is contained in:
ash66
2026-05-18 16:32:42 +08:00
parent 86b9ac806a
commit 3f69cad404
149 changed files with 4786 additions and 5957 deletions

12
.env
View File

@@ -9,7 +9,7 @@ DEBUG=false
# ===== Milvus向量数据库配置已有===== # ===== Milvus向量数据库配置已有=====
MILVUS_HOST=localhost MILVUS_HOST=localhost
MILVUS_PORT=19530 MILVUS_PORT=19530
MILVUS_COLLECTION=regulations MILVUS_COLLECTION=regulations_dense_1536
MILVUS_DB_NAME=default MILVUS_DB_NAME=default
# ===== MinIO对象存储配置已有===== # ===== MinIO对象存储配置已有=====
@@ -33,11 +33,11 @@ POSTGRES_PASSWORD=postgresql123456
POSTGRES_DB=compliance_db POSTGRES_DB=compliance_db
# ===== 嵌入模型配置 ===== # ===== 嵌入模型配置 =====
EMBEDDING_MODEL=BAAI/bge-m3 EMBEDDING_MODEL=text-embedding-v3
EMBEDDING_DIM=1024 EMBEDDING_DIM=1536
EMBEDDING_MAX_LENGTH=8192 EMBEDDING_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
EMBEDDING_BATCH_SIZE=12 EMBEDDING_BASE_URL=http://6.86.80.4:30080/v1
EMBEDDING_USE_FP16=true EMBEDDING_TIMEOUT_SECONDS=120
# ===== 文档处理配置 ===== # ===== 文档处理配置 =====
CHUNK_SIZE=512 CHUNK_SIZE=512

29
.env.development Normal file
View File

@@ -0,0 +1,29 @@
# Local development overrides for the repo-root backend configuration.
# Keep shared defaults in .env and put machine-specific or remote-dev values here.
# ===== Milvus向量数据库配置已有=====
MILVUS_HOST=6.86.80.8
MILVUS_PORT=19530
MILVUS_COLLECTION=regulations_dense_1536
MILVUS_DB_NAME=default
# ===== MinIO对象存储配置已有=====
MINIO_ENDPOINT=6.86.80.8:9000
MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin
MINIO_BUCKET=compliance-docs
MINIO_SECURE=false
# ===== Redis配置已有=====
REDIS_HOST=6.86.80.8
REDIS_PORT=6379
REDIS_PASSWORD=redis@123
REDIS_DB=0
# ===== PostgreSQL配置已有=====
POSTGRES_HOST=6.86.80.8
POSTGRES_PORT=5432
POSTGRES_USER=postgresql
POSTGRES_PASSWORD=postgresql123456
POSTGRES_DB=compliance_db

View File

@@ -9,15 +9,15 @@ DEBUG=false
# ===== Milvus向量数据库配置 ===== # ===== Milvus向量数据库配置 =====
MILVUS_HOST=localhost MILVUS_HOST=localhost
MILVUS_PORT=19530 MILVUS_PORT=19530
MILVUS_COLLECTION=regulations MILVUS_COLLECTION=regulations_dense_1536
MILVUS_DB_NAME=default MILVUS_DB_NAME=default
# ===== 嵌入模型配置 ===== # ===== 嵌入模型配置 =====
EMBEDDING_MODEL=BAAI/bge-m3 EMBEDDING_MODEL=text-embedding-v3
EMBEDDING_DIM=1024 EMBEDDING_DIM=1536
EMBEDDING_MAX_LENGTH=8192 EMBEDDING_API_KEY=your_embedding_api_key_here
EMBEDDING_BATCH_SIZE=12 EMBEDDING_BASE_URL=http://6.86.80.4:30080/v1
EMBEDDING_USE_FP16=true EMBEDDING_TIMEOUT_SECONDS=120
# ===== MinIO对象存储配置 ===== # ===== MinIO对象存储配置 =====
MINIO_ENDPOINT=localhost:9000 MINIO_ENDPOINT=localhost:9000
@@ -43,6 +43,12 @@ POSTGRES_DB=compliance_db
CHUNK_SIZE=512 CHUNK_SIZE=512
CHUNK_OVERLAP=50 CHUNK_OVERLAP=50
MAX_FILE_SIZE_MB=100 MAX_FILE_SIZE_MB=100
DOCUMENT_METADATA_PATH=backend/data/documents.json
# ===== 阿里云文档解析 =====
ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id
ALIBABA_ACCESS_KEY_SECRET=your_aliyun_access_key_secret
ALIBABA_ENDPOINT=docmind-api.cn-hangzhou.aliyuncs.com
# ===== API服务配置 ===== # ===== API服务配置 =====
API_HOST=0.0.0.0 API_HOST=0.0.0.0
@@ -75,4 +81,7 @@ DEEPSEEK_MODEL=deepseek-v4-flash
RAG_TOP_K=10 RAG_TOP_K=10
RAG_MAX_CONTEXT_TOKENS=4000 RAG_MAX_CONTEXT_TOKENS=4000
RAG_SUMMARY_MAX_TOKENS=1024 RAG_SUMMARY_MAX_TOKENS=1024
RAG_SKILLS_MAX_TOKENS=2048
# ===== 会话配置 =====
SESSION_MAX_SESSIONS=100
SESSION_TIMEOUT_MINUTES=30

2
.gitignore vendored
View File

@@ -39,8 +39,6 @@ nosetests.xml
*.py,cover *.py,cover
# Environments # Environments
.env
.env.*
.venv/ .venv/
venv/ venv/
env/ env/

View File

@@ -2,38 +2,47 @@
## Scope ## Scope
- This repo uses `backend/app/` for the backend and `frontend/` for the Vite React app. - Backend code lives under `backend/app/`; frontend is the Vite app in `frontend/`.
## Real Entrypoints ## Entrypoints
- Current backend app entrypoint is `backend/app/main.py`, exporting `app` from `app.api.main`. - Backend entrypoint is `backend/app/main.py`, which re-exports `app` from `app.api.main`.
- Current backend dev start command is `python -m uvicorn app.main:app --reload` with `PYTHONPATH=backend`. - FastAPI mounts the real API under `/api/v1`; health endpoints are `GET /health` and `GET /`.
- `dev.sh start api --foreground` is the repo script flow that encodes the expected backend startup behavior. - Frontend API calls use relative `/api` URLs from `frontend/src/api/index.ts`.
- Frontend dev server is `frontend` Vite on port `5173`, proxying `/api` to `http://localhost:8000`. - Current Vite dev proxy in `frontend/vite.config.ts` forwards `/api` to `http://6.86.80.8:8000`, not the local backend.
## Commands ## Preferred Commands
- Backend install: `pip install -r backend/requirements.txt` - Prefer the repo scripts over ad hoc startup commands: `./dev.sh ...` on Unix, `dev.bat ...` on Windows.
- Backend run from repo root: `PYTHONPATH=backend uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload` - First-time local setup: `./dev.sh setup` or `dev.bat setup`.
- Frontend install: `cd frontend && npm install` - Backend foreground dev server: `./dev.sh start api --foreground` or `dev.bat start api --foreground`.
- Frontend dev: `cd frontend && npm run dev` - Equivalent direct backend run from repo root: `PYTHONPATH=backend uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload`.
- Frontend build: `cd frontend && npm run build` - Python tooling must use the repo-root `.venv` environment; do not use the global `python`.
- Frontend lint: `cd frontend && npm run lint` - Prefer `uv`-managed execution from the repo root, for example `uv run --python .venv\\Scripts\\python.exe python ...` or other `uv run ...` forms that resolve within the project environment.
- Frontend dev: `npm --prefix frontend run dev`.
- Frontend verification: `npm --prefix frontend run lint` and `npm --prefix frontend run build`.
- Use `npm`, not `pnpm`; the checked-in scripts run `npm install` even though `frontend/pnpm-lock.yaml` also exists.
## Infra And Env ## Env And Infra
- Backend settings load from root `.env`, not `backend/.env`, because `backend/app/config/settings.py` uses `env_file = ".env"`. - Backend settings must resolve env files from the repo root only. The supported files are root `.env` and optional root `.env.development`; files under `backend/` must not be treated as authoritative env sources.
- Docker infra is defined in `docker/docker-compose.yml`; it starts Milvus, MinIO, Redis, and PostgreSQL. - The dev scripts read `API_HOST`, `API_PORT`, `FRONTEND_PORT`, and `FRONTEND_MODE` from root `.env` first, then root `.env.development` as an override; all other backend config comes from the repo-root env files via Pydantic settings.
- Default local service ports: Milvus `19530`, MinIO `9000/9001`, Redis `6379`, PostgreSQL `5432`, backend `8000`, frontend `5173`. - Docker infra is `docker/docker-compose.yml`; it brings up `milvus`, `minio`, `redis`, and `postgres`.
- LLM base URLs in `.env.example` point at a shared remote gateway (`http://6.86.80.4:30080/v1`); do not assume offline/local-only LLM execution. - Default ports from config/scripts: backend `8000`, frontend `5173`, Milvus `19530`, MinIO `9000/9001`, Redis `6379`, PostgreSQL `5432`.
- `.env.example` points embedding/LLM base URLs at a shared remote gateway `http://6.86.80.4:30080/v1`; do not assume model inference is local.
## Verification ## Verification
- Root pytest config in `pyproject.toml` points at root `tests/`, and those tests import from `backend/app` via `PYTHONPATH` setup. - Root `pyproject.toml` is the active Python manifest and pytest config; `testpaths = ["tests"]`.
- For frontend-only changes, run `npm run lint` and `npm run build` in `frontend`. - Tests under `tests/` insert `backend/` into `sys.path` themselves, so targeted runs can be launched from the repo root.
- For backend changes, prefer focused import/startup verification against `backend/app`, and run root `tests/` when the environment supports it. - `tests/test_milvus.py` and `tests/verify_mvp.py` require Milvus and model/runtime dependencies; they are not cheap smoke tests.
- `tests/verify_mvp.py` also expects the `BGEM3Embedder` stack to be available and explicitly mentions `FlagEmbedding`.
- For backend-only changes, prefer focused import/startup checks unless you know the external services and model dependencies are available.
## Gotchas ## Backend Commenting Standard
- Backend settings load from root `.env`, not `backend/.env`, because `backend/app/config/settings.py` uses `env_file = ".env"`. - All comments and docstrings in `backend/**/*.py` must be written in English.
- The root `pyproject.toml` is the active Python package manifest for the repo. - Every Python file under `backend` must include a module-level explanation and at least one meaningful `#` code comment.
- Every function and method must include a docstring.
- Files without functions, including `__init__.py`, schemas, enums, dataclasses, and export-only modules, must still include a module docstring, class docstrings when applicable, and at least one meaningful `#` code comment.
- Comments must explain intent, assumptions, invariants, or non-obvious logic; do not add empty, placeholder, or restatement-only comments.

View File

@@ -96,19 +96,16 @@ pip install -r backend/requirements.txt
--- ---
## 四、下载嵌入模型 ## 四、配置解析与 embedding
BGE-M3模型约2GB首次使用需下载。 当前主链路不再依赖本地 BGE-M3 模型文件,必须配置:
### 方式A自动下载联网环境 ```env
ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id
首次启动API时自动下载到 `~/.cache/huggingface/` ALIBABA_ACCESS_KEY_SECRET=your_aliyun_access_key_secret
EMBEDDING_API_KEY=your_embedding_api_key_here
### 方式B手动下载离线环境 EMBEDDING_MODEL=text-embedding-v3
EMBEDDING_DIM=1536
```bash
# 从ModelScope下载
python -c "from modelscope import snapshot_download; snapshot_download('Xorbits/bge-m3', cache_dir='~/.cache/modelscope')"
``` ```
--- ---
@@ -320,7 +317,7 @@ export HF_ENDPOINT=https://hf-mirror.com
或手动下载: 或手动下载:
```bash ```bash
python -c "from modelscope import snapshot_download; snapshot_download('Xorbits/bge-m3')" 当前版本无需下载本地 BGE-M3 模型;请改为确认 `EMBEDDING_API_KEY` 与阿里云文档解析凭据已配置。
``` ```
### Q3: LLM调用失败 ### Q3: LLM调用失败

View File

@@ -6,10 +6,10 @@
本次实现的核心功能(最小可用版本): 本次实现的核心功能(最小可用版本):
- ✅ PDF/DOCX文档解析MinerU + PyMuPDF - ✅ PDF/DOC/DOCX 文档解析(阿里云文档智能
-智能分块(章节级+条款级双粒度切割) -基于阿里云 `vector_chunks` 的统一切片
-BGE-M3嵌入Dense+Sparse双路向量 -OpenAI 兼容 embedding`text-embedding-v3`1536维
- ✅ Milvus向量数据库存储与混合检索 - ✅ Milvus 向量数据库存储与 dense-only 检索
- ✅ FastAPI接口封装 - ✅ FastAPI接口封装
## 项目结构 ## 项目结构
@@ -19,8 +19,10 @@ AIRegulation-DocAnalysis-Demo/
├── backend/ ├── backend/
│ ├── app/ │ ├── app/
│ │ ├── api/ # FastAPI 接口层 │ │ ├── api/ # FastAPI 接口层
│ │ ├── application/ # 用例编排层
│ │ ├── domain/ # 领域模型与稳定端口
│ │ ├── infrastructure/ # MinIO / Milvus / 阿里云 / embedding / session 适配
│ │ ├── config/ # 配置与日志 │ │ ├── config/ # 配置与日志
│ │ ├── services/ # 解析、分块、嵌入、存储、Agent
│ │ └── workers/ │ │ └── workers/
│ ├── requirements.txt │ ├── requirements.txt
│ └── main.py │ └── main.py
@@ -52,15 +54,7 @@ docker-compose up -d
docker-compose logs -f milvus docker-compose logs -f milvus
``` ```
### 3. 运行验证脚本 ### 3. 启动API服务
```bash
python tests/verify_mvp.py
```
根级测试脚本会自动把 `backend/` 加入导入路径,并从 `app.*` 加载当前后端代码。
### 4. 启动API服务
```bash ```bash
PYTHONPATH=backend uvicorn app.main:app --reload --port 8000 PYTHONPATH=backend uvicorn app.main:app --reload --port 8000
@@ -91,11 +85,11 @@ curl -X POST http://localhost:8000/api/v1/knowledge/search \
| 类别 | 技术 | | 类别 | 技术 |
|------|------| |------|------|
| 文档解析 | MinerU + PyMuPDF + python-docx | | 文档解析 | 阿里云文档智能 + python-docx |
| 分块策略 | 章节级+条款级双粒度切割 | | 分块策略 | 阿里云 `vector_chunks` |
| 嵌入模型 | BGE-M31024维 Dense + Sparse | | 嵌入模型 | `text-embedding-v3`1536维 Dense |
| 向量数据库 | Milvus 2.4本地Docker部署 | | 向量数据库 | Milvus 2.4本地Docker部署 |
| 检索方式 | Dense+Sparse混合检索 + RRF融合 | | 检索方式 | Dense-only 检索 |
| API框架 | FastAPI | | API框架 | FastAPI |
## 配置 ## 配置
@@ -107,9 +101,14 @@ curl -X POST http://localhost:8000/api/v1/knowledge/search \
MILVUS_HOST=localhost MILVUS_HOST=localhost
MILVUS_PORT=19530 MILVUS_PORT=19530
# 嵌入模型配置 # 阿里云文档解析
EMBEDDING_MODEL=BAAI/bge-m3 ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id
EMBEDDING_DIM=1024 ALIBABA_ACCESS_KEY_SECRET=your_aliyun_access_key_secret
# embedding 配置
EMBEDDING_MODEL=text-embedding-v3
EMBEDDING_DIM=1536
EMBEDDING_API_KEY=your_embedding_api_key_here
# 分块配置 # 分块配置
CHUNK_SIZE=512 CHUNK_SIZE=512
@@ -117,7 +116,7 @@ CHUNK_SIZE=512
## 后续迭代不在本次MVP范围 ## 后续迭代不在本次MVP范围
- LLM摘要生成DeepSeek/Qwen API - LLM摘要生成当前上传主链路默认不生成
- 文档上传UI界面 - 文档上传UI界面
- 混合检索问答功能 - 混合检索问答功能
- 法规变更监控与自动更新 - 法规变更监控与自动更新

View File

@@ -1,56 +0,0 @@
APP_NAME=AI+合规智能中枢
APP_VERSION=0.1.0
DEBUG=false
MILVUS_HOST=localhost
MILVUS_PORT=19530
MILVUS_COLLECTION=regulations
MILVUS_DB_NAME=default
MINIO_ENDPOINT=localhost:9000
MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin
MINIO_BUCKET=compliance-docs
MINIO_SECURE=false
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=redis@123
REDIS_DB=0
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_USER=postgresql
POSTGRES_PASSWORD=postgresql123456
POSTGRES_DB=compliance_db
EMBEDDING_MODEL=BAAI/bge-m3
EMBEDDING_DIM=1024
EMBEDDING_MAX_LENGTH=8192
EMBEDDING_BATCH_SIZE=12
EMBEDDING_USE_FP16=true
CHUNK_SIZE=512
CHUNK_OVERLAP=50
MAX_FILE_SIZE_MB=100
API_HOST=0.0.0.0
API_PORT=8000
LLM_PROVIDER=deepseek
LLM_MODEL=deepseek-v4-flash
LLM_MAX_TOKENS=4096
LLM_TEMPERATURE=0.7
QWEN_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
QWEN_BASE_URL=http://6.86.80.4:30080/v1
QWEN_MODEL=qwen3.5-plus
QWEN_VL_MODEL=qwen3-vl-plus
DEEPSEEK_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
DEEPSEEK_MODEL=deepseek-v4-flash
RAG_TOP_K=10
RAG_MAX_CONTEXT_TOKENS=4000
RAG_SUMMARY_MAX_TOKENS=1024

View File

@@ -1,56 +0,0 @@
APP_NAME=AI+合规智能中枢
APP_VERSION=0.1.0
DEBUG=false
MILVUS_HOST=localhost
MILVUS_PORT=19530
MILVUS_COLLECTION=regulations
MILVUS_DB_NAME=default
EMBEDDING_MODEL=BAAI/bge-m3
EMBEDDING_DIM=1024
EMBEDDING_MAX_LENGTH=8192
EMBEDDING_BATCH_SIZE=12
EMBEDDING_USE_FP16=true
MINIO_ENDPOINT=localhost:9000
MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin123
MINIO_BUCKET=compliance-docs
MINIO_SECURE=false
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=
REDIS_DB=0
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_USER=compliance
POSTGRES_PASSWORD=compliance123
POSTGRES_DB=compliance_db
CHUNK_SIZE=512
CHUNK_OVERLAP=50
MAX_FILE_SIZE_MB=100
API_HOST=0.0.0.0
API_PORT=8000
LLM_PROVIDER=deepseek
LLM_MODEL=deepseek-v4-flash
LLM_MAX_TOKENS=4096
LLM_TEMPERATURE=0.7
QWEN_API_KEY=your_api_key_here
DEEPSEEK_API_KEY=your_api_key_here
QWEN_BASE_URL=http://6.86.80.4:30080/v1
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
QWEN_MODEL=qwen3.5-plus
QWEN_VL_MODEL=qwen3-vl-plus
DEEPSEEK_MODEL=deepseek-v4-flash
RAG_TOP_K=10
RAG_MAX_CONTEXT_TOKENS=4000
RAG_SUMMARY_MAX_TOKENS=1024

View File

@@ -1,3 +1,14 @@
from .main import app """Initialize the app package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["app"] __all__ = ["app"]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name == "app":
from .main import app
return app
raise AttributeError(name)

View File

@@ -1,14 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """Handle Aliyun parsing support for parse pdf."""
阿里云文档智能 API 解析 PDF输出三层结构 chunks
- structure_nodes: 目录树结构
- semantic_blocks: 语义块(章节文本、表格、图片)
- vector_chunks: 检索块(带 overlap 切分)
"""
import argparse import argparse
import json import json
import os
import re import re
import time import time
from pathlib import Path from pathlib import Path
@@ -19,16 +15,16 @@ from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_docmind_api20220711 import models as docmind_models from alibabacloud_docmind_api20220711 import models as docmind_models
from alibabacloud_tea_util import models as util_models from alibabacloud_tea_util import models as util_models
# ===================== 阿里云配置 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
ALIBABA_ACCESS_KEY_ID = "LTAI5t6fWvAsvZkoF9WTbtys" ALIBABA_ACCESS_KEY_ID = os.getenv("ALIBABA_ACCESS_KEY_ID", "")
ALIBABA_ACCESS_KEY_SECRET = "WX4oaE4FLYRa5L85TMQkqRPHeTJAF0" ALIBABA_ACCESS_KEY_SECRET = os.getenv("ALIBABA_ACCESS_KEY_SECRET", "")
ALIBABA_ENDPOINT = "docmind-api.cn-hangzhou.aliyuncs.com" ALIBABA_ENDPOINT = os.getenv("ALIBABA_ENDPOINT", "docmind-api.cn-hangzhou.aliyuncs.com")
# ===================== 切分参数 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
MAX_CHARS = 600 MAX_CHARS = 600
OVERLAP_CHARS = 80 OVERLAP_CHARS = 80
# ===================== 布局类型常量 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
TOC_TITLES = {"目次", "目录"} TOC_TITLES = {"目次", "目录"}
TITLE_SUBTYPES = {"doc_title", "para_title"} TITLE_SUBTYPES = {"doc_title", "para_title"}
TEXT_SUBTYPES = {"para", "none"} TEXT_SUBTYPES = {"para", "none"}
@@ -36,8 +32,11 @@ FIGURE_TYPES = {"figure", "figure_name", "figure_note"}
FIGURE_SUBTYPES = {"picture", "pic_title", "pic_caption"} FIGURE_SUBTYPES = {"picture", "pic_title", "pic_caption"}
# ===================== 阿里云 API 客户端 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def init_client() -> DocmindClient: def init_client() -> DocmindClient:
"""Handle init client."""
if not ALIBABA_ACCESS_KEY_ID or not ALIBABA_ACCESS_KEY_SECRET:
raise ValueError("缺少阿里云文档解析凭据,请设置 ALIBABA_ACCESS_KEY_ID 和 ALIBABA_ACCESS_KEY_SECRET")
config = open_api_models.Config( config = open_api_models.Config(
access_key_id=ALIBABA_ACCESS_KEY_ID, access_key_id=ALIBABA_ACCESS_KEY_ID,
access_key_secret=ALIBABA_ACCESS_KEY_SECRET, access_key_secret=ALIBABA_ACCESS_KEY_SECRET,
@@ -47,7 +46,7 @@ def init_client() -> DocmindClient:
def submit_job(client: DocmindClient, file_path: str) -> str: def submit_job(client: DocmindClient, file_path: str) -> str:
"""提交文档解析任务""" """Submit job."""
file_name = Path(file_path).name file_name = Path(file_path).name
request = docmind_models.SubmitDocParserJobAdvanceRequest( request = docmind_models.SubmitDocParserJobAdvanceRequest(
file_url_object=open(file_path, "rb"), file_url_object=open(file_path, "rb"),
@@ -62,14 +61,14 @@ def submit_job(client: DocmindClient, file_path: str) -> str:
def query_status(client: DocmindClient, task_id: str) -> Dict: def query_status(client: DocmindClient, task_id: str) -> Dict:
"""查询任务状态""" """Handle query status."""
request = docmind_models.QueryDocParserStatusRequest(id=task_id) request = docmind_models.QueryDocParserStatusRequest(id=task_id)
response = client.query_doc_parser_status(request) response = client.query_doc_parser_status(request)
return response.body.data.to_map() if response.body.data else None return response.body.data.to_map() if response.body.data else None
def wait_for_completion(client: DocmindClient, task_id: str, poll_interval: int = 5) -> bool: def wait_for_completion(client: DocmindClient, task_id: str, poll_interval: int = 5) -> bool:
"""等待任务完成""" """Wait for for completion."""
while True: while True:
status_data = query_status(client, task_id) status_data = query_status(client, task_id)
if not status_data: if not status_data:
@@ -85,7 +84,7 @@ def wait_for_completion(client: DocmindClient, task_id: str, poll_interval: int
def get_result(client: DocmindClient, task_id: str, layout_num: int = 0, layout_step_size: int = 50) -> Dict: def get_result(client: DocmindClient, task_id: str, layout_num: int = 0, layout_step_size: int = 50) -> Dict:
"""获取解析结果""" """Return result."""
request = docmind_models.GetDocParserResultRequest( request = docmind_models.GetDocParserResultRequest(
id=task_id, id=task_id,
layout_step_size=layout_step_size, layout_step_size=layout_step_size,
@@ -96,7 +95,7 @@ def get_result(client: DocmindClient, task_id: str, layout_num: int = 0, layout_
def collect_all_results(client: DocmindClient, task_id: str, layout_step_size: int = 50) -> List[Dict]: def collect_all_results(client: DocmindClient, task_id: str, layout_step_size: int = 50) -> List[Dict]:
"""收集所有解析结果""" """Collect all results."""
all_layouts = [] all_layouts = []
layout_num = 0 layout_num = 0
while True: while True:
@@ -113,8 +112,9 @@ def collect_all_results(client: DocmindClient, task_id: str, layout_step_size: i
return all_layouts return all_layouts
# ===================== 文本处理 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def normalize_text(text: str) -> str: def normalize_text(text: str) -> str:
"""Normalize text."""
text = text.replace("\r", "\n") text = text.replace("\r", "\n")
text = text.replace(" ", " ") text = text.replace(" ", " ")
text = re.sub(r"\n+", "\n", text) text = re.sub(r"\n+", "\n", text)
@@ -123,34 +123,41 @@ def normalize_text(text: str) -> str:
def get_page(layout: Dict) -> int: def get_page(layout: Dict) -> int:
"""Return page."""
return layout.get("pageNum", layout.get("pageNumber", 0)) return layout.get("pageNum", layout.get("pageNumber", 0))
def get_text(layout: Dict) -> str: def get_text(layout: Dict) -> str:
"""Return text."""
text = normalize_text(layout.get("text", "")) text = normalize_text(layout.get("text", ""))
if text: if text:
return text return text
return normalize_text(layout.get("markdownContent", "")) return normalize_text(layout.get("markdownContent", ""))
# ===================== 布局类型判断 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def is_title(layout: Dict) -> bool: def is_title(layout: Dict) -> bool:
"""Return whether title."""
return layout.get("type") == "title" or layout.get("subType") in TITLE_SUBTYPES return layout.get("type") == "title" or layout.get("subType") in TITLE_SUBTYPES
def is_text(layout: Dict) -> bool: def is_text(layout: Dict) -> bool:
"""Return whether text."""
return layout.get("type") == "text" and layout.get("subType", "none") in TEXT_SUBTYPES return layout.get("type") == "text" and layout.get("subType", "none") in TEXT_SUBTYPES
def is_figure(layout: Dict) -> bool: def is_figure(layout: Dict) -> bool:
"""Return whether figure."""
return layout.get("type") in FIGURE_TYPES or layout.get("subType") in FIGURE_SUBTYPES return layout.get("type") in FIGURE_TYPES or layout.get("subType") in FIGURE_SUBTYPES
def is_table(layout: Dict) -> bool: def is_table(layout: Dict) -> bool:
"""Return whether table."""
return layout.get("type") == "table" return layout.get("type") == "table"
def is_toc_layout(layout: Dict) -> bool: def is_toc_layout(layout: Dict) -> bool:
"""Return whether toc layout."""
text = get_text(layout) text = get_text(layout)
if text in TOC_TITLES: if text in TOC_TITLES:
return True return True
@@ -160,6 +167,7 @@ def is_toc_layout(layout: Dict) -> bool:
def extract_table_text(layout: Dict) -> str: def extract_table_text(layout: Dict) -> str:
"""Extract table text."""
rows = [] rows = []
for cell in layout.get("cells", []): for cell in layout.get("cells", []):
texts = [] texts = []
@@ -172,8 +180,9 @@ def extract_table_text(layout: Dict) -> str:
return "\n".join(rows).strip() return "\n".join(rows).strip()
# ===================== 结构层:目录树 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def build_structure_nodes(layouts: List[Dict]) -> List[Dict]: def build_structure_nodes(layouts: List[Dict]) -> List[Dict]:
"""Build structure nodes."""
nodes = [] nodes = []
for layout in layouts: for layout in layouts:
if not is_title(layout): if not is_title(layout):
@@ -195,8 +204,9 @@ def build_structure_nodes(layouts: List[Dict]) -> List[Dict]:
return nodes return nodes
# ===================== 语义层:章节内容 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def update_section_path(section_stack: List[Dict], layout: Dict) -> List[Dict]: def update_section_path(section_stack: List[Dict], layout: Dict) -> List[Dict]:
"""Update section path."""
level = layout.get("level", 0) level = layout.get("level", 0)
title = get_text(layout) title = get_text(layout)
while section_stack and section_stack[-1]["level"] >= level: while section_stack and section_stack[-1]["level"] >= level:
@@ -213,10 +223,12 @@ def update_section_path(section_stack: List[Dict], layout: Dict) -> List[Dict]:
def section_path_titles(section_stack: List[Dict]) -> List[str]: def section_path_titles(section_stack: List[Dict]) -> List[str]:
"""Handle section path titles."""
return [item["title"] for item in section_stack] return [item["title"] for item in section_stack]
def flush_text_block(blocks: List[Dict], semantic_blocks: List[Dict], block_id: int) -> int: def flush_text_block(blocks: List[Dict], semantic_blocks: List[Dict], block_id: int) -> int:
"""Handle flush text block."""
if not blocks: if not blocks:
return block_id return block_id
@@ -242,6 +254,7 @@ def flush_text_block(blocks: List[Dict], semantic_blocks: List[Dict], block_id:
def build_semantic_blocks(layouts: List[Dict]) -> List[Dict]: def build_semantic_blocks(layouts: List[Dict]) -> List[Dict]:
"""Build semantic blocks."""
semantic_blocks = [] semantic_blocks = []
section_stack = [] section_stack = []
pending_text_blocks = [] pending_text_blocks = []
@@ -327,8 +340,9 @@ def build_semantic_blocks(layouts: List[Dict]) -> List[Dict]:
return semantic_blocks return semantic_blocks
# ===================== 检索层:向量 chunks ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def split_text_with_overlap(text: str, max_chars: int, overlap_chars: int) -> List[str]: def split_text_with_overlap(text: str, max_chars: int, overlap_chars: int) -> List[str]:
"""Handle split text with overlap."""
text = text.strip() text = text.strip()
if len(text) <= max_chars: if len(text) <= max_chars:
return [text] if text else [] return [text] if text else []
@@ -351,6 +365,7 @@ def build_vector_chunks(
max_chars: int, max_chars: int,
overlap_chars: int, overlap_chars: int,
) -> List[Dict]: ) -> List[Dict]:
"""Build vector chunks."""
vector_chunks = [] vector_chunks = []
chunk_index = 1 chunk_index = 1
@@ -385,7 +400,31 @@ def build_vector_chunks(
return vector_chunks return vector_chunks
# ===================== 主转换函数 ===================== def parse_pdf_to_structured_chunks(
pdf_path: str,
*,
doc_id: str,
doc_title: str,
max_chars: int = MAX_CHARS,
overlap_chars: int = OVERLAP_CHARS,
poll_interval: int = 5,
) -> Dict:
"""Parse pdf to structured chunks."""
client = init_client()
task_id = submit_job(client, pdf_path)
if not wait_for_completion(client, task_id, poll_interval):
raise RuntimeError("阿里云文档解析任务失败")
layouts = collect_all_results(client, task_id)
return convert_layouts(
layouts,
doc_id=doc_id,
doc_title=doc_title,
max_chars=max_chars,
overlap_chars=overlap_chars,
)
# Keep parser integration steps explicit so external workflow behavior stays traceable.
def convert_layouts( def convert_layouts(
layouts: List[Dict], layouts: List[Dict],
doc_id: str, doc_id: str,
@@ -393,6 +432,7 @@ def convert_layouts(
max_chars: int, max_chars: int,
overlap_chars: int, overlap_chars: int,
) -> Dict: ) -> Dict:
"""Handle convert layouts."""
structure_nodes = build_structure_nodes(layouts) structure_nodes = build_structure_nodes(layouts)
semantic_blocks = build_semantic_blocks(layouts) semantic_blocks = build_semantic_blocks(layouts)
vector_chunks = build_vector_chunks( vector_chunks = build_vector_chunks(
@@ -411,8 +451,9 @@ def convert_layouts(
} }
# ===================== CLI 入口 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def main() -> None: def main() -> None:
"""Run the module entrypoint."""
parser = argparse.ArgumentParser(description="阿里云文档智能解析 PDF输出三层结构 chunks") parser = argparse.ArgumentParser(description="阿里云文档智能解析 PDF输出三层结构 chunks")
parser.add_argument("pdf_path", help="PDF 文件路径") parser.add_argument("pdf_path", help="PDF 文件路径")
parser.add_argument("--out", default="vector_chunks.json", help="输出 JSON 文件路径") parser.add_argument("--out", default="vector_chunks.json", help="输出 JSON 文件路径")
@@ -428,30 +469,30 @@ def main() -> None:
if not pdf_path.exists(): if not pdf_path.exists():
raise FileNotFoundError(f"PDF 文件不存在: {pdf_path}") raise FileNotFoundError(f"PDF 文件不存在: {pdf_path}")
# 1. 提交阿里云任务 # Keep parser integration steps explicit so external workflow behavior stays traceable.
client = init_client() client = init_client()
print(f"提交任务: {pdf_path}") print(f"提交任务: {pdf_path}")
task_id = submit_job(client, str(pdf_path)) task_id = submit_job(client, str(pdf_path))
print(f"任务 ID: {task_id}") print(f"任务 ID: {task_id}")
# 2. 等待完成 # Keep parser integration steps explicit so external workflow behavior stays traceable.
print("等待任务完成...") print("等待任务完成...")
if not wait_for_completion(client, task_id, args.poll_interval): if not wait_for_completion(client, task_id, args.poll_interval):
print("任务失败,退出") print("任务失败,退出")
return return
# 3. 获取 layouts # Keep parser integration steps explicit so external workflow behavior stays traceable.
print("获取解析结果...") print("获取解析结果...")
layouts = collect_all_results(client, task_id) layouts = collect_all_results(client, task_id)
print(f"获取到 {len(layouts)} 个布局块") print(f"获取到 {len(layouts)} 个布局块")
# 4. 输出原始 layouts可选 # Keep parser integration steps explicit so external workflow behavior stays traceable.
if args.layouts_output: if args.layouts_output:
layouts_path = Path(args.layouts_output).expanduser().resolve() layouts_path = Path(args.layouts_output).expanduser().resolve()
layouts_path.write_text(json.dumps(layouts, ensure_ascii=False, indent=2), encoding="utf-8") layouts_path.write_text(json.dumps(layouts, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"原始 layouts 已写入: {layouts_path}") print(f"原始 layouts 已写入: {layouts_path}")
# 5. 转换为三层结构 # Keep parser integration steps explicit so external workflow behavior stays traceable.
print("转换为三层结构...") print("转换为三层结构...")
data = convert_layouts( data = convert_layouts(
layouts, layouts,
@@ -461,7 +502,7 @@ def main() -> None:
overlap_chars=args.overlap_chars, overlap_chars=args.overlap_chars,
) )
# 6. 输出结果 # Keep parser integration steps explicit so external workflow behavior stays traceable.
output_path = Path(args.out).expanduser().resolve() output_path = Path(args.out).expanduser().resolve()
output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")

View File

@@ -1,9 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """Handle Aliyun parsing support for upload to milvus."""
将 vector_chunks.json 向量化并上传到 Milvus 和 PostgreSQL
使用中转站的 OpenAI 兼容 API
"""
import argparse import argparse
import json import json
@@ -23,18 +20,18 @@ from pymilvus import (
) )
from openai import OpenAI from openai import OpenAI
# ===================== 配置 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
# 中转站配置 # Keep parser integration steps explicit so external workflow behavior stays traceable.
RELAY_BASE_URL = "http://6.86.80.4:30080/v1" RELAY_BASE_URL = "http://6.86.80.4:30080/v1"
RELAY_API_KEY = "sk-5HeY7gfSIlyZMacfuXOf5cphpymsNqufEu1ou4U3avbULcyY" RELAY_API_KEY = "sk-5HeY7gfSIlyZMacfuXOf5cphpymsNqufEu1ou4U3avbULcyY"
EMBEDDING_MODEL = "text-embedding-v3" # 中转站支持的 embedding 模型 EMBEDDING_MODEL = "text-embedding-v3" # Keep parser integration steps explicit so external workflow behavior stays traceable.
# Milvus 配置 # Keep parser integration steps explicit so external workflow behavior stays traceable.
MILVUS_HOST = "localhost" MILVUS_HOST = "localhost"
MILVUS_PORT = "19530" MILVUS_PORT = "19530"
COLLECTION_NAME = "regulation_chunks" COLLECTION_NAME = "regulation_chunks"
# PostgreSQL 配置 # Keep parser integration steps explicit so external workflow behavior stays traceable.
PG_HOST = "6.86.80.10" PG_HOST = "6.86.80.10"
PG_PORT = 5432 PG_PORT = 5432
PG_USER = "postgresql" PG_USER = "postgresql"
@@ -44,12 +41,12 @@ PG_DATABASE = "postgres"
# ===================== Embedding ===================== # ===================== Embedding =====================
def get_openai_client(api_key: str, base_url: str) -> OpenAI: def get_openai_client(api_key: str, base_url: str) -> OpenAI:
"""创建 OpenAI 客户端连接到中转站""" """Return openai client."""
return OpenAI(api_key=api_key, base_url=base_url) return OpenAI(api_key=api_key, base_url=base_url)
def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10) -> List[List[float]]: def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10) -> List[List[float]]:
"""批量获取文本向量""" """Return embeddings batch."""
all_embeddings = [] all_embeddings = []
for i in range(0, len(texts), batch_size): for i in range(0, len(texts), batch_size):
@@ -69,12 +66,13 @@ def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10)
# ===================== Milvus ===================== # ===================== Milvus =====================
def init_milvus(host: str, port: str): def init_milvus(host: str, port: str):
"""Handle init milvus."""
connections.connect("default", host=host, port=port) connections.connect("default", host=host, port=port)
print(f"已连接 Milvus: {host}:{port}") print(f"已连接 Milvus: {host}:{port}")
def create_collection(name: str, dim: int) -> Collection: def create_collection(name: str, dim: int) -> Collection:
"""创建或获取 collection""" """Create collection."""
if utility.has_collection(name): if utility.has_collection(name):
print(f"Collection '{name}' 已存在,删除重建") print(f"Collection '{name}' 已存在,删除重建")
utility.drop_collection(name) utility.drop_collection(name)
@@ -90,14 +88,14 @@ def create_collection(name: str, dim: int) -> Collection:
FieldSchema(name="page_end", dtype=DataType.INT64), FieldSchema(name="page_end", dtype=DataType.INT64),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512), FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048), FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048),
FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096), # JSON 字符串 FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096), # Keep parser integration steps explicit so external workflow behavior stays traceable.
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim), FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
] ]
schema = CollectionSchema(fields, description="法规文档检索 chunks") schema = CollectionSchema(fields, description="法规文档检索 chunks")
collection = Collection(name, schema) collection = Collection(name, schema)
# 创建向量索引IVF_FLAT适合中小规模 # Keep parser integration steps explicit so external workflow behavior stays traceable.
index_params = { index_params = {
"metric_type": "COSINE", "metric_type": "COSINE",
"index_type": "IVF_FLAT", "index_type": "IVF_FLAT",
@@ -110,7 +108,7 @@ def create_collection(name: str, dim: int) -> Collection:
def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[List[float]]): def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[List[float]]):
"""插入 chunks 到 Milvus""" """Handle insert chunks."""
data = [ data = [
[c["chunk_id"] for c in chunks], [c["chunk_id"] for c in chunks],
[c["doc_id"] for c in chunks], [c["doc_id"] for c in chunks],
@@ -122,7 +120,7 @@ def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[L
[c["page_end"] for c in chunks], [c["page_end"] for c in chunks],
[c["section_title"] for c in chunks], [c["section_title"] for c in chunks],
[c["text"] for c in chunks], [c["text"] for c in chunks],
[json.dumps(c.get("source_ids", [])) for c in chunks], # JSON 字符串 [json.dumps(c.get("source_ids", [])) for c in chunks], # Keep parser integration steps explicit so external workflow behavior stays traceable.
embeddings, embeddings,
] ]
@@ -132,14 +130,14 @@ def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[L
def load_collection(collection: Collection): def load_collection(collection: Collection):
"""加载 collection 到内存(搜索前必须)""" """Load collection."""
collection.load() collection.load()
print(f"Collection 已加载到内存") print(f"Collection 已加载到内存")
# ===================== PostgreSQL ===================== # ===================== PostgreSQL =====================
def get_pg_connection(host: str, port: int, user: str, password: str, database: str): def get_pg_connection(host: str, port: int, user: str, password: str, database: str):
"""获取 PostgreSQL 连接""" """Return pg connection."""
conn = psycopg2.connect( conn = psycopg2.connect(
host=host, host=host,
port=port, port=port,
@@ -152,18 +150,18 @@ def get_pg_connection(host: str, port: int, user: str, password: str, database:
def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict): def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
"""插入 chunks 和相关数据到 PostgreSQL""" """Handle insert chunks to pg."""
cursor = conn.cursor() cursor = conn.cursor()
try: try:
# 1. 插入文档 # Keep parser integration steps explicit so external workflow behavior stays traceable.
cursor.execute(""" cursor.execute("""
INSERT INTO documents (doc_id, title, standard_number, upload_time) INSERT INTO documents (doc_id, title, standard_number, upload_time)
VALUES (%s, %s, %s, NOW()) VALUES (%s, %s, %s, NOW())
ON CONFLICT (doc_id) DO UPDATE SET title = EXCLUDED.title, updated_at = NOW() ON CONFLICT (doc_id) DO UPDATE SET title = EXCLUDED.title, updated_at = NOW()
""", (doc_data["doc_id"], doc_data["doc_title"], doc_data.get("standard_number"))) """, (doc_data["doc_id"], doc_data["doc_title"], doc_data.get("standard_number")))
# 2. 插入语义块 # Keep parser integration steps explicit so external workflow behavior stays traceable.
semantic_blocks = doc_data.get("semantic_blocks", []) semantic_blocks = doc_data.get("semantic_blocks", [])
if semantic_blocks: if semantic_blocks:
block_rows = [ block_rows = [
@@ -192,7 +190,7 @@ def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
) )
print(f"已插入 {len(semantic_blocks)} 个语义块") print(f"已插入 {len(semantic_blocks)} 个语义块")
# 3. 插入向量块元数据 # Keep parser integration steps explicit so external workflow behavior stays traceable.
chunk_rows = [ chunk_rows = [
( (
doc_data["doc_id"], doc_data["doc_id"],
@@ -230,9 +228,9 @@ def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
cursor.close() cursor.close()
# ===================== 主流程 ===================== # Keep parser integration steps explicit so external workflow behavior stays traceable.
def load_data(file_path: Path) -> Dict: def load_data(file_path: Path) -> Dict:
"""加载 vector_chunks.json返回完整数据""" """Load data."""
data = json.loads(file_path.read_text(encoding="utf-8")) data = json.loads(file_path.read_text(encoding="utf-8"))
return data return data
@@ -251,7 +249,8 @@ def upload_to_milvus_and_pg(
pg_password: str, pg_password: str,
pg_database: str, pg_database: str,
): ):
# 1. 加载完整数据 # Keep parser integration steps explicit so external workflow behavior stays traceable.
"""Handle upload to milvus and pg."""
chunks_path = Path(chunks_file).expanduser().resolve() chunks_path = Path(chunks_file).expanduser().resolve()
if not chunks_path.exists(): if not chunks_path.exists():
raise FileNotFoundError(f"文件不存在: {chunks_path}") raise FileNotFoundError(f"文件不存在: {chunks_path}")
@@ -262,29 +261,29 @@ def upload_to_milvus_and_pg(
raise ValueError("vector_chunks 为空") raise ValueError("vector_chunks 为空")
print(f"加载 {len(chunks)} 个 chunks") print(f"加载 {len(chunks)} 个 chunks")
# 2. 初始化连接 # Keep parser integration steps explicit so external workflow behavior stays traceable.
client = get_openai_client(api_key, base_url) client = get_openai_client(api_key, base_url)
init_milvus(milvus_host, milvus_port) init_milvus(milvus_host, milvus_port)
pg_conn = get_pg_connection(pg_host, pg_port, pg_user, pg_password, pg_database) pg_conn = get_pg_connection(pg_host, pg_port, pg_user, pg_password, pg_database)
# 3. 获取 embeddings # Keep parser integration steps explicit so external workflow behavior stays traceable.
texts = [c["embedding_text"] for c in chunks] texts = [c["embedding_text"] for c in chunks]
embeddings = get_embeddings_batch(client, texts, batch_size) embeddings = get_embeddings_batch(client, texts, batch_size)
print(f"生成 {len(embeddings)} 个向量") print(f"生成 {len(embeddings)} 个向量")
# 4. 获取 embedding 维度 # Keep parser integration steps explicit so external workflow behavior stays traceable.
embedding_dim = len(embeddings[0]) embedding_dim = len(embeddings[0])
print(f"Embedding 维度: {embedding_dim}") print(f"Embedding 维度: {embedding_dim}")
# 5. 创建 collection 并插入 Milvus # Keep parser integration steps explicit so external workflow behavior stays traceable.
collection = create_collection(collection_name, embedding_dim) collection = create_collection(collection_name, embedding_dim)
insert_chunks(collection, chunks, embeddings) insert_chunks(collection, chunks, embeddings)
load_collection(collection) load_collection(collection)
# 6. 插入 PostgreSQL # Keep parser integration steps explicit so external workflow behavior stays traceable.
insert_chunks_to_pg(pg_conn, chunks, data) insert_chunks_to_pg(pg_conn, chunks, data)
# 7. 关闭连接 # Keep parser integration steps explicit so external workflow behavior stays traceable.
pg_conn.close() pg_conn.close()
print("上传完成!") print("上传完成!")
@@ -292,6 +291,7 @@ def upload_to_milvus_and_pg(
# ===================== CLI ===================== # ===================== CLI =====================
def main(): def main():
"""Run the module entrypoint."""
parser = argparse.ArgumentParser(description="将 vector_chunks 向量化并上传到 Milvus 和 PostgreSQL") parser = argparse.ArgumentParser(description="将 vector_chunks 向量化并上传到 Milvus 和 PostgreSQL")
parser.add_argument("chunks_file", help="vector_chunks.json 文件路径") parser.add_argument("chunks_file", help="vector_chunks.json 文件路径")
parser.add_argument("--api-key", default=RELAY_API_KEY, help="中转站 API Key") parser.add_argument("--api-key", default=RELAY_API_KEY, help="中转站 API Key")

View File

@@ -1 +1,3 @@
"""API接口模块""" """Initialize the app.api package."""
# Keep package boundaries explicit so backend imports stay predictable.

View File

@@ -12,6 +12,8 @@ from app.api.routes import api_router
from app.config.logging import setup_logging from app.config.logging import setup_logging
from app.config.settings import settings from app.config.settings import settings
from app.services.llm.llm_factory import LLMFactory from app.services.llm.llm_factory import LLMFactory
# Keep module behavior explicit so the backend flow stays easy to audit.
setup_logging(level="INFO" if not settings.debug else "DEBUG") setup_logging(level="INFO" if not settings.debug else "DEBUG")

View File

@@ -1,5 +1,14 @@
"""API数据模型""" """Initialize the app.api.models package."""
from .agent import (
AskRequest,
AskResponse,
ChatRequest,
ChatResponse,
FeedbackRequest,
SessionInfo,
TemplateListResponse,
)
from .document import ( from .document import (
DocumentUploadRequest, DocumentUploadRequest,
DocumentUploadResponse, DocumentUploadResponse,
@@ -9,8 +18,17 @@ from .document import (
DocumentStatusResponse, DocumentStatusResponse,
ErrorResponse ErrorResponse
) )
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = [ __all__ = [
"AskRequest",
"AskResponse",
"ChatRequest",
"ChatResponse",
"FeedbackRequest",
"SessionInfo",
"TemplateListResponse",
"DocumentUploadRequest", "DocumentUploadRequest",
"DocumentUploadResponse", "DocumentUploadResponse",
"SearchRequest", "SearchRequest",

View File

@@ -0,0 +1,79 @@
"""Define API models for agent endpoints."""
from __future__ import annotations
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
# Group agent transport schemas together so route modules stay focused on HTTP flow.
class AskRequest(BaseModel):
"""Define the Ask Request API model."""
query: str = Field(..., min_length=1, max_length=2000)
filters: Optional[str] = None
provider: Optional[str] = None
model: Optional[str] = None
top_k: Optional[int] = Field(default=None, ge=1, le=20)
prompt_template: Optional[str] = None
class AskResponse(BaseModel):
"""Define the Ask Response API model."""
answer: str
sources: List[Dict] = Field(default_factory=list)
model: str = ""
latency_ms: int = 0
retrieved_count: int = 0
context_tokens: int = 0
truncated: bool = False
error: Optional[str] = None
class ChatRequest(BaseModel):
"""Define the Chat Request API model."""
query: str = Field(..., min_length=1, max_length=2000)
session_id: Optional[str] = None
filters: Optional[str] = None
provider: Optional[str] = None
model: Optional[str] = None
top_k: Optional[int] = Field(default=None, ge=1, le=20)
class ChatResponse(BaseModel):
"""Define the Chat Response API model."""
session_id: str
answer: str
sources: List[Dict] = Field(default_factory=list)
model: str = ""
latency_ms: int = 0
message_count: int = 0
class SessionInfo(BaseModel):
"""Define the Session Info API model."""
session_id: str
message_count: int
created_at: int
updated_at: int
class FeedbackRequest(BaseModel):
"""Define the Feedback Request API model."""
session_id: str
message_index: int
rating: int = Field(..., ge=1, le=5)
comment: Optional[str] = None
class TemplateListResponse(BaseModel):
"""Define the Template List Response API model."""
templates: Dict[str, str]

View File

@@ -1,19 +1,21 @@
"""文档相关Pydantic数据模型""" """Define API models for document."""
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from datetime import datetime from datetime import datetime
# Group related schema definitions so validation rules stay consistent.
class DocumentUploadRequest(BaseModel): class DocumentUploadRequest(BaseModel):
"""文档上传请求""" """Define the Document Upload Request API model."""
doc_name: Optional[str] = Field(None, description="文档名称") doc_name: Optional[str] = Field(None, description="文档名称")
regulation_type: Optional[str] = Field(None, description="法规类型") regulation_type: Optional[str] = Field(None, description="法规类型")
version: Optional[str] = Field(None, description="文档版本") version: Optional[str] = Field(None, description="文档版本")
class DocumentUploadResponse(BaseModel): class DocumentUploadResponse(BaseModel):
"""文档上传响应""" """Define the Document Upload Response API model."""
doc_id: str = Field(..., description="文档ID") doc_id: str = Field(..., description="文档ID")
doc_name: str = Field(..., description="文档名称") doc_name: str = Field(..., description="文档名称")
status: str = Field(..., description="处理状态") status: str = Field(..., description="处理状态")
@@ -25,14 +27,14 @@ class DocumentUploadResponse(BaseModel):
class SearchRequest(BaseModel): class SearchRequest(BaseModel):
"""检索请求""" """Define the Search Request API model."""
query: str = Field(..., description="查询文本") query: str = Field(..., description="查询文本")
top_k: int = Field(default=10, description="返回结果数量") top_k: int = Field(default=10, description="返回结果数量")
filters: Optional[str] = Field(None, description="过滤条件") filters: Optional[str] = Field(None, description="过滤条件")
class SearchResultItem(BaseModel): class SearchResultItem(BaseModel):
"""单个检索结果""" """Define the Search Result Item API model."""
id: int = Field(..., description="记录ID") id: int = Field(..., description="记录ID")
content: str = Field(..., description="内容") content: str = Field(..., description="内容")
score: float = Field(..., description="相似度分数") score: float = Field(..., description="相似度分数")
@@ -40,7 +42,7 @@ class SearchResultItem(BaseModel):
class SearchResponse(BaseModel): class SearchResponse(BaseModel):
"""检索响应""" """Define the Search Response API model."""
query: str = Field(..., description="查询文本") query: str = Field(..., description="查询文本")
total: int = Field(..., description="结果总数") total: int = Field(..., description="结果总数")
results: List[SearchResultItem] = Field(default_factory=list, description="结果列表") results: List[SearchResultItem] = Field(default_factory=list, description="结果列表")
@@ -48,7 +50,7 @@ class SearchResponse(BaseModel):
class DocumentStatusResponse(BaseModel): class DocumentStatusResponse(BaseModel):
"""文档状态响应""" """Define the Document Status Response API model."""
doc_id: str = Field(..., description="文档ID") doc_id: str = Field(..., description="文档ID")
status: str = Field(..., description="状态") status: str = Field(..., description="状态")
num_chunks: Optional[int] = Field(None, description="分块数量") num_chunks: Optional[int] = Field(None, description="分块数量")
@@ -56,7 +58,7 @@ class DocumentStatusResponse(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""错误响应""" """Define the Error Response API model."""
error: str = Field(..., description="错误类型") error: str = Field(..., description="错误类型")
message: str = Field(..., description="错误消息") message: str = Field(..., description="错误消息")
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")

View File

@@ -1,16 +1,29 @@
"""API路由模块""" """Initialize the app.api.routes package."""
from fastapi import APIRouter from fastapi import APIRouter
from .compliance import router as compliance_router
from .documents import router as documents_router from .documents import router as documents_router
from .knowledge import router as knowledge_router from .knowledge import router as knowledge_router
from .agent import router as agent_router from .agent import router as agent_router
from .status import router as status_router
# Keep package boundaries explicit so backend imports stay predictable.
# 主路由
# Keep package boundaries explicit so backend imports stay predictable.
api_router = APIRouter() api_router = APIRouter()
# 注册子路由 # Keep package boundaries explicit so backend imports stay predictable.
api_router.include_router(documents_router) api_router.include_router(documents_router)
api_router.include_router(knowledge_router) api_router.include_router(knowledge_router)
api_router.include_router(agent_router) api_router.include_router(agent_router)
api_router.include_router(compliance_router)
api_router.include_router(status_router)
__all__ = ["api_router", "documents_router", "knowledge_router", "agent_router"] __all__ = [
"api_router",
"documents_router",
"knowledge_router",
"agent_router",
"compliance_router",
"status_router",
]

View File

@@ -1,186 +1,83 @@
"""Agent API接口 - 问答对话接口""" """Define API routes for agent."""
from __future__ import annotations
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, AsyncGenerator
from loguru import logger
import json
import asyncio import asyncio
import json
from typing import AsyncGenerator, List, Optional
from app.services.agent.qa_agent import QAAgent, AgentConfig from fastapi import APIRouter, HTTPException
from app.services.agent.session_manager import SessionManager from fastapi.responses import StreamingResponse
from dataclasses import asdict
from app.api.models import (
AskRequest,
AskResponse,
ChatRequest,
ChatResponse,
FeedbackRequest,
SessionInfo,
)
from app.config.settings import settings from app.config.settings import settings
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/agent", tags=["agent"]) router = APIRouter(prefix="/agent", tags=["agent"])
# 会话管理器(全局实例)
session_manager = SessionManager()
# ===== Pydantic Models =====
class AskRequest(BaseModel):
"""单次问答请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
model: Optional[str] = Field(None, description="LLM模型名称")
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
class AskResponse(BaseModel):
"""问答响应"""
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
retrieved_count: int = 0
context_tokens: int = 0
truncated: bool = False
error: Optional[str] = None
class ChatRequest(BaseModel):
"""多轮对话请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
session_id: Optional[str] = Field(None, description="会话ID首次对话可不传")
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商")
model: Optional[str] = Field(None, description="LLM模型名称")
class ChatResponse(BaseModel):
"""多轮对话响应"""
session_id: str
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
message_count: int = 0
class SessionInfo(BaseModel):
"""会话信息"""
session_id: str
message_count: int
created_at: int
updated_at: int
class FeedbackRequest(BaseModel):
"""反馈请求"""
session_id: str
message_index: int
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
comment: Optional[str] = Field(None, description="反馈内容")
class TemplateListResponse(BaseModel):
"""模板列表响应"""
templates: Dict[str, str]
# ===== API Endpoints =====
@router.post("/ask", response_model=AskResponse) @router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest): async def ask_question(request: AskRequest):
""" """Handle ask question."""
单次问答接口
不保存会话历史,适合单次查询场景。
"""
logger.info(f"收到问答请求: {request.query}")
try: try:
# 构建Agent配置 _, result = get_agent_conversation_service().ask(
config = AgentConfig(
llm_provider=request.provider or settings.llm_provider,
llm_model=request.model or settings.llm_model,
top_k=request.top_k or settings.rag_top_k
)
# 创建Agent并执行问答
agent = QAAgent(config)
response = agent.ask(
query=request.query, query=request.query,
filters=request.filters, filters=request.filters,
prompt_template=request.prompt_template provider=request.provider or settings.llm_provider,
model=request.model or settings.llm_model,
top_k=request.top_k or settings.rag_top_k,
prompt_template=request.prompt_template,
) )
agent.close()
return AskResponse( return AskResponse(
answer=response.answer, answer=result.answer,
sources=response.sources, sources=[asdict(source) for source in result.sources],
model=response.model, model=result.model,
latency_ms=response.latency_ms, latency_ms=result.latency_ms,
retrieved_count=response.retrieved_count, retrieved_count=result.retrieved_count,
context_tokens=response.context_tokens, context_tokens=result.context_tokens,
truncated=response.truncated, truncated=result.truncated,
error=response.error error=result.error,
) )
except Exception as exc:
except Exception as e: raise HTTPException(status_code=500, detail=str(exc))
logger.error(f"问答失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/chat", response_model=ChatResponse) @router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest): async def chat_with_session(request: ChatRequest):
""" """Handle chat with session."""
多轮对话接口
支持会话历史记录,适合连续对话场景。
"""
logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
try: try:
# 获取或创建会话 session_id, result = get_agent_conversation_service().chat(
if request.session_id:
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="会话不存在或已过期")
else:
session = session_manager.create_session()
# 添加用户消息
session.add_user_message(request.query)
# 执行问答
config = AgentConfig(
llm_provider=request.provider or settings.llm_provider,
llm_model=request.model or settings.llm_model
)
agent = QAAgent(config)
response = agent.ask(
query=request.query, query=request.query,
filters=request.filters session_id=request.session_id,
filters=request.filters,
provider=request.provider or settings.llm_provider,
model=request.model or settings.llm_model,
top_k=request.top_k or settings.rag_top_k,
) )
agent.close() session = get_conversation_store().get_session(session_id)
# 添加助手消息
session.add_assistant_message(
response.answer,
response.sources
)
return ChatResponse( return ChatResponse(
session_id=session.session_id, session_id=session_id,
answer=response.answer, answer=result.answer,
sources=response.sources, sources=[asdict(source) for source in result.sources],
model=response.model, model=result.model,
latency_ms=response.latency_ms, latency_ms=result.latency_ms,
message_count=session.message_count message_count=len(session.messages) if session else 0,
) )
except ValueError as exc:
except HTTPException: raise HTTPException(status_code=404, detail=str(exc))
raise except Exception as exc:
except Exception as e: raise HTTPException(status_code=500, detail=str(exc))
logger.error(f"对话失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/chat/stream") @router.get("/chat/stream")
@@ -189,260 +86,93 @@ async def chat_stream_get(
session_id: Optional[str] = None, session_id: Optional[str] = None,
filters: Optional[str] = None, filters: Optional[str] = None,
provider: Optional[str] = None, provider: Optional[str] = None,
model: Optional[str] = None model: Optional[str] = None,
): ):
""" """Handle chat stream get."""
流式对话接口SSE- GET版本
EventSource只能发送GET请求因此提供此接口。
query参数通过URL传递。
SSE事件格式
- event: session - 会话ID
- event: status - 状态更新(检索中、生成中)
- event: sources - 引用来源
- event: content - 回答内容片段
- event: done - 完成,包含统计信息
- event: error - 错误信息
"""
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
async def generate_sse() -> AsyncGenerator[str, None]: async def generate_sse() -> AsyncGenerator[str, None]:
"""生成SSE事件流""" """Handle generate sse."""
try: try:
# 获取或创建会话 session_id_, event_stream = get_agent_conversation_service().stream_chat(
if session_id:
session = session_manager.get_session(session_id)
if not session:
yield f"event: error\ndata: 会话不存在或已过期\n\n"
return
else:
session = session_manager.create_session()
# 发送session_id
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
# 添加用户消息
session.add_user_message(query)
# 创建Agent
config = AgentConfig(
llm_provider=provider or settings.llm_provider,
llm_model=model or settings.llm_model
)
agent = QAAgent(config)
# 执行流式问答
full_answer = ""
sources = []
done_data = {}
for event_data in agent.ask_stream(
query=query, query=query,
filters=filters session_id=session_id,
): filters=filters,
provider=provider or settings.llm_provider,
model=model or settings.llm_model,
top_k=settings.rag_top_k,
)
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
for event_data in event_stream:
event_type = event_data.get("event", "content") event_type = event_data.get("event", "content")
data = event_data.get("data", "") data = event_data.get("data", "")
# 收集完整回答和来源
if event_type == "content":
full_answer += str(data)
elif event_type == "sources":
sources = data
elif event_type == "done":
done_data = data
# 发送SSE事件
if isinstance(data, (dict, list)): if isinstance(data, (dict, list)):
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n" yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
else: else:
yield f"event: {event_type}\ndata: {data}\n\n" yield f"event: {event_type}\ndata: {data}\n\n"
# 小延迟让其他任务有机会执行
await asyncio.sleep(0) await asyncio.sleep(0)
except Exception as exc:
agent.close() yield f"event: error\ndata: {str(exc)}\n\n"
# 保存到会话历史
session.add_assistant_message(full_answer, sources)
except Exception as e:
logger.error(f"流式对话失败: {e}")
yield f"event: error\ndata: {str(e)}\n\n"
return StreamingResponse( return StreamingResponse(
generate_sse(), generate_sse(),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
) )
@router.post("/chat/stream") @router.post("/chat/stream")
async def chat_stream(request: ChatRequest): async def chat_stream(request: ChatRequest):
""" """Handle chat stream."""
流式对话接口SSE return await chat_stream_get(
query=request.query,
返回Server-Sent Events格式的流式响应用户可实时看到思考过程和回答生成。 session_id=request.session_id,
filters=request.filters,
SSE事件格式 provider=request.provider,
- event: status - 状态更新(检索中、生成中) model=request.model,
- event: sources - 引用来源
- event: content - 回答内容片段
- event: done - 完成,包含统计信息
- event: error - 错误信息
"""
logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
async def generate_sse() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
try:
# 获取或创建会话
if request.session_id:
session = session_manager.get_session(request.session_id)
if not session:
yield f"event: error\ndata: 会话不存在或已过期\n\n"
return
else:
session = session_manager.create_session()
# 发送session_id
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
# 添加用户消息
session.add_user_message(request.query)
# 创建Agent
config = AgentConfig(
llm_provider=request.provider or settings.llm_provider,
llm_model=request.model or settings.llm_model
)
agent = QAAgent(config)
# 执行流式问答
full_answer = ""
sources = []
done_data = {}
for event_data in agent.ask_stream(
query=request.query,
filters=request.filters
):
event_type = event_data.get("event", "content")
data = event_data.get("data", "")
# 收集完整回答和来源
if event_type == "content":
full_answer += str(data)
elif event_type == "sources":
sources = data
elif event_type == "done":
done_data = data
# 发送SSE事件
if isinstance(data, (dict, list)):
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
else:
yield f"event: {event_type}\ndata: {data}\n\n"
# 小延迟让其他任务有机会执行
await asyncio.sleep(0)
agent.close()
# 保存到会话历史
session.add_assistant_message(full_answer, sources)
except Exception as e:
logger.error(f"流式对话失败: {e}")
yield f"event: error\ndata: {str(e)}\n\n"
return StreamingResponse(
generate_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
) )
@router.get("/session/{session_id}", response_model=SessionInfo) @router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str): async def get_session_info(session_id: str):
"""获取会话信息""" """Return session info."""
session = session_manager.get_session(session_id) session = get_conversation_store().get_session(session_id)
if not session: if not session:
raise HTTPException(status_code=404, detail="会话不存在或已过期") raise HTTPException(status_code=404, detail="会话不存在或已过期")
return SessionInfo( return SessionInfo(
session_id=session.session_id, session_id=session.session_id,
message_count=session.message_count, message_count=len(session.messages),
created_at=session.created_at, created_at=session.created_at,
updated_at=session.updated_at updated_at=session.updated_at,
) )
@router.get("/session/{session_id}/history") @router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5): async def get_session_history(session_id: str, max_turns: int = 5):
"""获取会话历史""" """Return session history."""
session = session_manager.get_session(session_id) session = get_conversation_store().get_session(session_id)
if not session: if not session:
raise HTTPException(status_code=404, detail="会话不存在或已过期") raise HTTPException(status_code=404, detail="会话不存在或已过期")
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-(max_turns * 2):]]
history = session.get_history(max_turns)
return {"session_id": session_id, "history": history} return {"session_id": session_id, "history": history}
@router.delete("/session/{session_id}") @router.delete("/session/{session_id}")
async def delete_session(session_id: str): async def delete_session(session_id: str):
"""删除会话""" """Delete session."""
success = session_manager.delete_session(session_id) if not get_conversation_store().delete_session(session_id):
if not success:
raise HTTPException(status_code=404, detail="会话不存在") raise HTTPException(status_code=404, detail="会话不存在")
return {"message": "会话已删除", "session_id": session_id} return {"message": "会话已删除", "session_id": session_id}
@router.get("/sessions", response_model=List[SessionInfo]) @router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions(): async def list_sessions():
"""列出所有活跃会话""" """List sessions."""
sessions = session_manager.list_sessions() return [SessionInfo(**item) for item in get_conversation_store().list_sessions()]
return [SessionInfo(**s) for s in sessions]
@router.post("/feedback") @router.post("/feedback")
async def submit_feedback(request: FeedbackRequest): async def submit_feedback(request: FeedbackRequest):
"""提交问答反馈""" """Submit feedback."""
session = session_manager.get_session(request.session_id) session = get_conversation_store().get_session(request.session_id)
if not session: if not session:
raise HTTPException(status_code=404, detail="会话不存在") raise HTTPException(status_code=404, detail="会话不存在")
return {"message": "反馈已提交", "session_id": request.session_id, "message_index": request.message_index}
# 记录反馈(实际应用中可存储到数据库)
logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
return {"message": "反馈已记录", "rating": request.rating}
@router.get("/templates", response_model=TemplateListResponse)
async def list_prompt_templates():
"""列出可用的Prompt模板"""
from app.services.rag.prompt_templates import PromptTemplates
templates = PromptTemplates.list_templates()
return TemplateListResponse(templates=templates)
@router.get("/models")
async def list_available_models():
"""列出可用的LLM模型"""
from app.services.llm import LLMFactory
factory = LLMFactory()
models = factory.list_available_providers()
return {"models": models}

View File

@@ -1,9 +1,15 @@
from fastapi import APIRouter, UploadFile, File, HTTPException """Define API routes for compliance."""
from sse_starlette.sse import EventSourceResponse
import uuid from __future__ import annotations
import os
import json
import asyncio import asyncio
import json
from pathlib import Path
from typing import AsyncGenerator
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import StreamingResponse
from app.schemas.compliance import ( from app.schemas.compliance import (
AnalyzeResponse, AnalyzeResponse,
ComplianceChatRequest, ComplianceChatRequest,
@@ -13,38 +19,42 @@ from app.services.mock_data import (
get_mock_compliance_result, get_mock_compliance_result,
get_mock_compliance_chat_response, get_mock_compliance_chat_response,
) )
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/compliance", tags=["合规分析"]) router = APIRouter(prefix="/compliance", tags=["合规分析"])
# 临时存储分析任务 # Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store: dict[str, dict] = {} tasks_store: dict[str, dict] = {}
# Store uploaded compliance files inside the local backend data directory.
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
@router.post("/analyze", response_model=AnalyzeResponse) @router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)): async def analyze_document(file: UploadFile = File(...)):
"""上传设计方案进行分析""" """Handle analyze document."""
# 生成任务ID # Keep route handlers close to their transport-layer wiring for easier auditing.
task_id = generate_task_id() task_id = generate_task_id()
# 保存文件 # Keep route handlers close to their transport-layer wiring for easier auditing.
raw_dir = "/airegulation/demo-mao/backend/data/raw" RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
os.makedirs(raw_dir, exist_ok=True) file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
file_path = os.path.join(raw_dir, f"compliance_{task_id}_{file.filename}")
content = await file.read() content = await file.read()
with open(file_path, "wb") as f: with file_path.open("wb") as f:
f.write(content) f.write(content)
# 记录任务 # Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store[task_id] = { tasks_store[task_id] = {
"task_id": task_id, "task_id": task_id,
"file_path": file_path, "file_path": str(file_path),
"status": "processing", "status": "processing",
"result": None, "result": None,
} }
# 模拟异步处理完成(立即返回结果) # Keep route handlers close to their transport-layer wiring for easier auditing.
# 实际应用中这应该是后台任务 # Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store[task_id]["status"] = "completed" tasks_store[task_id]["status"] = "completed"
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id) tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
@@ -53,9 +63,9 @@ async def analyze_document(file: UploadFile = File(...)):
@router.get("/result/{task_id}") @router.get("/result/{task_id}")
async def get_result(task_id: str): async def get_result(task_id: str):
"""获取分析结果""" """Return result."""
if task_id not in tasks_store: if task_id not in tasks_store:
# 如果任务ID不存在返回默认mock结果 # Keep route handlers close to their transport-layer wiring for easier auditing.
return get_mock_compliance_result(task_id) return get_mock_compliance_result(task_id)
task = tasks_store[task_id] task = tasks_store[task_id]
@@ -68,8 +78,8 @@ async def get_result(task_id: str):
@router.post("/chat/{segment_id}") @router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest): async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
"""针对段落进行合规对话""" """Handle compliance chat."""
# 根据segment_id获取对应的intent # Keep route handlers close to their transport-layer wiring for easier auditing.
intent_map = { intent_map = {
1: "车身结构设计", 1: "车身结构设计",
2: "动力系统配置", 2: "动力系统配置",
@@ -77,11 +87,12 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
} }
intent = intent_map.get(segment_id, "车身结构设计") intent = intent_map.get(segment_id, "车身结构设计")
async def generate(): async def generate() -> AsyncGenerator[str, None]:
# 获取预设响应 # Keep route handlers close to their transport-layer wiring for easier auditing.
"""Handle generate."""
response = get_mock_compliance_chat_response(intent, request.query) response = get_mock_compliance_chat_response(intent, request.query)
# 流式输出响应 # Keep route handlers close to their transport-layer wiring for easier auditing.
sentences = response.split("\n\n") sentences = response.split("\n\n")
for sentence in sentences: for sentence in sentences:
if sentence.strip(): if sentence.strip():
@@ -89,8 +100,15 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
for chunk in chunks: for chunk in chunks:
if chunk.strip(): if chunk.strip():
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})} yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
)
yield {"event": "message", "data": json.dumps({"type": "done"})} yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return EventSourceResponse(generate()) return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)

View File

@@ -1,3 +1,5 @@
"""Define API routes for docs."""
from fastapi import APIRouter, UploadFile, File, HTTPException from fastapi import APIRouter, UploadFile, File, HTTPException
import os import os
import uuid import uuid
@@ -10,30 +12,32 @@ from app.schemas.doc import (
EmbedResponse, EmbedResponse,
) )
from app.services.mock_data import get_mock_documents, generate_doc_id from app.services.mock_data import get_mock_documents, generate_doc_id
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/docs", tags=["文档管理"]) router = APIRouter(prefix="/docs", tags=["文档管理"])
# 临时存储文档信息包含预设的mock文档 # Keep route handlers close to their transport-layer wiring for easier auditing.
documents_store: dict[str, dict] = {} documents_store: dict[str, dict] = {}
# 初始化时加载mock文档 # Keep route handlers close to their transport-layer wiring for easier auditing.
for doc in get_mock_documents(): for doc in get_mock_documents():
documents_store[doc["id"]] = doc documents_store[doc["id"]] = doc
@router.post("/upload", response_model=DocumentUploadResponse) @router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(file: UploadFile = File(...)): async def upload_document(file: UploadFile = File(...)):
"""上传法规文档""" """Handle upload document."""
# 检查文件格式 # Keep route handlers close to their transport-layer wiring for easier auditing.
allowed_ext = [".pdf", ".docx", ".doc", ".txt"] allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
ext = os.path.splitext(file.filename)[1].lower() ext = os.path.splitext(file.filename)[1].lower()
if ext not in allowed_ext: if ext not in allowed_ext:
raise HTTPException(400, f"Unsupported file format: {ext}") raise HTTPException(400, f"Unsupported file format: {ext}")
# 生成文档ID # Keep route handlers close to their transport-layer wiring for easier auditing.
doc_id = generate_doc_id() doc_id = generate_doc_id()
# 保存文件 # Keep route handlers close to their transport-layer wiring for easier auditing.
raw_dir = "/airegulation/demo-mao/backend/data/raw" raw_dir = "/airegulation/demo-mao/backend/data/raw"
os.makedirs(raw_dir, exist_ok=True) os.makedirs(raw_dir, exist_ok=True)
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}") file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
@@ -42,7 +46,7 @@ async def upload_document(file: UploadFile = File(...)):
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(content) f.write(content)
# 记录文档信息 # Keep route handlers close to their transport-layer wiring for easier auditing.
documents_store[doc_id] = { documents_store[doc_id] = {
"id": doc_id, "id": doc_id,
"name": file.filename, "name": file.filename,
@@ -62,7 +66,7 @@ async def upload_document(file: UploadFile = File(...)):
@router.get("/list", response_model=DocumentListResponse) @router.get("/list", response_model=DocumentListResponse)
async def list_documents(): async def list_documents():
"""获取已索引文档列表""" """List documents."""
docs = [ docs = [
DocumentInfo( DocumentInfo(
id=d["id"], id=d["id"],
@@ -78,14 +82,14 @@ async def list_documents():
@router.post("/parse/{doc_id}", response_model=ParseResponse) @router.post("/parse/{doc_id}", response_model=ParseResponse)
async def parse_document(doc_id: str): async def parse_document(doc_id: str):
"""解析文档并分块""" """Parse document."""
if doc_id not in documents_store: if doc_id not in documents_store:
raise HTTPException(404, "Document not found") raise HTTPException(404, "Document not found")
doc = documents_store[doc_id] doc = documents_store[doc_id]
# 模拟解析逻辑 # Keep route handlers close to their transport-layer wiring for easier auditing.
doc["status"] = "parsed" doc["status"] = "parsed"
# 根据文件大小计算chunks数量 # Keep route handlers close to their transport-layer wiring for easier auditing.
file_size = doc.get("size", 100000) file_size = doc.get("size", 100000)
doc["chunks"] = max(20, file_size // 8000) doc["chunks"] = max(20, file_size // 8000)
@@ -94,12 +98,12 @@ async def parse_document(doc_id: str):
@router.post("/embed/{doc_id}", response_model=EmbedResponse) @router.post("/embed/{doc_id}", response_model=EmbedResponse)
async def embed_document(doc_id: str): async def embed_document(doc_id: str):
"""嵌入并存入向量库""" """Embed document."""
if doc_id not in documents_store: if doc_id not in documents_store:
raise HTTPException(404, "Document not found") raise HTTPException(404, "Document not found")
doc = documents_store[doc_id] doc = documents_store[doc_id]
# 模拟嵌入逻辑 # Keep route handlers close to their transport-layer wiring for easier auditing.
doc["status"] = "indexed" doc["status"] = "indexed"
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"]) return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
@@ -107,7 +111,7 @@ async def embed_document(doc_id: str):
@router.delete("/delete/{doc_id}") @router.delete("/delete/{doc_id}")
async def delete_document(doc_id: str): async def delete_document(doc_id: str):
"""删除文档""" """Delete document."""
if doc_id not in documents_store: if doc_id not in documents_store:
raise HTTPException(404, "Document not found") raise HTTPException(404, "Document not found")

View File

@@ -1,290 +1,140 @@
"""文档上传与处理接口""" """Define API routes for documents."""
from __future__ import annotations
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from typing import Optional
import os
import uuid
import tempfile
from pathlib import Path
from loguru import logger
from io import BytesIO from io import BytesIO
from urllib.parse import quote from urllib.parse import quote
from ..models import DocumentUploadResponse, ErrorResponse from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from app.services.document_processor import DocumentProcessor from fastapi.responses import StreamingResponse
from app.services.storage.minio_client import MinIOClient from loguru import logger
from app.config.settings import settings
from app.api.models import DocumentUploadResponse
from app.application.documents import DocumentProcessResult
from app.shared.bootstrap import get_document_command_service, get_document_query_service
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/documents", tags=["documents"]) router = APIRouter(prefix="/documents", tags=["documents"])
# MinIO客户端用于文档存储
minio_client: Optional[MinIOClient] = None
def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
def get_minio_client() -> MinIOClient: """Handle document response for this module."""
"""获取MinIO客户端实例""" return DocumentUploadResponse(
global minio_client doc_id=result.doc_id,
if minio_client is None: doc_name=result.doc_name,
minio_client = MinIOClient() status=result.status,
minio_client.connect() message=result.message,
minio_client.ensure_bucket() num_chunks=result.num_chunks,
return minio_client summary=result.summary,
summary_latency_ms=result.summary_latency_ms,
)
def _build_document_records(limit: Optional[int] = None):
"""构建文档列表记录,支持按最近更新时间倒序截断。"""
minio = get_minio_client()
document_records = []
objects = minio.client.list_objects(minio.bucket, recursive=True)
for obj in objects:
parts = obj.object_name.split("/", 1)
if len(parts) != 2:
continue
doc_id, filename = parts
last_modified = getattr(obj, "last_modified", None)
document_records.append({
"doc_id": doc_id,
"filename": filename,
"size": getattr(obj, "size", 0) or 0,
"object_name": obj.object_name,
"download_url": f"/api/v1/documents/download/{doc_id}",
"last_modified": last_modified.isoformat() if last_modified else None,
"_sort_key": last_modified.timestamp() if last_modified else 0,
})
document_records.sort(key=lambda item: item["_sort_key"], reverse=True)
if limit is not None:
document_records = document_records[:limit]
for item in document_records:
item.pop("_sort_key", None)
return document_records
@router.post("/upload", response_model=DocumentUploadResponse) @router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document( async def upload_document(
file: UploadFile = File(..., description="上传的文档文件"), file: UploadFile = File(..., description="上传的文档文件"),
doc_name: Optional[str] = Form(None, description="文档名称"), doc_name: str | None = Form(None, description="文档名称"),
regulation_type: Optional[str] = Form(None, description="法规类型"), regulation_type: str | None = Form(None, description="法规类型"),
version: Optional[str] = Form(None, description="文档版本"), version: str | None = Form(None, description="文档版本"),
generate_summary: bool = Form(False, description="是否生成摘要默认不生成可节省约60秒") generate_summary: bool = Form(False, description="是否生成摘要"),
): ):
""" """Handle upload document."""
上传文档并处理 content = await file.read()
if not file.filename:
支持格式PDF、DOCX、DOC raise HTTPException(status_code=400, detail="文件名不能为空")
处理流程:解析 → 分块 → 嵌入 → 入库(摘要可选) if not content:
文件存储MinIO对象存储 raise HTTPException(status_code=400, detail="上传文件为空")
参数说明:
- generate_summary: 是否生成LLM摘要默认False。勾选后处理时间增加约60秒。
"""
# 验证文件类型
ext = os.path.splitext(file.filename)[1].lower()
if ext not in [".pdf", ".docx", ".doc"]:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型: {ext}仅支持PDF、DOCX、DOC"
)
# 验证文件大小
if file.size and file.size > settings.max_file_size_mb * 1024 * 1024:
raise HTTPException(
status_code=400,
detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
)
# 生成文档ID
doc_id = str(uuid.uuid4())[:8]
# 文档名称
final_doc_name = doc_name or file.filename
# MinIO对象名称
object_name = f"{doc_id}/{file.filename}"
logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}, doc_id={doc_id}")
try: try:
# 读取文件内容 result = get_document_command_service().upload_and_process(
content = await file.read() file_name=file.filename,
content=content,
# 保存临时文件用于处理 content_type=file.content_type or "application/octet-stream",
temp_dir = tempfile.gettempdir() doc_name=doc_name,
temp_path = os.path.join(temp_dir, f"{doc_id}_{file.filename}")
with open(temp_path, "wb") as f:
f.write(content)
logger.info(f"临时文件已保存到: {temp_path}")
# 上传到MinIO
minio = get_minio_client()
upload_success = minio.upload_bytes(
data=content,
object_name=object_name,
content_type=minio._get_content_type(file.filename),
metadata={
"doc_id": doc_id # 仅传递ASCII安全的metadata
}
)
if upload_success:
logger.success(f"文件已上传到MinIO: {object_name}")
else:
logger.warning(f"MinIO上传失败仅使用本地临时文件")
# 处理文档传入相同的doc_id保持一致性
processor = DocumentProcessor(generate_summary=generate_summary)
result = processor.process(
file_path=temp_path,
doc_id=doc_id, # 使用相同的doc_id
doc_name=final_doc_name,
regulation_type=regulation_type or "", regulation_type=regulation_type or "",
version=version or "" version=version or "",
) generate_summary=generate_summary,
processor.close()
# 清理临时文件
try:
os.remove(temp_path)
except:
pass
if result.success:
return DocumentUploadResponse(
doc_id=result.doc_id,
doc_name=result.doc_name,
status="success",
message=result.message,
num_chunks=result.num_chunks,
summary=result.summary,
summary_latency_ms=result.summary_latency_ms
)
else:
raise HTTPException(
status_code=500,
detail=result.message
)
except Exception as e:
logger.error(f"文档处理失败: {e}")
raise HTTPException(
status_code=500,
detail=f"文档处理失败: {str(e)}"
) )
if result.status == "failed":
raise HTTPException(status_code=500, detail=result.message)
return _document_response(result)
except HTTPException:
raise
except Exception as exc:
logger.exception("文档上传失败")
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse) @router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
async def get_document_status(doc_id: str): async def get_document_status(doc_id: str):
""" """Return document status."""
查询文档处理状态 document = get_document_query_service().get(doc_id)
if not document:
Args: raise HTTPException(status_code=404, detail="文档不存在")
doc_id: 文档ID
"""
# TODO: 实现状态查询(需要数据库支持)
return DocumentUploadResponse( return DocumentUploadResponse(
doc_id=doc_id, doc_id=document.doc_id,
doc_name="", doc_name=document.doc_name,
status="unknown", status=document.status.value,
message="状态查询功能待实现" message=document.error_message or "查询成功",
num_chunks=document.chunk_count,
summary=document.summary,
summary_latency_ms=document.summary_latency_ms,
) )
@router.get("/download/{doc_id}") @router.get("/download/{doc_id}")
async def download_document(doc_id: str): async def download_document(doc_id: str):
""" """Handle download document."""
下载文档从MinIO获取
Args:
doc_id: 文档ID
Returns:
文件下载响应
"""
logger.info(f"请求下载文档: doc_id={doc_id}")
try: try:
minio = get_minio_client() document, file_data = get_document_query_service().download(doc_id)
encoded_name = quote(document.file_name)
# 查找该doc_id下的文件MinIO对象名称格式: {doc_id}/{filename}
objects = minio.list_objects(prefix=f"{doc_id}/")
if not objects:
logger.warning(f"MinIO中未找到文档: doc_id={doc_id}")
raise HTTPException(
status_code=404,
detail=f"文档不存在: doc_id={doc_id}"
)
# 获取第一个匹配的对象
object_name = objects[0]
logger.info(f"找到MinIO对象: {object_name}")
# 获取文件数据
file_data = minio.get_object_data(object_name)
if file_data is None:
raise HTTPException(
status_code=500,
detail=f"获取文档数据失败"
)
# 解析原始文件名
original_name = object_name.split("/", 1)[1] if "/" in object_name else object_name
# 获取Content-Type
content_type = minio._get_content_type(original_name)
logger.success(f"文档下载成功: {original_name}, 大小={len(file_data)}")
# 返回文件流URL编码文件名以支持中文
encoded_name = quote(original_name)
return StreamingResponse( return StreamingResponse(
BytesIO(file_data), BytesIO(file_data),
media_type=content_type, media_type=document.content_type or "application/octet-stream",
headers={ headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"文档下载失败: {e}")
raise HTTPException(
status_code=500,
detail=f"文档下载失败: {str(e)}"
) )
except FileNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc))
except Exception as exc:
logger.exception("文档下载失败")
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/list") @router.get("/list")
async def list_documents(): async def list_documents():
""" """List documents."""
列出所有已上传的文档从MinIO获取 documents = get_document_query_service().list_documents()
""" return {
try: "documents": [
documents = _build_document_records() {
return {"documents": documents, "total": len(documents)} "doc_id": item.doc_id,
"doc_name": item.doc_name,
except Exception as e: "status": item.status.value,
logger.error(f"列出文档失败: {e}") "chunk_count": item.chunk_count,
return {"documents": [], "total": 0, "error": str(e)} "updated_at": item.updated_at.isoformat(),
}
for item in documents
],
"total": len(documents),
}
@router.get("/management-list") @router.get("/management-list")
async def get_document_management_list(): async def get_document_management_list():
""" """Return document management list."""
文档管理清单接口仅返回最近的10条文档。 documents = get_document_query_service().list_documents(limit=10)
""" return {
try: "documents": [
documents = _build_document_records(limit=10) {
return {"documents": documents, "total": len(documents), "limit": 10} "doc_id": item.doc_id,
"doc_name": item.doc_name,
except Exception as e: "status": item.status.value,
logger.error(f"获取文档管理清单失败: {e}") "chunk_count": item.chunk_count,
return {"documents": [], "total": 0, "limit": 10, "error": str(e)} "updated_at": item.updated_at.isoformat(),
}
for item in documents
],
"total": len(documents),
"limit": 10,
}

View File

@@ -1,80 +1,51 @@
"""知识库检索接口""" """Define API routes for knowledge."""
from __future__ import annotations
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from loguru import logger
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse from app.api.models import SearchResponse, SearchResultItem, SearchRequest
from app.services.document_processor import DocumentProcessor from app.shared.bootstrap import get_retrieval_service
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/knowledge", tags=["knowledge"]) router = APIRouter(prefix="/knowledge", tags=["knowledge"])
@router.post("/search", response_model=SearchResponse) @router.post("/search", response_model=SearchResponse)
async def search_knowledge(request: SearchRequest): async def search_knowledge(request: SearchRequest):
""" """Search knowledge."""
检索法规知识库 if not request.query or not request.query.strip():
raise HTTPException(status_code=400, detail="查询文本不能为空")
使用混合检索Dense向量 + Sparse向量 + RRF融合 results = get_retrieval_service().retrieve(
query=request.query,
Args: top_k=request.top_k,
request: 检索请求参数 filters=request.filters,
""" )
if not request.query or len(request.query.strip()) == 0: return SearchResponse(
raise HTTPException( query=request.query,
status_code=400, total=len(results),
detail="查询文本不能为空" results=[
) SearchResultItem(
id=index + 1,
logger.info(f"收到检索请求: {request.query}") content=item.content,
score=item.score,
try: metadata={
# 执行检索 "doc_id": item.doc_id,
processor = DocumentProcessor() "doc_name": item.doc_name,
results = processor.search( "chunk_id": item.chunk_id,
query=request.query, "section_title": item.section_title,
top_k=request.top_k, "page_number": item.page_number,
filters=request.filters **item.metadata,
) },
processor.close()
# 转换结果格式
result_items = []
for r in results:
item = SearchResultItem(
id=r.get("id", 0),
content=r.get("content", ""),
score=r.get("score", 0.0),
metadata=r.get("metadata", {})
) )
result_items.append(item) for index, item in enumerate(results)
],
return SearchResponse( )
query=request.query,
total=len(result_items),
results=result_items
)
except Exception as e:
logger.error(f"检索失败: {e}")
raise HTTPException(
status_code=500,
detail=f"检索失败: {str(e)}"
)
@router.post("/retrieval", response_model=SearchResponse) @router.post("/retrieval", response_model=SearchResponse)
async def knowledge_retrieval(request: SearchRequest): async def knowledge_retrieval(request: SearchRequest):
""" """Handle knowledge retrieval."""
知识检索接口(与架构文档对齐)
该接口实现完整的检索流程:
1. 意图识别
2. BM25关键词检索 + 向量语义检索(双路召回)
3. Cross-Encoder精排
4. 返回结果
Args:
request: 检索请求
"""
# 当前版本使用混合检索,后续可添加精排步骤
return await search_knowledge(request) return await search_knowledge(request)

View File

@@ -1,29 +1,39 @@
"""Define API routes for rag."""
from __future__ import annotations
import asyncio
import json
from typing import AsyncGenerator
from fastapi import APIRouter from fastapi import APIRouter
from sse_starlette.sse import EventSourceResponse from fastapi.responses import StreamingResponse
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.services.mock_data import ( from app.services.mock_data import (
get_mock_quick_questions, get_mock_quick_questions,
get_mock_retrieval, get_mock_retrieval,
get_mock_rag_answer, get_mock_rag_answer,
) )
import json # Keep route handlers close to their transport-layer wiring for easier auditing.
import asyncio
router = APIRouter(prefix="/rag", tags=["RAG问答"]) router = APIRouter(prefix="/rag", tags=["RAG问答"])
@router.post("/chat") @router.post("/chat")
async def rag_chat(request: RagChatRequest): async def rag_chat(request: RagChatRequest):
"""SSE流式问答""" """Handle rag chat."""
async def generate(): async def generate() -> AsyncGenerator[str, None]:
# 发送检索开始事件 # Keep route handlers close to their transport-layer wiring for easier auditing.
yield {"event": "message", "data": json.dumps({"type": "retrieving"})} """Handle generate."""
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
# 模拟检索延迟 # Keep route handlers close to their transport-layer wiring for easier auditing.
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
# 执行检索 # Keep route handlers close to their transport-layer wiring for easier auditing.
docs = get_mock_retrieval(request.query, top_k=request.top_k) docs = get_mock_retrieval(request.query, top_k=request.top_k)
retrieved_data = [ retrieved_data = [
@@ -36,37 +46,47 @@ async def rag_chat(request: RagChatRequest):
} }
for d in docs for d in docs
] ]
yield {"event": "message", "data": json.dumps({"type": "retrieved", "docs": retrieved_data})} yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
# 发送生成开始事件 # Keep route handlers close to their transport-layer wiring for easier auditing.
yield {"event": "message", "data": json.dumps({"type": "generating", "text": "正在生成答案..."})} yield (
f"event: message\ndata: "
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
)
# 模拟生成延迟 # Keep route handlers close to their transport-layer wiring for easier auditing.
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
# 获取预设答案 # Keep route handlers close to their transport-layer wiring for easier auditing.
answer = get_mock_rag_answer(request.query) answer = get_mock_rag_answer(request.query)
# 流式输出答案(按句子分割) # Keep route handlers close to their transport-layer wiring for easier auditing.
sentences = answer.split("\n\n") sentences = answer.split("\n\n")
for sentence in sentences: for sentence in sentences:
if sentence.strip(): if sentence.strip():
# 进一步分割长句子 # Keep route handlers close to their transport-layer wiring for easier auditing.
chunks = sentence.split("\n") chunks = sentence.split("\n")
for chunk in chunks: for chunk in chunks:
if chunk.strip(): if chunk.strip():
await asyncio.sleep(0.05) # 模拟生成延迟 await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})} yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
)
# 发送完成事件 # Keep route handlers close to their transport-layer wiring for easier auditing.
yield {"event": "message", "data": json.dumps({"type": "done"})} yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return EventSourceResponse(generate()) return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.get("/quick-questions", response_model=QuickQuestionsResponse) @router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions(): async def get_quick_questions():
"""获取预设快捷问题""" """Return quick questions."""
questions = [ questions = [
QuickQuestion(id=q["id"], question=q["question"], category=q["category"]) QuickQuestion(id=q["id"], question=q["question"], category=q["category"])
for q in get_mock_quick_questions() for q in get_mock_quick_questions()

View File

@@ -1,28 +1,44 @@
"""Define API routes for status."""
from fastapi import APIRouter from fastapi import APIRouter
from app.core.config import settings
from app.services.mock_data import MOCK_SYSTEM_STATS, MOCK_SYSTEM_CONFIG from app.config.settings import settings
from app.shared.bootstrap import get_document_query_service, get_vector_index
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/status", tags=["系统状态"]) router = APIRouter(prefix="/status", tags=["系统状态"])
@router.get("/stats") @router.get("/stats")
async def get_stats(): async def get_stats():
"""获取系统统计""" """Return stats."""
# 返回预设统计数据 documents = get_document_query_service().list_documents()
return MOCK_SYSTEM_STATS indexed = sum(1 for item in documents if item.status.value == "indexed")
failed = sum(1 for item in documents if item.status.value == "failed")
return {
"documents_total": len(documents),
"documents_indexed": indexed,
"documents_failed": failed,
"chunks_total": sum(item.chunk_count for item in documents),
}
@router.get("/config") @router.get("/config")
async def get_config(): async def get_config():
"""获取当前配置""" """Return config."""
return MOCK_SYSTEM_CONFIG return {
"embedding_model": settings.embedding_model,
"embedding_dim": settings.embedding_dim,
"embedding_base_url": settings.embedding_base_url,
"milvus_collection": settings.milvus_collection,
"llm_provider": settings.llm_provider,
"llm_model": settings.llm_model,
"document_metadata_path": settings.document_metadata_path,
}
@router.get("/milvus/health") @router.get("/milvus/health")
async def milvus_health(): async def milvus_health():
"""Milvus健康检查""" """Handle milvus health."""
# 模拟连接状态(假数据模式下始终返回连接成功) return get_vector_index().health()
return {
"connected": True,
"collections": ["vehicle_regulations"],
}

View File

@@ -0,0 +1,5 @@
"""Initialize the app.application package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,7 @@
"""Initialize the app.application.agent package."""
from .services import AgentConversationService
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["AgentConversationService"]

View File

@@ -0,0 +1,145 @@
"""Implement application-layer logic for services."""
from __future__ import annotations
from typing import Generator
from app.domain.conversation import AnswerGenerator, AnswerResult, ConversationStore
from app.domain.retrieval import RetrievedChunk
from app.application.knowledge import KnowledgeRetrievalService
# Keep orchestration logic centralized so use-case flow stays easy to trace.
class AgentConversationService:
"""Provide the Agent Conversation Service service."""
def __init__(
self,
*,
retrieval_service: KnowledgeRetrievalService,
answer_generator: AnswerGenerator,
conversation_store: ConversationStore,
) -> None:
"""Initialize the Agent Conversation Service instance."""
self.retrieval_service = retrieval_service
self.answer_generator = answer_generator
self.conversation_store = conversation_store
def ask(
self,
*,
query: str,
filters: str | None = None,
provider: str | None = None,
model: str | None = None,
top_k: int = 5,
prompt_template: str | None = None,
session_id: str | None = None,
) -> tuple[str | None, AnswerResult]:
"""Handle ask for the Agent Conversation Service instance."""
history = None
active_session_id = None
if session_id:
session = self.conversation_store.get_session(session_id)
if not session:
raise ValueError("会话不存在或已过期")
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-10:]]
active_session_id = session.session_id
self.conversation_store.save_message(session_id, role="user", content=query)
retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
result = self.answer_generator.generate(
query=query,
retrieved_chunks=retrieved,
history=history,
provider=provider,
model=model,
prompt_template=prompt_template,
)
if active_session_id:
self.conversation_store.save_message(
active_session_id,
role="assistant",
content=result.answer,
sources=[source.__dict__ for source in result.sources],
)
return active_session_id, result
def chat(
self,
*,
query: str,
session_id: str | None = None,
filters: str | None = None,
provider: str | None = None,
model: str | None = None,
top_k: int = 5,
) -> tuple[str, AnswerResult]:
"""Handle chat for the Agent Conversation Service instance."""
session = self.conversation_store.get_session(session_id) if session_id else None
if session is None:
session = self.conversation_store.create_session()
self.conversation_store.save_message(session.session_id, role="user", content=query)
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-10:]]
retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
result = self.answer_generator.generate(
query=query,
retrieved_chunks=retrieved,
history=history,
provider=provider,
model=model,
)
self.conversation_store.save_message(
session.session_id,
role="assistant",
content=result.answer,
sources=[source.__dict__ for source in result.sources],
)
return session.session_id, result
def stream_chat(
self,
*,
query: str,
session_id: str | None = None,
filters: str | None = None,
provider: str | None = None,
model: str | None = None,
top_k: int = 5,
prompt_template: str | None = None,
) -> tuple[str, Generator[dict, None, None]]:
"""Stream chat for the Agent Conversation Service instance."""
session = self.conversation_store.get_session(session_id) if session_id else None
if session is None:
session = self.conversation_store.create_session()
self.conversation_store.save_message(session.session_id, role="user", content=query)
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-10:]]
retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
def event_stream() -> Generator[dict, None, None]:
"""Handle event stream for the Agent Conversation Service instance."""
yield {"event": "status", "data": f"找到{len(retrieved)}条相关法规,正在生成回答..."}
answer_parts: list[str] = []
sources_payload: list[dict] = []
for event in self.answer_generator.stream_generate(
query=query,
retrieved_chunks=retrieved,
history=history,
provider=provider,
model=model,
prompt_template=prompt_template,
):
if event.get("event") == "sources":
sources_payload = event.get("data", [])
if event.get("event") == "content":
answer_parts.append(str(event.get("data", "")))
yield event
full_answer = "".join(answer_parts)
self.conversation_store.save_message(
session.session_id,
role="assistant",
content=full_answer,
sources=sources_payload,
)
return session.session_id, event_stream()

View File

@@ -0,0 +1,7 @@
"""Initialize the app.application.documents package."""
from .services import DocumentCommandService, DocumentProcessResult, DocumentQueryService
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["DocumentCommandService", "DocumentProcessResult", "DocumentQueryService"]

View File

@@ -0,0 +1,186 @@
"""Implement application-layer logic for services."""
from __future__ import annotations
import os
import tempfile
import uuid
from dataclasses import dataclass
from loguru import logger
from app.domain.documents import (
ChunkBuilder,
Document,
DocumentBinaryStore,
DocumentParser,
DocumentRepository,
DocumentStatus,
)
from app.domain.retrieval import EmbeddingProvider, VectorIndex
# Keep orchestration logic centralized so use-case flow stays easy to trace.
@dataclass
class DocumentProcessResult:
"""Represent document process result data."""
doc_id: str
doc_name: str
status: str
message: str
num_chunks: int = 0
summary: str = ""
summary_latency_ms: int = 0
class DocumentCommandService:
"""Provide the Document Command Service service."""
def __init__(
self,
*,
document_repository: DocumentRepository,
binary_store: DocumentBinaryStore,
parser: DocumentParser,
chunk_builder: ChunkBuilder,
embedding_provider: EmbeddingProvider,
vector_index: VectorIndex,
) -> None:
"""Initialize the Document Command Service instance."""
self.document_repository = document_repository
self.binary_store = binary_store
self.parser = parser
self.chunk_builder = chunk_builder
self.embedding_provider = embedding_provider
self.vector_index = vector_index
def upload_and_process(
self,
*,
doc_id: str | None = None,
file_name: str,
content: bytes,
content_type: str,
doc_name: str | None,
regulation_type: str,
version: str,
generate_summary: bool,
) -> DocumentProcessResult:
"""Handle upload and process for the Document Command Service instance."""
doc_id = doc_id or str(uuid.uuid4())[:8]
final_doc_name = doc_name or file_name
object_name = f"{doc_id}/{file_name}"
document = Document(
doc_id=doc_id,
doc_name=final_doc_name,
file_name=file_name,
object_name=object_name,
content_type=content_type,
size_bytes=len(content),
regulation_type=regulation_type,
version=version,
metadata={"generate_summary": generate_summary},
)
self.document_repository.create(document)
temp_path = ""
try:
self.binary_store.save(
object_name=object_name,
data=content,
content_type=content_type,
metadata={"doc_id": doc_id},
)
self.document_repository.update_status(doc_id, DocumentStatus.STORED)
suffix = os.path.splitext(file_name)[1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
temp_file.write(content)
temp_path = temp_file.name
parsed_document = self.parser.parse(
file_path=temp_path,
doc_id=doc_id,
doc_name=final_doc_name,
)
self.document_repository.update_status(
doc_id,
DocumentStatus.PARSED,
parser_name=parsed_document.parser_name,
metadata={"structure_nodes": len(parsed_document.structure_nodes)},
)
chunks = self.chunk_builder.build(
parsed_document=parsed_document,
regulation_type=regulation_type,
version=version,
)
if not chunks:
raise ValueError("解析完成但没有生成可入库的 chunks")
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
inserted = self.vector_index.upsert(chunks, vectors)
if inserted != len(chunks):
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
self.document_repository.update_status(
doc_id,
DocumentStatus.INDEXED,
chunk_count=len(chunks),
summary="",
summary_latency_ms=0,
index_name=self.vector_index.health().get("collection_name", ""),
)
stored = self.document_repository.get(doc_id)
return DocumentProcessResult(
doc_id=doc_id,
doc_name=final_doc_name,
status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
message="处理成功",
num_chunks=len(chunks),
summary=stored.summary if stored else "",
summary_latency_ms=stored.summary_latency_ms if stored else 0,
)
except Exception as exc:
logger.exception("文档处理失败: doc_id={}", doc_id)
self.document_repository.update_status(
doc_id,
DocumentStatus.FAILED,
error_message=str(exc),
)
return DocumentProcessResult(
doc_id=doc_id,
doc_name=final_doc_name,
status=DocumentStatus.FAILED.value,
message=f"文档处理失败: {exc}",
)
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except OSError:
logger.warning("临时文件清理失败: {}", temp_path)
class DocumentQueryService:
"""Provide the Document Query Service service."""
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore) -> None:
"""Initialize the Document Query Service instance."""
self.document_repository = document_repository
self.binary_store = binary_store
def get(self, doc_id: str) -> Document | None:
"""Handle get for the Document Query Service instance."""
return self.document_repository.get(doc_id)
def list_documents(self, limit: int | None = None) -> list[Document]:
"""List documents for the Document Query Service instance."""
return self.document_repository.list(limit=limit)
def download(self, doc_id: str) -> tuple[Document, bytes]:
"""Handle download for the Document Query Service instance."""
document = self.document_repository.get(doc_id)
if not document:
raise FileNotFoundError(f"文档不存在: {doc_id}")
return document, self.binary_store.read(document.object_name)

View File

@@ -0,0 +1,7 @@
"""Initialize the app.application.knowledge package."""
from .services import KnowledgeRetrievalService
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["KnowledgeRetrievalService"]

View File

@@ -0,0 +1,19 @@
"""Implement application-layer logic for services."""
from __future__ import annotations
from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk
# Keep orchestration logic centralized so use-case flow stays easy to trace.
class KnowledgeRetrievalService:
"""Provide the Knowledge Retrieval Service service."""
def __init__(self, *, retriever: Retriever) -> None:
"""Initialize the Knowledge Retrieval Service instance."""
self.retriever = retriever
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle retrieve for the Knowledge Retrieval Service instance."""
retrieval_query = RetrievalQuery(query=query, top_k=top_k, filters=filters)
return self.retriever.retrieve(retrieval_query)

View File

@@ -1,5 +1,7 @@
"""配置模块""" """Initialize the app.config package."""
from .settings import Settings, get_settings, settings from .settings import Settings, get_settings, settings
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["Settings", "get_settings", "settings"] __all__ = ["Settings", "get_settings", "settings"]

View File

@@ -1,16 +1,18 @@
"""日志配置""" """Configure backend settings for logging."""
from loguru import logger from loguru import logger
import sys import sys
# Keep configuration setup explicit so runtime behavior is easy to reason about.
def setup_logging(level: str = "INFO"): def setup_logging(level: str = "INFO"):
"""设置日志配置""" """Handle setup logging."""
# 移除默认handler # Keep configuration setup explicit so runtime behavior is easy to reason about.
logger.remove() logger.remove()
# 添加控制台输出 # Keep configuration setup explicit so runtime behavior is easy to reason about.
logger.add( logger.add(
sys.stdout, sys.stdout,
level=level, level=level,
@@ -18,7 +20,7 @@ def setup_logging(level: str = "INFO"):
colorize=True colorize=True
) )
# 添加文件输出 # Keep configuration setup explicit so runtime behavior is easy to reason about.
logger.add( logger.add(
"logs/app_{time:YYYY-MM-DD}.log", "logs/app_{time:YYYY-MM-DD}.log",
level=level, level=level,

View File

@@ -1,94 +1,119 @@
"""配置管理 - 环境变量和默认配置""" """Configure backend settings for settings."""
from pydantic_settings import BaseSettings from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field from pydantic import Field
from typing import Optional
from functools import lru_cache from functools import lru_cache
# Keep configuration setup explicit so runtime behavior is easy to reason about.
ROOT_DIR = Path(__file__).resolve().parents[3]
ROOT_ENV_FILES = (
ROOT_DIR / ".env",
ROOT_DIR / ".env.development",
)
class Settings(BaseSettings): class Settings(BaseSettings):
"""应用配置""" """Define configuration for settings."""
# 应用基础配置 model_config = SettingsConfigDict(
env_file=tuple(str(env_file) for env_file in ROOT_ENV_FILES),
env_file_encoding="utf-8",
extra="ignore",
)
# Keep configuration setup explicit so runtime behavior is easy to reason about.
app_name: str = Field(default="AI Regulations Demo", description="Application name") app_name: str = Field(default="AI Regulations Demo", description="Application name")
app_version: str = Field(default="0.1.0", description="应用版本") app_version: str = Field(default="0.1.0", description="应用版本")
debug: bool = Field(default=False, description="调试模式") debug: bool = Field(default=False, description="调试模式")
# Milvus向量数据库配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
milvus_host: str = Field(default="localhost", description="Milvus服务地址") milvus_host: str = Field(default="localhost", description="Milvus服务地址")
milvus_port: int = Field(default=19530, description="Milvus服务端口") milvus_port: int = Field(default=19530, description="Milvus服务端口")
milvus_collection: str = Field(default="regulations", description="法规向量集合名称") milvus_collection: str = Field(default="regulations_dense_1536", description="法规向量集合名称")
milvus_db_name: str = Field(default="default", description="Milvus数据库名称") milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
# 嵌入模型配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
embedding_model: str = Field(default="BAAI/bge-m3", description="嵌入模型名称") embedding_model: str = Field(default="text-embedding-v3", description="嵌入模型名称")
embedding_dim: int = Field(default=1024, description="嵌入向量维度") embedding_dim: int = Field(default=1536, description="嵌入向量维度")
embedding_max_length: int = Field(default=8192, description="最大嵌入长度") embedding_api_key: str = Field(default="", description="Embedding API密钥")
embedding_batch_size: int = Field(default=12, description="嵌入批处理大小") embedding_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="Embedding API地址")
embedding_use_fp16: bool = Field(default=True, description="使用FP16加速") embedding_timeout_seconds: int = Field(default=120, description="Embedding API超时时间(秒)")
alibaba_access_key_id: str = Field(default="", description="阿里云文档解析 Access Key ID")
alibaba_access_key_secret: str = Field(default="", description="阿里云文档解析 Access Key Secret")
alibaba_endpoint: str = Field(default="docmind-api.cn-hangzhou.aliyuncs.com", description="阿里云文档解析 endpoint")
# MinIO对象存储配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址") minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址")
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥") minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥") minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
minio_bucket: str = Field(default="upload-files", description="文档存储桶名称") minio_bucket: str = Field(default="upload-files", description="文档存储桶名称")
minio_secure: bool = Field(default=False, description="是否使用HTTPS") minio_secure: bool = Field(default=False, description="是否使用HTTPS")
# Redis配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
redis_host: str = Field(default="localhost", description="Redis服务地址") redis_host: str = Field(default="localhost", description="Redis服务地址")
redis_port: int = Field(default=6379, description="Redis服务端口") redis_port: int = Field(default=6379, description="Redis服务端口")
redis_password: str = Field(default="", description="Redis密码") redis_password: str = Field(default="", description="Redis密码")
redis_db: int = Field(default=0, description="Redis数据库编号") redis_db: int = Field(default=0, description="Redis数据库编号")
# PostgreSQL配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址") postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址")
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口") postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名") postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码") postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
postgres_db: str = Field(default="compliance_db", description="PostgreSQL数据库名称") postgres_db: str = Field(default="compliance_db", description="PostgreSQL数据库名称")
# 文档处理配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
chunk_size: int = Field(default=512, description="分块大小(字符数)") chunk_size: int = Field(default=512, description="分块大小(字符数)")
chunk_overlap: int = Field(default=50, description="分块重叠大小") chunk_overlap: int = Field(default=50, description="分块重叠大小")
max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)") max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)")
document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径")
parser_backend: str = Field(default="local", description="解析后端(local/aliyun)")
chunk_backend: str = Field(default="local", description="分块后端(local/aliyun)")
# API配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
api_host: str = Field(default="0.0.0.0", description="API服务地址") api_host: str = Field(default="0.0.0.0", description="API服务地址")
api_port: int = Field(default=8000, description="API服务端口") api_port: int = Field(default=8000, description="API服务端口")
# LLM配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
llm_provider: str = Field(default="deepseek", description="LLM提供商 (deepseek/qwen/qwen_vl)") llm_provider: str = Field(default="deepseek", description="LLM提供商 (deepseek/qwen/qwen_vl)")
llm_model: str = Field(default="deepseek-v4-flash", description="LLM模型名称") llm_model: str = Field(default="deepseek-v4-flash", description="LLM模型名称")
llm_max_tokens: int = Field(default=4096, description="LLM最大输出token数") llm_max_tokens: int = Field(default=4096, description="LLM最大输出token数")
llm_temperature: float = Field(default=0.7, description="LLM温度参数") llm_temperature: float = Field(default=0.7, description="LLM温度参数")
# DeepSeek配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
deepseek_api_key: str = Field(default="", description="DeepSeek API密钥") deepseek_api_key: str = Field(default="", description="DeepSeek API密钥")
deepseek_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="DeepSeek API地址") deepseek_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="DeepSeek API地址")
deepseek_model: str = Field(default="deepseek-v4-flash", description="DeepSeek模型") deepseek_model: str = Field(default="deepseek-v4-flash", description="DeepSeek模型")
# Qwen配置通过统一代理API # Keep configuration setup explicit so runtime behavior is easy to reason about.
qwen_api_key: str = Field(default="", description="Qwen API密钥") qwen_api_key: str = Field(default="", description="Qwen API密钥")
qwen_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="Qwen API地址") qwen_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="Qwen API地址")
qwen_model: str = Field(default="qwen3.5-flash", description="Qwen文本模型") qwen_model: str = Field(default="qwen3.5-flash", description="Qwen文本模型")
qwen_vl_model: str = Field(default="qwen3-vl-plus", description="Qwen视觉模型") qwen_vl_model: str = Field(default="qwen3-vl-plus", description="Qwen视觉模型")
# RAG配置 # Keep configuration setup explicit so runtime behavior is easy to reason about.
rag_top_k: int = Field(default=5, description="检索召回数量") rag_top_k: int = Field(default=5, description="检索召回数量")
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数") rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数") rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
class Config: # Keep configuration setup explicit so runtime behavior is easy to reason about.
env_file = ".env" milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型")
env_file_encoding = "utf-8" milvus_nlist: int = Field(default=128, description="Milvus nlist参数")
extra = "ignore" milvus_nprobe: int = Field(default=16, description="Milvus nprobe参数")
# Keep configuration setup explicit so runtime behavior is easy to reason about.
session_max_sessions: int = Field(default=100, description="最大会话数量")
session_timeout_minutes: int = Field(default=30, description="会话超时时间(分钟)")
@lru_cache @lru_cache
def get_settings() -> Settings: def get_settings() -> Settings:
"""获取配置实例(缓存)""" """Return settings."""
return Settings() return Settings()
# 导出默认配置实例 # Keep configuration setup explicit so runtime behavior is easy to reason about.
settings = get_settings() settings = get_settings()

View File

@@ -1,3 +1,7 @@
"""Initialize the app.core package."""
from .config import settings, Settings from .config import settings, Settings
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["settings", "Settings"] __all__ = ["settings", "Settings"]

View File

@@ -1,41 +1,54 @@
from pydantic_settings import BaseSettings """Legacy-compatible config used by older utility modules."""
from typing import Optional
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
# Keep legacy settings aligned with the root-level env loading rules.
ROOT_DIR = Path(__file__).resolve().parents[3]
ROOT_ENV_FILES = tuple(str(path) for path in (ROOT_DIR / ".env", ROOT_DIR / ".env.development"))
class Settings(BaseSettings): class Settings(BaseSettings):
# DashScope API # DashScope API
"""Define configuration for settings."""
model_config = SettingsConfigDict(
env_file=ROOT_ENV_FILES,
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
dashscope_api_key: str = "" dashscope_api_key: str = ""
# Milvus # Milvus
milvus_host: str = "localhost" milvus_host: str = "localhost"
milvus_port: int = 19530 milvus_port: int = 19530
milvus_collection: str = "regulations_dense_1536"
# LLM配置 # LLM / embedding defaults aligned with the migrated backend path.
llm_model: str = "qwen-max" llm_model: str = "qwen-max"
embedding_model: str = "text-embedding-v3" embedding_model: str = "text-embedding-v3"
embedding_dim: int = 1536 embedding_dim: int = 1536
# 检索配置 # Legacy workflow compatibility only.
vector_top_k: int = 10 vector_top_k: int = 10
bm25_top_k: int = 10
final_top_k: int = 5 final_top_k: int = 5
# 分块配置 # Legacy local chunking compatibility only; main ingest now uses Aliyun vector_chunks.
chunk_size: int = 800 chunk_size: int = 800
chunk_overlap: int = 50 chunk_overlap: int = 50
# 服务配置 # Service config.
api_host: str = "0.0.0.0" api_host: str = "0.0.0.0"
api_port: int = 8000 api_port: int = 8000
# Collection名称 # Legacy aliases retained for old utility modules.
regulations_collection: str = "vehicle_regulations" regulations_collection: str = "regulations_dense_1536"
compliance_collection: str = "compliance_cache" compliance_collection: str = "compliance_cache"
class Config: # Preserve the legacy module API while keeping env resolution centralized at the repo root.
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = False
settings = Settings() settings = Settings()

View File

@@ -0,0 +1,5 @@
"""Initialize the app.domain package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,15 @@
"""Initialize the app.domain.conversation package."""
from .models import AnswerResult, AnswerSource, ConversationMessage, ConversationSession
from .ports import AnswerGenerator, ConversationStore
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = [
"AnswerGenerator",
"AnswerResult",
"AnswerSource",
"ConversationMessage",
"ConversationSession",
"ConversationStore",
]

View File

@@ -0,0 +1,53 @@
"""Define domain models for conversation."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
# Keep module behavior explicit so the backend flow stays easy to audit.
@dataclass
class AnswerSource:
"""Represent answer source data."""
doc_id: str
doc_name: str
chunk_id: str
section_title: str
page_number: int
score: float
content: str
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ConversationMessage:
"""Represent conversation message data."""
role: str
content: str
timestamp: int
sources: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class ConversationSession:
"""Represent conversation session data."""
session_id: str
messages: list[ConversationMessage] = field(default_factory=list)
created_at: int = 0
updated_at: int = 0
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class AnswerResult:
"""Represent answer result data."""
answer: str
sources: list[AnswerSource] = field(default_factory=list)
model: str = ""
latency_ms: int = 0
retrieved_count: int = 0
context_tokens: int = 0
truncated: bool = False
error: str | None = None

View File

@@ -0,0 +1,78 @@
"""Define domain ports for conversation."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generator
from app.domain.retrieval.models import RetrievedChunk
from .models import AnswerResult, ConversationSession
# Keep domain contracts explicit so adapters can swap implementations cleanly.
class AnswerGenerator(ABC):
"""Represent the Answer Generator type."""
@abstractmethod
def generate(
self,
*,
query: str,
retrieved_chunks: list[RetrievedChunk],
history: list[dict[str, str]] | None = None,
provider: str | None = None,
model: str | None = None,
prompt_template: str | None = None,
) -> AnswerResult:
"""Handle generate for the Answer Generator instance."""
pass
@abstractmethod
def stream_generate(
self,
*,
query: str,
retrieved_chunks: list[RetrievedChunk],
history: list[dict[str, str]] | None = None,
provider: str | None = None,
model: str | None = None,
prompt_template: str | None = None,
) -> Generator[dict, None, AnswerResult]:
"""Stream generate for the Answer Generator instance."""
pass
class ConversationStore(ABC):
"""Provide the Conversation Store store implementation."""
@abstractmethod
def create_session(self, metadata: dict | None = None) -> ConversationSession:
"""Create session for the Conversation Store instance."""
pass
@abstractmethod
def get_session(self, session_id: str) -> ConversationSession | None:
"""Return session for the Conversation Store instance."""
pass
@abstractmethod
def save_message(
self,
session_id: str,
*,
role: str,
content: str,
sources: list[dict] | None = None,
) -> ConversationSession | None:
"""Save message for the Conversation Store instance."""
pass
@abstractmethod
def delete_session(self, session_id: str) -> bool:
"""Delete session for the Conversation Store instance."""
pass
@abstractmethod
def list_sessions(self) -> list[dict]:
"""List sessions for the Conversation Store instance."""
pass

View File

@@ -0,0 +1,17 @@
"""Initialize the app.domain.documents package."""
from .models import Chunk, Document, DocumentStatus, ParsedDocument
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = [
"Chunk",
"Document",
"DocumentStatus",
"ParsedDocument",
"ChunkBuilder",
"DocumentBinaryStore",
"DocumentParser",
"DocumentRepository",
]

View File

@@ -0,0 +1,77 @@
"""Define domain models for documents."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
# Keep module behavior explicit so the backend flow stays easy to audit.
def utcnow() -> datetime:
return datetime.now(UTC)
class DocumentStatus(str, Enum):
"""Define the Document Status enumeration."""
PENDING = "pending"
STORED = "stored"
PARSED = "parsed"
INDEXED = "indexed"
FAILED = "failed"
@dataclass
class Document:
"""Represent the Document type."""
doc_id: str
doc_name: str
file_name: str
object_name: str
content_type: str
size_bytes: int
status: DocumentStatus = DocumentStatus.PENDING
regulation_type: str = ""
version: str = ""
summary: str = ""
summary_latency_ms: int = 0
chunk_count: int = 0
parser_name: str = ""
index_name: str = ""
error_message: str = ""
created_at: datetime = field(default_factory=utcnow)
updated_at: datetime = field(default_factory=utcnow)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ParsedDocument:
"""Represent the Parsed Document type."""
doc_id: str
doc_name: str
structure_nodes: list[dict[str, Any]]
semantic_blocks: list[dict[str, Any]]
vector_chunks: list[dict[str, Any]]
parser_name: str
raw_text: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class Chunk:
"""Represent the Chunk type."""
chunk_id: str
doc_id: str
doc_name: str
content: str
embedding_text: str
section_title: str = ""
section_path: list[str] = field(default_factory=list)
page_number: int = 0
regulation_type: str = ""
version: str = ""
semantic_id: str = ""
block_type: str = ""
metadata: dict[str, Any] = field(default_factory=dict)

View File

@@ -0,0 +1,96 @@
"""Define domain ports for documents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from .models import Chunk, Document, DocumentStatus, ParsedDocument
# Keep domain contracts explicit so adapters can swap implementations cleanly.
class DocumentRepository(ABC):
"""Provide the Document Repository repository implementation."""
@abstractmethod
def create(self, document: Document) -> Document:
"""Handle create for the Document Repository instance."""
pass
@abstractmethod
def update(self, document: Document) -> Document:
"""Handle update for the Document Repository instance."""
pass
@abstractmethod
def get(self, doc_id: str) -> Document | None:
"""Handle get for the Document Repository instance."""
pass
@abstractmethod
def list(self, limit: int | None = None) -> list[Document]:
"""Handle list for the Document Repository instance."""
pass
@abstractmethod
def update_status(
self,
doc_id: str,
status: DocumentStatus,
*,
error_message: str = "",
chunk_count: int | None = None,
summary: str | None = None,
summary_latency_ms: int | None = None,
parser_name: str | None = None,
index_name: str | None = None,
metadata: dict | None = None,
) -> Document | None:
"""Update status for the Document Repository instance."""
pass
class DocumentBinaryStore(ABC):
"""Provide the Document Binary Store store implementation."""
@abstractmethod
def save(
self,
*,
object_name: str,
data: bytes,
content_type: str,
metadata: dict[str, str] | None = None,
) -> None:
"""Handle save for the Document Binary Store instance."""
pass
@abstractmethod
def read(self, object_name: str) -> bytes:
"""Handle read for the Document Binary Store instance."""
pass
@abstractmethod
def delete(self, object_name: str) -> None:
"""Handle delete for the Document Binary Store instance."""
pass
class DocumentParser(ABC):
"""Provide the Document Parser parser."""
@abstractmethod
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
"""Handle parse for the Document Parser instance."""
pass
class ChunkBuilder(ABC):
"""Provide the Chunk Builder builder."""
@abstractmethod
def build(
self,
*,
parsed_document: ParsedDocument,
regulation_type: str,
version: str,
) -> list[Chunk]:
"""Handle build for the Chunk Builder instance."""
pass

View File

@@ -0,0 +1,8 @@
"""Initialize the app.domain.retrieval package."""
from .models import RetrievalQuery, RetrievedChunk
from .ports import EmbeddingProvider, Retriever, VectorIndex
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Retriever", "VectorIndex"]

View File

@@ -0,0 +1,29 @@
"""Define domain models for retrieval."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
# Keep module behavior explicit so the backend flow stays easy to audit.
@dataclass
class RetrievalQuery:
"""Represent the Retrieval Query type."""
query: str
top_k: int
filters: str | None = None
@dataclass
class RetrievedChunk:
"""Represent the Retrieved Chunk type."""
chunk_id: str
doc_id: str
doc_name: str
content: str
score: float
section_title: str = ""
page_number: int = 0
metadata: dict[str, Any] = field(default_factory=dict)

View File

@@ -0,0 +1,60 @@
"""Define domain ports for retrieval."""
from __future__ import annotations
from abc import ABC, abstractmethod
from app.domain.documents.models import Chunk
from .models import RetrievalQuery, RetrievedChunk
# Keep domain contracts explicit so adapters can swap implementations cleanly.
class EmbeddingProvider(ABC):
"""Provide the Embedding Provider provider."""
@abstractmethod
def embed_texts(self, texts: list[str]) -> list[list[float]]:
"""Embed texts for the Embedding Provider instance."""
pass
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query for the Embedding Provider instance."""
pass
class VectorIndex(ABC):
"""Provide the Vector Index index implementation."""
@abstractmethod
def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
"""Handle upsert for the Vector Index instance."""
pass
@abstractmethod
def delete_by_document(self, doc_id: str) -> int:
"""Delete by document for the Vector Index instance."""
pass
@abstractmethod
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Vector Index instance."""
pass
@abstractmethod
def health(self) -> dict:
"""Handle health for the Vector Index instance."""
pass
class Retriever(ABC):
"""Provide the Retriever retriever."""
@abstractmethod
def retrieve(self, query: RetrievalQuery) -> list[RetrievedChunk]:
"""Handle retrieve for the Retriever instance."""
pass
@abstractmethod
def search(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Retriever instance."""
pass

View File

@@ -0,0 +1,5 @@
"""Initialize the app.infrastructure package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,5 @@
"""Initialize the app.infrastructure.embedding package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,59 @@
"""Implement infrastructure support for openai compatible embedding provider."""
from __future__ import annotations
import os
import httpx
from app.config.settings import settings
from app.domain.retrieval import EmbeddingProvider
# Keep adapter behavior explicit so integration details remain easy to audit.
class OpenAICompatibleEmbeddingProvider(EmbeddingProvider):
"""Provide the Open A I Compatible Embedding Provider provider."""
def __init__(self) -> None:
"""Initialize the Open A I Compatible Embedding Provider instance."""
self.base_url = settings.embedding_base_url.rstrip("/")
self.api_key = (
settings.embedding_api_key
or os.getenv("OPENAI_API_KEY", "")
or os.getenv("QWEN_API_KEY", "")
or os.getenv("DEEPSEEK_API_KEY", "")
)
self.model = settings.embedding_model
self.timeout = settings.embedding_timeout_seconds
self.dimension = settings.embedding_dim
def _request(self, texts: list[str]) -> list[list[float]]:
"""Handle request for this module for the Open A I Compatible Embedding Provider instance."""
if not self.api_key:
raise ValueError("缺少 EMBEDDING_API_KEY / OPENAI_API_KEY")
response = httpx.post(
f"{self.base_url}/embeddings",
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
json={"model": self.model, "input": texts},
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
vectors = [item["embedding"] for item in sorted(data.get("data", []), key=lambda item: item["index"])]
if any(len(vector) != self.dimension for vector in vectors):
raise ValueError(f"embedding 维度不匹配,期望 {self.dimension}")
return vectors
def embed_texts(self, texts: list[str]) -> list[list[float]]:
"""Embed texts for the Open A I Compatible Embedding Provider instance."""
if not texts:
return []
return self._request(texts)
def embed_query(self, text: str) -> list[float]:
"""Embed query for the Open A I Compatible Embedding Provider instance."""
vectors = self._request([text])
return vectors[0]

View File

@@ -0,0 +1,144 @@
"""Implement infrastructure support for openai compatible answer generator."""
from __future__ import annotations
import time
from typing import Generator
from app.config.settings import settings
from app.domain.conversation import AnswerGenerator, AnswerResult, AnswerSource
from app.domain.retrieval import RetrievedChunk
from app.services.llm.llm_factory import get_llm_client
# Keep adapter behavior explicit so integration details remain easy to audit.
PROMPT_TEMPLATES = {
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
}
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
"""Represent the Open A I Compatible Answer Generator type."""
def _build_messages(
self,
*,
query: str,
retrieved_chunks: list[RetrievedChunk],
history: list[dict[str, str]] | None,
prompt_template: str | None,
) -> tuple[list[dict[str, str]], int]:
"""Handle build messages for this module for the Open A I Compatible Answer Generator instance."""
system_prompt = PROMPT_TEMPLATES.get(prompt_template or "compliance_qa", PROMPT_TEMPLATES["default"])
context_blocks = []
context_tokens = 0
for idx, chunk in enumerate(retrieved_chunks, start=1):
block = (
f"[{idx}] 文档: {chunk.doc_name}\n"
f"章节: {chunk.section_title or '未标注'}\n"
f"页码: {chunk.page_number}\n"
f"内容: {chunk.content}"
)
context_tokens += len(block)
context_blocks.append(block)
context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4]
messages = [{"role": "system", "content": system_prompt}]
for item in history or []:
messages.append({"role": item["role"], "content": item["content"]})
messages.append(
{
"role": "user",
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
}
)
return messages, min(context_tokens, settings.rag_max_context_tokens)
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
return [
AnswerSource(
doc_id=chunk.doc_id,
doc_name=chunk.doc_name,
chunk_id=chunk.chunk_id,
section_title=chunk.section_title,
page_number=chunk.page_number,
score=chunk.score,
content=chunk.content,
metadata=chunk.metadata,
)
for chunk in chunks
]
def generate(
self,
*,
query: str,
retrieved_chunks: list[RetrievedChunk],
history: list[dict[str, str]] | None = None,
provider: str | None = None,
model: str | None = None,
prompt_template: str | None = None,
) -> AnswerResult:
"""Handle generate for the Open A I Compatible Answer Generator instance."""
start = time.time()
messages, context_tokens = self._build_messages(
query=query,
retrieved_chunks=retrieved_chunks,
history=history,
prompt_template=prompt_template,
)
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
response = client.chat(messages)
latency_ms = int((time.time() - start) * 1000)
return AnswerResult(
answer=response.content if response.is_success else "",
sources=self._sources(retrieved_chunks),
model=response.model or (model or settings.llm_model),
latency_ms=latency_ms,
retrieved_count=len(retrieved_chunks),
context_tokens=context_tokens,
truncated=False,
error=response.error,
)
def stream_generate(
self,
*,
query: str,
retrieved_chunks: list[RetrievedChunk],
history: list[dict[str, str]] | None = None,
provider: str | None = None,
model: str | None = None,
prompt_template: str | None = None,
) -> Generator[dict, None, AnswerResult]:
"""Stream generate for the Open A I Compatible Answer Generator instance."""
start = time.time()
messages, context_tokens = self._build_messages(
query=query,
retrieved_chunks=retrieved_chunks,
history=history,
prompt_template=prompt_template,
)
sources = [source.__dict__ for source in self._sources(retrieved_chunks)]
yield {"event": "sources", "data": sources}
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
answer_parts: list[str] = []
if hasattr(client, "stream_chat"):
for chunk in client.stream_chat(messages):
answer_parts.append(chunk)
yield {"event": "content", "data": chunk}
else:
response = client.chat(messages)
answer_parts.append(response.content)
yield {"event": "content", "data": response.content}
full_answer = "".join(answer_parts)
yield {
"event": "done",
"data": {
"latency_ms": int((time.time() - start) * 1000),
"retrieved_count": len(retrieved_chunks),
"context_tokens": context_tokens,
"model": model or settings.llm_model,
},
}

View File

@@ -0,0 +1,5 @@
"""Initialize the app.infrastructure.parser package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,55 @@
"""Implement infrastructure support for aliyun document parser."""
from __future__ import annotations
from app.aliyun_parser.parse_pdf import (
MAX_CHARS,
OVERLAP_CHARS,
build_semantic_blocks,
build_structure_nodes,
build_vector_chunks,
collect_all_results,
init_client,
submit_job,
wait_for_completion,
)
from app.domain.documents import DocumentParser, ParsedDocument
# Keep adapter behavior explicit so integration details remain easy to audit.
class AliyunDocumentParser(DocumentParser):
"""Provide the Aliyun Document Parser parser."""
parser_name = "aliyun_docmind"
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
"""Handle parse for the Aliyun Document Parser instance."""
client = init_client()
task_id = submit_job(client, file_path)
if not wait_for_completion(client, task_id):
raise RuntimeError("阿里云文档解析任务失败")
layouts = collect_all_results(client, task_id)
structure_nodes = build_structure_nodes(layouts)
semantic_blocks = build_semantic_blocks(layouts)
vector_chunks = build_vector_chunks(
semantic_blocks,
doc_id=doc_id,
doc_title=doc_name,
max_chars=MAX_CHARS,
overlap_chars=OVERLAP_CHARS,
)
raw_text = "\n\n".join(
block.get("text", "")
for block in semantic_blocks
if block.get("text")
)
return ParsedDocument(
doc_id=doc_id,
doc_name=doc_name,
structure_nodes=structure_nodes,
semantic_blocks=semantic_blocks,
vector_chunks=vector_chunks,
parser_name=self.parser_name,
raw_text=raw_text,
metadata={"task_id": task_id, "layout_count": len(layouts)},
)

View File

@@ -0,0 +1,66 @@
"""Local chunk builder adapter for the migrated backend architecture."""
from __future__ import annotations
from app.domain.documents import Chunk, ChunkBuilder, ParsedDocument
from app.services.embedding.text_chunker import RegulationChunker
class LocalRegulationChunkBuilder(ChunkBuilder):
"""Adapt the existing markdown chunker to the new chunk builder port."""
def __init__(self, *, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
self.chunker = RegulationChunker(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
def build(
self,
*,
parsed_document: ParsedDocument,
regulation_type: str,
version: str,
) -> list[Chunk]:
markdown_text = parsed_document.raw_text.strip()
if not markdown_text:
return []
legacy_chunks = self.chunker.chunk_document(
markdown_text,
doc_id=parsed_document.doc_id,
doc_name=parsed_document.doc_name,
regulation_type=regulation_type,
version=version,
)
chunks: list[Chunk] = []
for item in legacy_chunks:
metadata = {
"section_number": item.metadata.section_number,
"section_title": item.metadata.section_title,
"clause_number": item.metadata.clause_number,
"start_position": item.metadata.start_position,
"end_position": item.metadata.end_position,
"token_count": item.token_count,
"source": "local_chunker",
}
section_path = [value for value in [item.metadata.section_number, item.metadata.section_title] if value]
chunks.append(
Chunk(
chunk_id=item.metadata.chunk_id,
doc_id=parsed_document.doc_id,
doc_name=parsed_document.doc_name,
content=item.content,
embedding_text=item.content,
section_title=item.metadata.section_title or item.metadata.section_number,
section_path=section_path,
page_number=item.metadata.page_number,
regulation_type=regulation_type,
version=version,
semantic_id=item.metadata.clause_number,
block_type="local_markdown_chunk",
metadata=metadata,
)
)
return chunks

View File

@@ -0,0 +1,38 @@
"""Local parser adapter for the migrated backend architecture."""
from __future__ import annotations
from pathlib import Path
from app.domain.documents import DocumentParser, ParsedDocument
from app.services.parser.docx_parser import parse_docx_to_markdown
from app.services.parser.pdf_parser import parse_pdf_to_markdown
class LocalDocumentParser(DocumentParser):
"""Adapt the existing local PDF/DOCX parsers to the new parser port."""
parser_name = "local_markdown_parser"
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
suffix = Path(file_path).suffix.lower()
if suffix == ".pdf":
markdown_text = parse_pdf_to_markdown(file_path)
elif suffix in {".docx", ".doc"}:
markdown_text = parse_docx_to_markdown(file_path)
else:
raise ValueError(f"不支持的文件类型: {suffix}")
if not markdown_text.strip():
raise ValueError("本地解析完成但未提取到有效文本")
return ParsedDocument(
doc_id=doc_id,
doc_name=doc_name,
structure_nodes=[],
semantic_blocks=[],
vector_chunks=[],
parser_name=self.parser_name,
raw_text=markdown_text,
metadata={"source": "local_parser", "file_suffix": suffix},
)

View File

@@ -0,0 +1,48 @@
"""Implement infrastructure support for vector chunk builder."""
from __future__ import annotations
from app.domain.documents import Chunk, ChunkBuilder, ParsedDocument
# Keep adapter behavior explicit so integration details remain easy to audit.
class AliyunVectorChunkBuilder(ChunkBuilder):
"""Provide the Aliyun Vector Chunk Builder builder."""
def build(
self,
*,
parsed_document: ParsedDocument,
regulation_type: str,
version: str,
) -> list[Chunk]:
"""Handle build for the Aliyun Vector Chunk Builder instance."""
chunks: list[Chunk] = []
for index, item in enumerate(parsed_document.vector_chunks):
content = item.get("content") or item.get("text") or ""
embedding_text = item.get("embedding_text") or content
if not embedding_text.strip():
continue
section_path = item.get("section_path") or []
section_title = item.get("section_title") or (section_path[-1] if section_path else "")
page_number = item.get("page_start") or item.get("page") or 0
chunk_id = item.get("chunk_id") or f"{parsed_document.doc_id}-chunk-{index}"
metadata = {k: v for k, v in item.items() if k not in {"content", "embedding_text"}}
chunks.append(
Chunk(
chunk_id=str(chunk_id),
doc_id=parsed_document.doc_id,
doc_name=parsed_document.doc_name,
content=content,
embedding_text=embedding_text,
section_title=section_title,
section_path=section_path,
page_number=int(page_number or 0),
regulation_type=regulation_type,
version=version,
semantic_id=item.get("semantic_id", ""),
block_type=item.get("block_type", ""),
metadata=metadata,
)
)
return chunks

View File

@@ -0,0 +1,5 @@
"""Initialize the app.infrastructure.session package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,95 @@
"""Implement infrastructure support for in memory conversation store."""
from __future__ import annotations
import time
import uuid
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
# Keep adapter behavior explicit so integration details remain easy to audit.
class InMemoryConversationStore(ConversationStore):
"""Provide the In Memory Conversation Store store implementation."""
def __init__(self, *, max_sessions: int = 100, timeout_minutes: int = 30) -> None:
"""Initialize the In Memory Conversation Store instance."""
self.max_sessions = max_sessions
self.timeout_seconds = timeout_minutes * 60
self.sessions: dict[str, ConversationSession] = {}
def _now(self) -> int:
"""Handle now for this module for the In Memory Conversation Store instance."""
return int(time.time())
def _cleanup_expired(self) -> None:
"""Handle cleanup expired for this module for the In Memory Conversation Store instance."""
now = self._now()
expired = [
session_id
for session_id, session in self.sessions.items()
if (now - session.updated_at) > self.timeout_seconds
]
for session_id in expired:
self.sessions.pop(session_id, None)
def create_session(self, metadata: dict | None = None) -> ConversationSession:
"""Create session for the In Memory Conversation Store instance."""
self._cleanup_expired()
if len(self.sessions) >= self.max_sessions:
oldest = min(self.sessions.values(), key=lambda item: item.updated_at)
self.sessions.pop(oldest.session_id, None)
session_id = str(uuid.uuid4())[:8]
session = ConversationSession(
session_id=session_id,
created_at=self._now(),
updated_at=self._now(),
metadata=metadata or {},
)
self.sessions[session_id] = session
return session
def get_session(self, session_id: str) -> ConversationSession | None:
"""Return session for the In Memory Conversation Store instance."""
self._cleanup_expired()
return self.sessions.get(session_id)
def save_message(
self,
session_id: str,
*,
role: str,
content: str,
sources: list[dict] | None = None,
) -> ConversationSession | None:
"""Save message for the In Memory Conversation Store instance."""
session = self.get_session(session_id)
if not session:
return None
session.messages.append(
ConversationMessage(
role=role,
content=content,
timestamp=self._now(),
sources=sources or [],
)
)
session.updated_at = self._now()
return session
def delete_session(self, session_id: str) -> bool:
"""Delete session for the In Memory Conversation Store instance."""
return self.sessions.pop(session_id, None) is not None
def list_sessions(self) -> list[dict]:
"""List sessions for the In Memory Conversation Store instance."""
self._cleanup_expired()
return [
{
"session_id": session.session_id,
"message_count": len(session.messages),
"created_at": session.created_at,
"updated_at": session.updated_at,
}
for session in self.sessions.values()
]

View File

@@ -0,0 +1,5 @@
"""Initialize the app.infrastructure.storage package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,109 @@
"""Implement infrastructure support for json document repository."""
from __future__ import annotations
import json
from datetime import UTC, datetime
from pathlib import Path
from app.domain.documents import Document, DocumentRepository, DocumentStatus
# Keep adapter behavior explicit so integration details remain easy to audit.
class JsonDocumentRepository(DocumentRepository):
"""Provide the Json Document Repository repository implementation."""
def __init__(self, file_path: str) -> None:
"""Initialize the Json Document Repository instance."""
self.file_path = Path(file_path)
self.file_path.parent.mkdir(parents=True, exist_ok=True)
if not self.file_path.exists():
self.file_path.write_text("{}", encoding="utf-8")
def _load(self) -> dict[str, dict]:
"""Handle load for this module for the Json Document Repository instance."""
return json.loads(self.file_path.read_text(encoding="utf-8") or "{}")
def _save(self, payload: dict[str, dict]) -> None:
"""Handle save for this module for the Json Document Repository instance."""
self.file_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def _serialize(self, document: Document) -> dict:
"""Handle serialize for this module for the Json Document Repository instance."""
payload = document.__dict__.copy()
payload["status"] = document.status.value
payload["created_at"] = document.created_at.isoformat()
payload["updated_at"] = document.updated_at.isoformat()
return payload
def _deserialize(self, payload: dict) -> Document:
"""Handle deserialize for this module for the Json Document Repository instance."""
return Document(
**{
**payload,
"status": DocumentStatus(payload["status"]),
"created_at": datetime.fromisoformat(payload["created_at"]),
"updated_at": datetime.fromisoformat(payload["updated_at"]),
}
)
def create(self, document: Document) -> Document:
"""Handle create for the Json Document Repository instance."""
payload = self._load()
payload[document.doc_id] = self._serialize(document)
self._save(payload)
return document
def update(self, document: Document) -> Document:
"""Handle update for the Json Document Repository instance."""
document.updated_at = datetime.now(UTC)
payload = self._load()
payload[document.doc_id] = self._serialize(document)
self._save(payload)
return document
def get(self, doc_id: str) -> Document | None:
"""Handle get for the Json Document Repository instance."""
payload = self._load()
item = payload.get(doc_id)
return self._deserialize(item) if item else None
def list(self, limit: int | None = None) -> list[Document]:
"""Handle list for the Json Document Repository instance."""
payload = self._load()
documents = [self._deserialize(item) for item in payload.values()]
documents.sort(key=lambda item: item.updated_at, reverse=True)
return documents[:limit] if limit is not None else documents
def update_status(
self,
doc_id: str,
status: DocumentStatus,
*,
error_message: str = "",
chunk_count: int | None = None,
summary: str | None = None,
summary_latency_ms: int | None = None,
parser_name: str | None = None,
index_name: str | None = None,
metadata: dict | None = None,
) -> Document | None:
"""Update status for the Json Document Repository instance."""
document = self.get(doc_id)
if not document:
return None
document.status = status
document.error_message = error_message
if chunk_count is not None:
document.chunk_count = chunk_count
if summary is not None:
document.summary = summary
if summary_latency_ms is not None:
document.summary_latency_ms = summary_latency_ms
if parser_name is not None:
document.parser_name = parser_name
if index_name is not None:
document.index_name = index_name
if metadata:
document.metadata.update(metadata)
return self.update(document)

View File

@@ -0,0 +1,47 @@
"""Implement infrastructure support for minio binary store."""
from __future__ import annotations
from app.domain.documents import DocumentBinaryStore
from app.services.storage.minio_client import MinIOClient
# Keep adapter behavior explicit so integration details remain easy to audit.
class MinioDocumentBinaryStore(DocumentBinaryStore):
"""Provide the Minio Document Binary Store store implementation."""
def __init__(self) -> None:
"""Initialize the Minio Document Binary Store instance."""
self.client = MinIOClient()
self.client.connect()
self.client.ensure_bucket()
def save(
self,
*,
object_name: str,
data: bytes,
content_type: str,
metadata: dict[str, str] | None = None,
) -> None:
"""Handle save for the Minio Document Binary Store instance."""
success = self.client.upload_bytes(
data=data,
object_name=object_name,
content_type=content_type,
metadata=metadata,
)
if not success:
raise RuntimeError("MinIO 保存失败")
def read(self, object_name: str) -> bytes:
"""Handle read for the Minio Document Binary Store instance."""
data = self.client.get_object_data(object_name)
if data is None:
raise FileNotFoundError(f"对象不存在: {object_name}")
return data
def delete(self, object_name: str) -> None:
"""Handle delete for the Minio Document Binary Store instance."""
if not self.client.delete_object(object_name):
raise FileNotFoundError(f"对象删除失败: {object_name}")

View File

@@ -0,0 +1,5 @@
"""Initialize the app.infrastructure.vectorstore package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,24 @@
"""Implement infrastructure support for dense retriever."""
from __future__ import annotations
from app.domain.retrieval import EmbeddingProvider, RetrievalQuery, Retriever, RetrievedChunk, VectorIndex
# Keep adapter behavior explicit so integration details remain easy to audit.
class DenseRetriever(Retriever):
"""Provide the Dense Retriever retriever."""
def __init__(self, *, embedding_provider: EmbeddingProvider, vector_index: VectorIndex) -> None:
"""Initialize the Dense Retriever instance."""
self.embedding_provider = embedding_provider
self.vector_index = vector_index
def retrieve(self, query: RetrievalQuery) -> list[RetrievedChunk]:
"""Handle retrieve for the Dense Retriever instance."""
query_vector = self.embedding_provider.embed_query(query.query)
return self.vector_index.search(query_vector, query.top_k, query.filters)
def search(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Dense Retriever instance."""
return self.retrieve(RetrievalQuery(query=query, top_k=top_k, filters=filters))

View File

@@ -0,0 +1,154 @@
"""Implement infrastructure support for milvus vector index."""
from __future__ import annotations
import json
import time
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
from app.config.settings import settings
from app.domain.documents import Chunk
from app.domain.retrieval import RetrievedChunk, VectorIndex
# Keep adapter behavior explicit so integration details remain easy to audit.
class MilvusVectorIndex(VectorIndex):
"""Provide the Milvus Vector Index index implementation."""
def __init__(self) -> None:
"""Initialize the Milvus Vector Index instance."""
self.collection_name = settings.milvus_collection
self.db_name = settings.milvus_db_name
connections.connect(
alias="default",
host=settings.milvus_host,
port=settings.milvus_port,
db_name=self.db_name,
)
self.collection = self._ensure_collection()
def _ensure_collection(self) -> Collection:
"""Handle ensure collection for this module for the Milvus Vector Index instance."""
if utility.has_collection(self.collection_name):
collection = Collection(self.collection_name)
collection.load()
return collection
schema = CollectionSchema(
fields=[
FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=128, is_primary=True, auto_id=False),
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
FieldSchema(name="page_number", dtype=DataType.INT64),
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="semantic_id", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="block_type", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="metadata_json", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="created_at", dtype=DataType.INT64),
],
description="Dense-only regulations index",
enable_dynamic_field=False,
)
collection = Collection(name=self.collection_name, schema=schema)
collection.create_index(
field_name="embedding",
index_params={
"metric_type": "COSINE",
"index_type": settings.milvus_index_type,
"params": {"nlist": settings.milvus_nlist},
},
)
collection.load()
return collection
def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
"""Handle upsert for the Milvus Vector Index instance."""
if len(chunks) != len(vectors):
raise ValueError("chunks 与 vectors 数量不一致")
data = []
now = int(time.time())
for chunk, vector in zip(chunks, vectors):
data.append(
{
"id": chunk.chunk_id,
"doc_id": chunk.doc_id,
"doc_name": chunk.doc_name,
"content": chunk.content[:65535],
"embedding": vector,
"section_title": chunk.section_title[:512],
"section_path": json.dumps(chunk.section_path, ensure_ascii=False)[:4096],
"page_number": chunk.page_number,
"regulation_type": chunk.regulation_type[:128],
"version": chunk.version[:64],
"semantic_id": chunk.semantic_id[:128],
"block_type": chunk.block_type[:64],
"metadata_json": json.dumps(chunk.metadata, ensure_ascii=False)[:65535],
"created_at": now,
}
)
self.collection.insert(data)
self.collection.flush()
return len(data)
def delete_by_document(self, doc_id: str) -> int:
"""Delete by document for the Milvus Vector Index instance."""
result = self.collection.delete(f'doc_id == "{doc_id}"')
return len(result.primary_keys)
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Milvus Vector Index instance."""
results = self.collection.search(
data=[query_vector],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k,
filter=filters,
output_fields=[
"doc_id",
"doc_name",
"content",
"section_title",
"page_number",
"regulation_type",
"version",
"semantic_id",
"block_type",
"metadata_json",
],
)
payload: list[RetrievedChunk] = []
for hits in results:
for hit in hits:
metadata = {}
raw_metadata = hit.entity.get("metadata_json", "")
if raw_metadata:
try:
metadata = json.loads(raw_metadata)
except json.JSONDecodeError:
metadata = {"raw_metadata": raw_metadata}
payload.append(
RetrievedChunk(
chunk_id=str(hit.id),
doc_id=hit.entity.get("doc_id", ""),
doc_name=hit.entity.get("doc_name", ""),
content=hit.entity.get("content", ""),
score=float(hit.score),
section_title=hit.entity.get("section_title", ""),
page_number=int(hit.entity.get("page_number", 0) or 0),
metadata=metadata,
)
)
return payload
def health(self) -> dict:
"""Handle health for the Milvus Vector Index instance."""
return {
"connected": True,
"collection_name": self.collection_name,
"num_entities": self.collection.num_entities if self.collection else 0,
}

View File

@@ -1,5 +1,7 @@
"""Backend application entrypoint.""" """Backend application entrypoint."""
from app.api.main import app from app.api.main import app
# Keep module behavior explicit so the backend flow stays easy to audit.
__all__ = ["app"] __all__ = ["app"]

View File

@@ -1,3 +1,5 @@
"""Initialize the app.schemas package."""
from .doc import ( from .doc import (
DocumentUploadResponse, DocumentUploadResponse,
DocumentInfo, DocumentInfo,
@@ -24,6 +26,8 @@ from .compliance import (
ComplianceChatRequest, ComplianceChatRequest,
AnalyzeResponse, AnalyzeResponse,
) )
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = [ __all__ = [
"DocumentUploadResponse", "DocumentUploadResponse",

View File

@@ -1,21 +1,28 @@
"""Define schema models for compliance."""
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
from enum import Enum from enum import Enum
# Group related schema definitions so validation rules stay consistent.
class RiskLevel(str, Enum): class RiskLevel(str, Enum):
"""Define the Risk Level enumeration."""
high = "high" high = "high"
medium = "medium" medium = "medium"
low = "low" low = "low"
class ComplianceStatus(str, Enum): class ComplianceStatus(str, Enum):
"""Define the Compliance Status enumeration."""
pass_status = "pass" pass_status = "pass"
warning = "warning" warning = "warning"
fail = "fail" fail = "fail"
class Regulation(BaseModel): class Regulation(BaseModel):
"""Define the Regulation API model."""
id: int id: int
name: str name: str
clause: Optional[str] = None clause: Optional[str] = None
@@ -26,6 +33,7 @@ class Regulation(BaseModel):
class ComplianceSegment(BaseModel): class ComplianceSegment(BaseModel):
"""Define the Compliance Segment API model."""
id: int id: int
index: int index: int
intent: str intent: str
@@ -37,6 +45,7 @@ class ComplianceSegment(BaseModel):
class RiskDashboard(BaseModel): class RiskDashboard(BaseModel):
"""Define the Risk Dashboard API model."""
score: float score: float
high_risk_count: int high_risk_count: int
medium_risk_count: int medium_risk_count: int
@@ -47,6 +56,7 @@ class RiskDashboard(BaseModel):
class PriorityAction(BaseModel): class PriorityAction(BaseModel):
"""Define the Priority Action API model."""
regulation: str regulation: str
issue: str issue: str
suggestion: str suggestion: str
@@ -54,6 +64,7 @@ class PriorityAction(BaseModel):
class ComplianceResult(BaseModel): class ComplianceResult(BaseModel):
"""Define the Compliance Result API model."""
task_id: str task_id: str
dashboard: RiskDashboard dashboard: RiskDashboard
segments: list[ComplianceSegment] segments: list[ComplianceSegment]
@@ -61,9 +72,11 @@ class ComplianceResult(BaseModel):
class ComplianceChatRequest(BaseModel): class ComplianceChatRequest(BaseModel):
"""Define the Compliance Chat Request API model."""
query: str query: str
class AnalyzeResponse(BaseModel): class AnalyzeResponse(BaseModel):
"""Define the Analyze Response API model."""
task_id: str task_id: str
status: str = "processing" status: str = "processing"

View File

@@ -1,9 +1,14 @@
"""Define schema models for doc."""
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
# Group related schema definitions so validation rules stay consistent.
class DocumentUploadResponse(BaseModel): class DocumentUploadResponse(BaseModel):
"""Define the Document Upload Response API model."""
doc_id: str doc_id: str
filename: str filename: str
size: int size: int
@@ -11,6 +16,7 @@ class DocumentUploadResponse(BaseModel):
class DocumentInfo(BaseModel): class DocumentInfo(BaseModel):
"""Define the Document Info API model."""
id: str id: str
name: str name: str
chunks: int chunks: int
@@ -19,10 +25,12 @@ class DocumentInfo(BaseModel):
class DocumentListResponse(BaseModel): class DocumentListResponse(BaseModel):
"""Define the Document List Response API model."""
docs: list[DocumentInfo] docs: list[DocumentInfo]
class ChunkInfo(BaseModel): class ChunkInfo(BaseModel):
"""Define the Chunk Info API model."""
chunk_id: str chunk_id: str
doc_name: str doc_name: str
clause_id: Optional[str] = None clause_id: Optional[str] = None
@@ -33,12 +41,14 @@ class ChunkInfo(BaseModel):
class ParseResponse(BaseModel): class ParseResponse(BaseModel):
"""Define the Parse Response API model."""
doc_id: str doc_id: str
chunks: int chunks: int
status: str = "parsed" status: str = "parsed"
class EmbedResponse(BaseModel): class EmbedResponse(BaseModel):
"""Define the Embed Response API model."""
doc_id: str doc_id: str
vectors: int vectors: int
status: str = "embedded" status: str = "embedded"

View File

@@ -1,13 +1,19 @@
"""Define schema models for rag."""
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
# Group related schema definitions so validation rules stay consistent.
class RagChatRequest(BaseModel): class RagChatRequest(BaseModel):
"""Define the Rag Chat Request API model."""
query: str query: str
top_k: int = 5 top_k: int = 5
class RetrievedDoc(BaseModel): class RetrievedDoc(BaseModel):
"""Define the Retrieved Doc API model."""
id: str id: str
doc_name: str doc_name: str
clause_id: Optional[str] = None clause_id: Optional[str] = None
@@ -17,15 +23,18 @@ class RetrievedDoc(BaseModel):
class SourceInfo(BaseModel): class SourceInfo(BaseModel):
"""Define the Source Info API model."""
name: str name: str
clause: Optional[str] = None clause: Optional[str] = None
class QuickQuestion(BaseModel): class QuickQuestion(BaseModel):
"""Define the Quick Question API model."""
id: str id: str
question: str question: str
category: str category: str
class QuickQuestionsResponse(BaseModel): class QuickQuestionsResponse(BaseModel):
"""Define the Quick Questions Response API model."""
questions: list[QuickQuestion] questions: list[QuickQuestion]

View File

@@ -1,3 +1,5 @@
"""Backend service package.""" """Backend service package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__: list[str] = [] __all__: list[str] = []

View File

@@ -1,6 +1,8 @@
"""Agent服务模块""" """Initialize the app.services.agent package."""
from .qa_agent import QAAgent, ask_compliance_question from .qa_agent import QAAgent, ask_compliance_question
from .session_manager import SessionManager, ChatSession from .session_manager import SessionManager, ChatSession
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"] __all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]

View File

@@ -1,21 +1,19 @@
"""RAG问答Agent - 合规智能问答核心实现""" """Provide service-layer logic for qa agent."""
from __future__ import annotations
import time
from typing import List, Dict, Optional, Any, Generator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from loguru import logger from typing import Dict, Generator, List, Optional
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
from app.services.llm.llm_factory import LLMFactory
from app.services.rag.retriever import Retriever, RetrievedDocument
from app.services.rag.context_builder import ContextBuilder, RAGContext
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
from app.config.settings import settings from app.config.settings import settings
from app.shared.bootstrap import get_agent_conversation_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class AgentResponse: class AgentResponse:
"""Agent响应结果""" """Represent the Agent Response type."""
answer: str answer: str
sources: List[Dict] = field(default_factory=list) sources: List[Dict] = field(default_factory=list)
model: str = "" model: str = ""
@@ -27,385 +25,73 @@ class AgentResponse:
@property @property
def is_success(self) -> bool: def is_success(self) -> bool:
"""Return whether success for the Agent Response instance."""
return self.error is None return self.error is None
@dataclass @dataclass
class AgentConfig: class AgentConfig:
"""Agent配置""" """Define configuration for agent config."""
llm_provider: str = "deepseek" llm_provider: str = settings.llm_provider
llm_model: str = "deepseek-v4-flash" llm_model: str = settings.llm_model
top_k: int = 5 top_k: int = settings.rag_top_k
min_score: float = 0.3 min_score: float = 0.0
max_context_tokens: int = 2000 max_context_tokens: int = settings.rag_max_context_tokens
temperature: float = 0.7 temperature: float = settings.llm_temperature
prompt_template: str = "compliance_qa" prompt_template: str = "compliance_qa"
include_metadata: bool = True include_metadata: bool = True
class QAAgent: class QAAgent:
""" """Represent the Q A Agent type."""
合规问答Agent
核心流程:
1. 接收用户问题
2. Milvus混合检索相关法规条款
3. 构建RAG上下文
4. 调用LLM生成回答
5. 返回答案和引用来源
使用示例:
agent = QAAgent()
response = agent.ask("机动车安全技术检验有哪些要求?")
print(response.answer)
for source in response.sources:
print(f"引用: {source['doc_name']} - {source['clause_number']}")
"""
def __init__(self, config: Optional[AgentConfig] = None): def __init__(self, config: Optional[AgentConfig] = None):
""" """Initialize the Q A Agent instance."""
初始化问答Agent self.config = config or AgentConfig()
Args:
config: Agent配置可选使用默认配置
"""
self.config = config or AgentConfig(
llm_provider=settings.llm_provider,
llm_model=settings.llm_model,
top_k=settings.rag_top_k,
max_context_tokens=settings.rag_max_context_tokens
)
# 初始化组件(延迟加载)
self.llm: Optional[BaseLLMClient] = None
self.retriever: Optional[Retriever] = None
self.context_builder: Optional[ContextBuilder] = None
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
def _init_llm(self):
"""延迟初始化LLM客户端优先使用全局缓存"""
if self.llm is None:
# 尝试先获取全局缓存的客户端
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
if cached:
self.llm = cached
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
else:
logger.info("创建新的LLM客户端...")
self.llm = get_llm_client(
provider=self.config.llm_provider,
model=self.config.llm_model,
temperature=self.config.temperature
)
def _init_retriever(self):
"""延迟初始化检索器"""
if self.retriever is None:
logger.info("初始化检索器...")
self.retriever = Retriever(
top_k=self.config.top_k,
min_score=self.config.min_score
)
def _init_context_builder(self):
"""延迟初始化上下文构建器"""
if self.context_builder is None:
logger.info("初始化上下文构建器...")
self.context_builder = ContextBuilder(
max_context_tokens=self.config.max_context_tokens,
include_metadata=self.config.include_metadata
)
def ask( def ask(
self, self,
query: str, query: str,
filters: Optional[str] = None, filters: Optional[str] = None,
prompt_template: Optional[str] = None prompt_template: Optional[str] = None,
) -> AgentResponse: ) -> AgentResponse:
""" """Handle ask for the Q A Agent instance."""
回答用户问题 _, result = get_agent_conversation_service().ask(
query=query,
Args: filters=filters,
query: 用户问题 provider=self.config.llm_provider,
filters: 检索过滤条件(如 "regulation_type=='车辆安全'" model=self.config.llm_model,
prompt_template: Prompt模板名称可选覆盖默认配置 top_k=self.config.top_k,
prompt_template=prompt_template or self.config.prompt_template,
Returns: )
AgentResponse: 包含答案和引用来源的响应对象 return AgentResponse(
""" answer=result.answer,
start_time = time.time() sources=[source.__dict__ for source in result.sources],
logger.info(f"收到问题: {query}") model=result.model,
latency_ms=result.latency_ms,
try: retrieved_count=result.retrieved_count,
# Step 1: 检索相关法规 context_tokens=result.context_tokens,
self._init_retriever() truncated=result.truncated,
documents = self.retriever.retrieve(query, filters) error=result.error,
retrieved_count = len(documents)
if retrieved_count == 0:
return AgentResponse(
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
retrieved_count=0,
error="no_retrieved_documents"
)
# Step 2: 构建RAG上下文
self._init_context_builder()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
# Step 3: 构建LLM输入消息
messages = self._build_messages(template, context)
# Step 4: 调用LLM生成回答
self._init_llm()
llm_response = self.llm.chat(
messages=messages,
temperature=self.config.temperature
)
if not llm_response.is_success:
return AgentResponse(
answer="",
retrieved_count=retrieved_count,
error=llm_response.error
)
latency_ms = int((time.time() - start_time) * 1000)
# Step 5: 返回结果
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
return AgentResponse(
answer=llm_response.content,
sources=context.sources,
model=llm_response.model,
latency_ms=latency_ms,
retrieved_count=retrieved_count,
context_tokens=context.total_tokens,
truncated=context.truncated
)
except Exception as e:
logger.error(f"问答失败: {e}")
return AgentResponse(
answer="",
error=str(e)
)
def ask_with_context(
self,
query: str,
documents: List[RetrievedDocument],
prompt_template: Optional[str] = None
) -> AgentResponse:
"""
使用提供的文档回答问题(不执行检索)
Args:
query: 用户问题
documents: 已检索的文档列表
prompt_template: Prompt模板名称
Returns:
AgentResponse: 响应结果
"""
start_time = time.time()
try:
self._init_context_builder()
self._init_llm()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
messages = self._build_messages(template, context)
llm_response = self.llm.chat(messages)
latency_ms = int((time.time() - start_time) * 1000)
return AgentResponse(
answer=llm_response.content,
sources=context.sources,
model=llm_response.model,
latency_ms=latency_ms,
retrieved_count=len(documents),
context_tokens=context.total_tokens,
truncated=context.truncated
)
except Exception as e:
logger.error(f"问答失败: {e}")
return AgentResponse(answer="", error=str(e))
def _build_messages(
self,
template: PromptTemplate,
context: RAGContext
) -> List[Dict[str, str]]:
"""构建LLM输入消息"""
user_content = template.user_template.format(
context=context.context_text,
query=context.user_query
) )
return [ def ask_stream(self, query: str, filters: Optional[str] = None) -> Generator[dict, None, None]:
{"role": "system", "content": template.system_prompt}, """Handle ask stream for the Q A Agent instance."""
{"role": "user", "content": user_content} _, stream = get_agent_conversation_service().stream_chat(
] query=query,
filters=filters,
def ask_stream( provider=self.config.llm_provider,
self, model=self.config.llm_model,
query: str, top_k=self.config.top_k,
filters: Optional[str] = None, prompt_template=self.config.prompt_template,
prompt_template: Optional[str] = None )
) -> Generator[Dict[str, Any], None, None]: for event in stream:
""" yield event
流式回答用户问题SSE模式
返回事件类型:
- {"event": "status", "data": "正在检索..."} - 状态更新
- {"event": "sources", "data": [...]} - 引用来源
- {"event": "content", "data": "文本片段"} - 回答内容
- {"event": "done", "data": {"latency_ms": ..., "model": ...}} - 完成
Args:
query: 用户问题
filters: 检索过滤条件
prompt_template: Prompt模板名称
Yields:
Dict: SSE事件数据
"""
start_time = time.time()
logger.info(f"收到流式问题: {query}")
try:
# Step 1: 检索相关法规
yield {"event": "status", "data": "正在检索相关法规..."}
self._init_retriever()
documents = self.retriever.retrieve(query, filters)
retrieved_count = len(documents)
if retrieved_count == 0:
yield {"event": "status", "data": "未找到相关法规"}
yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
return
# Step 2: 发送检索结果
yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
sources = [
{
"doc_name": doc.doc_name,
"doc_id": doc.doc_id,
"clause_number": doc.clause_number,
"score": doc.score
}
for doc in documents[:5] # 只返回前5条引用
]
yield {"event": "sources", "data": sources}
# Step 3: 构建RAG上下文
self._init_context_builder()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
# Step 4: 构建LLM输入消息
messages = self._build_messages(template, context)
# Step 5: 流式调用LLM生成回答
self._init_llm()
full_answer = ""
# 检查LLM是否支持流式输出
if hasattr(self.llm, 'stream_chat'):
yield {"event": "status", "data": "思考中..."}
for chunk in self.llm.stream_chat(
messages=messages,
temperature=self.config.temperature
):
full_answer += chunk
yield {"event": "content", "data": chunk}
else:
# 如果不支持流式,回退到普通调用
yield {"event": "status", "data": "生成回答中..."}
llm_response = self.llm.chat(
messages=messages,
temperature=self.config.temperature
)
if llm_response.is_success:
full_answer = llm_response.content
yield {"event": "content", "data": full_answer}
# Step 6: 发送完成事件
latency_ms = int((time.time() - start_time) * 1000)
logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
yield {
"event": "done",
"data": {
"latency_ms": latency_ms,
"model": self.config.llm_model,
"retrieved_count": retrieved_count,
"context_tokens": context.total_tokens
}
}
except Exception as e:
logger.error(f"流式问答失败: {e}")
yield {"event": "error", "data": str(e)}
def close(self): def close(self):
"""关闭Agent资源不关闭LLM客户端因为它全局缓存""" """Release the resources held by this component."""
if self.retriever: return None
self.retriever.close()
logger.info("问答Agent已关闭")
def ask_compliance_question( def ask_compliance_question(query: str, top_k: int = 5) -> AgentResponse:
query: str, """Handle ask compliance question."""
provider: str = "deepseek", return QAAgent(AgentConfig(top_k=top_k)).ask(query)
model: str = "deepseek-v4-flash",
top_k: int = 10
) -> AgentResponse:
"""
便捷函数:问答合规问题
Args:
query: 用户问题
provider: LLM提供商
model: LLM模型
top_k: 检索数量
Returns:
AgentResponse: 响应结果
"""
config = AgentConfig(
llm_provider=provider,
llm_model=model,
top_k=top_k
)
agent = QAAgent(config)
response = agent.ask(query)
agent.close()
return response

View File

@@ -1,4 +1,4 @@
"""多轮对话会话管理""" """Provide service-layer logic for session manager."""
import time import time
import uuid import uuid
@@ -9,7 +9,7 @@ from loguru import logger
@dataclass @dataclass
class ChatMessage: class ChatMessage:
"""对话消息""" """Represent the Chat Message type."""
role: str # "user" / "assistant" / "system" role: str # "user" / "assistant" / "system"
content: str content: str
timestamp: int timestamp: int
@@ -19,7 +19,7 @@ class ChatMessage:
@dataclass @dataclass
class ChatSession: class ChatSession:
"""对话会话""" """Represent the Chat Session type."""
session_id: str session_id: str
messages: List[ChatMessage] = field(default_factory=list) messages: List[ChatMessage] = field(default_factory=list)
created_at: int = field(default_factory=lambda: int(time.time())) created_at: int = field(default_factory=lambda: int(time.time()))
@@ -27,7 +27,7 @@ class ChatSession:
metadata: Dict = field(default_factory=dict) metadata: Dict = field(default_factory=dict)
def add_user_message(self, content: str) -> ChatMessage: def add_user_message(self, content: str) -> ChatMessage:
"""添加用户消息""" """Handle add user message for the Chat Session instance."""
message = ChatMessage( message = ChatMessage(
role="user", role="user",
content=content, content=content,
@@ -42,7 +42,7 @@ class ChatSession:
content: str, content: str,
sources: List[Dict] = None sources: List[Dict] = None
) -> ChatMessage: ) -> ChatMessage:
"""添加助手消息""" """Handle add assistant message for the Chat Session instance."""
message = ChatMessage( message = ChatMessage(
role="assistant", role="assistant",
content=content, content=content,
@@ -54,9 +54,9 @@ class ChatSession:
return message return message
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]: def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
"""获取历史对话用于LLM上下文""" """Return history for the Chat Session instance."""
history = [] history = []
# 获取最近N轮对话每轮包含user + assistant # Keep service responsibilities explicit so downstream behavior stays predictable.
recent_messages = self.messages[-(max_turns * 2):] recent_messages = self.messages[-(max_turns * 2):]
for msg in recent_messages: for msg in recent_messages:
@@ -68,81 +68,47 @@ class ChatSession:
return history return history
def clear_history(self): def clear_history(self):
"""清空对话历史""" """Handle clear history for the Chat Session instance."""
self.messages = [] self.messages = []
self.updated_at = int(time.time()) self.updated_at = int(time.time())
logger.info(f"会话历史已清空: {self.session_id}") logger.info(f"会话历史已清空: {self.session_id}")
@property @property
def message_count(self) -> int: def message_count(self) -> int:
"""消息数量""" """Handle message count for the Chat Session instance."""
return len(self.messages) return len(self.messages)
@property @property
def is_empty(self) -> bool: def is_empty(self) -> bool:
"""是否为空会话""" """Return whether empty for the Chat Session instance."""
return len(self.messages) == 0 return len(self.messages) == 0
class SessionManager: class SessionManager:
""" """Represent the Session Manager type."""
会话管理器
功能:
- 创建/获取/删除会话
- 会话超时清理
- 会话历史记录管理
使用示例:
manager = SessionManager()
# 创建会话
session = manager.create_session()
# 添加消息
session.add_user_message("什么是机动车安全技术检验?")
session.add_assistant_message("根据GB 7258...", sources=[...])
# 获取历史用于LLM多轮对话
history = session.get_history(max_turns=3)
"""
def __init__( def __init__(
self, self,
max_sessions: int = 100, max_sessions: int = 100,
session_timeout_minutes: int = 30 session_timeout_minutes: int = 30
): ):
""" """Initialize the Session Manager instance."""
初始化会话管理器
Args:
max_sessions: 最大会话数量
session_timeout_minutes: 会话超时时间(分钟)
"""
self.max_sessions = max_sessions self.max_sessions = max_sessions
self.session_timeout = session_timeout_minutes * 60 self.session_timeout = session_timeout_minutes * 60
# 会话存储(内存) # Keep service responsibilities explicit so downstream behavior stays predictable.
self._sessions: Dict[str, ChatSession] = {} self._sessions: Dict[str, ChatSession] = {}
logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min") logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
def create_session(self, metadata: Dict = None) -> ChatSession: def create_session(self, metadata: Dict = None) -> ChatSession:
""" """Create session for the Session Manager instance."""
创建新会话 # Keep service responsibilities explicit so downstream behavior stays predictable.
Args:
metadata: 会话元数据(可选)
Returns:
ChatSession: 新创建的会话
"""
# 检查会话数量限制
if len(self._sessions) >= self.max_sessions: if len(self._sessions) >= self.max_sessions:
# 清理过期会话 # Keep service responsibilities explicit so downstream behavior stays predictable.
self._cleanup_expired_sessions() self._cleanup_expired_sessions()
# 如果仍然超出限制,删除最老的会话 # Keep service responsibilities explicit so downstream behavior stays predictable.
if len(self._sessions) >= self.max_sessions: if len(self._sessions) >= self.max_sessions:
oldest_id = min( oldest_id = min(
self._sessions.keys(), self._sessions.keys(),
@@ -163,19 +129,11 @@ class SessionManager:
return session return session
def get_session(self, session_id: str) -> Optional[ChatSession]: def get_session(self, session_id: str) -> Optional[ChatSession]:
""" """Return session for the Session Manager instance."""
获取会话
Args:
session_id: 会话ID
Returns:
ChatSession: 会话对象如不存在返回None
"""
session = self._sessions.get(session_id) session = self._sessions.get(session_id)
if session: if session:
# 检查是否过期 # Keep service responsibilities explicit so downstream behavior stays predictable.
if self._is_session_expired(session): if self._is_session_expired(session):
self.delete_session(session_id) self.delete_session(session_id)
logger.info(f"会话已过期,已删除: {session_id}") logger.info(f"会话已过期,已删除: {session_id}")
@@ -184,15 +142,7 @@ class SessionManager:
return session return session
def delete_session(self, session_id: str) -> bool: def delete_session(self, session_id: str) -> bool:
""" """Delete session for the Session Manager instance."""
删除会话
Args:
session_id: 会话ID
Returns:
bool: 是否成功删除
"""
if session_id in self._sessions: if session_id in self._sessions:
del self._sessions[session_id] del self._sessions[session_id]
logger.info(f"删除会话: {session_id}") logger.info(f"删除会话: {session_id}")
@@ -200,12 +150,7 @@ class SessionManager:
return False return False
def list_sessions(self) -> List[Dict]: def list_sessions(self) -> List[Dict]:
""" """List sessions for the Session Manager instance."""
列出所有会话
Returns:
List[Dict]: 会话列表摘要
"""
return [ return [
{ {
"session_id": s.session_id, "session_id": s.session_id,
@@ -217,12 +162,12 @@ class SessionManager:
] ]
def _is_session_expired(self, session: ChatSession) -> bool: def _is_session_expired(self, session: ChatSession) -> bool:
"""检查会话是否过期""" """Handle is session expired for this module for the Session Manager instance."""
current_time = int(time.time()) current_time = int(time.time())
return (current_time - session.updated_at) > self.session_timeout return (current_time - session.updated_at) > self.session_timeout
def _cleanup_expired_sessions(self) -> int: def _cleanup_expired_sessions(self) -> int:
"""清理过期会话""" """Handle cleanup expired sessions for this module for the Session Manager instance."""
expired_ids = [ expired_ids = [
sid for sid, session in self._sessions.items() sid for sid, session in self._sessions.items()
if self._is_session_expired(session) if self._is_session_expired(session)
@@ -237,10 +182,10 @@ class SessionManager:
return len(expired_ids) return len(expired_ids)
def get_session_count(self) -> int: def get_session_count(self) -> int:
"""获取当前会话数量""" """Return session count for the Session Manager instance."""
return len(self._sessions) return len(self._sessions)
def clear_all_sessions(self): def clear_all_sessions(self):
"""清空所有会话""" """Handle clear all sessions for the Session Manager instance."""
self._sessions.clear() self._sessions.clear()
logger.info("所有会话已清空") logger.info("所有会话已清空")

View File

@@ -1,24 +1,19 @@
"""文档处理主流程 - 解析→摘要→分块→嵌入→入库""" """Provide service-layer logic for document processor."""
from __future__ import annotations
import os
from typing import List, Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger from pathlib import Path
import uuid from typing import Optional
from app.shared.bootstrap import get_document_command_service, get_retrieval_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
from .parser.pdf_parser import PDFParser
from .parser.docx_parser import DocxParser
from .parser.mineru_parser import ParserOrchestrator
from .embedding.text_chunker import RegulationChunker, TextChunk
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
from .storage.milvus_client import MilvusClient
from .llm.document_summarizer import DocumentSummarizer, DocumentSummary
from app.config.settings import settings
@dataclass @dataclass
class ProcessingResult: class ProcessingResult:
"""文档处理结果""" """Represent the Processing Result type."""
doc_id: str doc_id: str
doc_name: str doc_name: str
success: bool success: bool
@@ -30,87 +25,10 @@ class ProcessingResult:
class DocumentProcessor: class DocumentProcessor:
""" """Represent the Document Processor type."""
文档处理服务 - 完整处理流程 def __init__(self, *args, generate_summary: bool = False, **kwargs):
"""Initialize the Document Processor instance."""
流程:
1. 文档解析PDF/DOCX → Markdown
2. 智能分块(章节级+条款级)
3. LLM摘要生成可选
4. 向量嵌入BGE-M3 Dense+Sparse
5. 存储入库Milvus向量数据库
"""
def __init__(
self,
chunk_size: int = None,
embedding_model: str = None,
use_mineru: bool = True,
generate_summary: bool = False, # 默认不生成摘要节省约60秒
llm_provider: str = None,
llm_model: str = None
):
"""
初始化文档处理器
Args:
chunk_size: 分块大小
embedding_model: 嵌入模型名称
use_mineru: 是否优先使用MinerU解析
generate_summary: 是否生成文档摘要默认False可节省约60秒处理时间
llm_provider: LLM提供商
llm_model: LLM模型名称
"""
self.chunk_size = chunk_size or settings.chunk_size
self.embedding_model = embedding_model or settings.embedding_model
self.use_mineru = use_mineru
self.generate_summary = generate_summary self.generate_summary = generate_summary
self.llm_provider = llm_provider or settings.llm_provider
self.llm_model = llm_model or settings.llm_model
# 初始化各组件
logger.info("初始化文档处理组件...")
# 解析器
self.parser = ParserOrchestrator()
# 分块器
self.chunker = RegulationChunker(chunk_size=self.chunk_size)
# 嵌入模型(延迟加载)
self.embedder: Optional[BGEM3Embedder] = None
# Milvus客户端延迟连接
self.milvus: Optional[MilvusClient] = None
# 摘要生成器(延迟加载)
self.summarizer: Optional[DocumentSummarizer] = None
logger.success("文档处理器初始化完成")
def _init_embedder(self):
"""延迟初始化嵌入模型"""
if self.embedder is None:
logger.info("加载嵌入模型...")
self.embedder = BGEM3Embedder(model_name=self.embedding_model)
def _init_milvus(self):
"""延迟初始化Milvus连接"""
if self.milvus is None:
logger.info("连接Milvus...")
self.milvus = MilvusClient()
self.milvus.connect()
self.milvus.create_collection(recreate=False)
self.milvus.load_collection()
def _init_summarizer(self):
"""延迟初始化摘要生成器"""
if self.summarizer is None:
logger.info("初始化摘要生成器...")
self.summarizer = DocumentSummarizer(
provider=self.llm_provider,
model=self.llm_model
)
def process( def process(
self, self,
@@ -118,286 +36,51 @@ class DocumentProcessor:
doc_id: Optional[str] = None, doc_id: Optional[str] = None,
doc_name: Optional[str] = None, doc_name: Optional[str] = None,
regulation_type: str = "", regulation_type: str = "",
version: str = "" version: str = "",
) -> ProcessingResult: ) -> ProcessingResult:
""" """Handle process for the Document Processor instance."""
处理单个文档 path = Path(file_path)
content = path.read_bytes()
result = get_document_command_service().upload_and_process(
doc_id=doc_id,
file_name=path.name,
content=content,
content_type="application/octet-stream",
doc_name=doc_name or path.name,
regulation_type=regulation_type,
version=version,
generate_summary=self.generate_summary,
)
return ProcessingResult(
doc_id=result.doc_id,
doc_name=result.doc_name,
success=result.status != "failed",
num_chunks=result.num_chunks,
message=result.message,
summary=result.summary,
summary_latency_ms=result.summary_latency_ms,
)
Args: def search(self, query: str, top_k: int = 10, filters: str | None = None) -> list[dict]:
file_path: 文档文件路径 """Handle search for the Document Processor instance."""
doc_id: 文档ID可选默认自动生成 results = get_retrieval_service().retrieve(query=query, top_k=top_k, filters=filters)
doc_name: 文档名称(可选,默认从文件名获取) return [
regulation_type: 法规类型 {
version: 文档版本 "id": item.chunk_id,
"content": item.content,
Returns: "score": item.score,
ProcessingResult: 处理结果 "metadata": {
""" "doc_id": item.doc_id,
# 生成或使用传入的文档ID "doc_name": item.doc_name,
if doc_id is None: "chunk_id": item.chunk_id,
doc_id = str(uuid.uuid4())[:8] "section_title": item.section_title,
"page_number": item.page_number,
# 获取文档名称 **item.metadata,
if doc_name is None: },
doc_name = os.path.basename(file_path) }
for item in results
logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})") ]
# 初始化结果变量
summary = ""
summary_latency_ms = 0
try:
# 1. 文档解析
logger.info("Step 1: 文档解析")
markdown_text = self._parse_document(file_path)
if not markdown_text:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="文档解析失败,内容为空"
)
# 2. LLM摘要生成可选
if self.generate_summary:
logger.info("Step 2: LLM摘要生成")
self._init_summarizer()
summary_result = self.summarizer.summarize(
doc_name,
markdown_text,
regulation_type
)
if summary_result.is_success:
summary = summary_result.summary
summary_latency_ms = summary_result.latency_ms
logger.success(f"摘要生成完成: {summary_latency_ms}ms")
else:
logger.warning(f"摘要生成失败: {summary_result.error}")
else:
logger.info("Step 2: 跳过摘要生成未勾选节省约60秒")
# 3. 智能分块
logger.info("Step 3: 智能分块")
chunks = self._chunk_document(
markdown_text,
doc_id,
doc_name,
regulation_type,
version
)
if not chunks:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="分块失败,无有效内容",
markdown_text=markdown_text,
summary=summary
)
# 4. 向量嵌入
logger.info("Step 4: 向量嵌入")
embeddings = self._embed_chunks(chunks)
if embeddings is None:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="向量嵌入失败",
markdown_text=markdown_text,
summary=summary
)
# 5. 存储入库
logger.info("Step 5: 存储入库")
inserted_ids = self._insert_to_milvus(chunks, embeddings)
logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=True,
num_chunks=len(inserted_ids),
message="处理成功",
markdown_text=markdown_text,
summary=summary,
summary_latency_ms=summary_latency_ms
)
except Exception as e:
logger.error(f"文档处理失败: {e}")
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message=f"处理失败: {str(e)}"
)
def _parse_document(self, file_path: str) -> str:
"""解析文档"""
ext = os.path.splitext(file_path)[1].lower()
try:
if ext == ".pdf":
# PDF文档解析优先MinerU回退PyMuPDF
markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
elif ext in [".docx", ".doc"]:
# Word文档解析
markdown_text = self.parser.parse_docx(file_path)
else:
logger.warning(f"不支持的文件类型: {ext}")
return ""
logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
return markdown_text
except Exception as e:
logger.error(f"文档解析失败: {e}")
return ""
def _chunk_document(
self,
markdown_text: str,
doc_id: str,
doc_name: str,
regulation_type: str,
version: str
) -> List[TextChunk]:
"""分块文档"""
try:
chunks = self.chunker.chunk_document(
markdown_text,
doc_id=doc_id,
doc_name=doc_name,
regulation_type=regulation_type,
version=version
)
logger.success(f"分块完成,共{len(chunks)}个chunk")
return chunks
except Exception as e:
logger.error(f"分块失败: {e}")
return []
def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
"""嵌入分块"""
try:
# 延迟初始化嵌入模型
self._init_embedder()
# 提取文本内容
texts = [chunk.content for chunk in chunks]
# 执行嵌入
embeddings = self.embedder.embed(texts)
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
return embeddings
except Exception as e:
logger.error(f"嵌入失败: {e}")
return None
def _insert_to_milvus(
self,
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""插入Milvus"""
try:
# 延迟初始化Milvus
self._init_milvus()
# 执行插入
inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
return inserted_ids
except Exception as e:
logger.error(f"入库失败: {e}")
return []
def search(
self,
query: str,
top_k: int = 10,
filters: Optional[str] = None
) -> List[Dict]:
"""
检索法规内容
Args:
query: 查询文本
top_k: 返回结果数量
filters: 过滤条件
Returns:
List[Dict]: 检索结果
"""
logger.info(f"执行检索: {query}")
try:
# 延迟初始化
self._init_embedder()
self._init_milvus()
# 生成查询向量
query_embedding = self.embedder.embed_single(query)
# 执行混合检索
results = self.milvus.hybrid_search(
query_dense=query_embedding['dense'].tolist(),
query_sparse=query_embedding['sparse'],
top_k=top_k,
filters=filters
)
# 转换为字典格式
result_dicts = []
for r in results:
result_dicts.append({
"id": r.id,
"content": r.content,
"score": r.score,
"metadata": r.metadata
})
logger.success(f"检索完成,返回{len(result_dicts)}条结果")
return result_dicts
except Exception as e:
logger.error(f"检索失败: {e}")
return []
def close(self): def close(self):
"""关闭连接""" """Release the resources held by this component."""
if self.milvus: return None
self.milvus.disconnect()
logger.info("文档处理器已关闭")
def process_document(
file_path: str,
doc_name: Optional[str] = None,
regulation_type: str = "",
version: str = ""
) -> ProcessingResult:
"""便捷函数:处理单个文档"""
processor = DocumentProcessor()
result = processor.process(file_path, doc_name, regulation_type, version)
processor.close()
return result
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
"""便捷函数:检索法规"""
processor = DocumentProcessor()
results = processor.search(query, top_k)
processor.close()
return results

View File

@@ -1,6 +1,18 @@
"""嵌入和分块服务""" """Initialize the app.services.embedding package."""
# Keep package boundaries explicit so backend imports stay predictable.
from .text_chunker import RegulationChunker
from .bge_m3_embedder import BGEM3Embedder
__all__ = ["RegulationChunker", "BGEM3Embedder"] __all__ = ["RegulationChunker", "BGEM3Embedder"]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name == "RegulationChunker":
from .text_chunker import RegulationChunker
return RegulationChunker
if name == "BGEM3Embedder":
from .bge_m3_embedder import BGEM3Embedder
return BGEM3Embedder
raise AttributeError(name)

View File

@@ -1,4 +1,4 @@
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成""" """Provide service-layer logic for bge m3 embedder."""
import numpy as np import numpy as np
from typing import List, Dict, Optional, Union from typing import List, Dict, Optional, Union
@@ -6,43 +6,31 @@ from dataclasses import dataclass, field
from loguru import logger from loguru import logger
import torch import torch
import os import os
# Keep service responsibilities explicit so downstream behavior stays predictable.
# 设置HuggingFace镜像国内网络
# Keep service responsibilities explicit so downstream behavior stays predictable.
if 'HF_ENDPOINT' not in os.environ: if 'HF_ENDPOINT' not in os.environ:
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 本地模型路径(按优先级检查) # Keep service responsibilities explicit so downstream behavior stays predictable.
LOCAL_MODEL_PATHS = [ LOCAL_MODEL_PATHS = [
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径 os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # Keep service responsibilities explicit so downstream behavior stays predictable.
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径 os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # Keep service responsibilities explicit so downstream behavior stays predictable.
] ]
@dataclass @dataclass
class EmbeddingResult: class EmbeddingResult:
"""嵌入结果""" """Represent the Embedding Result type."""
dense_embeddings: np.ndarray # Dense向量语义检索 dense_embeddings: np.ndarray # Keep service responsibilities explicit so downstream behavior stays predictable.
sparse_embeddings: List[Dict[int, float]] # Sparse向量关键词匹配 sparse_embeddings: List[Dict[int, float]] # Keep service responsibilities explicit so downstream behavior stays predictable.
texts: List[str] texts: List[str]
dim: int = 1024 dim: int = 1024
class BGEM3Embedder: class BGEM3Embedder:
""" """Represent the B G E M3 Embedder type."""
BGE-M3多语言嵌入模型服务
BGE-M3是BAAI发布的多语言嵌入模型支持
- Dense向量用于语义相似度检索
- Sparse向量用于关键词精确匹配BM25风格
- ColBERT向量用于细粒度交互匹配可选
特点:
- 支持100+语言(中英双语优化)
- 8192 tokens超长上下文
- Dense+Sparse双路检索能力
GitHub: https://github.com/FlagOpen/FlagEmbedding
"""
def __init__( def __init__(
self, self,
@@ -53,28 +41,18 @@ class BGEM3Embedder:
max_length: int = 8192, max_length: int = 8192,
local_model_path: Optional[str] = None local_model_path: Optional[str] = None
): ):
""" """Initialize the B G E M3 Embedder instance."""
初始化BGE-M3嵌入模型
Args:
model_name: 模型名称(如果使用本地路径,此参数会被忽略)
use_fp16: 是否使用FP16加速
device: 设备类型cuda/cpu默认自动选择
batch_size: 批处理大小
max_length: 最大序列长度
local_model_path: 本地模型路径(可选,优先使用)
"""
self.use_fp16 = use_fp16 self.use_fp16 = use_fp16
self.batch_size = batch_size self.batch_size = batch_size
self.max_length = max_length self.max_length = max_length
# 确定模型路径(优先使用本地路径) # Keep service responsibilities explicit so downstream behavior stays predictable.
if local_model_path and os.path.exists(local_model_path): if local_model_path and os.path.exists(local_model_path):
self.model_path = local_model_path self.model_path = local_model_path
self.model_name = "local" self.model_name = "local"
logger.info(f"使用本地模型路径: {local_model_path}") logger.info(f"使用本地模型路径: {local_model_path}")
else: else:
# 检查多个可能的本地路径 # Keep service responsibilities explicit so downstream behavior stays predictable.
found_local = False found_local = False
for path in LOCAL_MODEL_PATHS: for path in LOCAL_MODEL_PATHS:
if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")): if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
@@ -89,7 +67,7 @@ class BGEM3Embedder:
self.model_name = model_name self.model_name = model_name
logger.info(f"使用远程模型: {model_name}") logger.info(f"使用远程模型: {model_name}")
# 自动选择设备 # Keep service responsibilities explicit so downstream behavior stays predictable.
if device is None: if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
else: else:
@@ -101,7 +79,7 @@ class BGEM3Embedder:
self._load_model() self._load_model()
def _load_model(self): def _load_model(self):
"""加载嵌入模型""" """Handle load model for this module for the B G E M3 Embedder instance."""
try: try:
from FlagEmbedding import BGEM3FlagModel from FlagEmbedding import BGEM3FlagModel
@@ -127,18 +105,7 @@ class BGEM3Embedder:
return_sparse: bool = True, return_sparse: bool = True,
return_colbert_vecs: bool = False return_colbert_vecs: bool = False
) -> EmbeddingResult: ) -> EmbeddingResult:
""" """Handle embed for the B G E M3 Embedder instance."""
对文本列表生成嵌入向量
Args:
texts: 文本列表
return_dense: 是否返回Dense向量
return_sparse: 是否返回Sparse向量
return_colbert_vecs: 是否返回ColBERT向量
Returns:
EmbeddingResult: 嵌入结果
"""
if not texts: if not texts:
logger.warning("输入文本列表为空") logger.warning("输入文本列表为空")
return EmbeddingResult( return EmbeddingResult(
@@ -151,7 +118,7 @@ class BGEM3Embedder:
logger.info(f"开始嵌入{len(texts)}个文本块") logger.info(f"开始嵌入{len(texts)}个文本块")
try: try:
# 执行嵌入 # Keep service responsibilities explicit so downstream behavior stays predictable.
embeddings = self.model.encode( embeddings = self.model.encode(
texts, texts,
batch_size=self.batch_size, batch_size=self.batch_size,
@@ -161,11 +128,11 @@ class BGEM3Embedder:
return_colbert_vecs=return_colbert_vecs return_colbert_vecs=return_colbert_vecs
) )
# 提取结果 # Keep service responsibilities explicit so downstream behavior stays predictable.
dense_embeddings = embeddings.get('dense_vecs', np.array([])) dense_embeddings = embeddings.get('dense_vecs', np.array([]))
sparse_embeddings = embeddings.get('lexical_weights', []) sparse_embeddings = embeddings.get('lexical_weights', [])
# 获取维度 # Keep service responsibilities explicit so downstream behavior stays predictable.
dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024 dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
logger.success(f"嵌入完成,向量维度: {dim}") logger.success(f"嵌入完成,向量维度: {dim}")
@@ -182,15 +149,7 @@ class BGEM3Embedder:
raise raise
def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]: def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
""" """Embed single for the B G E M3 Embedder instance."""
对单个文本生成嵌入向量
Args:
text: 输入文本
Returns:
Dict: 包含dense和sparse向量
"""
result = self.embed([text]) result = self.embed([text])
return { return {
'dense': result.dense_embeddings[0], 'dense': result.dense_embeddings[0],
@@ -199,25 +158,17 @@ class BGEM3Embedder:
} }
def embed_dense(self, texts: List[str]) -> np.ndarray: def embed_dense(self, texts: List[str]) -> np.ndarray:
"""只生成Dense向量""" """Embed dense for the B G E M3 Embedder instance."""
result = self.embed(texts, return_sparse=False, return_colbert_vecs=False) result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
return result.dense_embeddings return result.dense_embeddings
def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]: def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
"""只生成Sparse向量""" """Embed sparse for the B G E M3 Embedder instance."""
result = self.embed(texts, return_dense=False, return_colbert_vecs=False) result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
return result.sparse_embeddings return result.sparse_embeddings
def embed_query(self, query: str) -> Dict: def embed_query(self, query: str) -> Dict:
""" """Embed query for the B G E M3 Embedder instance."""
对查询文本生成嵌入(用于检索)
Args:
query: 查询文本
Returns:
Dict: 包含dense和sparse向量
"""
return self.embed_single(query) return self.embed_single(query)
def compute_similarity( def compute_similarity(
@@ -226,26 +177,16 @@ class BGEM3Embedder:
doc_embeddings: np.ndarray, doc_embeddings: np.ndarray,
metric: str = "cosine" metric: str = "cosine"
) -> np.ndarray: ) -> np.ndarray:
""" """Handle compute similarity for the B G E M3 Embedder instance."""
计算查询与文档的相似度
Args:
query_embedding: 查询向量
doc_embeddings: 文档向量矩阵
metric: 相似度度量cosine/dot
Returns:
np.ndarray: 相似度分数数组
"""
if metric == "cosine": if metric == "cosine":
# 余弦相似度 # Keep service responsibilities explicit so downstream behavior stays predictable.
query_norm = np.linalg.norm(query_embedding) query_norm = np.linalg.norm(query_embedding)
doc_norms = np.linalg.norm(doc_embeddings, axis=1) doc_norms = np.linalg.norm(doc_embeddings, axis=1)
similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm) similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm)
elif metric == "dot": elif metric == "dot":
# 点积相似度 # Keep service responsibilities explicit so downstream behavior stays predictable.
similarities = np.dot(doc_embeddings, query_embedding) similarities = np.dot(doc_embeddings, query_embedding)
else: else:
@@ -258,17 +199,8 @@ class BGEM3Embedder:
query_sparse: Dict[int, float], query_sparse: Dict[int, float],
doc_sparse: Dict[int, float] doc_sparse: Dict[int, float]
) -> float: ) -> float:
""" """Handle sparse similarity for the B G E M3 Embedder instance."""
计算Sparse向量的相似度BM25风格 # Keep service responsibilities explicit so downstream behavior stays predictable.
Args:
query_sparse: 查询的Sparse向量词ID -> 权重)
doc_sparse: 文档的Sparse向量
Returns:
float: 相似度分数
"""
# 计算交集词的点积
common_keys = set(query_sparse.keys()) & set(doc_sparse.keys()) common_keys = set(query_sparse.keys()) & set(doc_sparse.keys())
score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys) score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
@@ -280,7 +212,7 @@ def embed_texts(
model_name: str = "BAAI/bge-m3", model_name: str = "BAAI/bge-m3",
**kwargs **kwargs
) -> EmbeddingResult: ) -> EmbeddingResult:
"""便捷函数:对文本列表生成嵌入""" """Embed texts."""
embedder = BGEM3Embedder(model_name=model_name, **kwargs) embedder = BGEM3Embedder(model_name=model_name, **kwargs)
return embedder.embed(texts) return embedder.embed(texts)
@@ -290,6 +222,6 @@ def embed_single_text(
model_name: str = "BAAI/bge-m3", model_name: str = "BAAI/bge-m3",
**kwargs **kwargs
) -> Dict: ) -> Dict:
"""便捷函数:对单个文本生成嵌入""" """Embed single text."""
embedder = BGEM3Embedder(model_name=model_name, **kwargs) embedder = BGEM3Embedder(model_name=model_name, **kwargs)
return embedder.embed_single(text) return embedder.embed_single(text)

View File

@@ -1,51 +1,46 @@
"""智能分块器 - 章节级+条款级双粒度切割""" """Provide service-layer logic for text chunker."""
import re import re
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field from dataclasses import dataclass, field
from loguru import logger from loguru import logger
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class ChunkMetadata: class ChunkMetadata:
"""分块元数据""" """Represent the Chunk Metadata type."""
doc_id: str = "" doc_id: str = ""
doc_name: str = "" doc_name: str = ""
chunk_id: str = "" chunk_id: str = ""
section_number: str = "" # 章节编号(如 "第一章" section_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
section_title: str = "" # 章节标题 section_title: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
clause_number: str = "" # 条款编号(如 "第一条" clause_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
page_number: int = 0 page_number: int = 0
start_position: int = 0 # 在原文中的起始位置 start_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
end_position: int = 0 # 在原文中的结束位置 end_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
regulation_type: str = "" # 法规类型 regulation_type: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
version: str = "" version: str = ""
@dataclass @dataclass
class TextChunk: class TextChunk:
"""文本分块""" """Represent the Text Chunk type."""
content: str content: str
metadata: ChunkMetadata metadata: ChunkMetadata
token_count: int = 0 # 估算的token数量 token_count: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
class RegulationChunker: class RegulationChunker:
""" """Represent the Regulation Chunker type."""
法规文档智能分块器
实现章节级/条款级双粒度切割适配国标GB文档结构 # Keep service responsibilities explicit so downstream behavior stays predictable.
- 国标文档通常有明确的层级结构:章 > 节 > 条
- 每个条款应作为一个独立的语义单元
- 保留条款完整性,避免跨条款截断
"""
# 法规标题模式
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+') CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+') SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s') CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
# 条款子项模式 # Keep service responsibilities explicit so downstream behavior stays predictable.
SUB_ITEM_PATTERN = re.compile(r'^[\(][一二三四五六七八九十]+[\)]\s') SUB_ITEM_PATTERN = re.compile(r'^[\(][一二三四五六七八九十]+[\)]\s')
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s') NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
@@ -56,15 +51,7 @@ class RegulationChunker:
max_chunk_size: int = 2048, max_chunk_size: int = 2048,
min_chunk_size: int = 100 min_chunk_size: int = 100
): ):
""" """Initialize the Regulation Chunker instance."""
初始化分块器
Args:
chunk_size: 默认分块大小(字符数)
chunk_overlap: 分块重叠大小
max_chunk_size: 最大分块大小(防止单个条款过长)
min_chunk_size: 最小分块大小(防止碎片化)
"""
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap self.chunk_overlap = chunk_overlap
self.max_chunk_size = max_chunk_size self.max_chunk_size = max_chunk_size
@@ -78,30 +65,18 @@ class RegulationChunker:
regulation_type: str = "", regulation_type: str = "",
version: str = "" version: str = ""
) -> List[TextChunk]: ) -> List[TextChunk]:
""" """Handle chunk document for the Regulation Chunker instance."""
对法规文档进行智能分块
Args:
markdown_text: Markdown格式的文档内容
doc_id: 文档ID
doc_name: 文档名称
regulation_type: 法规类型
version: 文档版本
Returns:
List[TextChunk]: 分块列表
"""
logger.info(f"开始分块文档: {doc_name}") logger.info(f"开始分块文档: {doc_name}")
# 1. 按章节分割(一级分块) # Keep service responsibilities explicit so downstream behavior stays predictable.
sections = self._split_by_sections(markdown_text) sections = self._split_by_sections(markdown_text)
# 2. 在每个章节内按条款分割(二级分块) # Keep service responsibilities explicit so downstream behavior stays predictable.
chunks = [] chunks = []
global_position = 0 global_position = 0
for section_num, section_title, section_content, section_start in sections: for section_num, section_title, section_content, section_start in sections:
# 在章节内按条款分割 # Keep service responsibilities explicit so downstream behavior stays predictable.
clause_chunks = self._split_by_clauses( clause_chunks = self._split_by_clauses(
section_content, section_content,
section_num, section_num,
@@ -110,7 +85,7 @@ class RegulationChunker:
) )
for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks: for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks:
# 处理过长的条款(进一步细分) # Keep service responsibilities explicit so downstream behavior stays predictable.
if len(chunk_content) > self.max_chunk_size: if len(chunk_content) > self.max_chunk_size:
sub_chunks = self._split_long_clause( sub_chunks = self._split_long_clause(
chunk_content, chunk_content,
@@ -150,12 +125,7 @@ class RegulationChunker:
return chunks return chunks
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]: def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
""" """Handle split by sections for this module for the Regulation Chunker instance."""
按章节分割文档
Returns:
List of (section_number, section_title, section_content, start_position)
"""
sections = [] sections = []
lines = markdown_text.split('\n') lines = markdown_text.split('\n')
@@ -165,12 +135,12 @@ class RegulationChunker:
current_section_start = 0 current_section_start = 0
for i, line in enumerate(lines): for i, line in enumerate(lines):
# 检测章节标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
chapter_match = self.CHAPTER_PATTERN.match(line.strip()) chapter_match = self.CHAPTER_PATTERN.match(line.strip())
section_match = self.SECTION_PATTERN.match(line.strip()) section_match = self.SECTION_PATTERN.match(line.strip())
if chapter_match or section_match: if chapter_match or section_match:
# 保存上一个章节 # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_section_content: if current_section_content:
content = '\n'.join(current_section_content) content = '\n'.join(current_section_content)
sections.append(( sections.append((
@@ -180,7 +150,7 @@ class RegulationChunker:
current_section_start current_section_start
)) ))
# 开始新章节 # Keep service responsibilities explicit so downstream behavior stays predictable.
current_section_start = sum(len(l) + 1 for l in lines[:i]) current_section_start = sum(len(l) + 1 for l in lines[:i])
current_section_content = [] current_section_content = []
@@ -193,7 +163,7 @@ class RegulationChunker:
current_section_content.append(line) current_section_content.append(line)
# 保存最后一个章节 # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_section_content: if current_section_content:
content = '\n'.join(current_section_content) content = '\n'.join(current_section_content)
sections.append(( sections.append((
@@ -203,7 +173,7 @@ class RegulationChunker:
current_section_start current_section_start
)) ))
# 如果没有检测到章节,将整个文档作为一个大章节 # Keep service responsibilities explicit so downstream behavior stays predictable.
if not sections: if not sections:
sections.append(( sections.append((
"", "",
@@ -221,12 +191,7 @@ class RegulationChunker:
section_title: str, section_title: str,
section_start: int section_start: int
) -> List[Tuple[str, str, str, int, int]]: ) -> List[Tuple[str, str, str, int, int]]:
""" """Handle split by clauses for this module for the Regulation Chunker instance."""
在章节内按条款分割
Returns:
List of (content, clause_number, clause_title, start_position, end_position)
"""
clauses = [] clauses = []
lines = section_content.split('\n') lines = section_content.split('\n')
@@ -236,11 +201,11 @@ class RegulationChunker:
current_clause_start = section_start current_clause_start = section_start
for i, line in enumerate(lines): for i, line in enumerate(lines):
# 检测条款标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
clause_match = self.CLAUSE_PATTERN.match(line.strip()) clause_match = self.CLAUSE_PATTERN.match(line.strip())
if clause_match: if clause_match:
# 保存上一个条款 # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_clause_content: if current_clause_content:
content = '\n'.join(current_clause_content) content = '\n'.join(current_clause_content)
end_pos = current_clause_start + len(content) end_pos = current_clause_start + len(content)
@@ -252,7 +217,7 @@ class RegulationChunker:
end_pos end_pos
)) ))
# 开始新条款 # Keep service responsibilities explicit so downstream behavior stays predictable.
current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i]) current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i])
current_clause_content = [] current_clause_content = []
current_clause_num = self._extract_clause_number(line.strip()) current_clause_num = self._extract_clause_number(line.strip())
@@ -260,7 +225,7 @@ class RegulationChunker:
current_clause_content.append(line) current_clause_content.append(line)
# 保存最后一个条款 # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_clause_content: if current_clause_content:
content = '\n'.join(current_clause_content) content = '\n'.join(current_clause_content)
end_pos = current_clause_start + len(content) end_pos = current_clause_start + len(content)
@@ -272,7 +237,7 @@ class RegulationChunker:
end_pos end_pos
)) ))
# 如果没有检测到条款,将整个章节作为一个条款 # Keep service responsibilities explicit so downstream behavior stays predictable.
if not clauses: if not clauses:
clauses.append(( clauses.append((
section_content, section_content,
@@ -290,15 +255,11 @@ class RegulationChunker:
clause_num: str, clause_num: str,
clause_title: str clause_title: str
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
""" """Handle split long clause for this module for the Regulation Chunker instance."""
分割过长的条款内容
按条款子项或段落分割,保持语义完整性
"""
sub_chunks = [] sub_chunks = []
lines = content.split('\n') lines = content.split('\n')
# 检测是否有子项结构 # Keep service responsibilities explicit so downstream behavior stays predictable.
has_sub_items = any( has_sub_items = any(
self.SUB_ITEM_PATTERN.match(line.strip()) or self.SUB_ITEM_PATTERN.match(line.strip()) or
self.NUMBER_ITEM_PATTERN.match(line.strip()) self.NUMBER_ITEM_PATTERN.match(line.strip())
@@ -306,7 +267,7 @@ class RegulationChunker:
) )
if has_sub_items: if has_sub_items:
# 按子项分割 # Keep service responsibilities explicit so downstream behavior stays predictable.
current_sub_content = [] current_sub_content = []
current_sub_start = 0 current_sub_start = 0
@@ -326,14 +287,14 @@ class RegulationChunker:
current_sub_content.append(line) current_sub_content.append(line)
# 保存最后一个子项 # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_sub_content: if current_sub_content:
sub_content = '\n'.join(current_sub_content) sub_content = '\n'.join(current_sub_content)
sub_end = current_sub_start + len(sub_content) sub_end = current_sub_start + len(sub_content)
sub_chunks.append((sub_content, current_sub_start, sub_end)) sub_chunks.append((sub_content, current_sub_start, sub_end))
else: else:
# 按段落分割(滑动窗口) # Keep service responsibilities explicit so downstream behavior stays predictable.
paragraphs = [] paragraphs = []
current_para = [] current_para = []
@@ -348,7 +309,7 @@ class RegulationChunker:
if current_para: if current_para:
paragraphs.append('\n'.join(current_para)) paragraphs.append('\n'.join(current_para))
# 合并段落直到达到chunk_size # Keep service responsibilities explicit so downstream behavior stays predictable.
current_chunk = [] current_chunk = []
current_length = 0 current_length = 0
chunk_start = 0 chunk_start = 0
@@ -365,7 +326,7 @@ class RegulationChunker:
current_chunk.append(para) current_chunk.append(para)
current_length += len(para) current_length += len(para)
# 保存最后一个chunk # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_chunk: if current_chunk:
chunk_content = '\n'.join(current_chunk) chunk_content = '\n'.join(current_chunk)
chunk_end = chunk_start + len(chunk_content) chunk_end = chunk_start + len(chunk_content)
@@ -374,13 +335,13 @@ class RegulationChunker:
return sub_chunks return sub_chunks
def _extract_title(self, header_line: str) -> str: def _extract_title(self, header_line: str) -> str:
"""从标题行提取标题内容""" """Handle extract title for this module for the Regulation Chunker instance."""
# 移除"第X章"、"第X节"前缀 # Keep service responsibilities explicit so downstream behavior stays predictable.
title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line) title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line)
return title.strip() return title.strip()
def _extract_clause_number(self, clause_line: str) -> str: def _extract_clause_number(self, clause_line: str) -> str:
"""从条款行提取条款编号""" """Handle extract clause number for this module for the Regulation Chunker instance."""
match = self.CLAUSE_PATTERN.match(clause_line) match = self.CLAUSE_PATTERN.match(clause_line)
if match: if match:
return match.group(0).strip() return match.group(0).strip()
@@ -399,14 +360,14 @@ class RegulationChunker:
regulation_type: str, regulation_type: str,
version: str version: str
) -> TextChunk: ) -> TextChunk:
"""创建文本分块""" """Handle create chunk for this module for the Regulation Chunker instance."""
# 清理内容 # Keep service responsibilities explicit so downstream behavior stays predictable.
content = content.strip() content = content.strip()
# 计算估算token数中文约1.5字符/token # Keep service responsibilities explicit so downstream behavior stays predictable.
token_count = int(len(content) * 0.7) # 简化估算 token_count = int(len(content) * 0.7) # Keep service responsibilities explicit so downstream behavior stays predictable.
# 生成chunk_id # Keep service responsibilities explicit so downstream behavior stays predictable.
chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}" chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}"
metadata = ChunkMetadata( metadata = ChunkMetadata(
@@ -437,7 +398,7 @@ def chunk_regulation_document(
version: str = "", version: str = "",
chunk_size: int = 512 chunk_size: int = 512
) -> List[TextChunk]: ) -> List[TextChunk]:
"""便捷函数:对法规文档进行分块""" """Handle chunk regulation document."""
chunker = RegulationChunker(chunk_size=chunk_size) chunker = RegulationChunker(chunk_size=chunk_size)
return chunker.chunk_document( return chunker.chunk_document(
markdown_text, markdown_text,

View File

@@ -1,14 +1,36 @@
"""LLM服务模块""" """Initialize the app.services.llm package."""
from .llm_factory import LLMFactory, get_llm_client from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
from .deepseek_client import DeepSeekClient from .deepseek_client import DeepSeekClient
from .llm_factory import LLMFactory, get_llm_client
from .qwen_client import QwenClient, QwenVLClient from .qwen_client import QwenClient, QwenVLClient
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary # Keep package boundaries explicit so backend imports stay predictable.
__all__ = [ __all__ = [
"LLMFactory", "get_llm_client", "LLMFactory",
"BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider", "get_llm_client",
"DeepSeekClient", "QwenClient", "QwenVLClient", "BaseLLMClient",
"DocumentSummarizer", "summarize_document", "DocumentSummary" "LLMResponse",
"LLMConfig",
"LLMProvider",
"DeepSeekClient",
"QwenClient",
"QwenVLClient",
"DocumentSummarizer",
"summarize_document",
"DocumentSummary",
] ]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name in {"DocumentSummarizer", "summarize_document", "DocumentSummary"}:
from .document_summarizer import DocumentSummarizer, DocumentSummary, summarize_document
return {
"DocumentSummarizer": DocumentSummarizer,
"summarize_document": summarize_document,
"DocumentSummary": DocumentSummary,
}[name]
raise AttributeError(name)

View File

@@ -1,13 +1,15 @@
"""LLM客户端基类 - 统一接口定义""" """Provide service-layer logic for base client."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from enum import Enum from enum import Enum
# Keep provider-specific behavior explicit so debugging stays straightforward.
class LLMProvider(Enum): class LLMProvider(Enum):
"""LLM提供商""" """Define the L L M Provider enumeration."""
DEEPSEEK = "deepseek" DEEPSEEK = "deepseek"
QWEN = "qwen" QWEN = "qwen"
QWEN_VL = "qwen_vl" QWEN_VL = "qwen_vl"
@@ -15,7 +17,7 @@ class LLMProvider(Enum):
@dataclass @dataclass
class LLMResponse: class LLMResponse:
"""LLM响应结果""" """Represent the L L M Response type."""
content: str content: str
model: str model: str
usage: Dict[str, int] = field(default_factory=dict) usage: Dict[str, int] = field(default_factory=dict)
@@ -25,12 +27,13 @@ class LLMResponse:
@property @property
def is_success(self) -> bool: def is_success(self) -> bool:
"""Return whether success for the L L M Response instance."""
return self.error is None return self.error is None
@dataclass @dataclass
class LLMConfig: class LLMConfig:
"""LLM配置""" """Define configuration for l l m config."""
provider: LLMProvider provider: LLMProvider
model: str model: str
api_key: str api_key: str
@@ -38,19 +41,20 @@ class LLMConfig:
max_tokens: int = 4096 max_tokens: int = 4096
temperature: float = 0.7 temperature: float = 0.7
top_p: float = 0.9 top_p: float = 0.9
timeout: int = 300 # 默认超时300秒摘要/Skills生成可能需要较长时间 timeout: int = 300 # Keep provider-specific behavior explicit so debugging stays straightforward.
class BaseLLMClient(ABC): class BaseLLMClient(ABC):
"""LLM客户端基类""" """Represent the Base L L M Client type."""
def __init__(self, config: LLMConfig): def __init__(self, config: LLMConfig):
"""Initialize the Base L L M Client instance."""
self.config = config self.config = config
self._client = None self._client = None
@abstractmethod @abstractmethod
def _init_client(self): def _init_client(self):
"""初始化客户端""" """Handle init client for this module for the Base L L M Client instance."""
pass pass
@abstractmethod @abstractmethod
@@ -61,18 +65,7 @@ class BaseLLMClient(ABC):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> LLMResponse: ) -> LLMResponse:
""" """Handle chat for the Base L L M Client instance."""
对话补全
Args:
messages: 对话消息列表 [{"role": "user/assistant/system", "content": "..."}]
max_tokens: 最大输出token数
temperature: 温度参数
**kwargs: 其他参数
Returns:
LLMResponse: 响应结果
"""
pass pass
def complete( def complete(
@@ -83,18 +76,7 @@ class BaseLLMClient(ABC):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> LLMResponse: ) -> LLMResponse:
""" """Handle complete for the Base L L M Client instance."""
单轮补全(便捷方法)
Args:
prompt: 用户输入
system_prompt: 系统提示词
max_tokens: 最大输出token数
temperature: 温度参数
Returns:
LLMResponse: 响应结果
"""
messages = [] messages = []
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
@@ -104,12 +86,12 @@ class BaseLLMClient(ABC):
@abstractmethod @abstractmethod
def get_available_models(self) -> List[str]: def get_available_models(self) -> List[str]:
"""获取可用模型列表""" """Return available models for the Base L L M Client instance."""
pass pass
def estimate_tokens(self, text: str) -> int: def estimate_tokens(self, text: str) -> int:
"""估算文本token数粗略估计""" """Handle estimate tokens for the Base L L M Client instance."""
# 中文字符约1.5 token英文约0.25 token # Keep provider-specific behavior explicit so debugging stays straightforward.
chinese_chars = sum(1 for c in text if '' <= c <= '鿿') chinese_chars = sum(1 for c in text if '' <= c <= '鿿')
other_chars = len(text) - chinese_chars other_chars = len(text) - chinese_chars
return int(chinese_chars * 1.5 + other_chars * 0.25) return int(chinese_chars * 1.5 + other_chars * 0.25)

View File

@@ -1,4 +1,4 @@
"""DeepSeek LLM客户端 - OpenAI兼容API""" """Provide service-layer logic for deepseek client."""
import time import time
from typing import List, Dict, Optional from typing import List, Dict, Optional
@@ -6,20 +6,12 @@ from loguru import logger
import httpx import httpx
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
# Keep provider-specific behavior explicit so debugging stays straightforward.
class DeepSeekClient(BaseLLMClient): class DeepSeekClient(BaseLLMClient):
""" """Represent the Deep Seek Client type."""
DeepSeek API客户端OpenAI兼容格式
支持模型:
- deepseek-chat
- deepseek-coder
- deepseek-reasoner
- deepseek-v3
- deepseek-v3.2
- deepseek-v4-flash
"""
SUPPORTED_MODELS = [ SUPPORTED_MODELS = [
"deepseek-chat", "deepseek-chat",
@@ -31,13 +23,14 @@ class DeepSeekClient(BaseLLMClient):
] ]
def __init__(self, config: LLMConfig): def __init__(self, config: LLMConfig):
"""Initialize the Deep Seek Client instance."""
if config.provider != LLMProvider.DEEPSEEK: if config.provider != LLMProvider.DEEPSEEK:
raise ValueError(f"配置provider应为DEEPSEEK实际为{config.provider}") raise ValueError(f"配置provider应为DEEPSEEK实际为{config.provider}")
super().__init__(config) super().__init__(config)
self._init_client() self._init_client()
def _init_client(self): def _init_client(self):
"""初始化HTTP客户端""" """Handle init client for this module for the Deep Seek Client instance."""
self._client = httpx.Client( self._client = httpx.Client(
base_url=self.config.base_url, base_url=self.config.base_url,
headers={ headers={
@@ -55,7 +48,7 @@ class DeepSeekClient(BaseLLMClient):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> LLMResponse: ) -> LLMResponse:
"""对话补全""" """Handle chat for the Deep Seek Client instance."""
start_time = time.time() start_time = time.time()
try: try:
@@ -103,11 +96,11 @@ class DeepSeekClient(BaseLLMClient):
) )
def get_available_models(self) -> List[str]: def get_available_models(self) -> List[str]:
"""获取可用模型列表""" """Return available models for the Deep Seek Client instance."""
return self.SUPPORTED_MODELS return self.SUPPORTED_MODELS
def close(self): def close(self):
"""关闭客户端""" """Release the resources held by this component."""
if self._client: if self._client:
self._client.close() self._client.close()
@@ -118,7 +111,7 @@ def create_deepseek_client(
base_url: str = "http://6.86.80.4:30080/v1", base_url: str = "http://6.86.80.4:30080/v1",
**kwargs **kwargs
) -> DeepSeekClient: ) -> DeepSeekClient:
"""便捷函数创建DeepSeek客户端""" """Create deepseek client."""
config = LLMConfig( config = LLMConfig(
provider=LLMProvider.DEEPSEEK, provider=LLMProvider.DEEPSEEK,
model=model, model=model,

View File

@@ -1,17 +1,20 @@
"""文档摘要生成服务 - LLM生成法规文档摘要""" """Provide service-layer logic for document summarizer."""
from typing import Dict, Optional from typing import Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger from loguru import logger
from app.services.llm import get_llm_client, BaseLLMClient from app.services.llm.base_client import BaseLLMClient
from app.services.llm.llm_factory import get_llm_client
from app.services.rag.prompt_templates import get_prompt_template from app.services.rag.prompt_templates import get_prompt_template
from app.config.settings import settings from app.config.settings import settings
# Keep provider-specific behavior explicit so debugging stays straightforward.
@dataclass @dataclass
class DocumentSummary: class DocumentSummary:
"""文档摘要结果""" """Represent the Document Summary type."""
doc_name: str doc_name: str
summary: str summary: str
applicable_scope: str applicable_scope: str
@@ -24,24 +27,12 @@ class DocumentSummary:
@property @property
def is_success(self) -> bool: def is_success(self) -> bool:
"""Return whether success for the Document Summary instance."""
return self.error is None return self.error is None
class DocumentSummarizer: class DocumentSummarizer:
""" """Represent the Document Summarizer type."""
文档摘要生成器
功能:
- 生成法规文档的核心要点摘要
- 提取适用范围
- 突出关键条款
- 列出合规要点
使用示例:
summarizer = DocumentSummarizer()
result = summarizer.summarize("GB 7258-2017", markdown_content)
print(result.summary)
"""
def __init__( def __init__(
self, self,
@@ -49,25 +40,18 @@ class DocumentSummarizer:
model: str = None, model: str = None,
max_tokens: int = None max_tokens: int = None
): ):
""" """Initialize the Document Summarizer instance."""
初始化摘要生成器
Args:
provider: LLM提供商
model: LLM模型名称
max_tokens: 最大输出token数
"""
self.provider = provider or settings.llm_provider self.provider = provider or settings.llm_provider
self.model = model or settings.llm_model self.model = model or settings.llm_model
self.max_tokens = max_tokens or settings.rag_summary_max_tokens self.max_tokens = max_tokens or settings.rag_summary_max_tokens
# LLM客户端延迟加载 # Keep provider-specific behavior explicit so debugging stays straightforward.
self.llm: Optional[BaseLLMClient] = None self.llm: Optional[BaseLLMClient] = None
logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}") logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")
def _init_llm(self): def _init_llm(self):
"""延迟初始化LLM""" """Handle init llm for this module for the Document Summarizer instance."""
if self.llm is None: if self.llm is None:
self.llm = get_llm_client( self.llm = get_llm_client(
provider=self.provider, provider=self.provider,
@@ -81,18 +65,7 @@ class DocumentSummarizer:
regulation_type: str = "", regulation_type: str = "",
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
) -> DocumentSummary: ) -> DocumentSummary:
""" """Handle summarize for the Document Summarizer instance."""
生成文档摘要
Args:
doc_name: 文档名称
content: 文档内容Markdown格式
regulation_type: 法规类型
max_tokens: 最大输出token数
Returns:
DocumentSummary: 摘要结果
"""
import time import time
start_time = time.time() start_time = time.time()
@@ -101,23 +74,23 @@ class DocumentSummarizer:
try: try:
self._init_llm() self._init_llm()
# 使用摘要模板 # Keep provider-specific behavior explicit so debugging stays straightforward.
template = get_prompt_template("document_summary") template = get_prompt_template("document_summary")
# 构建用户消息 # Keep provider-specific behavior explicit so debugging stays straightforward.
user_content = template.user_template.format( user_content = template.user_template.format(
doc_name=doc_name, doc_name=doc_name,
content=content[:8000] # 截取前8000字符避免超出token限制 content=content[:8000] # Keep provider-specific behavior explicit so debugging stays straightforward.
) )
# 调用LLM # Keep provider-specific behavior explicit so debugging stays straightforward.
response = self.llm.chat( response = self.llm.chat(
messages=[ messages=[
{"role": "system", "content": template.system_prompt}, {"role": "system", "content": template.system_prompt},
{"role": "user", "content": user_content} {"role": "user", "content": user_content}
], ],
max_tokens=max_tokens or self.max_tokens, max_tokens=max_tokens or self.max_tokens,
temperature=0.3 # 低温度保证摘要准确性 temperature=0.3 # Keep provider-specific behavior explicit so debugging stays straightforward.
) )
latency_ms = int((time.time() - start_time) * 1000) latency_ms = int((time.time() - start_time) * 1000)
@@ -135,7 +108,7 @@ class DocumentSummarizer:
error=response.error error=response.error
) )
# 解析摘要结构 # Keep provider-specific behavior explicit so debugging stays straightforward.
summary_data = self._parse_summary(response.content) summary_data = self._parse_summary(response.content)
logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms") logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
@@ -166,7 +139,7 @@ class DocumentSummarizer:
) )
def _parse_summary(self, content: str) -> Dict: def _parse_summary(self, content: str) -> Dict:
"""解析摘要内容(提取结构化信息)""" """Handle parse summary for this module for the Document Summarizer instance."""
result = { result = {
"summary": content, "summary": content,
"applicable_scope": "", "applicable_scope": "",
@@ -175,26 +148,26 @@ class DocumentSummarizer:
"compliance_points": [] "compliance_points": []
} }
# 简单解析(提取关键信息) # Keep provider-specific behavior explicit so debugging stays straightforward.
lines = content.split("\n") lines = content.split("\n")
for line in lines: for line in lines:
line = line.strip() line = line.strip()
# 提取适用范围 # Keep provider-specific behavior explicit so debugging stays straightforward.
if "适用范围" in line or "适用对象" in line: if "适用范围" in line or "适用对象" in line:
result["applicable_scope"] = line.split("")[-1].strip() if "" in line else line.split(":")[-1].strip() result["applicable_scope"] = line.split("")[-1].strip() if "" in line else line.split(":")[-1].strip()
# 提取关键条款 # Keep provider-specific behavior explicit so debugging stays straightforward.
if line.startswith("- 【条款") or line.startswith("【条款"): if line.startswith("- 【条款") or line.startswith("【条款"):
result["key_clauses"].append(line) result["key_clauses"].append(line)
# 提取关键术语 # Keep provider-specific behavior explicit so debugging stays straightforward.
if "关键术语" in line or "术语定义" in line: if "关键术语" in line or "术语定义" in line:
# 继续读取后续几行 # Keep provider-specific behavior explicit so debugging stays straightforward.
pass pass
# 提取合规要点 # Keep provider-specific behavior explicit so debugging stays straightforward.
if "合规要点" in line or "必须满足" in line: if "合规要点" in line or "必须满足" in line:
pass pass
@@ -204,15 +177,7 @@ class DocumentSummarizer:
self, self,
documents: list documents: list
) -> list: ) -> list:
""" """Handle batch summarize for the Document Summarizer instance."""
批量生成摘要
Args:
documents: 文档列表 [{"doc_name": str, "content": str}, ...]
Returns:
list: 摘要结果列表
"""
results = [] results = []
for doc in documents: for doc in documents:
result = self.summarize(doc["doc_name"], doc["content"]) result = self.summarize(doc["doc_name"], doc["content"])
@@ -225,6 +190,6 @@ def summarize_document(
content: str, content: str,
**kwargs **kwargs
) -> DocumentSummary: ) -> DocumentSummary:
"""便捷函数:生成文档摘要""" """Handle summarize document."""
summarizer = DocumentSummarizer(**kwargs) summarizer = DocumentSummarizer(**kwargs)
return summarizer.summarize(doc_name, content) return summarizer.summarize(doc_name, content)

View File

@@ -1,4 +1,4 @@
"""LLM工厂 - 统一创建和管理LLM客户端""" """Provide service-layer logic for llm factory."""
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from loguru import logger from loguru import logger
@@ -7,16 +7,18 @@ from functools import lru_cache
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient from .qwen_client import QwenClient, QwenVLClient
# Keep provider-specific behavior explicit so debugging stays straightforward.
# 默认模型映射
# Keep provider-specific behavior explicit so debugging stays straightforward.
DEFAULT_MODELS = { DEFAULT_MODELS = {
LLMProvider.DEEPSEEK: "deepseek-v4-flash", LLMProvider.DEEPSEEK: "deepseek-v4-flash",
LLMProvider.QWEN: "qwen3.5-flash", LLMProvider.QWEN: "qwen3.5-flash",
LLMProvider.QWEN_VL: "qwen3-vl-plus" LLMProvider.QWEN_VL: "qwen3-vl-plus"
} }
# API基础URL使用统一代理服务 # Keep provider-specific behavior explicit so debugging stays straightforward.
DEFAULT_BASE_URLS = { DEFAULT_BASE_URLS = {
LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1", LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
LLMProvider.QWEN: "http://6.86.80.4:30080/v1", LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
@@ -25,31 +27,13 @@ DEFAULT_BASE_URLS = {
class LLMFactory: class LLMFactory:
""" """Represent the L L M Factory type."""
LLM客户端工厂支持全局缓存
支持的提供商和模型: # Keep provider-specific behavior explicit so debugging stays straightforward.
- DeepSeek: deepseek-chat (DeepSeek-V3), deepseek-coder
- Qwen: qwen-turbo, qwen-plus, qwen-max, qwen-long
- QwenVL: qwen-vl-plus, qwen-vl-max (多模态)
使用示例:
factory = LLMFactory()
# 使用默认配置
client = factory.create("deepseek")
# 自定义配置
client = factory.create("qwen", model="qwen-max", temperature=0.5)
# 调用LLM
response = client.complete("你好,介绍一下自己")
"""
# 全局客户端缓存(类级别,跨实例共享)
_global_instances: Dict[str, BaseLLMClient] = {} _global_instances: Dict[str, BaseLLMClient] = {}
def __init__(self): def __init__(self):
"""Initialize the L L M Factory instance."""
self._config_cache: Dict[str, Any] = {} self._config_cache: Dict[str, Any] = {}
def create( def create(
@@ -62,24 +46,10 @@ class LLMFactory:
temperature: float = 0.7, temperature: float = 0.7,
**kwargs **kwargs
) -> BaseLLMClient: ) -> BaseLLMClient:
""" """Handle create for the L L M Factory instance."""
创建LLM客户端
Args:
provider: 提供商名称 ("deepseek", "qwen", "qwen_vl")
api_key: API密钥如未提供从环境变量获取
model: 模型名称(如未提供,使用默认模型)
base_url: API基础URL
max_tokens: 最大输出token数
temperature: 温度参数
**kwargs: 其他配置参数
Returns:
BaseLLMClient: LLM客户端实例
"""
provider_enum = self._parse_provider(provider) provider_enum = self._parse_provider(provider)
# 获取配置 # Keep provider-specific behavior explicit so debugging stays straightforward.
api_key = api_key or self._get_api_key(provider_enum) api_key = api_key or self._get_api_key(provider_enum)
model = model or DEFAULT_MODELS.get(provider_enum) model = model or DEFAULT_MODELS.get(provider_enum)
base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum) base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
@@ -87,7 +57,7 @@ class LLMFactory:
if not api_key: if not api_key:
raise ValueError(f"缺少API密钥请设置环境变量或传入api_key参数") raise ValueError(f"缺少API密钥请设置环境变量或传入api_key参数")
# 检查全局缓存 # Keep provider-specific behavior explicit so debugging stays straightforward.
cache_key = f"{provider}_{model}" cache_key = f"{provider}_{model}"
if cache_key in LLMFactory._global_instances: if cache_key in LLMFactory._global_instances:
logger.debug(f"使用缓存的LLM客户端: {cache_key}") logger.debug(f"使用缓存的LLM客户端: {cache_key}")
@@ -103,17 +73,17 @@ class LLMFactory:
**kwargs **kwargs
) )
# 创建客户端 # Keep provider-specific behavior explicit so debugging stays straightforward.
client = self._create_client(config) client = self._create_client(config)
# 缓存到全局实例 # Keep provider-specific behavior explicit so debugging stays straightforward.
LLMFactory._global_instances[cache_key] = client LLMFactory._global_instances[cache_key] = client
logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}") logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
return client return client
def _parse_provider(self, provider: str) -> LLMProvider: def _parse_provider(self, provider: str) -> LLMProvider:
"""解析提供商名称""" """Handle parse provider for this module for the L L M Factory instance."""
provider_map = { provider_map = {
"deepseek": LLMProvider.DEEPSEEK, "deepseek": LLMProvider.DEEPSEEK,
"deepseek-v3": LLMProvider.DEEPSEEK, "deepseek-v3": LLMProvider.DEEPSEEK,
@@ -137,7 +107,7 @@ class LLMFactory:
return provider_map[provider_lower] return provider_map[provider_lower]
def _get_api_key(self, provider: LLMProvider) -> Optional[str]: def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
"""从环境变量获取API密钥""" """Handle get api key for this module for the L L M Factory instance."""
import os import os
key_map = { key_map = {
@@ -154,7 +124,7 @@ class LLMFactory:
return None return None
def _create_client(self, config: LLMConfig) -> BaseLLMClient: def _create_client(self, config: LLMConfig) -> BaseLLMClient:
"""创建具体客户端""" """Handle create client for this module for the L L M Factory instance."""
client_map = { client_map = {
LLMProvider.DEEPSEEK: DeepSeekClient, LLMProvider.DEEPSEEK: DeepSeekClient,
LLMProvider.QWEN: QwenClient, LLMProvider.QWEN: QwenClient,
@@ -168,14 +138,14 @@ class LLMFactory:
return client_class(config) return client_class(config)
def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]: def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
"""获取缓存的客户端""" """Return cached for the L L M Factory instance."""
provider_enum = self._parse_provider(provider) provider_enum = self._parse_provider(provider)
model = model or DEFAULT_MODELS.get(provider_enum) model = model or DEFAULT_MODELS.get(provider_enum)
cache_key = f"{provider}_{model}" cache_key = f"{provider}_{model}"
return LLMFactory._global_instances.get(cache_key) return LLMFactory._global_instances.get(cache_key)
def list_available_providers(self) -> Dict[str, list]: def list_available_providers(self) -> Dict[str, list]:
"""列出可用的提供商和模型""" """List available providers for the L L M Factory instance."""
return { return {
"deepseek": DeepSeekClient.SUPPORTED_MODELS, "deepseek": DeepSeekClient.SUPPORTED_MODELS,
"qwen": QwenClient.SUPPORTED_MODELS, "qwen": QwenClient.SUPPORTED_MODELS,
@@ -184,12 +154,7 @@ class LLMFactory:
@classmethod @classmethod
def preload_clients(cls, providers: list = None): def preload_clients(cls, providers: list = None):
""" """Handle preload clients for the L L M Factory instance."""
预加载LLM客户端应用启动时调用
Args:
providers: 要预加载的提供商列表默认加载qwen和deepseek
"""
if providers is None: if providers is None:
providers = ["qwen", "deepseek"] providers = ["qwen", "deepseek"]
@@ -203,9 +168,9 @@ class LLMFactory:
@classmethod @classmethod
def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]: def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
"""获取全局缓存的客户端""" """Return global client for the L L M Factory instance."""
provider_lower = provider.lower() provider_lower = provider.lower()
# 处理模型名作为provider的情况(如 qwen3.5-flash # Keep provider-specific behavior explicit so debugging stays straightforward.
if provider_lower.startswith("qwen"): if provider_lower.startswith("qwen"):
provider_lower = "qwen" provider_lower = "qwen"
model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK) model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK)
@@ -214,7 +179,7 @@ class LLMFactory:
@classmethod @classmethod
def cleanup(cls): def cleanup(cls):
"""清理所有缓存的客户端""" """Handle cleanup for the L L M Factory instance."""
for cache_key, client in cls._global_instances.items(): for cache_key, client in cls._global_instances.items():
try: try:
client.close() client.close()
@@ -227,7 +192,7 @@ class LLMFactory:
@lru_cache @lru_cache
def get_llm_factory() -> LLMFactory: def get_llm_factory() -> LLMFactory:
"""获取LLM工厂实例缓存""" """Return llm factory."""
return LLMFactory() return LLMFactory()
@@ -236,20 +201,10 @@ def get_llm_client(
model: Optional[str] = None, model: Optional[str] = None,
**kwargs **kwargs
) -> BaseLLMClient: ) -> BaseLLMClient:
""" """Return llm client."""
便捷函数获取LLM客户端优先使用缓存
Args:
provider: 提供商名称
model: 模型名称
**kwargs: 其他配置
Returns:
BaseLLMClient: LLM客户端实例
"""
factory = get_llm_factory() factory = get_llm_factory()
# 先尝试获取缓存的实例 # Keep provider-specific behavior explicit so debugging stays straightforward.
cached = factory.get_cached(provider, model) cached = factory.get_cached(provider, model)
if cached: if cached:
return cached return cached

View File

@@ -1,4 +1,4 @@
"""Qwen LLM客户端 - 支持OpenAI兼容API格式""" """Provide service-layer logic for qwen client."""
import time import time
import json import json
@@ -7,21 +7,12 @@ from loguru import logger
import httpx import httpx
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
# Keep provider-specific behavior explicit so debugging stays straightforward.
class QwenClient(BaseLLMClient): class QwenClient(BaseLLMClient):
""" """Represent the Qwen Client type."""
Qwen API客户端OpenAI兼容格式
支持通过new-api等代理服务调用
- qwen-turbo
- qwen-plus
- qwen-max
- qwen3.5-flash (推荐:快速响应)
- qwen3.5-plus
- qwen-long
- qwen2.5系列
"""
SUPPORTED_MODELS = [ SUPPORTED_MODELS = [
"qwen-turbo", "qwen-turbo",
@@ -39,14 +30,15 @@ class QwenClient(BaseLLMClient):
] ]
def __init__(self, config: LLMConfig): def __init__(self, config: LLMConfig):
"""Initialize the Qwen Client instance."""
if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]: if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
raise ValueError(f"配置provider应为Qwen实际为{config.provider}") raise ValueError(f"配置provider应为Qwen实际为{config.provider}")
super().__init__(config) super().__init__(config)
self._init_client() self._init_client()
def _init_client(self): def _init_client(self):
"""初始化HTTP客户端""" """Handle init client for this module for the Qwen Client instance."""
# OpenAI兼容API格式 # Keep provider-specific behavior explicit so debugging stays straightforward.
self._client = httpx.Client( self._client = httpx.Client(
base_url=self.config.base_url, base_url=self.config.base_url,
headers={ headers={
@@ -64,11 +56,11 @@ class QwenClient(BaseLLMClient):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> LLMResponse: ) -> LLMResponse:
"""对话补全OpenAI兼容格式""" """Handle chat for the Qwen Client instance."""
start_time = time.time() start_time = time.time()
try: try:
# OpenAI兼容格式的请求体 # Keep provider-specific behavior explicit so debugging stays straightforward.
payload = { payload = {
"model": self.config.model, "model": self.config.model,
"messages": messages, "messages": messages,
@@ -78,7 +70,7 @@ class QwenClient(BaseLLMClient):
"stream": False "stream": False
} }
# OpenAI兼容接口路径 # Keep provider-specific behavior explicit so debugging stays straightforward.
response = self._client.post("/chat/completions", json=payload) response = self._client.post("/chat/completions", json=payload)
response.raise_for_status() response.raise_for_status()
@@ -86,7 +78,7 @@ class QwenClient(BaseLLMClient):
latency_ms = int((time.time() - start_time) * 1000) latency_ms = int((time.time() - start_time) * 1000)
# OpenAI兼容格式的响应解析 # Keep provider-specific behavior explicit so debugging stays straightforward.
choices = data.get("choices", [{}]) choices = data.get("choices", [{}])
message = choices[0].get("message", {}) message = choices[0].get("message", {})
@@ -121,42 +113,33 @@ class QwenClient(BaseLLMClient):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
""" """Stream chat for the Qwen Client instance."""
流式对话补全SSE格式
Yields:
str: 每次返回一个文本片段
使用示例:
for chunk in client.stream_chat(messages):
print(chunk, end="", flush=True)
"""
try: try:
# OpenAI兼容格式的请求体启用流式输出 # Keep provider-specific behavior explicit so debugging stays straightforward.
payload = { payload = {
"model": self.config.model, "model": self.config.model,
"messages": messages, "messages": messages,
"max_tokens": max_tokens or self.config.max_tokens, "max_tokens": max_tokens or self.config.max_tokens,
"temperature": temperature or self.config.temperature, "temperature": temperature or self.config.temperature,
"top_p": kwargs.get("top_p", self.config.top_p), "top_p": kwargs.get("top_p", self.config.top_p),
"stream": True # 启用流式输出 "stream": True # Keep provider-specific behavior explicit so debugging stays straightforward.
} }
# 使用stream模式发送请求 # Keep provider-specific behavior explicit so debugging stays straightforward.
with self._client.stream("POST", "/chat/completions", json=payload) as response: with self._client.stream("POST", "/chat/completions", json=payload) as response:
for line in response.iter_lines(): for line in response.iter_lines():
if line: if line:
line = line.strip() line = line.strip()
# SSE格式: data: {...} # Keep provider-specific behavior explicit so debugging stays straightforward.
if line.startswith("data: "): if line.startswith("data: "):
data_str = line[6:] # 移除 "data: " 前缀 data_str = line[6:] # Keep provider-specific behavior explicit so debugging stays straightforward.
if data_str == "[DONE]": if data_str == "[DONE]":
break break
try: try:
data = json.loads(data_str) data = json.loads(data_str)
choices = data.get("choices", []) choices = data.get("choices", [])
if not choices: if not choices:
continue # 跳过空的choices continue # Keep provider-specific behavior explicit so debugging stays straightforward.
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
content = delta.get("content", "") content = delta.get("content", "")
if content: if content:
@@ -179,41 +162,27 @@ class QwenClient(BaseLLMClient):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """Handle async stream chat for the Qwen Client instance."""
异步流式对话补全用于FastAPI SSE响应
Yields:
str: 每次返回一个文本片段
"""
import asyncio import asyncio
# 使用同步流式方法,包装为异步 # Keep provider-specific behavior explicit so debugging stays straightforward.
for chunk in self.stream_chat(messages, max_tokens, temperature, **kwargs): for chunk in self.stream_chat(messages, max_tokens, temperature, **kwargs):
yield chunk yield chunk
# 给async循环一个小延迟让其他任务有机会执行 # Keep provider-specific behavior explicit so debugging stays straightforward.
await asyncio.sleep(0) await asyncio.sleep(0)
def get_available_models(self) -> List[str]: def get_available_models(self) -> List[str]:
"""获取可用模型列表""" """Return available models for the Qwen Client instance."""
return self.SUPPORTED_MODELS return self.SUPPORTED_MODELS
def close(self): def close(self):
"""关闭客户端""" """Release the resources held by this component."""
if self._client: if self._client:
self._client.close() self._client.close()
class QwenVLClient(BaseLLMClient): class QwenVLClient(BaseLLMClient):
""" """Represent the Qwen V L Client type."""
Qwen VL多模态客户端OpenAI兼容格式
支持模型:
- qwen-vl-plus
- qwen-vl-max
- qwen3-vl-plus
- qwen2-vl-7b-instruct
- qwen2-vl-72b-instruct
"""
SUPPORTED_MODELS = [ SUPPORTED_MODELS = [
"qwen-vl-plus", "qwen-vl-plus",
@@ -224,13 +193,14 @@ class QwenVLClient(BaseLLMClient):
] ]
def __init__(self, config: LLMConfig): def __init__(self, config: LLMConfig):
"""Initialize the Qwen V L Client instance."""
if config.provider != LLMProvider.QWEN_VL: if config.provider != LLMProvider.QWEN_VL:
raise ValueError(f"配置provider应为QWEN_VL实际为{config.provider}") raise ValueError(f"配置provider应为QWEN_VL实际为{config.provider}")
super().__init__(config) super().__init__(config)
self._init_client() self._init_client()
def _init_client(self): def _init_client(self):
"""初始化HTTP客户端""" """Handle init client for this module for the Qwen V L Client instance."""
self._client = httpx.Client( self._client = httpx.Client(
base_url=self.config.base_url, base_url=self.config.base_url,
headers={ headers={
@@ -248,21 +218,11 @@ class QwenVLClient(BaseLLMClient):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> LLMResponse: ) -> LLMResponse:
"""多模态对话补全OpenAI兼容格式 """Handle chat for the Qwen V L Client instance."""
支持图片输入,消息格式:
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
{"type": "text", "text": "描述这张图片"}
]
}
"""
start_time = time.time() start_time = time.time()
try: try:
# OpenAI兼容格式的请求体 # Keep provider-specific behavior explicit so debugging stays straightforward.
payload = { payload = {
"model": self.config.model, "model": self.config.model,
"messages": messages, "messages": messages,
@@ -312,7 +272,7 @@ class QwenVLClient(BaseLLMClient):
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs **kwargs
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
"""流式多模态对话补全""" """Stream chat for the Qwen V L Client instance."""
try: try:
payload = { payload = {
"model": self.config.model, "model": self.config.model,
@@ -335,7 +295,7 @@ class QwenVLClient(BaseLLMClient):
data = json.loads(data_str) data = json.loads(data_str)
choices = data.get("choices", []) choices = data.get("choices", [])
if not choices: if not choices:
continue # 跳过空的choices continue # Keep provider-specific behavior explicit so debugging stays straightforward.
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
content = delta.get("content", "") content = delta.get("content", "")
if content: if content:
@@ -348,11 +308,11 @@ class QwenVLClient(BaseLLMClient):
yield f"[ERROR: {str(e)}]" yield f"[ERROR: {str(e)}]"
def get_available_models(self) -> List[str]: def get_available_models(self) -> List[str]:
"""获取可用模型列表""" """Return available models for the Qwen V L Client instance."""
return self.SUPPORTED_MODELS return self.SUPPORTED_MODELS
def close(self): def close(self):
"""关闭客户端""" """Release the resources held by this component."""
if self._client: if self._client:
self._client.close() self._client.close()
@@ -363,7 +323,7 @@ def create_qwen_client(
base_url: str = "http://6.86.80.4:30080/v1", base_url: str = "http://6.86.80.4:30080/v1",
**kwargs **kwargs
) -> QwenClient: ) -> QwenClient:
"""便捷函数创建Qwen客户端""" """Create qwen client."""
config = LLMConfig( config = LLMConfig(
provider=LLMProvider.QWEN, provider=LLMProvider.QWEN,
model=model, model=model,
@@ -380,7 +340,7 @@ def create_qwen_vl_client(
base_url: str = "http://6.86.80.4:30080/v1", base_url: str = "http://6.86.80.4:30080/v1",
**kwargs **kwargs
) -> QwenVLClient: ) -> QwenVLClient:
"""便捷函数创建QwenVL客户端""" """Create qwen vl client."""
config = LLMConfig( config = LLMConfig(
provider=LLMProvider.QWEN_VL, provider=LLMProvider.QWEN_VL,
model=model, model=model,

View File

@@ -1,12 +1,12 @@
""" """Provide service-layer logic for mock data."""
Mock数据服务 - 提供预设假数据供前后端对接测试
"""
from datetime import datetime from datetime import datetime
from typing import Dict, List, Any from typing import Dict, List, Any
import uuid import uuid
# Keep service responsibilities explicit so downstream behavior stays predictable.
# 预设法规文档列表
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_DOCUMENTS: List[Dict[str, Any]] = [ MOCK_DOCUMENTS: List[Dict[str, Any]] = [
{ {
"id": "doc-001", "id": "doc-001",
@@ -45,7 +45,7 @@ MOCK_DOCUMENTS: List[Dict[str, Any]] = [
}, },
] ]
# 预设快捷问题 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [ MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
{"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"}, {"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"},
{"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"}, {"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"},
@@ -53,7 +53,7 @@ MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
{"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"}, {"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"},
] ]
# 预设检索结果 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [ MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
{ {
"id": "chunk-001", "id": "chunk-001",
@@ -97,7 +97,7 @@ MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
}, },
] ]
# 预设RAG问答答案模板按关键词匹配 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = { MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
"电动自行车": { "电动自行车": {
"text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。", "text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。",
@@ -133,7 +133,7 @@ MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
}, },
} }
# 预设合规分析结果 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_COMPLIANCE_RESULT: Dict[str, Any] = { MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
"task_id": "task-001", "task_id": "task-001",
"dashboard": { "dashboard": {
@@ -310,7 +310,7 @@ MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
], ],
} }
# 预设合规对话响应模板 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = { MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
"车身结构设计": { "车身结构设计": {
"compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。", "compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。",
@@ -329,7 +329,7 @@ MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
}, },
} }
# 预设系统统计数据 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_SYSTEM_STATS: Dict[str, int] = { MOCK_SYSTEM_STATS: Dict[str, int] = {
"docs": 5, "docs": 5,
"chunks": 510, "chunks": 510,
@@ -337,7 +337,7 @@ MOCK_SYSTEM_STATS: Dict[str, int] = {
"segments": 0, "segments": 0,
} }
# 预设系统配置 # Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_SYSTEM_CONFIG: Dict[str, Any] = { MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
"llm": { "llm": {
"model": "qwen-max", "model": "qwen-max",
@@ -358,17 +358,17 @@ MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
def get_mock_documents() -> List[Dict[str, Any]]: def get_mock_documents() -> List[Dict[str, Any]]:
"""获取预设法规文档列表""" """Return mock documents."""
return MOCK_DOCUMENTS return MOCK_DOCUMENTS
def get_mock_quick_questions() -> List[Dict[str, str]]: def get_mock_quick_questions() -> List[Dict[str, str]]:
"""获取预设快捷问题""" """Return mock quick questions."""
return MOCK_QUICK_QUESTIONS return MOCK_QUICK_QUESTIONS
def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]: def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""根据查询关键词返回预设检索结果""" """Return mock retrieval."""
results = [] results = []
for keyword, data in MOCK_RAG_ANSWERS.items(): for keyword, data in MOCK_RAG_ANSWERS.items():
if keyword in query: if keyword in query:
@@ -389,7 +389,7 @@ def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
def get_mock_rag_answer(query: str) -> str: def get_mock_rag_answer(query: str) -> str:
"""根据查询关键词返回预设答案""" """Return mock rag answer."""
for keyword, data in MOCK_RAG_ANSWERS.items(): for keyword, data in MOCK_RAG_ANSWERS.items():
if keyword in query: if keyword in query:
return data["text"] return data["text"]
@@ -397,14 +397,14 @@ def get_mock_rag_answer(query: str) -> str:
def get_mock_compliance_result(task_id: str) -> Dict[str, Any]: def get_mock_compliance_result(task_id: str) -> Dict[str, Any]:
"""获取预设合规分析结果""" """Return mock compliance result."""
result = MOCK_COMPLIANCE_RESULT.copy() result = MOCK_COMPLIANCE_RESULT.copy()
result["task_id"] = task_id result["task_id"] = task_id
return result return result
def get_mock_compliance_chat_response(intent: str, query: str) -> str: def get_mock_compliance_chat_response(intent: str, query: str) -> str:
"""获取预设合规对话响应""" """Return mock compliance chat response."""
responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {}) responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {})
if "合规" in query or "符合" in query: if "合规" in query or "符合" in query:
return responses.get("compliance", "根据相关法规分析,该段落的合规性需进一步评估。") return responses.get("compliance", "根据相关法规分析,该段落的合规性需进一步评估。")
@@ -416,10 +416,10 @@ def get_mock_compliance_chat_response(intent: str, query: str) -> str:
def generate_task_id() -> str: def generate_task_id() -> str:
"""生成任务ID""" """Handle generate task id."""
return f"task-{uuid.uuid4().hex[:8]}" return f"task-{uuid.uuid4().hex[:8]}"
def generate_doc_id() -> str: def generate_doc_id() -> str:
"""生成文档ID""" """Handle generate doc id."""
return f"doc-{uuid.uuid4().hex[:8]}" return f"doc-{uuid.uuid4().hex[:8]}"

View File

@@ -1,6 +1,8 @@
"""文档解析服务""" """Initialize the app.services.parser package."""
from .pdf_parser import PDFParser from .pdf_parser import PDFParser
from .docx_parser import DocxParser from .docx_parser import DocxParser
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["PDFParser", "DocxParser"] __all__ = ["PDFParser", "DocxParser"]

View File

@@ -1,4 +1,4 @@
"""Word文档解析 - 使用python-docx""" """Provide service-layer logic for docx parser."""
from docx import Document from docx import Document
from docx.enum.text import WD_ALIGN_PARAGRAPH from docx.enum.text import WD_ALIGN_PARAGRAPH
@@ -6,27 +6,29 @@ from typing import List, Dict, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from loguru import logger from loguru import logger
import re import re
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class DocxParagraph: class DocxParagraph:
"""段落内容""" """Represent the Docx Paragraph type."""
text: str text: str
level: int = 0 # 标题级别0表示正文 level: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
is_list: bool = False is_list: bool = False
list_number: Optional[str] = None list_number: Optional[str] = None
@dataclass @dataclass
class DocxTable: class DocxTable:
"""表格内容""" """Represent the Docx Table type."""
rows: List[List[str]] rows: List[List[str]]
markdown: str = "" markdown: str = ""
@dataclass @dataclass
class DocxDocumentContent: class DocxDocumentContent:
"""Word文档完整内容""" """Represent the Docx Document Content type."""
file_path: str file_path: str
paragraphs: List[DocxParagraph] paragraphs: List[DocxParagraph]
tables: List[DocxTable] tables: List[DocxTable]
@@ -35,21 +37,14 @@ class DocxDocumentContent:
class DocxParser: class DocxParser:
"""Word文档解析器 - 基于python-docx""" """Provide the Docx Parser parser."""
def __init__(self): def __init__(self):
"""Initialize the Docx Parser instance."""
self.document = None self.document = None
def parse(self, file_path: str) -> DocxDocumentContent: def parse(self, file_path: str) -> DocxDocumentContent:
""" """Handle parse for the Docx Parser instance."""
解析Word文档
Args:
file_path: Word文档路径
Returns:
DocxDocumentContent: 解析后的文档内容
"""
logger.info(f"开始解析Word文档: {file_path}") logger.info(f"开始解析Word文档: {file_path}")
try: try:
@@ -60,16 +55,16 @@ class DocxParser:
tables=[] tables=[]
) )
# 提取文档元数据 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.metadata = self._extract_metadata() doc_content.metadata = self._extract_metadata()
# 提取段落 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.paragraphs = self._extract_paragraphs() doc_content.paragraphs = self._extract_paragraphs()
# 提取表格 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.tables = self._extract_tables() doc_content.tables = self._extract_tables()
# 生成Markdown格式文本 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.markdown_text = self._generate_markdown(doc_content) doc_content.markdown_text = self._generate_markdown(doc_content)
logger.success(f"Word文档解析完成{len(doc_content.paragraphs)}个段落") logger.success(f"Word文档解析完成{len(doc_content.paragraphs)}个段落")
@@ -81,7 +76,7 @@ class DocxParser:
raise raise
def _extract_metadata(self) -> Dict[str, str]: def _extract_metadata(self) -> Dict[str, str]:
"""提取文档元数据""" """Handle extract metadata for this module for the Docx Parser instance."""
metadata = {} metadata = {}
try: try:
core_props = self.document.core_properties core_props = self.document.core_properties
@@ -98,7 +93,7 @@ class DocxParser:
return metadata return metadata
def _extract_paragraphs(self) -> List[DocxParagraph]: def _extract_paragraphs(self) -> List[DocxParagraph]:
"""提取所有段落""" """Handle extract paragraphs for this module for the Docx Parser instance."""
paragraphs = [] paragraphs = []
for para in self.document.paragraphs: for para in self.document.paragraphs:
@@ -106,10 +101,10 @@ class DocxParser:
if not text: if not text:
continue continue
# 判断标题级别 # Keep service responsibilities explicit so downstream behavior stays predictable.
level = self._get_paragraph_level(para) level = self._get_paragraph_level(para)
# 判断是否是列表项 # Keep service responsibilities explicit so downstream behavior stays predictable.
is_list, list_number = self._detect_list_item(para) is_list, list_number = self._detect_list_item(para)
paragraph = DocxParagraph( paragraph = DocxParagraph(
@@ -123,66 +118,61 @@ class DocxParser:
return paragraphs return paragraphs
def _get_paragraph_level(self, para) -> int: def _get_paragraph_level(self, para) -> int:
""" """Handle get paragraph level for this module for the Docx Parser instance."""
判断段落标题级别 # Keep service responsibilities explicit so downstream behavior stays predictable.
Returns:
int: 标题级别0表示正文
"""
# 方法1检查段落样式
style_name = para.style.name if para.style else "" style_name = para.style.name if para.style else ""
if "Heading" in style_name or "标题" in style_name: if "Heading" in style_name or "标题" in style_name:
# 从样式名称中提取级别 # Keep service responsibilities explicit so downstream behavior stays predictable.
match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name) match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
if match: if match:
level = int(match.group(1) or match.group(2)) level = int(match.group(1) or match.group(2))
return level return level
# 方法2检查段落格式字号 # Keep service responsibilities explicit so downstream behavior stays predictable.
# 标题通常字号较大 # Keep service responsibilities explicit so downstream behavior stays predictable.
if para.paragraph_format: if para.paragraph_format:
# 可以根据字号判断,这里简化处理 # Keep service responsibilities explicit so downstream behavior stays predictable.
pass pass
# 方法3根据内容模式判断法规文档特征 # Keep service responsibilities explicit so downstream behavior stays predictable.
text = para.text.strip() text = para.text.strip()
# 第一章、第X章 -> 二级标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^第[一二三四五六七八九十百]+章\s', text): if re.match(r'^第[一二三四五六七八九十百]+章\s', text):
return 2 return 2
# 第X节 -> 三级标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
elif re.match(r'^第[一二三四五六七八九十百]+节\s', text): elif re.match(r'^第[一二三四五六七八九十百]+节\s', text):
return 3 return 3
# 第X条 -> 四级标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
elif re.match(r'^第[一二三四五六七八九十百]+条\s', text): elif re.match(r'^第[一二三四五六七八九十百]+条\s', text):
return 4 return 4
return 0 # 正文 return 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
def _detect_list_item(self, para) -> tuple[bool, Optional[str]]: def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
"""检测是否是列表项""" """Handle detect list item for this module for the Docx Parser instance."""
text = para.text.strip() text = para.text.strip()
# 数字列表1.、2.、1、[1]等 # Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^[\d]+[.、)\]]\s', text): if re.match(r'^[\d]+[.、)\]]\s', text):
match = re.match(r'^([\d]+[.、)\]])\s', text) match = re.match(r'^([\d]+[.、)\]])\s', text)
return True, match.group(1) if match else None return True, match.group(1) if match else None
# 中文数字列表:一、二、(一)等 # Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text): if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text):
match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text) match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text)
return True, match.group(1) if match else None return True, match.group(1) if match else None
# 检查段落格式中的列表编号 # Keep service responsibilities explicit so downstream behavior stays predictable.
if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'): if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'):
# 有缩进的可能是列表项 # Keep service responsibilities explicit so downstream behavior stays predictable.
pass pass
return False, None return False, None
def _extract_tables(self) -> List[DocxTable]: def _extract_tables(self) -> List[DocxTable]:
"""提取所有表格""" """Handle extract tables for this module for the Docx Parser instance."""
tables = [] tables = []
for table in self.document.tables: for table in self.document.tables:
@@ -193,7 +183,7 @@ class DocxParser:
cells.append(cell.text.strip()) cells.append(cell.text.strip())
rows.append(cells) rows.append(cells)
# 转换为Markdown表格 # Keep service responsibilities explicit so downstream behavior stays predictable.
markdown = self._table_to_markdown(rows) markdown = self._table_to_markdown(rows)
table_content = DocxTable(rows=rows, markdown=markdown) table_content = DocxTable(rows=rows, markdown=markdown)
@@ -202,34 +192,34 @@ class DocxParser:
return tables return tables
def _table_to_markdown(self, rows: List[List[str]]) -> str: def _table_to_markdown(self, rows: List[List[str]]) -> str:
"""将表格转换为Markdown格式""" """Handle table to markdown for this module for the Docx Parser instance."""
if not rows or len(rows) < 1: if not rows or len(rows) < 1:
return "" return ""
lines = [] lines = []
# 表头 # Keep service responsibilities explicit so downstream behavior stays predictable.
if len(rows) >= 1: if len(rows) >= 1:
header = rows[0] header = rows[0]
lines.append("| " + " | ".join(cell for cell in header) + " |") lines.append("| " + " | ".join(cell for cell in header) + " |")
lines.append("| " + " | ".join("---" for _ in header) + " |") lines.append("| " + " | ".join("---" for _ in header) + " |")
# 数据行 # Keep service responsibilities explicit so downstream behavior stays predictable.
for row in rows[1:]: for row in rows[1:]:
lines.append("| " + " | ".join(cell for cell in row) + " |") lines.append("| " + " | ".join(cell for cell in row) + " |")
return "\n".join(lines) return "\n".join(lines)
def _generate_markdown(self, doc_content: DocxDocumentContent) -> str: def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
"""生成Markdown格式文本""" """Handle generate markdown for this module for the Docx Parser instance."""
lines = [] lines = []
# 文档标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
title = doc_content.metadata.get("title", "") title = doc_content.metadata.get("title", "")
if title: if title:
lines.append(f"# {title}\n") lines.append(f"# {title}\n")
else: else:
# 从第一个段落获取标题(如果是标题样式) # Keep service responsibilities explicit so downstream behavior stays predictable.
for para in doc_content.paragraphs[:5]: for para in doc_content.paragraphs[:5]:
if para.level == 1: if para.level == 1:
lines.append(f"# {para.text}\n") lines.append(f"# {para.text}\n")
@@ -237,29 +227,29 @@ class DocxParser:
else: else:
lines.append(f"# {doc_content.file_path}\n") lines.append(f"# {doc_content.file_path}\n")
# 元数据信息 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 文档信息\n") lines.append("\n## 文档信息\n")
for key, value in doc_content.metadata.items(): for key, value in doc_content.metadata.items():
if value: if value:
lines.append(f"- **{key}**: {value}") lines.append(f"- **{key}**: {value}")
# 正文内容 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 正文\n") lines.append("\n## 正文\n")
table_index = 0 table_index = 0
for para in doc_content.paragraphs: for para in doc_content.paragraphs:
if para.level > 0: if para.level > 0:
# 标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
prefix = "#" * para.level prefix = "#" * para.level
lines.append(f"\n{prefix} {para.text}\n") lines.append(f"\n{prefix} {para.text}\n")
elif para.is_list: elif para.is_list:
# 列表项 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append(f"- {para.text}") lines.append(f"- {para.text}")
else: else:
# 正文 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append(para.text) lines.append(para.text)
# 添加表格 # Keep service responsibilities explicit so downstream behavior stays predictable.
if doc_content.tables: if doc_content.tables:
lines.append("\n## 表格\n") lines.append("\n## 表格\n")
for i, table in enumerate(doc_content.tables): for i, table in enumerate(doc_content.tables):
@@ -269,18 +259,18 @@ class DocxParser:
return "\n".join(lines) return "\n".join(lines)
def parse_to_markdown(self, file_path: str) -> str: def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本""" """Parse to markdown for the Docx Parser instance."""
doc_content = self.parse(file_path) doc_content = self.parse(file_path)
return doc_content.markdown_text return doc_content.markdown_text
def parse_docx(file_path: str) -> DocxDocumentContent: def parse_docx(file_path: str) -> DocxDocumentContent:
"""便捷函数解析Word文档""" """Parse docx."""
parser = DocxParser() parser = DocxParser()
return parser.parse(file_path) return parser.parse(file_path)
def parse_docx_to_markdown(file_path: str) -> str: def parse_docx_to_markdown(file_path: str) -> str:
"""便捷函数解析Word并返回Markdown""" """Parse docx to markdown."""
parser = DocxParser() parser = DocxParser()
return parser.parse_to_markdown(file_path) return parser.parse_to_markdown(file_path)

View File

@@ -1,14 +1,16 @@
"""MinerU多模态PDF解析 - 版面感知解析""" """Provide service-layer logic for mineru parser."""
from typing import Optional, Dict from typing import Optional, Dict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from loguru import logger from loguru import logger
import os import os
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class MinerUResult: class MinerUResult:
"""MinerU解析结果""" """Represent the Miner U Result type."""
file_path: str file_path: str
markdown_text: str markdown_text: str
metadata: Dict[str, str] = field(default_factory=dict) metadata: Dict[str, str] = field(default_factory=dict)
@@ -17,21 +19,14 @@ class MinerUResult:
class MinerUParser: class MinerUParser:
""" """Provide the Miner U Parser parser."""
MinerU多模态PDF解析器
MinerU (magic-pdf) 是一个开源的高质量PDF解析工具
支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素,
并输出结构化的Markdown格式。
GitHub: https://github.com/opendatalab/MinerU
"""
def __init__(self): def __init__(self):
"""Initialize the Miner U Parser instance."""
self.available = self._check_mineru_available() self.available = self._check_mineru_available()
def _check_mineru_available(self) -> bool: def _check_mineru_available(self) -> bool:
"""检查MinerU是否可用""" """Handle check mineru available for this module for the Miner U Parser instance."""
try: try:
from magic_pdf.pipe.UNIPipe import UNIPipe from magic_pdf.pipe.UNIPipe import UNIPipe
return True return True
@@ -40,16 +35,7 @@ class MinerUParser:
return False return False
def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult: def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
""" """Handle parse for the Miner U Parser instance."""
使用MinerU解析PDF文档
Args:
file_path: PDF文件路径
output_dir: 输出目录(可选,用于保存解析产物)
Returns:
MinerUResult: 解析结果
"""
logger.info(f"尝试使用MinerU解析: {file_path}") logger.info(f"尝试使用MinerU解析: {file_path}")
if not self.available: if not self.available:
@@ -64,19 +50,19 @@ class MinerUParser:
from magic_pdf.pipe.UNIPipe import UNIPipe from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.libs.MakeContentConfig import DropMode from magic_pdf.libs.MakeContentConfig import DropMode
# 设置输出目录 # Keep service responsibilities explicit so downstream behavior stays predictable.
if output_dir is None: if output_dir is None:
output_dir = os.path.dirname(file_path) output_dir = os.path.dirname(file_path)
# 创建解析管道 # Keep service responsibilities explicit so downstream behavior stays predictable.
# OCR模式可以根据PDF类型选择 # Keep service responsibilities explicit so downstream behavior stays predictable.
# auto: 自动判断是否需要OCR # Keep service responsibilities explicit so downstream behavior stays predictable.
# txt: 纯文本PDF无OCR # Keep service responsibilities explicit so downstream behavior stays predictable.
# ocr: 扫描件PDFOCR # Keep service responsibilities explicit so downstream behavior stays predictable.
pipe = UNIPipe(file_path, output_dir) pipe = UNIPipe(file_path, output_dir)
# 执行解析 # Keep service responsibilities explicit so downstream behavior stays predictable.
# pipe_mk() 返回Markdown格式文本 # Keep service responsibilities explicit so downstream behavior stays predictable.
markdown_content = pipe.pipe_mk() markdown_content = pipe.pipe_mk()
logger.success(f"MinerU解析成功") logger.success(f"MinerU解析成功")
@@ -98,13 +84,13 @@ class MinerUParser:
) )
def _extract_metadata(self, pipe) -> Dict[str, str]: def _extract_metadata(self, pipe) -> Dict[str, str]:
"""从解析管道提取元数据""" """Handle extract metadata for this module for the Miner U Parser instance."""
metadata = {} metadata = {}
try: try:
# MinerU解析管道中可能包含的元数据信息 # Keep service responsibilities explicit so downstream behavior stays predictable.
if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data: if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data:
mid_data = pipe.pdf_mid_data mid_data = pipe.pdf_mid_data
# 提取可能的元数据字段 # Keep service responsibilities explicit so downstream behavior stays predictable.
metadata = { metadata = {
"page_count": str(mid_data.get("page_count", "")), "page_count": str(mid_data.get("page_count", "")),
"language": str(mid_data.get("language", "")), "language": str(mid_data.get("language", "")),
@@ -116,41 +102,27 @@ class MinerUParser:
return metadata return metadata
def parse_to_markdown(self, file_path: str) -> str: def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本""" """Parse to markdown for the Miner U Parser instance."""
result = self.parse(file_path) result = self.parse(file_path)
return result.markdown_text if result.success else "" return result.markdown_text if result.success else ""
class ParserOrchestrator: class ParserOrchestrator:
""" """Represent the Parser Orchestrator type."""
解析服务编排 - 按优先级选择解析器
解析策略:
1. 优先尝试MinerU版面感知能力强
2. MinerU失败时回退到基础PyMuPDF解析
"""
def __init__(self): def __init__(self):
"""Initialize the Parser Orchestrator instance."""
from .pdf_parser import PDFParser from .pdf_parser import PDFParser
self.mineru_parser = MinerUParser() self.mineru_parser = MinerUParser()
self.pdf_parser = PDFParser() self.pdf_parser = PDFParser()
self.mineru_available = self.mineru_parser.available self.mineru_available = self.mineru_parser.available
def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str: def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
""" """Parse pdf for the Parser Orchestrator instance."""
解析PDF文档按优先级选择解析器
Args:
file_path: PDF文件路径
prefer_mineru: 是否优先使用MinerU
Returns:
str: Markdown格式文本
"""
markdown_text = "" markdown_text = ""
if prefer_mineru and self.mineru_available: if prefer_mineru and self.mineru_available:
# 优先尝试MinerU # Keep service responsibilities explicit so downstream behavior stays predictable.
result = self.mineru_parser.parse(file_path) result = self.mineru_parser.parse(file_path)
if result.success: if result.success:
markdown_text = result.markdown_text markdown_text = result.markdown_text
@@ -159,28 +131,20 @@ class ParserOrchestrator:
else: else:
logger.warning(f"MinerU解析失败回退到PyMuPDF: {result.error_message}") logger.warning(f"MinerU解析失败回退到PyMuPDF: {result.error_message}")
# 回退到PyMuPDF基础解析 # Keep service responsibilities explicit so downstream behavior stays predictable.
logger.info("使用PyMuPDF基础解析") logger.info("使用PyMuPDF基础解析")
markdown_text = self.pdf_parser.parse_to_markdown(file_path) markdown_text = self.pdf_parser.parse_to_markdown(file_path)
return markdown_text return markdown_text
def parse_docx(self, file_path: str) -> str: def parse_docx(self, file_path: str) -> str:
"""解析Word文档""" """Parse docx for the Parser Orchestrator instance."""
from .docx_parser import DocxParser from .docx_parser import DocxParser
docx_parser = DocxParser() docx_parser = DocxParser()
return docx_parser.parse_to_markdown(file_path) return docx_parser.parse_to_markdown(file_path)
def parse(self, file_path: str) -> str: def parse(self, file_path: str) -> str:
""" """Handle parse for the Parser Orchestrator instance."""
根据文件类型选择解析器
Args:
file_path: 文件路径
Returns:
str: Markdown格式文本
"""
ext = os.path.splitext(file_path)[1].lower() ext = os.path.splitext(file_path)[1].lower()
if ext == ".pdf": if ext == ".pdf":
@@ -192,12 +156,12 @@ class ParserOrchestrator:
def parse_with_mineru(file_path: str) -> MinerUResult: def parse_with_mineru(file_path: str) -> MinerUResult:
"""便捷函数使用MinerU解析""" """Parse with mineru."""
parser = MinerUParser() parser = MinerUParser()
return parser.parse(file_path) return parser.parse(file_path)
def parse_pdf_smart(file_path: str) -> str: def parse_pdf_smart(file_path: str) -> str:
"""便捷函数智能解析PDF自动选择最佳解析器""" """Parse pdf smart."""
orchestrator = ParserOrchestrator() orchestrator = ParserOrchestrator()
return orchestrator.parse_pdf(file_path) return orchestrator.parse_pdf(file_path)

View File

@@ -1,4 +1,4 @@
"""PDF文档解析 - 使用PyMuPDF基础解析""" """Provide service-layer logic for pdf parser."""
import fitz # PyMuPDF import fitz # PyMuPDF
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Optional, Tuple
@@ -9,17 +9,17 @@ import re
@dataclass @dataclass
class PDFPageContent: class PDFPageContent:
"""PDF页面内容""" """Represent the P D F Page Content type."""
page_number: int page_number: int
text: str text: str
tables: List[str] = field(default_factory=list) tables: List[str] = field(default_factory=list)
images: List[str] = field(default_factory=list) # 图片路径列表 images: List[str] = field(default_factory=list) # Keep service responsibilities explicit so downstream behavior stays predictable.
blocks: List[Dict] = field(default_factory=list) blocks: List[Dict] = field(default_factory=list)
@dataclass @dataclass
class PDFDocumentContent: class PDFDocumentContent:
"""PDF文档完整内容""" """Represent the P D F Document Content type."""
file_path: str file_path: str
total_pages: int total_pages: int
pages: List[PDFPageContent] pages: List[PDFPageContent]
@@ -28,23 +28,14 @@ class PDFDocumentContent:
class PDFParser: class PDFParser:
"""PDF文档解析器 - 基于PyMuPDF""" """Provide the P D F Parser parser."""
def __init__(self): def __init__(self):
"""Initialize the P D F Parser instance."""
self.pdf = None self.pdf = None
def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent: def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
""" """Handle parse for the P D F Parser instance."""
解析PDF文档
Args:
file_path: PDF文件路径
extract_tables: 是否提取表格
extract_images: 是否提取图片
Returns:
PDFDocumentContent: 解析后的文档内容
"""
logger.info(f"开始解析PDF文档: {file_path}") logger.info(f"开始解析PDF文档: {file_path}")
try: try:
@@ -55,16 +46,16 @@ class PDFParser:
pages=[] pages=[]
) )
# 提取文档元数据 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.metadata = self._extract_metadata() doc_content.metadata = self._extract_metadata()
# 逐页解析 # Keep service responsibilities explicit so downstream behavior stays predictable.
for page_num in range(self.pdf.page_count): for page_num in range(self.pdf.page_count):
page = self.pdf[page_num] page = self.pdf[page_num]
page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images) page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images)
doc_content.pages.append(page_content) doc_content.pages.append(page_content)
# 生成Markdown格式文本 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.markdown_text = self._generate_markdown(doc_content) doc_content.markdown_text = self._generate_markdown(doc_content)
self.pdf.close() self.pdf.close()
@@ -77,7 +68,7 @@ class PDFParser:
raise raise
def _extract_metadata(self) -> Dict[str, str]: def _extract_metadata(self) -> Dict[str, str]:
"""提取PDF元数据""" """Handle extract metadata for this module for the P D F Parser instance."""
metadata = {} metadata = {}
try: try:
meta = self.pdf.metadata meta = self.pdf.metadata
@@ -97,23 +88,23 @@ class PDFParser:
def _parse_page(self, page: fitz.Page, page_num: int, def _parse_page(self, page: fitz.Page, page_num: int,
extract_tables: bool, extract_images: bool) -> PDFPageContent: extract_tables: bool, extract_images: bool) -> PDFPageContent:
"""解析单页内容""" """Handle parse page for this module for the P D F Parser instance."""
page_content = PDFPageContent(page_number=page_num, text="") page_content = PDFPageContent(page_number=page_num, text="")
# 提取文本块(保留结构) # Keep service responsibilities explicit so downstream behavior stays predictable.
blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"] blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
page_content.blocks = blocks page_content.blocks = blocks
# 提取纯文本 # Keep service responsibilities explicit so downstream behavior stays predictable.
text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE) text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE)
page_content.text = text.strip() page_content.text = text.strip()
# 提取表格使用PyMuPDF的表格提取功能 # Keep service responsibilities explicit so downstream behavior stays predictable.
if extract_tables: if extract_tables:
tables = self._extract_tables_from_page(page) tables = self._extract_tables_from_page(page)
page_content.tables = tables page_content.tables = tables
# 提取图片 # Keep service responsibilities explicit so downstream behavior stays predictable.
if extract_images: if extract_images:
images = self._extract_images_from_page(page, page_num) images = self._extract_images_from_page(page, page_num)
page_content.images = images page_content.images = images
@@ -121,25 +112,22 @@ class PDFParser:
return page_content return page_content
def _extract_tables_from_page(self, page: fitz.Page) -> List[str]: def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
""" """Handle extract tables from page for this module for the P D F Parser instance."""
从页面提取表格(基于文本块分析)
注意PyMuPDF基础版表格提取能力有限复杂表格建议使用MinerU
"""
tables = [] tables = []
try: try:
# 使用PyMuPDF的表格提取方法2.4+版本) # Keep service responsibilities explicit so downstream behavior stays predictable.
# 对于更复杂的表格需要在mineru_parser中使用更高级的方法 # Keep service responsibilities explicit so downstream behavior stays predictable.
tabs = page.find_tables() tabs = page.find_tables()
if tabs: if tabs:
for tab in tabs: for tab in tabs:
table_text = tab.extract() table_text = tab.extract()
# 将表格转换为Markdown格式 # Keep service responsibilities explicit so downstream behavior stays predictable.
markdown_table = self._table_to_markdown(table_text) markdown_table = self._table_to_markdown(table_text)
tables.append(markdown_table) tables.append(markdown_table)
except AttributeError: except AttributeError:
# 旧版本PyMuPDF没有表格提取功能 # Keep service responsibilities explicit so downstream behavior stays predictable.
logger.warning("PyMuPDF版本不支持表格提取请升级到2.4+版本") logger.warning("PyMuPDF版本不支持表格提取请升级到2.4+版本")
except Exception as e: except Exception as e:
logger.warning(f"表格提取失败: {e}") logger.warning(f"表格提取失败: {e}")
@@ -147,28 +135,28 @@ class PDFParser:
return tables return tables
def _table_to_markdown(self, table_data: List[List[str]]) -> str: def _table_to_markdown(self, table_data: List[List[str]]) -> str:
"""将表格数据转换为Markdown格式""" """Handle table to markdown for this module for the P D F Parser instance."""
if not table_data or len(table_data) < 1: if not table_data or len(table_data) < 1:
return "" return ""
lines = [] lines = []
# 表头 # Keep service responsibilities explicit so downstream behavior stays predictable.
if len(table_data) >= 1: if len(table_data) >= 1:
header = table_data[0] header = table_data[0]
lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |") lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |")
lines.append("| " + " | ".join("---" for _ in header) + " |") lines.append("| " + " | ".join("---" for _ in header) + " |")
# 数据行 # Keep service responsibilities explicit so downstream behavior stays predictable.
for row in table_data[1:]: for row in table_data[1:]:
lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |") lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |")
return "\n".join(lines) return "\n".join(lines)
def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]: def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
"""提取页面图片""" """Handle extract images from page for this module for the P D F Parser instance."""
images = [] images = []
# 图片提取功能(可选实现) # Keep service responsibilities explicit so downstream behavior stays predictable.
# 这里仅记录图片信息,实际图片需要额外保存 # Keep service responsibilities explicit so downstream behavior stays predictable.
try: try:
image_list = page.get_images() image_list = page.get_images()
for img_index, img in enumerate(image_list): for img_index, img in enumerate(image_list):
@@ -179,52 +167,52 @@ class PDFParser:
return images return images
def _generate_markdown(self, doc_content: PDFDocumentContent) -> str: def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
"""生成Markdown格式文本""" """Handle generate markdown for this module for the P D F Parser instance."""
lines = [] lines = []
# 文档标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
title = doc_content.metadata.get("title", "") title = doc_content.metadata.get("title", "")
if title: if title:
lines.append(f"# {title}\n") lines.append(f"# {title}\n")
else: else:
lines.append(f"# {doc_content.file_path}\n") lines.append(f"# {doc_content.file_path}\n")
# 元数据信息 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 文档信息\n") lines.append("\n## 文档信息\n")
for key, value in doc_content.metadata.items(): for key, value in doc_content.metadata.items():
if value and key in ["author", "subject", "keywords", "creation_date"]: if value and key in ["author", "subject", "keywords", "creation_date"]:
lines.append(f"- **{key}**: {value}") lines.append(f"- **{key}**: {value}")
# 正文内容 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 正文\n") lines.append("\n## 正文\n")
for page in doc_content.pages: for page in doc_content.pages:
# 页码标记 # Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append(f"\n---\n**第 {page.page_number} 页**\n") lines.append(f"\n---\n**第 {page.page_number} 页**\n")
# 处理文本内容,识别标题结构 # Keep service responsibilities explicit so downstream behavior stays predictable.
text = self._process_page_text(page.text, page.blocks) text = self._process_page_text(page.text, page.blocks)
lines.append(text) lines.append(text)
# 添加表格 # Keep service responsibilities explicit so downstream behavior stays predictable.
for table in page.tables: for table in page.tables:
lines.append("\n" + table + "\n") lines.append("\n" + table + "\n")
return "\n".join(lines) return "\n".join(lines)
def _process_page_text(self, text: str, blocks: List[Dict]) -> str: def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
"""处理页面文本,识别标题结构""" """Handle process page text for this module for the P D F Parser instance."""
# 基于字体大小识别标题 # Keep service responsibilities explicit so downstream behavior stays predictable.
processed_text = text processed_text = text
# 尝试识别标题(基于字号) # Keep service responsibilities explicit so downstream behavior stays predictable.
# 法规文档通常有明确的层级结构:章、节、条 # Keep service responsibilities explicit so downstream behavior stays predictable.
processed_text = self._detect_headers(text, blocks) processed_text = self._detect_headers(text, blocks)
return processed_text return processed_text
def _detect_headers(self, text: str, blocks: List[Dict]) -> str: def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
"""检测并标记标题(基于字号或内容模式)""" """Handle detect headers for this module for the P D F Parser instance."""
lines = text.split("\n") lines = text.split("\n")
processed_lines = [] processed_lines = []
@@ -233,8 +221,8 @@ class PDFParser:
if not line: if not line:
continue continue
# 法规标题模式检测 # Keep service responsibilities explicit so downstream behavior stays predictable.
# 第一章、第X章、第X节、第X条等 # Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^第[一二三四五六七八九十百]+章\s', line): if re.match(r'^第[一二三四五六七八九十百]+章\s', line):
processed_lines.append(f"\n## {line}\n") processed_lines.append(f"\n## {line}\n")
elif re.match(r'^第[一二三四五六七八九十百]+节\s', line): elif re.match(r'^第[一二三四五六七八九十百]+节\s', line):
@@ -242,7 +230,7 @@ class PDFParser:
elif re.match(r'^第[一二三四五六七八九十百]+条\s', line): elif re.match(r'^第[一二三四五六七八九十百]+条\s', line):
processed_lines.append(f"\n#### {line}\n") processed_lines.append(f"\n#### {line}\n")
elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line): elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line):
# 条款子项 # Keep service responsibilities explicit so downstream behavior stays predictable.
processed_lines.append(f"- {line}") processed_lines.append(f"- {line}")
else: else:
processed_lines.append(line) processed_lines.append(line)
@@ -250,18 +238,18 @@ class PDFParser:
return "\n".join(processed_lines) return "\n".join(processed_lines)
def parse_to_markdown(self, file_path: str) -> str: def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本""" """Parse to markdown for the P D F Parser instance."""
doc_content = self.parse(file_path) doc_content = self.parse(file_path)
return doc_content.markdown_text return doc_content.markdown_text
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent: def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
"""便捷函数解析PDF文档""" """Parse pdf."""
parser = PDFParser() parser = PDFParser()
return parser.parse(file_path, **kwargs) return parser.parse(file_path, **kwargs)
def parse_pdf_to_markdown(file_path: str) -> str: def parse_pdf_to_markdown(file_path: str) -> str:
"""便捷函数解析PDF并返回Markdown""" """Parse pdf to markdown."""
parser = PDFParser() parser = PDFParser()
return parser.parse_to_markdown(file_path) return parser.parse_to_markdown(file_path)

View File

@@ -1,11 +1,29 @@
"""RAG服务模块""" """Initialize the app.services.rag package."""
# Keep package boundaries explicit so backend imports stay predictable.
from .retriever import Retriever, retrieve_regulations
from .context_builder import ContextBuilder, build_rag_context
from .prompt_templates import PromptTemplates, get_prompt_template
__all__ = [ __all__ = [
"Retriever", "retrieve_regulations", "Retriever",
"ContextBuilder", "build_rag_context", "retrieve_regulations",
"PromptTemplates", "get_prompt_template" "ContextBuilder",
"build_rag_context",
"PromptTemplates",
"get_prompt_template",
] ]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name in {"Retriever", "retrieve_regulations"}:
from .retriever import Retriever, retrieve_regulations
return {"Retriever": Retriever, "retrieve_regulations": retrieve_regulations}[name]
if name in {"ContextBuilder", "build_rag_context"}:
from .context_builder import ContextBuilder, build_rag_context
return {"ContextBuilder": ContextBuilder, "build_rag_context": build_rag_context}[name]
if name in {"PromptTemplates", "get_prompt_template"}:
from .prompt_templates import PromptTemplates, get_prompt_template
return {"PromptTemplates": PromptTemplates, "get_prompt_template": get_prompt_template}[name]
raise AttributeError(name)

View File

@@ -1,4 +1,4 @@
"""RAG上下文构建服务 - 构建LLM输入上下文""" """Provide service-layer logic for context builder."""
from typing import List, Dict, Optional from typing import List, Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
@@ -6,11 +6,13 @@ from loguru import logger
from .retriever import RetrievedDocument from .retriever import RetrievedDocument
from app.config.settings import settings from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class RAGContext: class RAGContext:
"""RAG构建的上下文""" """Represent the R A G Context type."""
system_prompt: str system_prompt: str
context_text: str context_text: str
user_query: str user_query: str
@@ -20,14 +22,7 @@ class RAGContext:
class ContextBuilder: class ContextBuilder:
""" """Provide the Context Builder builder."""
RAG上下文构建器
功能:
- 格式化检索结果为上下文文本
- 控制上下文长度token限制
- 构建完整的LLM输入格式
"""
def __init__( def __init__(
self, self,
@@ -35,14 +30,7 @@ class ContextBuilder:
include_metadata: bool = True, include_metadata: bool = True,
citation_format: str = "【条款{clause}" citation_format: str = "【条款{clause}"
): ):
""" """Initialize the Context Builder instance."""
初始化上下文构建器
Args:
max_context_tokens: 最大上下文token数
include_metadata: 是否包含元数据(文档名、条款号等)
citation_format: 引用格式模板
"""
self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
self.include_metadata = include_metadata self.include_metadata = include_metadata
self.citation_format = citation_format self.citation_format = citation_format
@@ -56,30 +44,19 @@ class ContextBuilder:
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
) -> RAGContext: ) -> RAGContext:
""" """Handle build for the Context Builder instance."""
构建RAG上下文
Args:
query: 用户查询
documents: 检索到的文档列表
system_prompt: 系统提示词(可选)
max_tokens: 最大token数可选覆盖默认值
Returns:
RAGContext: 构建的上下文对象
"""
max_tokens = max_tokens or self.max_context_tokens max_tokens = max_tokens or self.max_context_tokens
# 格式化文档内容 # Keep service responsibilities explicit so downstream behavior stays predictable.
context_text, sources, truncated = self._format_documents( context_text, sources, truncated = self._format_documents(
documents, documents,
max_tokens max_tokens
) )
# 构建系统提示词 # Keep service responsibilities explicit so downstream behavior stays predictable.
system_prompt = system_prompt or self._default_system_prompt() system_prompt = system_prompt or self._default_system_prompt()
# 估算总token数 # Keep service responsibilities explicit so downstream behavior stays predictable.
total_tokens = self._estimate_tokens(system_prompt + context_text + query) total_tokens = self._estimate_tokens(system_prompt + context_text + query)
logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}") logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
@@ -98,29 +75,20 @@ class ContextBuilder:
documents: List[RetrievedDocument], documents: List[RetrievedDocument],
max_tokens: int max_tokens: int
) -> tuple: ) -> tuple:
""" """Handle format documents for this module for the Context Builder instance."""
格式化文档内容
Args:
documents: 文档列表
max_tokens: 最大token数
Returns:
(context_text, sources, truncated)
"""
context_parts = [] context_parts = []
sources = [] sources = []
current_tokens = 0 current_tokens = 0
truncated = False truncated = False
for i, doc in enumerate(documents): for i, doc in enumerate(documents):
# 格式化单个文档 # Keep service responsibilities explicit so downstream behavior stays predictable.
formatted = self._format_single_doc(doc, i + 1) formatted = self._format_single_doc(doc, i + 1)
# 估算token数 # Keep service responsibilities explicit so downstream behavior stays predictable.
doc_tokens = self._estimate_tokens(formatted) doc_tokens = self._estimate_tokens(formatted)
# 检查是否超出限制 # Keep service responsibilities explicit so downstream behavior stays predictable.
if current_tokens + doc_tokens > max_tokens: if current_tokens + doc_tokens > max_tokens:
truncated = True truncated = True
logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制") logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
@@ -129,7 +97,7 @@ class ContextBuilder:
context_parts.append(formatted) context_parts.append(formatted)
current_tokens += doc_tokens current_tokens += doc_tokens
# 记录来源 # Keep service responsibilities explicit so downstream behavior stays predictable.
sources.append({ sources.append({
"index": i + 1, "index": i + 1,
"doc_id": doc.doc_id, "doc_id": doc.doc_id,
@@ -148,13 +116,13 @@ class ContextBuilder:
doc: RetrievedDocument, doc: RetrievedDocument,
index: int index: int
) -> str: ) -> str:
"""格式化单个文档""" """Handle format single doc for this module for the Context Builder instance."""
parts = [] parts = []
# 索引编号 # Keep service responsibilities explicit so downstream behavior stays predictable.
parts.append(f"[{index}]") parts.append(f"[{index}]")
# 元数据(可选) # Keep service responsibilities explicit so downstream behavior stays predictable.
if self.include_metadata: if self.include_metadata:
meta_parts = [] meta_parts = []
@@ -171,13 +139,13 @@ class ContextBuilder:
if meta_parts: if meta_parts:
parts.append(" | ".join(meta_parts)) parts.append(" | ".join(meta_parts))
# 内容 # Keep service responsibilities explicit so downstream behavior stays predictable.
parts.append(doc.content) parts.append(doc.content)
return "\n".join(parts) return "\n".join(parts)
def _default_system_prompt(self) -> str: def _default_system_prompt(self) -> str:
"""默认系统提示词""" """Handle default system prompt for this module for the Context Builder instance."""
return """你是合规专家助手,基于提供的法规条款回答问题。 return """你是合规专家助手,基于提供的法规条款回答问题。
回答要求: 回答要求:
@@ -192,8 +160,8 @@ class ContextBuilder:
- 最后给出合规建议""" - 最后给出合规建议"""
def _estimate_tokens(self, text: str) -> int: def _estimate_tokens(self, text: str) -> int:
"""估算文本token数""" """Handle estimate tokens for this module for the Context Builder instance."""
# 中文字符约1.5 token英文约0.25 token # Keep service responsibilities explicit so downstream behavior stays predictable.
chinese_chars = sum(1 for c in text if '' <= c <= '鿿') chinese_chars = sum(1 for c in text if '' <= c <= '鿿')
other_chars = len(text) - chinese_chars other_chars = len(text) - chinese_chars
return int(chinese_chars * 1.5 + other_chars * 0.25) return int(chinese_chars * 1.5 + other_chars * 0.25)
@@ -202,15 +170,7 @@ class ContextBuilder:
self, self,
context: RAGContext context: RAGContext
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
""" """Build messages for the Context Builder instance."""
构建LLM消息格式
Args:
context: RAG上下文对象
Returns:
List[Dict]: [{"role": "system/user/assistant", "content": "..."}]
"""
messages = [ messages = [
{"role": "system", "content": context.system_prompt}, {"role": "system", "content": context.system_prompt},
{"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"} {"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
@@ -224,6 +184,6 @@ def build_rag_context(
documents: List[RetrievedDocument], documents: List[RetrievedDocument],
**kwargs **kwargs
) -> RAGContext: ) -> RAGContext:
"""便捷函数构建RAG上下文""" """Build rag context."""
builder = ContextBuilder() builder = ContextBuilder()
return builder.build(query, documents, **kwargs) return builder.build(query, documents, **kwargs)

View File

@@ -1,12 +1,14 @@
"""RAG Prompt模板 - 合规问答专用Prompt""" """Provide service-layer logic for prompt templates."""
from typing import Dict, Optional from typing import Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class PromptTemplate: class PromptTemplate:
"""Prompt模板""" """Represent the Prompt Template type."""
name: str name: str
system_prompt: str system_prompt: str
user_template: str user_template: str
@@ -14,18 +16,9 @@ class PromptTemplate:
class PromptTemplates: class PromptTemplates:
""" """Represent the Prompt Templates type."""
合规问答Prompt模板库
包含多种场景的Prompt模板 # Keep service responsibilities explicit so downstream behavior stays predictable.
- 合规问答(标准)
- 条款解读(详细解释)
- 合规检查(判断合规状态)
- 差异对比(新旧法规对比)
- 报告生成(合规报告)
"""
# 合规问答标准模板
COMPLIANCE_QA = PromptTemplate( COMPLIANCE_QA = PromptTemplate(
name="compliance_qa", name="compliance_qa",
system_prompt="""你是合规专家助手,专门解答法规合规问题。 system_prompt="""你是合规专家助手,专门解答法规合规问题。
@@ -63,7 +56,7 @@ class PromptTemplates:
description="标准合规问答模板" description="标准合规问答模板"
) )
# 条款解读模板(详细解释) # Keep service responsibilities explicit so downstream behavior stays predictable.
CLAUSE_INTERPRETATION = PromptTemplate( CLAUSE_INTERPRETATION = PromptTemplate(
name="clause_interpretation", name="clause_interpretation",
system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。 system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
@@ -96,7 +89,7 @@ class PromptTemplates:
description="条款详细解读模板" description="条款详细解读模板"
) )
# 合规检查模板(判断合规状态) # Keep service responsibilities explicit so downstream behavior stays predictable.
COMPLIANCE_CHECK = PromptTemplate( COMPLIANCE_CHECK = PromptTemplate(
name="compliance_check", name="compliance_check",
system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。 system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
@@ -140,7 +133,7 @@ class PromptTemplates:
description="合规检查评估模板" description="合规检查评估模板"
) )
# 差异对比模板(新旧法规对比) # Keep service responsibilities explicit so downstream behavior stays predictable.
COMPARISON = PromptTemplate( COMPARISON = PromptTemplate(
name="comparison", name="comparison",
system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。 system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
@@ -192,7 +185,7 @@ class PromptTemplates:
description="法规版本对比模板" description="法规版本对比模板"
) )
# 报告生成模板 # Keep service responsibilities explicit so downstream behavior stays predictable.
REPORT_GENERATION = PromptTemplate( REPORT_GENERATION = PromptTemplate(
name="report_generation", name="report_generation",
system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。 system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
@@ -222,7 +215,7 @@ class PromptTemplates:
description="合规报告生成模板" description="合规报告生成模板"
) )
# 文档摘要生成模板 # Keep service responsibilities explicit so downstream behavior stays predictable.
DOCUMENT_SUMMARY = PromptTemplate( DOCUMENT_SUMMARY = PromptTemplate(
name="document_summary", name="document_summary",
system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。 system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
@@ -263,7 +256,7 @@ class PromptTemplates:
@classmethod @classmethod
def get_template(cls, name: str) -> Optional[PromptTemplate]: def get_template(cls, name: str) -> Optional[PromptTemplate]:
"""获取指定模板""" """Return template for the Prompt Templates instance."""
templates = { templates = {
"compliance_qa": cls.COMPLIANCE_QA, "compliance_qa": cls.COMPLIANCE_QA,
"clause_interpretation": cls.CLAUSE_INTERPRETATION, "clause_interpretation": cls.CLAUSE_INTERPRETATION,
@@ -276,7 +269,7 @@ class PromptTemplates:
@classmethod @classmethod
def list_templates(cls) -> Dict[str, str]: def list_templates(cls) -> Dict[str, str]:
"""列出所有模板""" """List templates for the Prompt Templates instance."""
return { return {
"compliance_qa": cls.COMPLIANCE_QA.description, "compliance_qa": cls.COMPLIANCE_QA.description,
"clause_interpretation": cls.CLAUSE_INTERPRETATION.description, "clause_interpretation": cls.CLAUSE_INTERPRETATION.description,
@@ -288,7 +281,7 @@ class PromptTemplates:
def get_prompt_template(name: str) -> PromptTemplate: def get_prompt_template(name: str) -> PromptTemplate:
"""便捷函数获取Prompt模板""" """Return prompt template."""
template = PromptTemplates.get_template(name) template = PromptTemplates.get_template(name)
if not template: if not template:
raise ValueError(f"不存在的模板: {name}") raise ValueError(f"不存在的模板: {name}")

View File

@@ -1,192 +1,82 @@
"""RAG检索服务 - 封装Milvus检索""" """Provide service-layer logic for retriever."""
from __future__ import annotations
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
from loguru import logger from typing import Any, Optional
from app.shared.bootstrap import get_retrieval_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
from app.services.storage.milvus_client import MilvusClient, SearchResult
from app.config.settings import settings
@dataclass @dataclass
class RetrievedDocument: class RetrievedDocument:
"""检索到的文档""" """Represent the Retrieved Document type."""
content: str content: str
doc_id: str # 文档ID用于下载 doc_id: str
doc_name: str doc_name: str
section_title: str section_title: str
clause_number: str clause_number: str
page_number: int page_number: int
score: float score: float
metadata: Dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
class Retriever: class Retriever:
""" """Provide the Retriever retriever."""
RAG检索器 def __init__(self, top_k: int = 5, rerank: bool = False, min_score: float = 0.0):
"""Initialize the Retriever instance."""
功能: self.top_k = top_k
- 向量检索Dense + Sparse混合
- 重排序(可选)
- 过滤和筛选
"""
def __init__(
self,
top_k: int = None,
rerank: bool = False,
min_score: float = 0.3
):
"""
初始化检索器
Args:
top_k: 检索召回数量
rerank: 是否启用重排序
min_score: 最低相关性分数阈值
"""
self.top_k = top_k or settings.rag_top_k
self.rerank = rerank self.rerank = rerank
self.min_score = min_score self.min_score = min_score
# 嵌入模型(延迟加载) def retrieve(self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None) -> list[RetrievedDocument]:
self.embedder: Optional[BGEM3Embedder] = None """Handle retrieve for the Retriever instance."""
results = get_retrieval_service().retrieve(query=query, top_k=top_k or self.top_k, filters=filters)
# Milvus客户端延迟连接
self.milvus: Optional[MilvusClient] = None
logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")
def _init_embedder(self):
"""延迟初始化嵌入模型"""
if self.embedder is None:
logger.info("加载嵌入模型...")
self.embedder = BGEM3Embedder(model_name=settings.embedding_model)
def _init_milvus(self):
"""延迟初始化Milvus"""
if self.milvus is None:
logger.info("连接Milvus...")
self.milvus = MilvusClient()
self.milvus.connect()
self.milvus.create_collection(recreate=False)
self.milvus.load_collection()
def retrieve(
self,
query: str,
filters: Optional[str] = None,
top_k: Optional[int] = None
) -> List[RetrievedDocument]:
"""
检索相关文档
Args:
query: 查询文本
filters: 过滤条件(如 "regulation_type=='车辆安全'"
top_k: 返回数量(可选,覆盖默认值)
Returns:
List[RetrievedDocument]: 检索结果列表
"""
logger.info(f"执行检索: {query}")
# 初始化组件
self._init_embedder()
self._init_milvus()
# 生成查询向量
query_embedding = self.embedder.embed_single(query)
# 执行混合检索
results = self.milvus.hybrid_search(
query_dense=query_embedding['dense'].tolist(),
query_sparse=query_embedding['sparse'],
top_k=top_k or self.top_k,
filters=filters
)
# 转换为RetrievedDocument格式
documents = []
for r in results:
if r.score >= self.min_score:
doc = RetrievedDocument(
content=r.content,
doc_id=r.metadata.get("doc_id", ""),
doc_name=r.metadata.get("doc_name", ""),
section_title=r.metadata.get("section_title", ""),
clause_number=r.metadata.get("clause_number", ""),
page_number=r.metadata.get("page_number", 0),
score=r.score,
metadata=r.metadata
)
documents.append(doc)
logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
return documents
def retrieve_with_scores(
self,
query: str,
filters: Optional[str] = None
) -> List[Dict]:
"""
检索并返回完整结果(包含分数)
Args:
query: 查询文本
filters: 过滤条件
Returns:
List[Dict]: 包含分数的检索结果
"""
documents = self.retrieve(query, filters)
return [ return [
{ RetrievedDocument(
"content": doc.content, content=item.content,
"doc_id": doc.doc_id, doc_id=item.doc_id,
"doc_name": doc.doc_name, doc_name=item.doc_name,
"section_title": doc.section_title, section_title=item.section_title,
"clause_number": doc.clause_number, clause_number=item.metadata.get("clause_number", ""),
"page_number": doc.page_number, page_number=item.page_number,
"score": doc.score score=item.score,
} metadata=item.metadata,
for doc in documents )
for item in results
if item.score >= self.min_score
] ]
def search_by_doc_name( def retrieve_with_scores(self, query: str, filters: Optional[str] = None) -> list[dict]:
self, """Handle retrieve with scores for the Retriever instance."""
query: str, return [
doc_name: str {
) -> List[RetrievedDocument]: "content": item.content,
"""按文档名称过滤检索""" "doc_id": item.doc_id,
filters = f'doc_name=="{doc_name}"' "doc_name": item.doc_name,
return self.retrieve(query, filters) "section_title": item.section_title,
"clause_number": item.clause_number,
"page_number": item.page_number,
"score": item.score,
}
for item in self.retrieve(query, filters)
]
def search_by_regulation_type( def search_by_doc_name(self, query: str, doc_name: str) -> list[RetrievedDocument]:
self, """Search by doc name for the Retriever instance."""
query: str, return self.retrieve(query, filters=f'doc_name == "{doc_name}"')
regulation_type: str
) -> List[RetrievedDocument]: def search_by_regulation_type(self, query: str, regulation_type: str) -> list[RetrievedDocument]:
"""按法规类型过滤检索""" """Search by regulation type for the Retriever instance."""
filters = f'regulation_type=="{regulation_type}"' return self.retrieve(query, filters=f'regulation_type == "{regulation_type}"')
return self.retrieve(query, filters)
def close(self): def close(self):
"""关闭连接""" """Release the resources held by this component."""
if self.milvus: return None
self.milvus.disconnect()
logger.info("检索器已关闭")
def retrieve_regulations( def retrieve_regulations(query: str, top_k: int = 10, filters: Optional[str] = None) -> list[RetrievedDocument]:
query: str, """Handle retrieve regulations."""
top_k: int = 10, return Retriever(top_k=top_k).retrieve(query, filters)
filters: Optional[str] = None
) -> List[RetrievedDocument]:
"""便捷函数:检索法规"""
retriever = Retriever(top_k=top_k)
results = retriever.retrieve(query, filters)
retriever.close()
return results

View File

@@ -1,6 +1,18 @@
"""存储服务""" """Initialize the app.services.storage package."""
# Keep package boundaries explicit so backend imports stay predictable.
from .milvus_client import MilvusClient
from .minio_client import MinIOClient
__all__ = ["MilvusClient", "MinIOClient"] __all__ = ["MilvusClient", "MinIOClient"]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name == "MilvusClient":
from .milvus_client import MilvusClient
return MilvusClient
if name == "MinIOClient":
from .minio_client import MinIOClient
return MinIOClient
raise AttributeError(name)

View File

@@ -1,4 +1,4 @@
"""Milvus向量数据库客户端 - 存储与检索服务""" """Provide service-layer logic for milvus client."""
from pymilvus import ( from pymilvus import (
connections, connections,
@@ -17,11 +17,13 @@ import numpy as np
from ..embedding.text_chunker import TextChunk from ..embedding.text_chunker import TextChunk
from ..embedding.bge_m3_embedder import EmbeddingResult from ..embedding.bge_m3_embedder import EmbeddingResult
from app.config.settings import settings from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass @dataclass
class SearchResult: class SearchResult:
"""检索结果""" """Represent the Search Result type."""
id: int id: int
content: str content: str
score: float score: float
@@ -30,7 +32,7 @@ class SearchResult:
@dataclass @dataclass
class MilvusDocument: class MilvusDocument:
"""Milvus文档数据结构""" """Represent the Milvus Document type."""
doc_id: str doc_id: str
chunk_id: str chunk_id: str
content: str content: str
@@ -46,7 +48,7 @@ class MilvusDocument:
class MilvusClient: class MilvusClient:
"""Milvus向量数据库客户端""" """Represent the Milvus Client type."""
COLLECTION_NAME = "regulations" COLLECTION_NAME = "regulations"
@@ -73,6 +75,7 @@ class MilvusClient:
collection_name: str = None, collection_name: str = None,
db_name: str = None db_name: str = None
): ):
"""Initialize the Milvus Client instance."""
self.host = host or settings.milvus_host self.host = host or settings.milvus_host
self.port = port or settings.milvus_port self.port = port or settings.milvus_port
self.collection_name = collection_name or settings.milvus_collection self.collection_name = collection_name or settings.milvus_collection
@@ -84,7 +87,7 @@ class MilvusClient:
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}") logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
def connect(self) -> bool: def connect(self) -> bool:
"""连接到Milvus服务器""" """Handle connect for the Milvus Client instance."""
try: try:
connections.connect( connections.connect(
alias="default", alias="default",
@@ -101,7 +104,7 @@ class MilvusClient:
return False return False
def disconnect(self): def disconnect(self):
"""断开连接""" """Handle disconnect for the Milvus Client instance."""
try: try:
connections.disconnect("default") connections.disconnect("default")
self.connected = False self.connected = False
@@ -110,7 +113,7 @@ class MilvusClient:
logger.warning(f"断开连接时出错: {e}") logger.warning(f"断开连接时出错: {e}")
def create_collection(self, recreate: bool = False) -> bool: def create_collection(self, recreate: bool = False) -> bool:
"""创建Collection""" """Create collection for the Milvus Client instance."""
if not self.connected: if not self.connected:
logger.warning("未连接到Milvus请先调用connect()") logger.warning("未连接到Milvus请先调用connect()")
return False return False
@@ -146,7 +149,7 @@ class MilvusClient:
return False return False
def _create_indexes(self): def _create_indexes(self):
"""创建向量索引""" """Handle create indexes for this module for the Milvus Client instance."""
if not self.collection: if not self.collection:
return return
@@ -177,13 +180,13 @@ class MilvusClient:
logger.warning(f"创建索引时出错: {e}") logger.warning(f"创建索引时出错: {e}")
def load_collection(self): def load_collection(self):
"""加载Collection到内存""" """Load collection for the Milvus Client instance."""
if self.collection: if self.collection:
self.collection.load() self.collection.load()
logger.info(f"Collection已加载: {self.collection_name}") logger.info(f"Collection已加载: {self.collection_name}")
def release_collection(self): def release_collection(self):
"""释放Collection内存""" """Handle release collection for the Milvus Client instance."""
if self.collection: if self.collection:
self.collection.release() self.collection.release()
logger.info(f"Collection已释放: {self.collection_name}") logger.info(f"Collection已释放: {self.collection_name}")
@@ -193,7 +196,7 @@ class MilvusClient:
chunks: List[TextChunk], chunks: List[TextChunk],
embeddings: EmbeddingResult embeddings: EmbeddingResult
) -> List[int]: ) -> List[int]:
"""插入文档分块和嵌入向量""" """Handle insert chunks for the Milvus Client instance."""
if not self.collection: if not self.collection:
logger.warning("Collection未初始化") logger.warning("Collection未初始化")
return [] return []
@@ -246,7 +249,7 @@ class MilvusClient:
top_k: int = 10, top_k: int = 10,
filters: Optional[str] = None filters: Optional[str] = None
) -> List[SearchResult]: ) -> List[SearchResult]:
"""混合检索Dense + Sparse""" """Handle hybrid search for the Milvus Client instance."""
if not self.collection: if not self.collection:
logger.warning("Collection未初始化") logger.warning("Collection未初始化")
return [] return []
@@ -254,10 +257,10 @@ class MilvusClient:
try: try:
self.collection.load() self.collection.load()
# 使用简单的Dense检索兼容所有版本 # Keep service responsibilities explicit so downstream behavior stays predictable.
dense_results = self.dense_search(query_dense, top_k, filters) dense_results = self.dense_search(query_dense, top_k, filters)
# 可选合并Sparse结果 # Keep service responsibilities explicit so downstream behavior stays predictable.
if query_sparse: if query_sparse:
sparse_results = self.sparse_search(query_sparse, top_k, filters) sparse_results = self.sparse_search(query_sparse, top_k, filters)
merged = self._merge_results(dense_results, sparse_results, top_k) merged = self._merge_results(dense_results, sparse_results, top_k)
@@ -277,7 +280,7 @@ class MilvusClient:
top_k: int, top_k: int,
dense_weight: float = 0.6 dense_weight: float = 0.6
) -> List[SearchResult]: ) -> List[SearchResult]:
"""手动融合Dense和Sparse结果""" """Handle merge results for this module for the Milvus Client instance."""
sparse_weight = 1 - dense_weight sparse_weight = 1 - dense_weight
merged_dict = {} merged_dict = {}
@@ -318,7 +321,7 @@ class MilvusClient:
top_k: int = 10, top_k: int = 10,
filters: Optional[str] = None filters: Optional[str] = None
) -> List[SearchResult]: ) -> List[SearchResult]:
"""纯Dense向量检索""" """Handle dense search for the Milvus Client instance."""
if not self.collection: if not self.collection:
return [] return []
@@ -375,7 +378,7 @@ class MilvusClient:
top_k: int = 10, top_k: int = 10,
filters: Optional[str] = None filters: Optional[str] = None
) -> List[SearchResult]: ) -> List[SearchResult]:
"""纯Sparse向量检索""" """Handle sparse search for the Milvus Client instance."""
if not self.collection: if not self.collection:
return [] return []
@@ -427,7 +430,7 @@ class MilvusClient:
return [] return []
def delete_by_doc_id(self, doc_id: str) -> int: def delete_by_doc_id(self, doc_id: str) -> int:
"""根据doc_id删除记录""" """Delete by doc id for the Milvus Client instance."""
if not self.collection: if not self.collection:
return 0 return 0
@@ -441,7 +444,7 @@ class MilvusClient:
return 0 return 0
def get_collection_stats(self) -> Dict[str, Any]: def get_collection_stats(self) -> Dict[str, Any]:
"""获取Collection统计信息""" """Return collection stats for the Milvus Client instance."""
if not self.collection: if not self.collection:
return {} return {}
@@ -458,7 +461,7 @@ class MilvusClient:
def create_milvus_client() -> MilvusClient: def create_milvus_client() -> MilvusClient:
"""便捷函数创建Milvus客户端""" """Create milvus client."""
client = MilvusClient() client = MilvusClient()
client.connect() client.connect()
client.create_collection(recreate=False) client.create_collection(recreate=False)
@@ -470,7 +473,7 @@ def insert_documents(
chunks: List[TextChunk], chunks: List[TextChunk],
embeddings: EmbeddingResult embeddings: EmbeddingResult
) -> List[int]: ) -> List[int]:
"""便捷函数:插入文档""" """Handle insert documents."""
return client.insert_chunks(chunks, embeddings) return client.insert_chunks(chunks, embeddings)
@@ -480,5 +483,5 @@ def search_regulations(
query_sparse: Dict[int, float], query_sparse: Dict[int, float],
top_k: int = 10 top_k: int = 10
) -> List[SearchResult]: ) -> List[SearchResult]:
"""便捷函数:检索法规""" """Search regulations."""
return client.hybrid_search(query_dense, query_sparse, top_k) return client.hybrid_search(query_dense, query_sparse, top_k)

View File

@@ -1,4 +1,4 @@
"""MinIO对象存储客户端 - 文档文件存储""" """Provide service-layer logic for minio client."""
from minio import Minio from minio import Minio
from minio.error import S3Error from minio.error import S3Error
@@ -8,10 +8,12 @@ from io import BytesIO
import os import os
from app.config.settings import settings from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
class MinIOClient: class MinIOClient:
"""MinIO对象存储客户端""" """Represent the Min I O Client type."""
def __init__( def __init__(
self, self,
@@ -21,16 +23,7 @@ class MinIOClient:
bucket: str = None, bucket: str = None,
secure: bool = None secure: bool = None
): ):
""" """Initialize the Min I O Client instance."""
初始化MinIO客户端
Args:
endpoint: MinIO服务地址
access_key: 访问密钥
secret_key: 秘密密钥
bucket: 存储桶名称
secure: 是否使用HTTPS
"""
self.endpoint = endpoint or settings.minio_endpoint self.endpoint = endpoint or settings.minio_endpoint
self.access_key = access_key or settings.minio_access_key self.access_key = access_key or settings.minio_access_key
self.secret_key = secret_key or settings.minio_secret_key self.secret_key = secret_key or settings.minio_secret_key
@@ -43,7 +36,7 @@ class MinIOClient:
logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}") logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")
def connect(self) -> bool: def connect(self) -> bool:
"""连接MinIO服务""" """Handle connect for the Min I O Client instance."""
try: try:
self.client = Minio( self.client = Minio(
self.endpoint, self.endpoint,
@@ -60,7 +53,7 @@ class MinIOClient:
return False return False
def ensure_bucket(self) -> bool: def ensure_bucket(self) -> bool:
"""确保存储桶存在""" """Handle ensure bucket for the Min I O Client instance."""
if not self.connected: if not self.connected:
logger.warning("未连接MinIO请先调用connect()") logger.warning("未连接MinIO请先调用connect()")
return False return False
@@ -82,17 +75,7 @@ class MinIOClient:
object_name: str, object_name: str,
metadata: Dict[str, Any] = None metadata: Dict[str, Any] = None
) -> bool: ) -> bool:
""" """Handle upload file for the Min I O Client instance."""
上传本地文件到MinIO
Args:
file_path: 本地文件路径
object_name: MinIO对象名称
metadata: 元数据
Returns:
bool: 是否成功
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
self.ensure_bucket() self.ensure_bucket()
@@ -125,18 +108,7 @@ class MinIOClient:
content_type: str = "application/octet-stream", content_type: str = "application/octet-stream",
metadata: Dict[str, Any] = None metadata: Dict[str, Any] = None
) -> bool: ) -> bool:
""" """Handle upload bytes for the Min I O Client instance."""
上传字节数据到MinIO
Args:
data: 文件字节数据
object_name: MinIO对象名称
content_type: 内容类型
metadata: 元数据注意MinIO仅支持US-ASCII字符
Returns:
bool: 是否成功
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
self.ensure_bucket() self.ensure_bucket()
@@ -144,18 +116,18 @@ class MinIOClient:
try: try:
data_stream = BytesIO(data) data_stream = BytesIO(data)
# 处理metadata仅保留ASCII安全字符 # Keep service responsibilities explicit so downstream behavior stays predictable.
safe_metadata = None safe_metadata = None
if metadata: if metadata:
safe_metadata = {} safe_metadata = {}
for key, value in metadata.items(): for key, value in metadata.items():
if isinstance(value, str): if isinstance(value, str):
# 只保留ASCII字符或转换为安全格式 # Keep service responsibilities explicit so downstream behavior stays predictable.
try: try:
value.encode('ascii') value.encode('ascii')
safe_metadata[key] = value safe_metadata[key] = value
except UnicodeEncodeError: except UnicodeEncodeError:
# 中文字符跳过或用占位符 # Keep service responsibilities explicit so downstream behavior stays predictable.
safe_metadata[key] = "" safe_metadata[key] = ""
else: else:
safe_metadata[key] = str(value) safe_metadata[key] = str(value)
@@ -181,16 +153,7 @@ class MinIOClient:
object_name: str, object_name: str,
file_path: str file_path: str
) -> bool: ) -> bool:
""" """Handle download file for the Min I O Client instance."""
从MinIO下载文件到本地
Args:
object_name: MinIO对象名称
file_path: 本地保存路径
Returns:
bool: 是否成功
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
@@ -212,16 +175,7 @@ class MinIOClient:
object_name: str, object_name: str,
expires: int = 3600 expires: int = 3600
) -> Optional[str]: ) -> Optional[str]:
""" """Return object url for the Min I O Client instance."""
获取对象下载URL临时URL
Args:
object_name: MinIO对象名称
expires: URL有效期
Returns:
str: 下载URL
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
@@ -238,15 +192,7 @@ class MinIOClient:
return None return None
def get_object_data(self, object_name: str) -> Optional[bytes]: def get_object_data(self, object_name: str) -> Optional[bytes]:
""" """Return object data for the Min I O Client instance."""
获取对象数据(字节)
Args:
object_name: MinIO对象名称
Returns:
bytes: 文件数据
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
@@ -262,15 +208,7 @@ class MinIOClient:
return None return None
def delete_object(self, object_name: str) -> bool: def delete_object(self, object_name: str) -> bool:
""" """Delete object for the Min I O Client instance."""
删除对象
Args:
object_name: MinIO对象名称
Returns:
bool: 是否成功
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
@@ -284,15 +222,7 @@ class MinIOClient:
return False return False
def list_objects(self, prefix: str = "") -> list: def list_objects(self, prefix: str = "") -> list:
""" """List objects for the Min I O Client instance."""
列出存储桶中的对象
Args:
prefix: 对象名称前缀
Returns:
list: 对象列表
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
@@ -305,15 +235,7 @@ class MinIOClient:
return [] return []
def object_exists(self, object_name: str) -> bool: def object_exists(self, object_name: str) -> bool:
""" """Handle object exists for the Min I O Client instance."""
检查对象是否存在
Args:
object_name: MinIO对象名称
Returns:
bool: 是否存在
"""
if not self.connected: if not self.connected:
self.connect() self.connect()
@@ -325,7 +247,7 @@ class MinIOClient:
return False return False
def _get_content_type(self, file_path: str) -> str: def _get_content_type(self, file_path: str) -> str:
"""根据文件扩展名获取Content-Type""" """Handle get content type for this module for the Min I O Client instance."""
ext = os.path.splitext(file_path)[1].lower() ext = os.path.splitext(file_path)[1].lower()
content_types = { content_types = {
'.pdf': 'application/pdf', '.pdf': 'application/pdf',
@@ -338,13 +260,13 @@ class MinIOClient:
return content_types.get(ext, 'application/octet-stream') return content_types.get(ext, 'application/octet-stream')
def close(self): def close(self):
"""关闭连接MinIO客户端无需显式关闭""" """Release the resources held by this component."""
self.connected = False self.connected = False
logger.info("MinIO客户端已关闭") logger.info("MinIO客户端已关闭")
def create_minio_client() -> MinIOClient: def create_minio_client() -> MinIOClient:
"""便捷函数创建MinIO客户端""" """Create minio client."""
client = MinIOClient() client = MinIOClient()
client.connect() client.connect()
client.ensure_bucket() client.ensure_bucket()

View File

@@ -0,0 +1,5 @@
"""Initialize the app.shared package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = []

View File

@@ -0,0 +1,117 @@
"""Share backend wiring for bootstrap."""
from __future__ import annotations
from functools import lru_cache
from app.application.agent import AgentConversationService
from app.application.documents import DocumentCommandService, DocumentQueryService
from app.application.knowledge import KnowledgeRetrievalService
from app.config.settings import settings
from app.infrastructure.embedding.openai_compatible_embedding_provider import OpenAICompatibleEmbeddingProvider
from app.infrastructure.llm.openai_compatible_answer_generator import OpenAICompatibleAnswerGenerator
from app.infrastructure.parser.aliyun_document_parser import AliyunDocumentParser
from app.infrastructure.parser.local_chunk_builder import LocalRegulationChunkBuilder
from app.infrastructure.parser.local_document_parser import LocalDocumentParser
from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuilder
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
# Keep shared wiring centralized so dependency construction remains consistent.
@lru_cache
def get_document_repository() -> JsonDocumentRepository:
"""Return document repository."""
return JsonDocumentRepository(settings.document_metadata_path)
@lru_cache
def get_binary_store() -> MinioDocumentBinaryStore:
"""Return binary store."""
return MinioDocumentBinaryStore()
@lru_cache
def get_parser():
"""Return parser."""
if settings.parser_backend == "aliyun":
return AliyunDocumentParser()
return LocalDocumentParser()
@lru_cache
def get_chunk_builder():
"""Return chunk builder."""
if settings.chunk_backend == "aliyun":
return AliyunVectorChunkBuilder()
return LocalRegulationChunkBuilder(
chunk_size=settings.chunk_size,
chunk_overlap=settings.chunk_overlap,
)
@lru_cache
def get_embedding_provider() -> OpenAICompatibleEmbeddingProvider:
"""Return embedding provider."""
return OpenAICompatibleEmbeddingProvider()
@lru_cache
def get_vector_index() -> MilvusVectorIndex:
"""Return vector index."""
return MilvusVectorIndex()
@lru_cache
def get_retrieval_service() -> KnowledgeRetrievalService:
"""Return retrieval service."""
retriever = DenseRetriever(
embedding_provider=get_embedding_provider(),
vector_index=get_vector_index(),
)
return KnowledgeRetrievalService(retriever=retriever)
@lru_cache
def get_document_command_service() -> DocumentCommandService:
"""Return document command service."""
return DocumentCommandService(
document_repository=get_document_repository(),
binary_store=get_binary_store(),
parser=get_parser(),
chunk_builder=get_chunk_builder(),
embedding_provider=get_embedding_provider(),
vector_index=get_vector_index(),
)
@lru_cache
def get_document_query_service() -> DocumentQueryService:
"""Return document query service."""
return DocumentQueryService(
document_repository=get_document_repository(),
binary_store=get_binary_store(),
)
@lru_cache
def get_conversation_store() -> InMemoryConversationStore:
"""Return conversation store."""
return InMemoryConversationStore(
max_sessions=settings.session_max_sessions,
timeout_minutes=settings.session_timeout_minutes,
)
@lru_cache
def get_agent_conversation_service() -> AgentConversationService:
"""Return agent conversation service."""
return AgentConversationService(
retrieval_service=get_retrieval_service(),
answer_generator=OpenAICompatibleAnswerGenerator(),
conversation_store=get_conversation_store(),
)

View File

@@ -1,4 +1,8 @@
"""Initialize the app.utils package."""
from .chunking import TextChunker, chunker from .chunking import TextChunker, chunker
from .logger import logger, setup_logging from .logger import logger, setup_logging
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["TextChunker", "chunker", "logger", "setup_logging"] __all__ = ["TextChunker", "chunker", "logger", "setup_logging"]

View File

@@ -1,19 +1,25 @@
"""Provide utility helpers for chunking."""
import re import re
from typing import List from typing import List
from app.core.config import settings from app.core.config import settings
# Keep module behavior explicit so the backend flow stays easy to audit.
class TextChunker: class TextChunker:
"""Represent the Text Chunker type."""
def __init__( def __init__(
self, self,
chunk_size: int = settings.chunk_size, chunk_size: int = settings.chunk_size,
chunk_overlap: int = settings.chunk_overlap, chunk_overlap: int = settings.chunk_overlap,
): ):
"""Initialize the Text Chunker instance."""
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap self.chunk_overlap = chunk_overlap
def chunk_by_clause(self, text: str) -> List[dict]: def chunk_by_clause(self, text: str) -> List[dict]:
"""按条款边界分块(适用于法规文档)""" """Handle chunk by clause for the Text Chunker instance."""
clause_pattern = r"(第[一二三四五六七八九十百]+条)" clause_pattern = r"(第[一二三四五六七八九十百]+条)"
parts = re.split(clause_pattern, text) parts = re.split(clause_pattern, text)
@@ -46,7 +52,7 @@ class TextChunker:
return chunks return chunks
def chunk_by_size(self, text: str) -> List[dict]: def chunk_by_size(self, text: str) -> List[dict]:
"""按固定大小分块""" """Handle chunk by size for the Text Chunker instance."""
chunks = [] chunks = []
start = 0 start = 0
chunk_index = 0 chunk_index = 0
@@ -69,7 +75,7 @@ class TextChunker:
return chunks return chunks
def estimate_tokens(self, text: str) -> int: def estimate_tokens(self, text: str) -> int:
"""估算token数量""" """Handle estimate tokens for the Text Chunker instance."""
chinese_chars = len(re.findall(r"[^\x00-\xff]", text)) chinese_chars = len(re.findall(r"[^\x00-\xff]", text))
english_chars = len(text) - chinese_chars english_chars = len(text) - chinese_chars
return int(chinese_chars / 1.5 + english_chars / 4) return int(chinese_chars / 1.5 + english_chars / 4)

View File

@@ -1,9 +1,13 @@
"""Provide utility helpers for logger."""
import logging import logging
import sys import sys
# Keep module behavior explicit so the backend flow stays easy to audit.
def setup_logging() -> logging.Logger: def setup_logging() -> logging.Logger:
"""配置日志""" """Handle setup logging."""
logger = logging.getLogger("app") logger = logging.getLogger("app")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)

Some files were not shown because too many files have changed in this diff Show More