update
This commit is contained in:
43
.env
43
.env
@@ -1,48 +1,75 @@
|
|||||||
# 环境变量配置 - 已有数据库服务
|
# 环境变量配置 - 已有数据库服务
|
||||||
|
# AI+合规智能中枢
|
||||||
|
|
||||||
# 应用配置
|
# ===== 应用配置 =====
|
||||||
APP_NAME=AI+合规智能中枢
|
APP_NAME=AI+合规智能中枢
|
||||||
APP_VERSION=0.1.0
|
APP_VERSION=0.1.0
|
||||||
DEBUG=false
|
DEBUG=false
|
||||||
|
|
||||||
# Milvus向量数据库配置(已有)
|
# ===== Milvus向量数据库配置(已有)=====
|
||||||
MILVUS_HOST=localhost
|
MILVUS_HOST=localhost
|
||||||
MILVUS_PORT=19530
|
MILVUS_PORT=19530
|
||||||
MILVUS_COLLECTION=regulations
|
MILVUS_COLLECTION=regulations
|
||||||
MILVUS_DB_NAME=default
|
MILVUS_DB_NAME=default
|
||||||
|
|
||||||
# MinIO对象存储配置(已有)
|
# ===== MinIO对象存储配置(已有)=====
|
||||||
MINIO_ENDPOINT=localhost:9000
|
MINIO_ENDPOINT=localhost:9000
|
||||||
MINIO_ACCESS_KEY=minioadmin
|
MINIO_ACCESS_KEY=minioadmin
|
||||||
MINIO_SECRET_KEY=minioadmin
|
MINIO_SECRET_KEY=minioadmin
|
||||||
MINIO_BUCKET=compliance-docs
|
MINIO_BUCKET=compliance-docs
|
||||||
MINIO_SECURE=false
|
MINIO_SECURE=false
|
||||||
|
|
||||||
# Redis配置(已有)
|
# ===== Redis配置(已有)=====
|
||||||
REDIS_HOST=localhost
|
REDIS_HOST=localhost
|
||||||
REDIS_PORT=6379
|
REDIS_PORT=6379
|
||||||
REDIS_PASSWORD=redis@123
|
REDIS_PASSWORD=redis@123
|
||||||
REDIS_DB=0
|
REDIS_DB=0
|
||||||
|
|
||||||
# PostgreSQL配置(已有)
|
# ===== PostgreSQL配置(已有)=====
|
||||||
POSTGRES_HOST=localhost
|
POSTGRES_HOST=localhost
|
||||||
POSTGRES_PORT=5432
|
POSTGRES_PORT=5432
|
||||||
POSTGRES_USER=postgresql
|
POSTGRES_USER=postgresql
|
||||||
POSTGRES_PASSWORD=postgresql123456
|
POSTGRES_PASSWORD=postgresql123456
|
||||||
POSTGRES_DB=compliance_db
|
POSTGRES_DB=compliance_db
|
||||||
|
|
||||||
# 嵌入模型配置
|
# ===== 嵌入模型配置 =====
|
||||||
EMBEDDING_MODEL=BAAI/bge-m3
|
EMBEDDING_MODEL=BAAI/bge-m3
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
EMBEDDING_MAX_LENGTH=8192
|
EMBEDDING_MAX_LENGTH=8192
|
||||||
EMBEDDING_BATCH_SIZE=12
|
EMBEDDING_BATCH_SIZE=12
|
||||||
EMBEDDING_USE_FP16=true
|
EMBEDDING_USE_FP16=true
|
||||||
|
|
||||||
# 文档处理配置
|
# ===== 文档处理配置 =====
|
||||||
CHUNK_SIZE=512
|
CHUNK_SIZE=512
|
||||||
CHUNK_OVERLAP=50
|
CHUNK_OVERLAP=50
|
||||||
MAX_FILE_SIZE_MB=100
|
MAX_FILE_SIZE_MB=100
|
||||||
|
|
||||||
# API配置
|
# ===== API配置 =====
|
||||||
API_HOST=0.0.0.0
|
API_HOST=0.0.0.0
|
||||||
API_PORT=8000
|
API_PORT=8000
|
||||||
|
|
||||||
|
# ===== LLM配置 =====
|
||||||
|
# LLM提供商选择: qwen / deepseek / qwen_vl
|
||||||
|
LLM_PROVIDER=deepseek
|
||||||
|
LLM_MODEL=deepseek-v4-flash
|
||||||
|
LLM_MAX_TOKENS=4096
|
||||||
|
LLM_TEMPERATURE=0.7
|
||||||
|
|
||||||
|
# ===== Qwen API配置(阿里云DashScope)=====
|
||||||
|
# 获取API Key: https://dashscope.console.aliyun.com/
|
||||||
|
QWEN_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
|
||||||
|
QWEN_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
QWEN_MODEL=qwen3.5-plus
|
||||||
|
QWEN_VL_MODEL=qwen3-vl-plus
|
||||||
|
|
||||||
|
# ===== DeepSeek API配置 =====
|
||||||
|
# 获取API Key: https://platform.deepseek.com/
|
||||||
|
DEEPSEEK_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
|
||||||
|
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
DEEPSEEK_MODEL=deepseek-v4-flash
|
||||||
|
|
||||||
|
# ===== RAG配置 =====
|
||||||
|
RAG_TOP_K=10
|
||||||
|
RAG_MAX_CONTEXT_TOKENS=4000
|
||||||
|
RAG_SUMMARY_MAX_TOKENS=1024
|
||||||
|
RAG_SKILLS_MAX_TOKENS=2048
|
||||||
|
|||||||
59
.env.example
59
.env.example
@@ -1,31 +1,78 @@
|
|||||||
# .env.example - 环境变量配置示例
|
# .env.example - 环境变量配置示例
|
||||||
|
# AI+合规智能中枢
|
||||||
|
|
||||||
# Milvus向量数据库配置
|
# ===== 应用基础配置 =====
|
||||||
|
APP_NAME=AI+合规智能中枢
|
||||||
|
APP_VERSION=0.1.0
|
||||||
|
DEBUG=false
|
||||||
|
|
||||||
|
# ===== Milvus向量数据库配置 =====
|
||||||
MILVUS_HOST=localhost
|
MILVUS_HOST=localhost
|
||||||
MILVUS_PORT=19530
|
MILVUS_PORT=19530
|
||||||
MILVUS_COLLECTION=regulations
|
MILVUS_COLLECTION=regulations
|
||||||
|
MILVUS_DB_NAME=default
|
||||||
|
|
||||||
# 嵌入模型配置
|
# ===== 嵌入模型配置 =====
|
||||||
EMBEDDING_MODEL=BAAI/bge-m3
|
EMBEDDING_MODEL=BAAI/bge-m3
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
|
EMBEDDING_MAX_LENGTH=8192
|
||||||
|
EMBEDDING_BATCH_SIZE=12
|
||||||
|
EMBEDDING_USE_FP16=true
|
||||||
|
|
||||||
# MinIO对象存储配置
|
# ===== MinIO对象存储配置 =====
|
||||||
MINIO_ENDPOINT=localhost:9000
|
MINIO_ENDPOINT=localhost:9000
|
||||||
MINIO_ACCESS_KEY=minioadmin
|
MINIO_ACCESS_KEY=minioadmin
|
||||||
MINIO_SECRET_KEY=minioadmin123
|
MINIO_SECRET_KEY=minioadmin123
|
||||||
MINIO_BUCKET=compliance-docs
|
MINIO_BUCKET=compliance-docs
|
||||||
|
MINIO_SECURE=false
|
||||||
|
|
||||||
# Redis配置
|
# ===== Redis配置 =====
|
||||||
REDIS_HOST=localhost
|
REDIS_HOST=localhost
|
||||||
REDIS_PORT=6379
|
REDIS_PORT=6379
|
||||||
|
REDIS_PASSWORD=
|
||||||
|
REDIS_DB=0
|
||||||
|
|
||||||
# PostgreSQL配置
|
# ===== PostgreSQL配置 =====
|
||||||
POSTGRES_HOST=localhost
|
POSTGRES_HOST=localhost
|
||||||
POSTGRES_PORT=5432
|
POSTGRES_PORT=5432
|
||||||
POSTGRES_USER=compliance
|
POSTGRES_USER=compliance
|
||||||
POSTGRES_PASSWORD=compliance123
|
POSTGRES_PASSWORD=compliance123
|
||||||
POSTGRES_DB=compliance_db
|
POSTGRES_DB=compliance_db
|
||||||
|
|
||||||
# 文档处理配置
|
# ===== 文档处理配置 =====
|
||||||
CHUNK_SIZE=512
|
CHUNK_SIZE=512
|
||||||
CHUNK_OVERLAP=50
|
CHUNK_OVERLAP=50
|
||||||
|
MAX_FILE_SIZE_MB=100
|
||||||
|
|
||||||
|
# ===== API服务配置 =====
|
||||||
|
API_HOST=0.0.0.0
|
||||||
|
API_PORT=8000
|
||||||
|
|
||||||
|
# ===== LLM配置(必填)=====
|
||||||
|
# LLM提供商选择: qwen / deepseek / qwen_vl
|
||||||
|
LLM_PROVIDER=deepseek
|
||||||
|
LLM_MODEL=deepseek-v4-flash
|
||||||
|
LLM_MAX_TOKENS=4096
|
||||||
|
LLM_TEMPERATURE=0.7
|
||||||
|
|
||||||
|
# ===== 统一API代理配置 =====
|
||||||
|
# 使用new-api代理服务,支持多个LLM模型
|
||||||
|
# 获取API Key: 向管理员申请
|
||||||
|
QWEN_API_KEY=your_api_key_here
|
||||||
|
DEEPSEEK_API_KEY=your_api_key_here
|
||||||
|
QWEN_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
|
||||||
|
# ===== 可用模型 =====
|
||||||
|
# Qwen系列: qwen3.5-plus, qwen3-plus, qwen-max, qwen-turbo, qwen-long
|
||||||
|
# Qwen VL系列: qwen3-vl-plus, qwen-vl-max
|
||||||
|
# DeepSeek系列: deepseek-v4-flash, deepseek-v3.2, deepseek-v3, deepseek-chat, deepseek-coder
|
||||||
|
QWEN_MODEL=qwen3.5-plus
|
||||||
|
QWEN_VL_MODEL=qwen3-vl-plus
|
||||||
|
DEEPSEEK_MODEL=deepseek-v4-flash
|
||||||
|
|
||||||
|
# ===== RAG配置 =====
|
||||||
|
RAG_TOP_K=10
|
||||||
|
RAG_MAX_CONTEXT_TOKENS=4000
|
||||||
|
RAG_SUMMARY_MAX_TOKENS=1024
|
||||||
|
RAG_SKILLS_MAX_TOKENS=2048
|
||||||
|
|||||||
422
QUICK_DEPLOY.md
Normal file
422
QUICK_DEPLOY.md
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
# AI+合规智能中枢 - 快速部署指南
|
||||||
|
|
||||||
|
## 系统要求
|
||||||
|
|
||||||
|
- Python 3.10+
|
||||||
|
- Docker & Docker Compose
|
||||||
|
- 8GB+ 内存(推荐16GB)
|
||||||
|
- 20GB+ 磁盘空间
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 一、环境准备
|
||||||
|
|
||||||
|
### 1. 克隆项目
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone <project_url>
|
||||||
|
cd Demo-glm
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 配置环境变量
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 复制配置模板
|
||||||
|
cp .env.example .env
|
||||||
|
|
||||||
|
# 编辑配置文件,填入API密钥
|
||||||
|
vim .env
|
||||||
|
```
|
||||||
|
|
||||||
|
**必填配置项**:
|
||||||
|
|
||||||
|
```env
|
||||||
|
# LLM配置(使用统一API代理)
|
||||||
|
LLM_PROVIDER=qwen
|
||||||
|
LLM_MODEL=qwen3.5-plus
|
||||||
|
|
||||||
|
# API密钥(通过 new-api.fletcher0516.online 代理)
|
||||||
|
QWEN_API_KEY=your_api_key_here
|
||||||
|
DEEPSEEK_API_KEY=your_api_key_here
|
||||||
|
QWEN_BASE_URL=https://new-api.fletcher0516.online/v1
|
||||||
|
DEEPSEEK_BASE_URL=https://new-api.fletcher0516.online/v1
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 二、启动基础设施
|
||||||
|
|
||||||
|
### 1. 启动Docker服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker-compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
等待服务启动完成(约30秒)。
|
||||||
|
|
||||||
|
### 2. 验证服务状态
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker ps
|
||||||
|
```
|
||||||
|
|
||||||
|
确认以下容器运行正常:
|
||||||
|
- `milvus` - 向量数据库
|
||||||
|
- `minio` - 对象存储
|
||||||
|
- `redis` - 缓存服务
|
||||||
|
- `postgres` - 关系数据库
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 三、安装Python依赖
|
||||||
|
|
||||||
|
### 方式A:使用快速启动脚本(推荐)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
chmod +x quick_start.sh
|
||||||
|
./quick_start.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
脚本自动完成:
|
||||||
|
- 创建虚拟环境
|
||||||
|
- 安装依赖(使用阿里云镜像)
|
||||||
|
- 检查各服务连接状态
|
||||||
|
|
||||||
|
### 方式B:手动安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建虚拟环境
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 四、下载嵌入模型
|
||||||
|
|
||||||
|
BGE-M3模型约2GB,首次使用需下载。
|
||||||
|
|
||||||
|
### 方式A:自动下载(联网环境)
|
||||||
|
|
||||||
|
首次启动API时自动下载到 `~/.cache/huggingface/`
|
||||||
|
|
||||||
|
### 方式B:手动下载(离线环境)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 从ModelScope下载
|
||||||
|
python -c "from modelscope import snapshot_download; snapshot_download('Xorbits/bge-m3', cache_dir='~/.cache/modelscope')"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 五、启动服务
|
||||||
|
|
||||||
|
### 整合启动脚本(推荐)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 赋予脚本执行权限
|
||||||
|
chmod +x start_all.sh stop_all.sh restart_all.sh status.sh
|
||||||
|
|
||||||
|
# 启动所有服务(API + 前端)
|
||||||
|
./start_all.sh
|
||||||
|
|
||||||
|
# 查看服务状态
|
||||||
|
./status.sh
|
||||||
|
|
||||||
|
# 重启所有服务
|
||||||
|
./restart_all.sh
|
||||||
|
|
||||||
|
# 停止所有服务
|
||||||
|
./stop_all.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 单独启动(可选)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 仅启动API服务(前台运行,可调试)
|
||||||
|
./start_api.sh
|
||||||
|
|
||||||
|
# 仅启动API服务(后台运行)
|
||||||
|
./start_api_background.sh
|
||||||
|
|
||||||
|
# 仅停止API服务
|
||||||
|
./stop_api.sh
|
||||||
|
|
||||||
|
# 仅启动前端服务
|
||||||
|
./start_frontend.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 六、服务访问地址
|
||||||
|
|
||||||
|
启动成功后访问:
|
||||||
|
|
||||||
|
| 服务 | 地址 |
|
||||||
|
|------|------|
|
||||||
|
| **API服务** | http://localhost:8000 |
|
||||||
|
| **API文档** | http://localhost:8000/docs |
|
||||||
|
| **健康检查** | http://localhost:8000/health |
|
||||||
|
| **前端测试页面** | http://localhost:3000 |
|
||||||
|
|
||||||
|
> 注意:前端测试页面通过 `http://localhost:3000` 访问,自动连接到API服务。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 七、功能测试
|
||||||
|
|
||||||
|
### 1. 上传文档测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/api/v1/documents/upload \
|
||||||
|
-F "file=@test.pdf" \
|
||||||
|
-F "doc_name=测试文档"
|
||||||
|
```
|
||||||
|
|
||||||
|
文档上传后会自动存储到MinIO对象存储(bucket: upload-files)。
|
||||||
|
|
||||||
|
### 2. 下载文档测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载已上传的文档
|
||||||
|
curl -O http://localhost:8000/api/v1/documents/download/{doc_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 检索测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/api/v1/knowledge/search \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"query": "机动车安全", "top_k": 10}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 智能问答测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/api/v1/agent/ask \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"query": "机动车安全技术检验有哪些要求?"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 多轮对话测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/api/v1/agent/chat \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"query": "什么是机动车安全技术检验?"}'
|
||||||
|
|
||||||
|
# 返回 session_id,继续对话
|
||||||
|
curl -X POST http://localhost:8000/api/v1/agent/chat \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"query": "检验周期是多久?", "session_id": "<session_id>"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 八、脚本命令速查表
|
||||||
|
|
||||||
|
| 操作 | 命令 |
|
||||||
|
|------|------|
|
||||||
|
| **启动所有服务** | `./start_all.sh` |
|
||||||
|
| **停止所有服务** | `./stop_all.sh` |
|
||||||
|
| **重启所有服务** | `./restart_all.sh` |
|
||||||
|
| **查看服务状态** | `./status.sh` |
|
||||||
|
| 查看API日志 | `tail -f logs/api.log` |
|
||||||
|
| 查看前端日志 | `tail -f logs/frontend.log` |
|
||||||
|
| 环境初始化 | `./quick_start.sh` |
|
||||||
|
| 重启Docker | `cd docker && docker-compose restart` |
|
||||||
|
| 下载嵌入模型 | `./download_model.sh` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 九、服务状态检查
|
||||||
|
|
||||||
|
运行状态检查脚本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./status.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
输出示例:
|
||||||
|
```
|
||||||
|
========================================
|
||||||
|
AI+合规智能中枢 - 服务状态
|
||||||
|
========================================
|
||||||
|
|
||||||
|
API服务:
|
||||||
|
状态: 运行中 ✓
|
||||||
|
PID: 12345
|
||||||
|
健康检查: 正常 ✓
|
||||||
|
地址: http://localhost:8000
|
||||||
|
|
||||||
|
前端服务:
|
||||||
|
状态: 运行中 ✓
|
||||||
|
PID: 12346
|
||||||
|
地址: http://localhost:3000
|
||||||
|
|
||||||
|
Docker服务:
|
||||||
|
milvus: 运行中 ✓
|
||||||
|
minio: 运行中 ✓
|
||||||
|
redis: 运行中 ✓
|
||||||
|
postgres: 运行中 ✓
|
||||||
|
|
||||||
|
========================================
|
||||||
|
所有服务正常运行 ✓
|
||||||
|
========================================
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十、常见问题
|
||||||
|
|
||||||
|
### Q1: Milvus连接失败
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 检查Milvus状态
|
||||||
|
docker logs milvus
|
||||||
|
|
||||||
|
# 重启Milvus
|
||||||
|
docker restart milvus
|
||||||
|
|
||||||
|
# 等待30秒后再启动服务
|
||||||
|
```
|
||||||
|
|
||||||
|
### Q2: 模型下载慢/失败
|
||||||
|
|
||||||
|
使用ModelScope镜像:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
```
|
||||||
|
|
||||||
|
或手动下载:
|
||||||
|
```bash
|
||||||
|
python -c "from modelscope import snapshot_download; snapshot_download('Xorbits/bge-m3')"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Q3: LLM调用失败
|
||||||
|
|
||||||
|
检查 `.env` 中API密钥配置:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 验证配置
|
||||||
|
cat .env | grep API_KEY
|
||||||
|
|
||||||
|
# 确保base_url正确
|
||||||
|
cat .env | grep BASE_URL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Q4: 端口被占用
|
||||||
|
|
||||||
|
修改 `.env` 中的端口配置:
|
||||||
|
|
||||||
|
```env
|
||||||
|
API_PORT=8001
|
||||||
|
FRONTEND_PORT=3001
|
||||||
|
```
|
||||||
|
|
||||||
|
### Q5: 服务无法停止
|
||||||
|
|
||||||
|
强制清理:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查找并停止所有相关进程
|
||||||
|
pkill -f uvicorn
|
||||||
|
pkill -f http.server
|
||||||
|
|
||||||
|
# 清理PID文件
|
||||||
|
rm -f logs/*.pid
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十一、目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
Demo-glm/
|
||||||
|
├── src/
|
||||||
|
│ ├── api/ # FastAPI接口
|
||||||
|
│ │ ├── main.py # API入口
|
||||||
|
│ │ └── routes/
|
||||||
|
│ │ ├── documents.py # 文档上传
|
||||||
|
│ │ ├── knowledge.py # 知识检索
|
||||||
|
│ │ └── agent.py # 智能问答
|
||||||
|
│ ├── services/ # 核心服务
|
||||||
|
│ │ ├── llm/ # LLM调用(Qwen/DeepSeek)
|
||||||
|
│ │ ├── rag/ # RAG检索
|
||||||
|
│ │ ├── agent/ # 问答Agent
|
||||||
|
│ │ ├── parser/ # 文档解析
|
||||||
|
│ │ ├── embedding/ # 向量嵌入(BGE-M3)
|
||||||
|
│ │ └── storage/ # Milvus存储
|
||||||
|
│ └── config/ # 配置管理
|
||||||
|
├── frontend/ # 前端测试页面
|
||||||
|
│ └── index.html # 测试界面
|
||||||
|
├── docker/ # Docker配置
|
||||||
|
│ └── docker-compose.yml
|
||||||
|
├── logs/ # 运行日志
|
||||||
|
│ ├── api.log
|
||||||
|
│ └── frontend.log
|
||||||
|
├── tests/ # 测试脚本
|
||||||
|
├── .env # 环境配置
|
||||||
|
├── .env.example # 配置模板
|
||||||
|
├── requirements.txt # Python依赖
|
||||||
|
├── quick_start.sh # 环境初始化脚本
|
||||||
|
├── start_all.sh # 整合启动脚本
|
||||||
|
├── stop_all.sh # 整合停止脚本
|
||||||
|
├── restart_all.sh # 重启脚本
|
||||||
|
├── status.sh # 状态检查脚本
|
||||||
|
├── start_api.sh # 单独启动API
|
||||||
|
├── start_frontend.sh # 单独启动前端
|
||||||
|
└── QUICK_DEPLOY.md # 本文档
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十二、API接口清单
|
||||||
|
|
||||||
|
| 接口 | 路径 | 方法 | 功能 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| 上传文档 | `/api/v1/documents/upload` | POST | 上传PDF/DOCX |
|
||||||
|
| 下载文档 | `/api/v1/documents/download/{doc_id}` | GET | 下载原文PDF/DOCX |
|
||||||
|
| 文档列表 | `/api/v1/documents/list` | GET | 列出已上传文档 |
|
||||||
|
| 检索知识 | `/api/v1/knowledge/search` | POST | 向量检索 |
|
||||||
|
| 单次问答 | `/api/v1/agent/ask` | POST | 智能问答 |
|
||||||
|
| 多轮对话 | `/api/v1/agent/chat` | POST | 会话对话 |
|
||||||
|
| 会话信息 | `/api/v1/agent/session/{id}` | GET | 获取会话 |
|
||||||
|
| 删除会话 | `/api/v1/agent/session/{id}` | DELETE | 删除会话 |
|
||||||
|
| Prompt模板 | `/api/v1/agent/templates` | GET | 模板列表 |
|
||||||
|
| 可用模型 | `/api/v1/agent/models` | GET | LLM模型列表 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十三、支持的LLM模型
|
||||||
|
|
||||||
|
通过统一API代理 `https://new-api.fletcher0516.online/v1` 支持:
|
||||||
|
|
||||||
|
**Qwen系列**:
|
||||||
|
- `qwen3.5-plus` (推荐)
|
||||||
|
- `qwen3-plus`
|
||||||
|
- `qwen-max`
|
||||||
|
- `qwen-turbo`
|
||||||
|
- `qwen-long`
|
||||||
|
|
||||||
|
**Qwen VL系列**(多模态):
|
||||||
|
- `qwen3-vl-plus`
|
||||||
|
- `qwen-vl-max`
|
||||||
|
|
||||||
|
**DeepSeek系列**:
|
||||||
|
- `deepseek-v3.2` (推荐)
|
||||||
|
- `deepseek-v3`
|
||||||
|
- `deepseek-chat`
|
||||||
|
- `deepseek-coder`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 技术支持
|
||||||
|
|
||||||
|
- API文档:http://localhost:8000/docs
|
||||||
|
- 问题反馈:提交Issue到项目仓库
|
||||||
56
backend/.env
Normal file
56
backend/.env
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
APP_NAME=AI+合规智能中枢
|
||||||
|
APP_VERSION=0.1.0
|
||||||
|
DEBUG=false
|
||||||
|
|
||||||
|
MILVUS_HOST=localhost
|
||||||
|
MILVUS_PORT=19530
|
||||||
|
MILVUS_COLLECTION=regulations
|
||||||
|
MILVUS_DB_NAME=default
|
||||||
|
|
||||||
|
MINIO_ENDPOINT=localhost:9000
|
||||||
|
MINIO_ACCESS_KEY=minioadmin
|
||||||
|
MINIO_SECRET_KEY=minioadmin
|
||||||
|
MINIO_BUCKET=compliance-docs
|
||||||
|
MINIO_SECURE=false
|
||||||
|
|
||||||
|
REDIS_HOST=localhost
|
||||||
|
REDIS_PORT=6379
|
||||||
|
REDIS_PASSWORD=redis@123
|
||||||
|
REDIS_DB=0
|
||||||
|
|
||||||
|
POSTGRES_HOST=localhost
|
||||||
|
POSTGRES_PORT=5432
|
||||||
|
POSTGRES_USER=postgresql
|
||||||
|
POSTGRES_PASSWORD=postgresql123456
|
||||||
|
POSTGRES_DB=compliance_db
|
||||||
|
|
||||||
|
EMBEDDING_MODEL=BAAI/bge-m3
|
||||||
|
EMBEDDING_DIM=1024
|
||||||
|
EMBEDDING_MAX_LENGTH=8192
|
||||||
|
EMBEDDING_BATCH_SIZE=12
|
||||||
|
EMBEDDING_USE_FP16=true
|
||||||
|
|
||||||
|
CHUNK_SIZE=512
|
||||||
|
CHUNK_OVERLAP=50
|
||||||
|
MAX_FILE_SIZE_MB=100
|
||||||
|
|
||||||
|
API_HOST=0.0.0.0
|
||||||
|
API_PORT=8000
|
||||||
|
|
||||||
|
LLM_PROVIDER=deepseek
|
||||||
|
LLM_MODEL=deepseek-v4-flash
|
||||||
|
LLM_MAX_TOKENS=4096
|
||||||
|
LLM_TEMPERATURE=0.7
|
||||||
|
|
||||||
|
QWEN_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
|
||||||
|
QWEN_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
QWEN_MODEL=qwen3.5-plus
|
||||||
|
QWEN_VL_MODEL=qwen3-vl-plus
|
||||||
|
|
||||||
|
DEEPSEEK_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
|
||||||
|
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
DEEPSEEK_MODEL=deepseek-v4-flash
|
||||||
|
|
||||||
|
RAG_TOP_K=10
|
||||||
|
RAG_MAX_CONTEXT_TOKENS=4000
|
||||||
|
RAG_SUMMARY_MAX_TOKENS=1024
|
||||||
56
backend/.env.example
Normal file
56
backend/.env.example
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
APP_NAME=AI+合规智能中枢
|
||||||
|
APP_VERSION=0.1.0
|
||||||
|
DEBUG=false
|
||||||
|
|
||||||
|
MILVUS_HOST=localhost
|
||||||
|
MILVUS_PORT=19530
|
||||||
|
MILVUS_COLLECTION=regulations
|
||||||
|
MILVUS_DB_NAME=default
|
||||||
|
|
||||||
|
EMBEDDING_MODEL=BAAI/bge-m3
|
||||||
|
EMBEDDING_DIM=1024
|
||||||
|
EMBEDDING_MAX_LENGTH=8192
|
||||||
|
EMBEDDING_BATCH_SIZE=12
|
||||||
|
EMBEDDING_USE_FP16=true
|
||||||
|
|
||||||
|
MINIO_ENDPOINT=localhost:9000
|
||||||
|
MINIO_ACCESS_KEY=minioadmin
|
||||||
|
MINIO_SECRET_KEY=minioadmin123
|
||||||
|
MINIO_BUCKET=compliance-docs
|
||||||
|
MINIO_SECURE=false
|
||||||
|
|
||||||
|
REDIS_HOST=localhost
|
||||||
|
REDIS_PORT=6379
|
||||||
|
REDIS_PASSWORD=
|
||||||
|
REDIS_DB=0
|
||||||
|
|
||||||
|
POSTGRES_HOST=localhost
|
||||||
|
POSTGRES_PORT=5432
|
||||||
|
POSTGRES_USER=compliance
|
||||||
|
POSTGRES_PASSWORD=compliance123
|
||||||
|
POSTGRES_DB=compliance_db
|
||||||
|
|
||||||
|
CHUNK_SIZE=512
|
||||||
|
CHUNK_OVERLAP=50
|
||||||
|
MAX_FILE_SIZE_MB=100
|
||||||
|
|
||||||
|
API_HOST=0.0.0.0
|
||||||
|
API_PORT=8000
|
||||||
|
|
||||||
|
LLM_PROVIDER=deepseek
|
||||||
|
LLM_MODEL=deepseek-v4-flash
|
||||||
|
LLM_MAX_TOKENS=4096
|
||||||
|
LLM_TEMPERATURE=0.7
|
||||||
|
|
||||||
|
QWEN_API_KEY=your_api_key_here
|
||||||
|
DEEPSEEK_API_KEY=your_api_key_here
|
||||||
|
QWEN_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
|
||||||
|
|
||||||
|
QWEN_MODEL=qwen3.5-plus
|
||||||
|
QWEN_VL_MODEL=qwen3-vl-plus
|
||||||
|
DEEPSEEK_MODEL=deepseek-v4-flash
|
||||||
|
|
||||||
|
RAG_TOP_K=10
|
||||||
|
RAG_MAX_CONTEXT_TOKENS=4000
|
||||||
|
RAG_SUMMARY_MAX_TOKENS=1024
|
||||||
10
backend/.gitignore
vendored
Normal file
10
backend/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
1
backend/.python-version
Normal file
1
backend/.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.9
|
||||||
18
backend/Dockerfile
Normal file
18
backend/Dockerfile
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# 复制代码
|
||||||
|
COPY app/ ./app/
|
||||||
|
COPY data/ ./data/
|
||||||
|
|
||||||
|
# 环境变量
|
||||||
|
ENV API_HOST=0.0.0.0
|
||||||
|
ENV API_PORT=8000
|
||||||
|
|
||||||
|
# 启动命令
|
||||||
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
50
backend/README.md
Normal file
50
backend/README.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# AI+合规智能中枢后端
|
||||||
|
|
||||||
|
`backend` 已承接原 `src` 的完整 FastAPI 后端能力,当前正式入口为 `app.main:app`。
|
||||||
|
|
||||||
|
## 启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r backend/requirements.txt
|
||||||
|
PYTHONPATH=backend uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
也可以直接使用根目录脚本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./start_api.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
## 主要接口
|
||||||
|
|
||||||
|
- `GET /health`
|
||||||
|
- `GET /`
|
||||||
|
- `POST /api/v1/documents/upload`
|
||||||
|
- `GET /api/v1/documents/list`
|
||||||
|
- `GET /api/v1/documents/management-list`
|
||||||
|
- `GET /api/v1/documents/download/{doc_id}`
|
||||||
|
- `POST /api/v1/knowledge/search`
|
||||||
|
- `POST /api/v1/knowledge/retrieval`
|
||||||
|
- `POST /api/v1/agent/ask`
|
||||||
|
- `POST /api/v1/agent/chat`
|
||||||
|
- `GET /api/v1/agent/chat/stream`
|
||||||
|
|
||||||
|
## 目录说明
|
||||||
|
|
||||||
|
```text
|
||||||
|
backend/
|
||||||
|
├── app/
|
||||||
|
│ ├── api/ # FastAPI 路由与模型
|
||||||
|
│ ├── config/ # 配置与日志
|
||||||
|
│ ├── services/ # 文档处理、LLM、RAG、存储
|
||||||
|
│ └── workers/ # 任务相关代码
|
||||||
|
├── .env.example
|
||||||
|
├── requirements.txt
|
||||||
|
└── main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 说明
|
||||||
|
|
||||||
|
- `backend/app/api/main.py` 来自原 `src/api/main.py`,已切换为 `app.*` 导入。
|
||||||
|
- 路由前缀保持为 `/api/v1`,以兼容当前前端。
|
||||||
|
- 原 `backend/app/api/routes/docs.py`、`rag.py`、`compliance.py`、`status.py` 仍保留在仓库中,但不再作为主路由入口。
|
||||||
3
backend/app/__init__.py
Normal file
3
backend/app/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .main import app
|
||||||
|
|
||||||
|
__all__ = ["app"]
|
||||||
2
backend/app/api/__init__.py
Normal file
2
backend/app/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# src/api/__init__.py
|
||||||
|
"""API接口模块"""
|
||||||
99
backend/app/api/main.py
Normal file
99
backend/app/api/main.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""FastAPI application entrypoint."""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.api.models import ErrorResponse
|
||||||
|
from app.api.routes import api_router
|
||||||
|
from app.config.logging import setup_logging
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.services.llm.llm_factory import LLMFactory
|
||||||
|
|
||||||
|
setup_logging(level="INFO" if not settings.debug else "DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Application lifecycle hooks."""
|
||||||
|
logger.info(f"启动 {settings.app_name} v{settings.app_version}")
|
||||||
|
logger.info(f"调试模式: {settings.debug}")
|
||||||
|
logger.info("预加载LLM客户端...")
|
||||||
|
LLMFactory.preload_clients(["qwen", "deepseek"])
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
logger.info("应用关闭,执行清理...")
|
||||||
|
LLMFactory.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title=settings.app_name,
|
||||||
|
description=(
|
||||||
|
"AI+合规智能中枢 - 法律法规文档解析入库功能\n\n"
|
||||||
|
"支持PDF/DOCX文档解析、智能分块、向量嵌入、Milvus存储"
|
||||||
|
),
|
||||||
|
version=settings.app_version,
|
||||||
|
lifespan=lifespan,
|
||||||
|
docs_url="/docs",
|
||||||
|
redoc_url="/redoc",
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(api_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
|
"""Global exception handler."""
|
||||||
|
logger.error(f"未处理的异常: {exc}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=ErrorResponse(
|
||||||
|
error="InternalServerError",
|
||||||
|
message=str(exc),
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health", tags=["health"])
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint."""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"app": settings.app_name,
|
||||||
|
"version": settings.app_version,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", tags=["root"])
|
||||||
|
async def root():
|
||||||
|
"""Root endpoint."""
|
||||||
|
return {
|
||||||
|
"message": f"Welcome to {settings.app_name}",
|
||||||
|
"version": settings.app_version,
|
||||||
|
"docs": "/docs",
|
||||||
|
"health": "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
"app.api.main:app",
|
||||||
|
host=settings.api_host,
|
||||||
|
port=settings.api_port,
|
||||||
|
reload=settings.debug,
|
||||||
|
log_level="info",
|
||||||
|
)
|
||||||
22
backend/app/api/models/__init__.py
Normal file
22
backend/app/api/models/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# src/api/models/__init__.py
|
||||||
|
"""API数据模型"""
|
||||||
|
|
||||||
|
from .document import (
|
||||||
|
DocumentUploadRequest,
|
||||||
|
DocumentUploadResponse,
|
||||||
|
SearchRequest,
|
||||||
|
SearchResultItem,
|
||||||
|
SearchResponse,
|
||||||
|
DocumentStatusResponse,
|
||||||
|
ErrorResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DocumentUploadRequest",
|
||||||
|
"DocumentUploadResponse",
|
||||||
|
"SearchRequest",
|
||||||
|
"SearchResultItem",
|
||||||
|
"SearchResponse",
|
||||||
|
"DocumentStatusResponse",
|
||||||
|
"ErrorResponse"
|
||||||
|
]
|
||||||
63
backend/app/api/models/document.py
Normal file
63
backend/app/api/models/document.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# src/api/models/document.py
|
||||||
|
"""文档相关Pydantic数据模型"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentUploadRequest(BaseModel):
|
||||||
|
"""文档上传请求"""
|
||||||
|
doc_name: Optional[str] = Field(None, description="文档名称")
|
||||||
|
regulation_type: Optional[str] = Field(None, description="法规类型")
|
||||||
|
version: Optional[str] = Field(None, description="文档版本")
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentUploadResponse(BaseModel):
|
||||||
|
"""文档上传响应"""
|
||||||
|
doc_id: str = Field(..., description="文档ID")
|
||||||
|
doc_name: str = Field(..., description="文档名称")
|
||||||
|
status: str = Field(..., description="处理状态")
|
||||||
|
message: str = Field(default="", description="状态消息")
|
||||||
|
num_chunks: int = Field(default=0, description="分块数量")
|
||||||
|
summary: str = Field(default="", description="LLM生成的文档摘要")
|
||||||
|
summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchRequest(BaseModel):
|
||||||
|
"""检索请求"""
|
||||||
|
query: str = Field(..., description="查询文本")
|
||||||
|
top_k: int = Field(default=10, description="返回结果数量")
|
||||||
|
filters: Optional[str] = Field(None, description="过滤条件")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResultItem(BaseModel):
|
||||||
|
"""单个检索结果"""
|
||||||
|
id: int = Field(..., description="记录ID")
|
||||||
|
content: str = Field(..., description="内容")
|
||||||
|
score: float = Field(..., description="相似度分数")
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResponse(BaseModel):
|
||||||
|
"""检索响应"""
|
||||||
|
query: str = Field(..., description="查询文本")
|
||||||
|
total: int = Field(..., description="结果总数")
|
||||||
|
results: List[SearchResultItem] = Field(default_factory=list, description="结果列表")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentStatusResponse(BaseModel):
|
||||||
|
"""文档状态响应"""
|
||||||
|
doc_id: str = Field(..., description="文档ID")
|
||||||
|
status: str = Field(..., description="状态")
|
||||||
|
num_chunks: Optional[int] = Field(None, description="分块数量")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
"""错误响应"""
|
||||||
|
error: str = Field(..., description="错误类型")
|
||||||
|
message: str = Field(..., description="错误消息")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||||
17
backend/app/api/routes/__init__.py
Normal file
17
backend/app/api/routes/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# src/api/routes/__init__.py
|
||||||
|
"""API路由模块"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from .documents import router as documents_router
|
||||||
|
from .knowledge import router as knowledge_router
|
||||||
|
from .agent import router as agent_router
|
||||||
|
|
||||||
|
# 主路由
|
||||||
|
api_router = APIRouter()
|
||||||
|
|
||||||
|
# 注册子路由
|
||||||
|
api_router.include_router(documents_router)
|
||||||
|
api_router.include_router(knowledge_router)
|
||||||
|
api_router.include_router(agent_router)
|
||||||
|
|
||||||
|
__all__ = ["api_router", "documents_router", "knowledge_router", "agent_router"]
|
||||||
449
backend/app/api/routes/agent.py
Normal file
449
backend/app/api/routes/agent.py
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
# src/api/routes/agent.py
|
||||||
|
"""Agent API接口 - 问答对话接口"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Dict, Optional, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.services.agent.qa_agent import QAAgent, AgentConfig
|
||||||
|
from app.services.agent.session_manager import SessionManager
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||||
|
|
||||||
|
# 会话管理器(全局实例)
|
||||||
|
session_manager = SessionManager()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Pydantic Models =====
|
||||||
|
|
||||||
|
class AskRequest(BaseModel):
|
||||||
|
"""单次问答请求"""
|
||||||
|
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||||
|
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||||
|
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
|
||||||
|
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||||
|
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
|
||||||
|
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
|
||||||
|
|
||||||
|
|
||||||
|
class AskResponse(BaseModel):
|
||||||
|
"""问答响应"""
|
||||||
|
answer: str
|
||||||
|
sources: List[Dict] = []
|
||||||
|
model: str = ""
|
||||||
|
latency_ms: int = 0
|
||||||
|
retrieved_count: int = 0
|
||||||
|
context_tokens: int = 0
|
||||||
|
truncated: bool = False
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
"""多轮对话请求"""
|
||||||
|
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||||
|
session_id: Optional[str] = Field(None, description="会话ID(首次对话可不传)")
|
||||||
|
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||||
|
provider: Optional[str] = Field(None, description="LLM提供商")
|
||||||
|
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
"""多轮对话响应"""
|
||||||
|
session_id: str
|
||||||
|
answer: str
|
||||||
|
sources: List[Dict] = []
|
||||||
|
model: str = ""
|
||||||
|
latency_ms: int = 0
|
||||||
|
message_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SessionInfo(BaseModel):
|
||||||
|
"""会话信息"""
|
||||||
|
session_id: str
|
||||||
|
message_count: int
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackRequest(BaseModel):
|
||||||
|
"""反馈请求"""
|
||||||
|
session_id: str
|
||||||
|
message_index: int
|
||||||
|
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
|
||||||
|
comment: Optional[str] = Field(None, description="反馈内容")
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateListResponse(BaseModel):
|
||||||
|
"""模板列表响应"""
|
||||||
|
templates: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
# ===== API Endpoints =====
|
||||||
|
|
||||||
|
@router.post("/ask", response_model=AskResponse)
|
||||||
|
async def ask_question(request: AskRequest):
|
||||||
|
"""
|
||||||
|
单次问答接口
|
||||||
|
|
||||||
|
不保存会话历史,适合单次查询场景。
|
||||||
|
"""
|
||||||
|
logger.info(f"收到问答请求: {request.query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建Agent配置
|
||||||
|
config = AgentConfig(
|
||||||
|
llm_provider=request.provider or settings.llm_provider,
|
||||||
|
llm_model=request.model or settings.llm_model,
|
||||||
|
top_k=request.top_k or settings.rag_top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建Agent并执行问答
|
||||||
|
agent = QAAgent(config)
|
||||||
|
response = agent.ask(
|
||||||
|
query=request.query,
|
||||||
|
filters=request.filters,
|
||||||
|
prompt_template=request.prompt_template
|
||||||
|
)
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
return AskResponse(
|
||||||
|
answer=response.answer,
|
||||||
|
sources=response.sources,
|
||||||
|
model=response.model,
|
||||||
|
latency_ms=response.latency_ms,
|
||||||
|
retrieved_count=response.retrieved_count,
|
||||||
|
context_tokens=response.context_tokens,
|
||||||
|
truncated=response.truncated,
|
||||||
|
error=response.error
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"问答失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat", response_model=ChatResponse)
|
||||||
|
async def chat_with_session(request: ChatRequest):
|
||||||
|
"""
|
||||||
|
多轮对话接口
|
||||||
|
|
||||||
|
支持会话历史记录,适合连续对话场景。
|
||||||
|
"""
|
||||||
|
logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取或创建会话
|
||||||
|
if request.session_id:
|
||||||
|
session = session_manager.get_session(request.session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||||
|
else:
|
||||||
|
session = session_manager.create_session()
|
||||||
|
|
||||||
|
# 添加用户消息
|
||||||
|
session.add_user_message(request.query)
|
||||||
|
|
||||||
|
# 执行问答
|
||||||
|
config = AgentConfig(
|
||||||
|
llm_provider=request.provider or settings.llm_provider,
|
||||||
|
llm_model=request.model or settings.llm_model
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = QAAgent(config)
|
||||||
|
response = agent.ask(
|
||||||
|
query=request.query,
|
||||||
|
filters=request.filters
|
||||||
|
)
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
# 添加助手消息
|
||||||
|
session.add_assistant_message(
|
||||||
|
response.answer,
|
||||||
|
response.sources
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
session_id=session.session_id,
|
||||||
|
answer=response.answer,
|
||||||
|
sources=response.sources,
|
||||||
|
model=response.model,
|
||||||
|
latency_ms=response.latency_ms,
|
||||||
|
message_count=session.message_count
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"对话失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chat/stream")
|
||||||
|
async def chat_stream_get(
|
||||||
|
query: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
filters: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
model: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
流式对话接口(SSE)- GET版本
|
||||||
|
|
||||||
|
EventSource只能发送GET请求,因此提供此接口。
|
||||||
|
query参数通过URL传递。
|
||||||
|
|
||||||
|
SSE事件格式:
|
||||||
|
- event: session - 会话ID
|
||||||
|
- event: status - 状态更新(检索中、生成中)
|
||||||
|
- event: sources - 引用来源
|
||||||
|
- event: content - 回答内容片段
|
||||||
|
- event: done - 完成,包含统计信息
|
||||||
|
- event: error - 错误信息
|
||||||
|
"""
|
||||||
|
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
|
||||||
|
|
||||||
|
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||||
|
"""生成SSE事件流"""
|
||||||
|
try:
|
||||||
|
# 获取或创建会话
|
||||||
|
if session_id:
|
||||||
|
session = session_manager.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
session = session_manager.create_session()
|
||||||
|
|
||||||
|
# 发送session_id
|
||||||
|
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||||
|
|
||||||
|
# 添加用户消息
|
||||||
|
session.add_user_message(query)
|
||||||
|
|
||||||
|
# 创建Agent
|
||||||
|
config = AgentConfig(
|
||||||
|
llm_provider=provider or settings.llm_provider,
|
||||||
|
llm_model=model or settings.llm_model
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = QAAgent(config)
|
||||||
|
|
||||||
|
# 执行流式问答
|
||||||
|
full_answer = ""
|
||||||
|
sources = []
|
||||||
|
done_data = {}
|
||||||
|
|
||||||
|
for event_data in agent.ask_stream(
|
||||||
|
query=query,
|
||||||
|
filters=filters
|
||||||
|
):
|
||||||
|
event_type = event_data.get("event", "content")
|
||||||
|
data = event_data.get("data", "")
|
||||||
|
|
||||||
|
# 收集完整回答和来源
|
||||||
|
if event_type == "content":
|
||||||
|
full_answer += str(data)
|
||||||
|
elif event_type == "sources":
|
||||||
|
sources = data
|
||||||
|
elif event_type == "done":
|
||||||
|
done_data = data
|
||||||
|
|
||||||
|
# 发送SSE事件
|
||||||
|
if isinstance(data, (dict, list)):
|
||||||
|
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||||
|
else:
|
||||||
|
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||||
|
|
||||||
|
# 小延迟让其他任务有机会执行
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
# 保存到会话历史
|
||||||
|
session.add_assistant_message(full_answer, sources)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式对话失败: {e}")
|
||||||
|
yield f"event: error\ndata: {str(e)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_sse(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/stream")
|
||||||
|
async def chat_stream(request: ChatRequest):
|
||||||
|
"""
|
||||||
|
流式对话接口(SSE)
|
||||||
|
|
||||||
|
返回Server-Sent Events格式的流式响应,用户可实时看到思考过程和回答生成。
|
||||||
|
|
||||||
|
SSE事件格式:
|
||||||
|
- event: status - 状态更新(检索中、生成中)
|
||||||
|
- event: sources - 引用来源
|
||||||
|
- event: content - 回答内容片段
|
||||||
|
- event: done - 完成,包含统计信息
|
||||||
|
- event: error - 错误信息
|
||||||
|
"""
|
||||||
|
logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
|
||||||
|
|
||||||
|
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||||
|
"""生成SSE事件流"""
|
||||||
|
try:
|
||||||
|
# 获取或创建会话
|
||||||
|
if request.session_id:
|
||||||
|
session = session_manager.get_session(request.session_id)
|
||||||
|
if not session:
|
||||||
|
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
session = session_manager.create_session()
|
||||||
|
|
||||||
|
# 发送session_id
|
||||||
|
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||||
|
|
||||||
|
# 添加用户消息
|
||||||
|
session.add_user_message(request.query)
|
||||||
|
|
||||||
|
# 创建Agent
|
||||||
|
config = AgentConfig(
|
||||||
|
llm_provider=request.provider or settings.llm_provider,
|
||||||
|
llm_model=request.model or settings.llm_model
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = QAAgent(config)
|
||||||
|
|
||||||
|
# 执行流式问答
|
||||||
|
full_answer = ""
|
||||||
|
sources = []
|
||||||
|
done_data = {}
|
||||||
|
|
||||||
|
for event_data in agent.ask_stream(
|
||||||
|
query=request.query,
|
||||||
|
filters=request.filters
|
||||||
|
):
|
||||||
|
event_type = event_data.get("event", "content")
|
||||||
|
data = event_data.get("data", "")
|
||||||
|
|
||||||
|
# 收集完整回答和来源
|
||||||
|
if event_type == "content":
|
||||||
|
full_answer += str(data)
|
||||||
|
elif event_type == "sources":
|
||||||
|
sources = data
|
||||||
|
elif event_type == "done":
|
||||||
|
done_data = data
|
||||||
|
|
||||||
|
# 发送SSE事件
|
||||||
|
if isinstance(data, (dict, list)):
|
||||||
|
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||||
|
else:
|
||||||
|
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||||
|
|
||||||
|
# 小延迟让其他任务有机会执行
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
# 保存到会话历史
|
||||||
|
session.add_assistant_message(full_answer, sources)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式对话失败: {e}")
|
||||||
|
yield f"event: error\ndata: {str(e)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_sse(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/session/{session_id}", response_model=SessionInfo)
|
||||||
|
async def get_session_info(session_id: str):
|
||||||
|
"""获取会话信息"""
|
||||||
|
session = session_manager.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||||
|
|
||||||
|
return SessionInfo(
|
||||||
|
session_id=session.session_id,
|
||||||
|
message_count=session.message_count,
|
||||||
|
created_at=session.created_at,
|
||||||
|
updated_at=session.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/session/{session_id}/history")
|
||||||
|
async def get_session_history(session_id: str, max_turns: int = 5):
|
||||||
|
"""获取会话历史"""
|
||||||
|
session = session_manager.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||||
|
|
||||||
|
history = session.get_history(max_turns)
|
||||||
|
return {"session_id": session_id, "history": history}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/session/{session_id}")
|
||||||
|
async def delete_session(session_id: str):
|
||||||
|
"""删除会话"""
|
||||||
|
success = session_manager.delete_session(session_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail="会话不存在")
|
||||||
|
|
||||||
|
return {"message": "会话已删除", "session_id": session_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions", response_model=List[SessionInfo])
|
||||||
|
async def list_sessions():
|
||||||
|
"""列出所有活跃会话"""
|
||||||
|
sessions = session_manager.list_sessions()
|
||||||
|
return [SessionInfo(**s) for s in sessions]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/feedback")
|
||||||
|
async def submit_feedback(request: FeedbackRequest):
|
||||||
|
"""提交问答反馈"""
|
||||||
|
session = session_manager.get_session(request.session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail="会话不存在")
|
||||||
|
|
||||||
|
# 记录反馈(实际应用中可存储到数据库)
|
||||||
|
logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
|
||||||
|
|
||||||
|
return {"message": "反馈已记录", "rating": request.rating}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/templates", response_model=TemplateListResponse)
|
||||||
|
async def list_prompt_templates():
|
||||||
|
"""列出可用的Prompt模板"""
|
||||||
|
from app.services.rag.prompt_templates import PromptTemplates
|
||||||
|
|
||||||
|
templates = PromptTemplates.list_templates()
|
||||||
|
return TemplateListResponse(templates=templates)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
async def list_available_models():
|
||||||
|
"""列出可用的LLM模型"""
|
||||||
|
from app.services.llm import LLMFactory
|
||||||
|
|
||||||
|
factory = LLMFactory()
|
||||||
|
models = factory.list_available_providers()
|
||||||
|
return {"models": models}
|
||||||
96
backend/app/api/routes/compliance.py
Normal file
96
backend/app/api/routes/compliance.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from app.schemas.compliance import (
|
||||||
|
AnalyzeResponse,
|
||||||
|
ComplianceChatRequest,
|
||||||
|
)
|
||||||
|
from app.services.mock_data import (
|
||||||
|
generate_task_id,
|
||||||
|
get_mock_compliance_result,
|
||||||
|
get_mock_compliance_chat_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||||
|
|
||||||
|
# 临时存储分析任务
|
||||||
|
tasks_store: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||||
|
async def analyze_document(file: UploadFile = File(...)):
|
||||||
|
"""上传设计方案进行分析"""
|
||||||
|
# 生成任务ID
|
||||||
|
task_id = generate_task_id()
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||||
|
os.makedirs(raw_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(raw_dir, f"compliance_{task_id}_{file.filename}")
|
||||||
|
|
||||||
|
content = await file.read()
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# 记录任务
|
||||||
|
tasks_store[task_id] = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"file_path": file_path,
|
||||||
|
"status": "processing",
|
||||||
|
"result": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 模拟异步处理完成(立即返回结果)
|
||||||
|
# 实际应用中这应该是后台任务
|
||||||
|
tasks_store[task_id]["status"] = "completed"
|
||||||
|
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
|
||||||
|
|
||||||
|
return AnalyzeResponse(task_id=task_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/result/{task_id}")
|
||||||
|
async def get_result(task_id: str):
|
||||||
|
"""获取分析结果"""
|
||||||
|
if task_id not in tasks_store:
|
||||||
|
# 如果任务ID不存在,返回默认mock结果
|
||||||
|
return get_mock_compliance_result(task_id)
|
||||||
|
|
||||||
|
task = tasks_store[task_id]
|
||||||
|
|
||||||
|
if task["status"] == "processing":
|
||||||
|
return {"status": "processing", "message": "分析进行中"}
|
||||||
|
|
||||||
|
return task["result"]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/{segment_id}")
|
||||||
|
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||||
|
"""针对段落进行合规对话"""
|
||||||
|
# 根据segment_id获取对应的intent
|
||||||
|
intent_map = {
|
||||||
|
1: "车身结构设计",
|
||||||
|
2: "动力系统配置",
|
||||||
|
3: "安全配置设计",
|
||||||
|
}
|
||||||
|
intent = intent_map.get(segment_id, "车身结构设计")
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
# 获取预设响应
|
||||||
|
response = get_mock_compliance_chat_response(intent, request.query)
|
||||||
|
|
||||||
|
# 流式输出响应
|
||||||
|
sentences = response.split("\n\n")
|
||||||
|
for sentence in sentences:
|
||||||
|
if sentence.strip():
|
||||||
|
chunks = sentence.split("\n")
|
||||||
|
for chunk in chunks:
|
||||||
|
if chunk.strip():
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
|
||||||
|
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "done"})}
|
||||||
|
|
||||||
|
return EventSourceResponse(generate())
|
||||||
115
backend/app/api/routes/docs.py
Normal file
115
backend/app/api/routes/docs.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from app.schemas.doc import (
|
||||||
|
DocumentUploadResponse,
|
||||||
|
DocumentListResponse,
|
||||||
|
DocumentInfo,
|
||||||
|
ParseResponse,
|
||||||
|
EmbedResponse,
|
||||||
|
)
|
||||||
|
from app.services.mock_data import get_mock_documents, generate_doc_id
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/docs", tags=["文档管理"])
|
||||||
|
|
||||||
|
# 临时存储文档信息(包含预设的mock文档)
|
||||||
|
documents_store: dict[str, dict] = {}
|
||||||
|
|
||||||
|
# 初始化时加载mock文档
|
||||||
|
for doc in get_mock_documents():
|
||||||
|
documents_store[doc["id"]] = doc
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||||
|
async def upload_document(file: UploadFile = File(...)):
|
||||||
|
"""上传法规文档"""
|
||||||
|
# 检查文件格式
|
||||||
|
allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
|
||||||
|
ext = os.path.splitext(file.filename)[1].lower()
|
||||||
|
if ext not in allowed_ext:
|
||||||
|
raise HTTPException(400, f"Unsupported file format: {ext}")
|
||||||
|
|
||||||
|
# 生成文档ID
|
||||||
|
doc_id = generate_doc_id()
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||||
|
os.makedirs(raw_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
|
||||||
|
|
||||||
|
content = await file.read()
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# 记录文档信息
|
||||||
|
documents_store[doc_id] = {
|
||||||
|
"id": doc_id,
|
||||||
|
"name": file.filename,
|
||||||
|
"path": file_path,
|
||||||
|
"size": len(content),
|
||||||
|
"status": "uploaded",
|
||||||
|
"chunks": 0,
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return DocumentUploadResponse(
|
||||||
|
doc_id=doc_id,
|
||||||
|
filename=file.filename,
|
||||||
|
size=len(content),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/list", response_model=DocumentListResponse)
|
||||||
|
async def list_documents():
|
||||||
|
"""获取已索引文档列表"""
|
||||||
|
docs = [
|
||||||
|
DocumentInfo(
|
||||||
|
id=d["id"],
|
||||||
|
name=d["name"],
|
||||||
|
chunks=d["chunks"],
|
||||||
|
status=d["status"],
|
||||||
|
created_at=d.get("created_at"),
|
||||||
|
)
|
||||||
|
for d in documents_store.values()
|
||||||
|
]
|
||||||
|
return DocumentListResponse(docs=docs)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/parse/{doc_id}", response_model=ParseResponse)
|
||||||
|
async def parse_document(doc_id: str):
|
||||||
|
"""解析文档并分块"""
|
||||||
|
if doc_id not in documents_store:
|
||||||
|
raise HTTPException(404, "Document not found")
|
||||||
|
|
||||||
|
doc = documents_store[doc_id]
|
||||||
|
# 模拟解析逻辑
|
||||||
|
doc["status"] = "parsed"
|
||||||
|
# 根据文件大小计算chunks数量
|
||||||
|
file_size = doc.get("size", 100000)
|
||||||
|
doc["chunks"] = max(20, file_size // 8000)
|
||||||
|
|
||||||
|
return ParseResponse(doc_id=doc_id, chunks=doc["chunks"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
|
||||||
|
async def embed_document(doc_id: str):
|
||||||
|
"""嵌入并存入向量库"""
|
||||||
|
if doc_id not in documents_store:
|
||||||
|
raise HTTPException(404, "Document not found")
|
||||||
|
|
||||||
|
doc = documents_store[doc_id]
|
||||||
|
# 模拟嵌入逻辑
|
||||||
|
doc["status"] = "indexed"
|
||||||
|
|
||||||
|
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/delete/{doc_id}")
|
||||||
|
async def delete_document(doc_id: str):
|
||||||
|
"""删除文档"""
|
||||||
|
if doc_id not in documents_store:
|
||||||
|
raise HTTPException(404, "Document not found")
|
||||||
|
|
||||||
|
del documents_store[doc_id]
|
||||||
|
return {"success": True}
|
||||||
291
backend/app/api/routes/documents.py
Normal file
291
backend/app/api/routes/documents.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
# src/api/routes/documents.py
|
||||||
|
"""文档上传与处理接口"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
||||||
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
from typing import Optional
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
from io import BytesIO
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from ..models import DocumentUploadResponse, ErrorResponse
|
||||||
|
from app.services.document_processor import DocumentProcessor
|
||||||
|
from app.services.storage.minio_client import MinIOClient
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||||
|
|
||||||
|
# MinIO客户端(用于文档存储)
|
||||||
|
minio_client: Optional[MinIOClient] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_minio_client() -> MinIOClient:
|
||||||
|
"""获取MinIO客户端实例"""
|
||||||
|
global minio_client
|
||||||
|
if minio_client is None:
|
||||||
|
minio_client = MinIOClient()
|
||||||
|
minio_client.connect()
|
||||||
|
minio_client.ensure_bucket()
|
||||||
|
return minio_client
|
||||||
|
|
||||||
|
|
||||||
|
def _build_document_records(limit: Optional[int] = None):
|
||||||
|
"""构建文档列表记录,支持按最近更新时间倒序截断。"""
|
||||||
|
minio = get_minio_client()
|
||||||
|
|
||||||
|
document_records = []
|
||||||
|
objects = minio.client.list_objects(minio.bucket, recursive=True)
|
||||||
|
for obj in objects:
|
||||||
|
parts = obj.object_name.split("/", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc_id, filename = parts
|
||||||
|
last_modified = getattr(obj, "last_modified", None)
|
||||||
|
document_records.append({
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"filename": filename,
|
||||||
|
"size": getattr(obj, "size", 0) or 0,
|
||||||
|
"object_name": obj.object_name,
|
||||||
|
"download_url": f"/api/v1/documents/download/{doc_id}",
|
||||||
|
"last_modified": last_modified.isoformat() if last_modified else None,
|
||||||
|
"_sort_key": last_modified.timestamp() if last_modified else 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
document_records.sort(key=lambda item: item["_sort_key"], reverse=True)
|
||||||
|
if limit is not None:
|
||||||
|
document_records = document_records[:limit]
|
||||||
|
|
||||||
|
for item in document_records:
|
||||||
|
item.pop("_sort_key", None)
|
||||||
|
|
||||||
|
return document_records
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||||
|
async def upload_document(
|
||||||
|
file: UploadFile = File(..., description="上传的文档文件"),
|
||||||
|
doc_name: Optional[str] = Form(None, description="文档名称"),
|
||||||
|
regulation_type: Optional[str] = Form(None, description="法规类型"),
|
||||||
|
version: Optional[str] = Form(None, description="文档版本"),
|
||||||
|
generate_summary: bool = Form(False, description="是否生成摘要(默认不生成,可节省约60秒)")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传文档并处理
|
||||||
|
|
||||||
|
支持格式:PDF、DOCX、DOC
|
||||||
|
处理流程:解析 → 分块 → 嵌入 → 入库(摘要可选)
|
||||||
|
文件存储:MinIO对象存储
|
||||||
|
|
||||||
|
参数说明:
|
||||||
|
- generate_summary: 是否生成LLM摘要,默认False。勾选后处理时间增加约60秒。
|
||||||
|
"""
|
||||||
|
# 验证文件类型
|
||||||
|
ext = os.path.splitext(file.filename)[1].lower()
|
||||||
|
if ext not in [".pdf", ".docx", ".doc"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"不支持的文件类型: {ext},仅支持PDF、DOCX、DOC"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证文件大小
|
||||||
|
if file.size and file.size > settings.max_file_size_mb * 1024 * 1024:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成文档ID
|
||||||
|
doc_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
|
# 文档名称
|
||||||
|
final_doc_name = doc_name or file.filename
|
||||||
|
|
||||||
|
# MinIO对象名称
|
||||||
|
object_name = f"{doc_id}/{file.filename}"
|
||||||
|
|
||||||
|
logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}, doc_id={doc_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取文件内容
|
||||||
|
content = await file.read()
|
||||||
|
|
||||||
|
# 保存临时文件用于处理
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
temp_path = os.path.join(temp_dir, f"{doc_id}_{file.filename}")
|
||||||
|
|
||||||
|
with open(temp_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
logger.info(f"临时文件已保存到: {temp_path}")
|
||||||
|
|
||||||
|
# 上传到MinIO
|
||||||
|
minio = get_minio_client()
|
||||||
|
upload_success = minio.upload_bytes(
|
||||||
|
data=content,
|
||||||
|
object_name=object_name,
|
||||||
|
content_type=minio._get_content_type(file.filename),
|
||||||
|
metadata={
|
||||||
|
"doc_id": doc_id # 仅传递ASCII安全的metadata
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if upload_success:
|
||||||
|
logger.success(f"文件已上传到MinIO: {object_name}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"MinIO上传失败,仅使用本地临时文件")
|
||||||
|
|
||||||
|
# 处理文档(传入相同的doc_id,保持一致性)
|
||||||
|
processor = DocumentProcessor(generate_summary=generate_summary)
|
||||||
|
result = processor.process(
|
||||||
|
file_path=temp_path,
|
||||||
|
doc_id=doc_id, # 使用相同的doc_id
|
||||||
|
doc_name=final_doc_name,
|
||||||
|
regulation_type=regulation_type or "",
|
||||||
|
version=version or ""
|
||||||
|
)
|
||||||
|
processor.close()
|
||||||
|
|
||||||
|
# 清理临时文件
|
||||||
|
try:
|
||||||
|
os.remove(temp_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
return DocumentUploadResponse(
|
||||||
|
doc_id=result.doc_id,
|
||||||
|
doc_name=result.doc_name,
|
||||||
|
status="success",
|
||||||
|
message=result.message,
|
||||||
|
num_chunks=result.num_chunks,
|
||||||
|
summary=result.summary,
|
||||||
|
summary_latency_ms=result.summary_latency_ms
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=result.message
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文档处理失败: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"文档处理失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
|
||||||
|
async def get_document_status(doc_id: str):
|
||||||
|
"""
|
||||||
|
查询文档处理状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: 文档ID
|
||||||
|
"""
|
||||||
|
# TODO: 实现状态查询(需要数据库支持)
|
||||||
|
return DocumentUploadResponse(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name="",
|
||||||
|
status="unknown",
|
||||||
|
message="状态查询功能待实现"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/download/{doc_id}")
|
||||||
|
async def download_document(doc_id: str):
|
||||||
|
"""
|
||||||
|
下载文档(从MinIO获取)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: 文档ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文件下载响应
|
||||||
|
"""
|
||||||
|
logger.info(f"请求下载文档: doc_id={doc_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
minio = get_minio_client()
|
||||||
|
|
||||||
|
# 查找该doc_id下的文件(MinIO对象名称格式: {doc_id}/{filename})
|
||||||
|
objects = minio.list_objects(prefix=f"{doc_id}/")
|
||||||
|
|
||||||
|
if not objects:
|
||||||
|
logger.warning(f"MinIO中未找到文档: doc_id={doc_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"文档不存在: doc_id={doc_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取第一个匹配的对象
|
||||||
|
object_name = objects[0]
|
||||||
|
logger.info(f"找到MinIO对象: {object_name}")
|
||||||
|
|
||||||
|
# 获取文件数据
|
||||||
|
file_data = minio.get_object_data(object_name)
|
||||||
|
if file_data is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"获取文档数据失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析原始文件名
|
||||||
|
original_name = object_name.split("/", 1)[1] if "/" in object_name else object_name
|
||||||
|
|
||||||
|
# 获取Content-Type
|
||||||
|
content_type = minio._get_content_type(original_name)
|
||||||
|
|
||||||
|
logger.success(f"文档下载成功: {original_name}, 大小={len(file_data)}")
|
||||||
|
|
||||||
|
# 返回文件流(URL编码文件名以支持中文)
|
||||||
|
encoded_name = quote(original_name)
|
||||||
|
return StreamingResponse(
|
||||||
|
BytesIO(file_data),
|
||||||
|
media_type=content_type,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文档下载失败: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"文档下载失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/list")
|
||||||
|
async def list_documents():
|
||||||
|
"""
|
||||||
|
列出所有已上传的文档(从MinIO获取)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
documents = _build_document_records()
|
||||||
|
return {"documents": documents, "total": len(documents)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"列出文档失败: {e}")
|
||||||
|
return {"documents": [], "total": 0, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/management-list")
|
||||||
|
async def get_document_management_list():
|
||||||
|
"""
|
||||||
|
文档管理清单接口:仅返回最近的10条文档。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
documents = _build_document_records(limit=10)
|
||||||
|
return {"documents": documents, "total": len(documents), "limit": 10}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取文档管理清单失败: {e}")
|
||||||
|
return {"documents": [], "total": 0, "limit": 10, "error": str(e)}
|
||||||
81
backend/app/api/routes/knowledge.py
Normal file
81
backend/app/api/routes/knowledge.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# src/api/routes/knowledge.py
|
||||||
|
"""知识库检索接口"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse
|
||||||
|
from app.services.document_processor import DocumentProcessor
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/search", response_model=SearchResponse)
|
||||||
|
async def search_knowledge(request: SearchRequest):
|
||||||
|
"""
|
||||||
|
检索法规知识库
|
||||||
|
|
||||||
|
使用混合检索:Dense向量 + Sparse向量 + RRF融合
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 检索请求参数
|
||||||
|
"""
|
||||||
|
if not request.query or len(request.query.strip()) == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="查询文本不能为空"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"收到检索请求: {request.query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 执行检索
|
||||||
|
processor = DocumentProcessor()
|
||||||
|
results = processor.search(
|
||||||
|
query=request.query,
|
||||||
|
top_k=request.top_k,
|
||||||
|
filters=request.filters
|
||||||
|
)
|
||||||
|
processor.close()
|
||||||
|
|
||||||
|
# 转换结果格式
|
||||||
|
result_items = []
|
||||||
|
for r in results:
|
||||||
|
item = SearchResultItem(
|
||||||
|
id=r.get("id", 0),
|
||||||
|
content=r.get("content", ""),
|
||||||
|
score=r.get("score", 0.0),
|
||||||
|
metadata=r.get("metadata", {})
|
||||||
|
)
|
||||||
|
result_items.append(item)
|
||||||
|
|
||||||
|
return SearchResponse(
|
||||||
|
query=request.query,
|
||||||
|
total=len(result_items),
|
||||||
|
results=result_items
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检索失败: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"检索失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/retrieval", response_model=SearchResponse)
|
||||||
|
async def knowledge_retrieval(request: SearchRequest):
|
||||||
|
"""
|
||||||
|
知识检索接口(与架构文档对齐)
|
||||||
|
|
||||||
|
该接口实现完整的检索流程:
|
||||||
|
1. 意图识别
|
||||||
|
2. BM25关键词检索 + 向量语义检索(双路召回)
|
||||||
|
3. Cross-Encoder精排
|
||||||
|
4. 返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 检索请求
|
||||||
|
"""
|
||||||
|
# 当前版本使用混合检索,后续可添加精排步骤
|
||||||
|
return await search_knowledge(request)
|
||||||
74
backend/app/api/routes/rag.py
Normal file
74
backend/app/api/routes/rag.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||||
|
from app.services.mock_data import (
|
||||||
|
get_mock_quick_questions,
|
||||||
|
get_mock_retrieval,
|
||||||
|
get_mock_rag_answer,
|
||||||
|
)
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat")
|
||||||
|
async def rag_chat(request: RagChatRequest):
|
||||||
|
"""SSE流式问答"""
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
# 发送检索开始事件
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "retrieving"})}
|
||||||
|
|
||||||
|
# 模拟检索延迟
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
# 执行检索
|
||||||
|
docs = get_mock_retrieval(request.query, top_k=request.top_k)
|
||||||
|
|
||||||
|
retrieved_data = [
|
||||||
|
{
|
||||||
|
"id": d["id"],
|
||||||
|
"score": d["score"],
|
||||||
|
"preview": d["preview"],
|
||||||
|
"doc_name": d.get("doc_name", ""),
|
||||||
|
"clause": d.get("clause", ""),
|
||||||
|
}
|
||||||
|
for d in docs
|
||||||
|
]
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "retrieved", "docs": retrieved_data})}
|
||||||
|
|
||||||
|
# 发送生成开始事件
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "generating", "text": "正在生成答案..."})}
|
||||||
|
|
||||||
|
# 模拟生成延迟
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
# 获取预设答案
|
||||||
|
answer = get_mock_rag_answer(request.query)
|
||||||
|
|
||||||
|
# 流式输出答案(按句子分割)
|
||||||
|
sentences = answer.split("\n\n")
|
||||||
|
for sentence in sentences:
|
||||||
|
if sentence.strip():
|
||||||
|
# 进一步分割长句子
|
||||||
|
chunks = sentence.split("\n")
|
||||||
|
for chunk in chunks:
|
||||||
|
if chunk.strip():
|
||||||
|
await asyncio.sleep(0.05) # 模拟生成延迟
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
|
||||||
|
|
||||||
|
# 发送完成事件
|
||||||
|
yield {"event": "message", "data": json.dumps({"type": "done"})}
|
||||||
|
|
||||||
|
return EventSourceResponse(generate())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
|
||||||
|
async def get_quick_questions():
|
||||||
|
"""获取预设快捷问题"""
|
||||||
|
questions = [
|
||||||
|
QuickQuestion(id=q["id"], question=q["question"], category=q["category"])
|
||||||
|
for q in get_mock_quick_questions()
|
||||||
|
]
|
||||||
|
return QuickQuestionsResponse(questions=questions)
|
||||||
28
backend/app/api/routes/status.py
Normal file
28
backend/app/api/routes/status.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.services.mock_data import MOCK_SYSTEM_STATS, MOCK_SYSTEM_CONFIG
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/status", tags=["系统状态"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stats")
|
||||||
|
async def get_stats():
|
||||||
|
"""获取系统统计"""
|
||||||
|
# 返回预设统计数据
|
||||||
|
return MOCK_SYSTEM_STATS
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config")
|
||||||
|
async def get_config():
|
||||||
|
"""获取当前配置"""
|
||||||
|
return MOCK_SYSTEM_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/milvus/health")
|
||||||
|
async def milvus_health():
|
||||||
|
"""Milvus健康检查"""
|
||||||
|
# 模拟连接状态(假数据模式下始终返回连接成功)
|
||||||
|
return {
|
||||||
|
"connected": True,
|
||||||
|
"collections": ["vehicle_regulations"],
|
||||||
|
}
|
||||||
6
backend/app/config/__init__.py
Normal file
6
backend/app/config/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# src/config/__init__.py
|
||||||
|
"""配置模块"""
|
||||||
|
|
||||||
|
from .settings import Settings, get_settings, settings
|
||||||
|
|
||||||
|
__all__ = ["Settings", "get_settings", "settings"]
|
||||||
32
backend/app/config/logging.py
Normal file
32
backend/app/config/logging.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# src/config/logging.py
|
||||||
|
"""日志配置"""
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(level: str = "INFO"):
|
||||||
|
"""设置日志配置"""
|
||||||
|
|
||||||
|
# 移除默认handler
|
||||||
|
logger.remove()
|
||||||
|
|
||||||
|
# 添加控制台输出
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
level=level,
|
||||||
|
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||||
|
colorize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加文件输出
|
||||||
|
logger.add(
|
||||||
|
"logs/app_{time:YYYY-MM-DD}.log",
|
||||||
|
level=level,
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
|
||||||
|
rotation="00:00",
|
||||||
|
retention="7 days",
|
||||||
|
compression="zip"
|
||||||
|
)
|
||||||
|
|
||||||
|
return logger
|
||||||
95
backend/app/config/settings.py
Normal file
95
backend/app/config/settings.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# src/config/settings.py
|
||||||
|
"""配置管理 - 环境变量和默认配置"""
|
||||||
|
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pydantic import Field
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""应用配置"""
|
||||||
|
|
||||||
|
# 应用基础配置
|
||||||
|
app_name: str = Field(default="AI Regulations Demo", description="Application name")
|
||||||
|
app_version: str = Field(default="0.1.0", description="应用版本")
|
||||||
|
debug: bool = Field(default=False, description="调试模式")
|
||||||
|
|
||||||
|
# Milvus向量数据库配置
|
||||||
|
milvus_host: str = Field(default="localhost", description="Milvus服务地址")
|
||||||
|
milvus_port: int = Field(default=19530, description="Milvus服务端口")
|
||||||
|
milvus_collection: str = Field(default="regulations", description="法规向量集合名称")
|
||||||
|
milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
|
||||||
|
|
||||||
|
# 嵌入模型配置
|
||||||
|
embedding_model: str = Field(default="BAAI/bge-m3", description="嵌入模型名称")
|
||||||
|
embedding_dim: int = Field(default=1024, description="嵌入向量维度")
|
||||||
|
embedding_max_length: int = Field(default=8192, description="最大嵌入长度")
|
||||||
|
embedding_batch_size: int = Field(default=12, description="嵌入批处理大小")
|
||||||
|
embedding_use_fp16: bool = Field(default=True, description="使用FP16加速")
|
||||||
|
|
||||||
|
# MinIO对象存储配置
|
||||||
|
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址")
|
||||||
|
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
|
||||||
|
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
|
||||||
|
minio_bucket: str = Field(default="upload-files", description="文档存储桶名称")
|
||||||
|
minio_secure: bool = Field(default=False, description="是否使用HTTPS")
|
||||||
|
|
||||||
|
# Redis配置
|
||||||
|
redis_host: str = Field(default="localhost", description="Redis服务地址")
|
||||||
|
redis_port: int = Field(default=6379, description="Redis服务端口")
|
||||||
|
redis_password: str = Field(default="", description="Redis密码")
|
||||||
|
redis_db: int = Field(default=0, description="Redis数据库编号")
|
||||||
|
|
||||||
|
# PostgreSQL配置
|
||||||
|
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址")
|
||||||
|
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
|
||||||
|
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
|
||||||
|
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
|
||||||
|
postgres_db: str = Field(default="compliance_db", description="PostgreSQL数据库名称")
|
||||||
|
|
||||||
|
# 文档处理配置
|
||||||
|
chunk_size: int = Field(default=512, description="分块大小(字符数)")
|
||||||
|
chunk_overlap: int = Field(default=50, description="分块重叠大小")
|
||||||
|
max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)")
|
||||||
|
|
||||||
|
# API配置
|
||||||
|
api_host: str = Field(default="0.0.0.0", description="API服务地址")
|
||||||
|
api_port: int = Field(default=8000, description="API服务端口")
|
||||||
|
|
||||||
|
# LLM配置
|
||||||
|
llm_provider: str = Field(default="deepseek", description="LLM提供商 (deepseek/qwen/qwen_vl)")
|
||||||
|
llm_model: str = Field(default="deepseek-v4-flash", description="LLM模型名称")
|
||||||
|
llm_max_tokens: int = Field(default=4096, description="LLM最大输出token数")
|
||||||
|
llm_temperature: float = Field(default=0.7, description="LLM温度参数")
|
||||||
|
|
||||||
|
# DeepSeek配置
|
||||||
|
deepseek_api_key: str = Field(default="", description="DeepSeek API密钥")
|
||||||
|
deepseek_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="DeepSeek API地址")
|
||||||
|
deepseek_model: str = Field(default="deepseek-v4-flash", description="DeepSeek模型")
|
||||||
|
|
||||||
|
# Qwen配置(通过统一代理API)
|
||||||
|
qwen_api_key: str = Field(default="", description="Qwen API密钥")
|
||||||
|
qwen_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="Qwen API地址")
|
||||||
|
qwen_model: str = Field(default="qwen3.5-flash", description="Qwen文本模型")
|
||||||
|
qwen_vl_model: str = Field(default="qwen3-vl-plus", description="Qwen视觉模型")
|
||||||
|
|
||||||
|
# RAG配置
|
||||||
|
rag_top_k: int = Field(default=5, description="检索召回数量")
|
||||||
|
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
|
||||||
|
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
extra = "ignore"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""获取配置实例(缓存)"""
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
|
# 导出默认配置实例
|
||||||
|
settings = get_settings()
|
||||||
3
backend/app/core/__init__.py
Normal file
3
backend/app/core/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .config import settings, Settings
|
||||||
|
|
||||||
|
__all__ = ["settings", "Settings"]
|
||||||
41
backend/app/core/config.py
Normal file
41
backend/app/core/config.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# DashScope API
|
||||||
|
dashscope_api_key: str = ""
|
||||||
|
|
||||||
|
# Milvus
|
||||||
|
milvus_host: str = "localhost"
|
||||||
|
milvus_port: int = 19530
|
||||||
|
|
||||||
|
# LLM配置
|
||||||
|
llm_model: str = "qwen-max"
|
||||||
|
embedding_model: str = "text-embedding-v3"
|
||||||
|
embedding_dim: int = 1536
|
||||||
|
|
||||||
|
# 检索配置
|
||||||
|
vector_top_k: int = 10
|
||||||
|
bm25_top_k: int = 10
|
||||||
|
final_top_k: int = 5
|
||||||
|
|
||||||
|
# 分块配置
|
||||||
|
chunk_size: int = 800
|
||||||
|
chunk_overlap: int = 50
|
||||||
|
|
||||||
|
# 服务配置
|
||||||
|
api_host: str = "0.0.0.0"
|
||||||
|
api_port: int = 8000
|
||||||
|
|
||||||
|
# Collection名称
|
||||||
|
regulations_collection: str = "vehicle_regulations"
|
||||||
|
compliance_collection: str = "compliance_cache"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
case_sensitive = False
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
5
backend/app/main.py
Normal file
5
backend/app/main.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Backend application entrypoint."""
|
||||||
|
|
||||||
|
from app.api.main import app
|
||||||
|
|
||||||
|
__all__ = ["app"]
|
||||||
49
backend/app/schemas/__init__.py
Normal file
49
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from .doc import (
|
||||||
|
DocumentUploadResponse,
|
||||||
|
DocumentInfo,
|
||||||
|
DocumentListResponse,
|
||||||
|
ChunkInfo,
|
||||||
|
ParseResponse,
|
||||||
|
EmbedResponse,
|
||||||
|
)
|
||||||
|
from .rag import (
|
||||||
|
RagChatRequest,
|
||||||
|
RetrievedDoc,
|
||||||
|
SourceInfo,
|
||||||
|
QuickQuestion,
|
||||||
|
QuickQuestionsResponse,
|
||||||
|
)
|
||||||
|
from .compliance import (
|
||||||
|
RiskLevel,
|
||||||
|
ComplianceStatus,
|
||||||
|
Regulation,
|
||||||
|
ComplianceSegment,
|
||||||
|
RiskDashboard,
|
||||||
|
PriorityAction,
|
||||||
|
ComplianceResult,
|
||||||
|
ComplianceChatRequest,
|
||||||
|
AnalyzeResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DocumentUploadResponse",
|
||||||
|
"DocumentInfo",
|
||||||
|
"DocumentListResponse",
|
||||||
|
"ChunkInfo",
|
||||||
|
"ParseResponse",
|
||||||
|
"EmbedResponse",
|
||||||
|
"RagChatRequest",
|
||||||
|
"RetrievedDoc",
|
||||||
|
"SourceInfo",
|
||||||
|
"QuickQuestion",
|
||||||
|
"QuickQuestionsResponse",
|
||||||
|
"RiskLevel",
|
||||||
|
"ComplianceStatus",
|
||||||
|
"Regulation",
|
||||||
|
"ComplianceSegment",
|
||||||
|
"RiskDashboard",
|
||||||
|
"PriorityAction",
|
||||||
|
"ComplianceResult",
|
||||||
|
"ComplianceChatRequest",
|
||||||
|
"AnalyzeResponse",
|
||||||
|
]
|
||||||
69
backend/app/schemas/compliance.py
Normal file
69
backend/app/schemas/compliance.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RiskLevel(str, Enum):
|
||||||
|
high = "high"
|
||||||
|
medium = "medium"
|
||||||
|
low = "low"
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceStatus(str, Enum):
|
||||||
|
pass_status = "pass"
|
||||||
|
warning = "warning"
|
||||||
|
fail = "fail"
|
||||||
|
|
||||||
|
|
||||||
|
class Regulation(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
clause: Optional[str] = None
|
||||||
|
score: float
|
||||||
|
match_keyword: str
|
||||||
|
category: RiskLevel
|
||||||
|
full_content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceSegment(BaseModel):
|
||||||
|
id: int
|
||||||
|
index: int
|
||||||
|
intent: str
|
||||||
|
start_pos: int
|
||||||
|
end_pos: int
|
||||||
|
content: str
|
||||||
|
risk_level: RiskLevel
|
||||||
|
regulations: list[Regulation]
|
||||||
|
|
||||||
|
|
||||||
|
class RiskDashboard(BaseModel):
|
||||||
|
score: float
|
||||||
|
high_risk_count: int
|
||||||
|
medium_risk_count: int
|
||||||
|
low_risk_count: int
|
||||||
|
need_fix_segments: int
|
||||||
|
status: ComplianceStatus
|
||||||
|
status_label: str
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityAction(BaseModel):
|
||||||
|
regulation: str
|
||||||
|
issue: str
|
||||||
|
suggestion: str
|
||||||
|
severity: RiskLevel
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceResult(BaseModel):
|
||||||
|
task_id: str
|
||||||
|
dashboard: RiskDashboard
|
||||||
|
segments: list[ComplianceSegment]
|
||||||
|
priority_actions: list[PriorityAction]
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceChatRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
|
||||||
|
|
||||||
|
class AnalyzeResponse(BaseModel):
|
||||||
|
task_id: str
|
||||||
|
status: str = "processing"
|
||||||
44
backend/app/schemas/doc.py
Normal file
44
backend/app/schemas/doc.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentUploadResponse(BaseModel):
|
||||||
|
doc_id: str
|
||||||
|
filename: str
|
||||||
|
size: int
|
||||||
|
status: str = "uploaded"
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentInfo(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
chunks: int
|
||||||
|
status: str
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentListResponse(BaseModel):
|
||||||
|
docs: list[DocumentInfo]
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkInfo(BaseModel):
|
||||||
|
chunk_id: str
|
||||||
|
doc_name: str
|
||||||
|
clause_id: Optional[str] = None
|
||||||
|
chapter: Optional[str] = None
|
||||||
|
content: str
|
||||||
|
token_count: int
|
||||||
|
chunk_index: int
|
||||||
|
|
||||||
|
|
||||||
|
class ParseResponse(BaseModel):
|
||||||
|
doc_id: str
|
||||||
|
chunks: int
|
||||||
|
status: str = "parsed"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedResponse(BaseModel):
|
||||||
|
doc_id: str
|
||||||
|
vectors: int
|
||||||
|
status: str = "embedded"
|
||||||
31
backend/app/schemas/rag.py
Normal file
31
backend/app/schemas/rag.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class RagChatRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
top_k: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievedDoc(BaseModel):
|
||||||
|
id: str
|
||||||
|
doc_name: str
|
||||||
|
clause_id: Optional[str] = None
|
||||||
|
score: float
|
||||||
|
content: str
|
||||||
|
preview: str
|
||||||
|
|
||||||
|
|
||||||
|
class SourceInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
clause: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class QuickQuestion(BaseModel):
|
||||||
|
id: str
|
||||||
|
question: str
|
||||||
|
category: str
|
||||||
|
|
||||||
|
|
||||||
|
class QuickQuestionsResponse(BaseModel):
|
||||||
|
questions: list[QuickQuestion]
|
||||||
3
backend/app/services/__init__.py
Normal file
3
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Backend service package."""
|
||||||
|
|
||||||
|
__all__: list[str] = []
|
||||||
7
backend/app/services/agent/__init__.py
Normal file
7
backend/app/services/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# src/services/agent/__init__.py
|
||||||
|
"""Agent服务模块"""
|
||||||
|
|
||||||
|
from .qa_agent import QAAgent, ask_compliance_question
|
||||||
|
from .session_manager import SessionManager, ChatSession
|
||||||
|
|
||||||
|
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]
|
||||||
412
backend/app/services/agent/qa_agent.py
Normal file
412
backend/app/services/agent/qa_agent.py
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
# src/services/agent/qa_agent.py
|
||||||
|
"""RAG问答Agent - 合规智能问答核心实现"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Optional, Any, Generator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
|
||||||
|
from app.services.llm.llm_factory import LLMFactory
|
||||||
|
from app.services.rag.retriever import Retriever, RetrievedDocument
|
||||||
|
from app.services.rag.context_builder import ContextBuilder, RAGContext
|
||||||
|
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentResponse:
|
||||||
|
"""Agent响应结果"""
|
||||||
|
answer: str
|
||||||
|
sources: List[Dict] = field(default_factory=list)
|
||||||
|
model: str = ""
|
||||||
|
latency_ms: int = 0
|
||||||
|
retrieved_count: int = 0
|
||||||
|
context_tokens: int = 0
|
||||||
|
truncated: bool = False
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_success(self) -> bool:
|
||||||
|
return self.error is None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentConfig:
|
||||||
|
"""Agent配置"""
|
||||||
|
llm_provider: str = "deepseek"
|
||||||
|
llm_model: str = "deepseek-v4-flash"
|
||||||
|
top_k: int = 5
|
||||||
|
min_score: float = 0.3
|
||||||
|
max_context_tokens: int = 2000
|
||||||
|
temperature: float = 0.7
|
||||||
|
prompt_template: str = "compliance_qa"
|
||||||
|
include_metadata: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class QAAgent:
|
||||||
|
"""
|
||||||
|
合规问答Agent
|
||||||
|
|
||||||
|
核心流程:
|
||||||
|
1. 接收用户问题
|
||||||
|
2. Milvus混合检索相关法规条款
|
||||||
|
3. 构建RAG上下文
|
||||||
|
4. 调用LLM生成回答
|
||||||
|
5. 返回答案和引用来源
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
agent = QAAgent()
|
||||||
|
response = agent.ask("机动车安全技术检验有哪些要求?")
|
||||||
|
print(response.answer)
|
||||||
|
for source in response.sources:
|
||||||
|
print(f"引用: {source['doc_name']} - {source['clause_number']}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[AgentConfig] = None):
|
||||||
|
"""
|
||||||
|
初始化问答Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Agent配置(可选,使用默认配置)
|
||||||
|
"""
|
||||||
|
self.config = config or AgentConfig(
|
||||||
|
llm_provider=settings.llm_provider,
|
||||||
|
llm_model=settings.llm_model,
|
||||||
|
top_k=settings.rag_top_k,
|
||||||
|
max_context_tokens=settings.rag_max_context_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化组件(延迟加载)
|
||||||
|
self.llm: Optional[BaseLLMClient] = None
|
||||||
|
self.retriever: Optional[Retriever] = None
|
||||||
|
self.context_builder: Optional[ContextBuilder] = None
|
||||||
|
|
||||||
|
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
|
||||||
|
|
||||||
|
def _init_llm(self):
|
||||||
|
"""延迟初始化LLM客户端(优先使用全局缓存)"""
|
||||||
|
if self.llm is None:
|
||||||
|
# 尝试先获取全局缓存的客户端
|
||||||
|
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
|
||||||
|
if cached:
|
||||||
|
self.llm = cached
|
||||||
|
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
|
||||||
|
else:
|
||||||
|
logger.info("创建新的LLM客户端...")
|
||||||
|
self.llm = get_llm_client(
|
||||||
|
provider=self.config.llm_provider,
|
||||||
|
model=self.config.llm_model,
|
||||||
|
temperature=self.config.temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_retriever(self):
|
||||||
|
"""延迟初始化检索器"""
|
||||||
|
if self.retriever is None:
|
||||||
|
logger.info("初始化检索器...")
|
||||||
|
self.retriever = Retriever(
|
||||||
|
top_k=self.config.top_k,
|
||||||
|
min_score=self.config.min_score
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_context_builder(self):
|
||||||
|
"""延迟初始化上下文构建器"""
|
||||||
|
if self.context_builder is None:
|
||||||
|
logger.info("初始化上下文构建器...")
|
||||||
|
self.context_builder = ContextBuilder(
|
||||||
|
max_context_tokens=self.config.max_context_tokens,
|
||||||
|
include_metadata=self.config.include_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
filters: Optional[str] = None,
|
||||||
|
prompt_template: Optional[str] = None
|
||||||
|
) -> AgentResponse:
|
||||||
|
"""
|
||||||
|
回答用户问题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户问题
|
||||||
|
filters: 检索过滤条件(如 "regulation_type=='车辆安全'")
|
||||||
|
prompt_template: Prompt模板名称(可选,覆盖默认配置)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentResponse: 包含答案和引用来源的响应对象
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
logger.info(f"收到问题: {query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Step 1: 检索相关法规
|
||||||
|
self._init_retriever()
|
||||||
|
documents = self.retriever.retrieve(query, filters)
|
||||||
|
retrieved_count = len(documents)
|
||||||
|
|
||||||
|
if retrieved_count == 0:
|
||||||
|
return AgentResponse(
|
||||||
|
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
|
||||||
|
retrieved_count=0,
|
||||||
|
error="no_retrieved_documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: 构建RAG上下文
|
||||||
|
self._init_context_builder()
|
||||||
|
template_name = prompt_template or self.config.prompt_template
|
||||||
|
template = get_prompt_template(template_name)
|
||||||
|
|
||||||
|
context = self.context_builder.build(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
system_prompt=template.system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: 构建LLM输入消息
|
||||||
|
messages = self._build_messages(template, context)
|
||||||
|
|
||||||
|
# Step 4: 调用LLM生成回答
|
||||||
|
self._init_llm()
|
||||||
|
llm_response = self.llm.chat(
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.config.temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
if not llm_response.is_success:
|
||||||
|
return AgentResponse(
|
||||||
|
answer="",
|
||||||
|
retrieved_count=retrieved_count,
|
||||||
|
error=llm_response.error
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
# Step 5: 返回结果
|
||||||
|
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||||
|
|
||||||
|
return AgentResponse(
|
||||||
|
answer=llm_response.content,
|
||||||
|
sources=context.sources,
|
||||||
|
model=llm_response.model,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
retrieved_count=retrieved_count,
|
||||||
|
context_tokens=context.total_tokens,
|
||||||
|
truncated=context.truncated
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"问答失败: {e}")
|
||||||
|
return AgentResponse(
|
||||||
|
answer="",
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask_with_context(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: List[RetrievedDocument],
|
||||||
|
prompt_template: Optional[str] = None
|
||||||
|
) -> AgentResponse:
|
||||||
|
"""
|
||||||
|
使用提供的文档回答问题(不执行检索)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户问题
|
||||||
|
documents: 已检索的文档列表
|
||||||
|
prompt_template: Prompt模板名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentResponse: 响应结果
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._init_context_builder()
|
||||||
|
self._init_llm()
|
||||||
|
|
||||||
|
template_name = prompt_template or self.config.prompt_template
|
||||||
|
template = get_prompt_template(template_name)
|
||||||
|
|
||||||
|
context = self.context_builder.build(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
system_prompt=template.system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = self._build_messages(template, context)
|
||||||
|
|
||||||
|
llm_response = self.llm.chat(messages)
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
return AgentResponse(
|
||||||
|
answer=llm_response.content,
|
||||||
|
sources=context.sources,
|
||||||
|
model=llm_response.model,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
retrieved_count=len(documents),
|
||||||
|
context_tokens=context.total_tokens,
|
||||||
|
truncated=context.truncated
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"问答失败: {e}")
|
||||||
|
return AgentResponse(answer="", error=str(e))
|
||||||
|
|
||||||
|
def _build_messages(
|
||||||
|
self,
|
||||||
|
template: PromptTemplate,
|
||||||
|
context: RAGContext
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""构建LLM输入消息"""
|
||||||
|
user_content = template.user_template.format(
|
||||||
|
context=context.context_text,
|
||||||
|
query=context.user_query
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": template.system_prompt},
|
||||||
|
{"role": "user", "content": user_content}
|
||||||
|
]
|
||||||
|
|
||||||
|
def ask_stream(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
filters: Optional[str] = None,
|
||||||
|
prompt_template: Optional[str] = None
|
||||||
|
) -> Generator[Dict[str, Any], None, None]:
|
||||||
|
"""
|
||||||
|
流式回答用户问题(SSE模式)
|
||||||
|
|
||||||
|
返回事件类型:
|
||||||
|
- {"event": "status", "data": "正在检索..."} - 状态更新
|
||||||
|
- {"event": "sources", "data": [...]} - 引用来源
|
||||||
|
- {"event": "content", "data": "文本片段"} - 回答内容
|
||||||
|
- {"event": "done", "data": {"latency_ms": ..., "model": ...}} - 完成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户问题
|
||||||
|
filters: 检索过滤条件
|
||||||
|
prompt_template: Prompt模板名称
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Dict: SSE事件数据
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
logger.info(f"收到流式问题: {query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Step 1: 检索相关法规
|
||||||
|
yield {"event": "status", "data": "正在检索相关法规..."}
|
||||||
|
self._init_retriever()
|
||||||
|
documents = self.retriever.retrieve(query, filters)
|
||||||
|
retrieved_count = len(documents)
|
||||||
|
|
||||||
|
if retrieved_count == 0:
|
||||||
|
yield {"event": "status", "data": "未找到相关法规"}
|
||||||
|
yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
|
||||||
|
yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 2: 发送检索结果
|
||||||
|
yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
|
||||||
|
sources = [
|
||||||
|
{
|
||||||
|
"doc_name": doc.doc_name,
|
||||||
|
"doc_id": doc.doc_id,
|
||||||
|
"clause_number": doc.clause_number,
|
||||||
|
"score": doc.score
|
||||||
|
}
|
||||||
|
for doc in documents[:5] # 只返回前5条引用
|
||||||
|
]
|
||||||
|
yield {"event": "sources", "data": sources}
|
||||||
|
|
||||||
|
# Step 3: 构建RAG上下文
|
||||||
|
self._init_context_builder()
|
||||||
|
template_name = prompt_template or self.config.prompt_template
|
||||||
|
template = get_prompt_template(template_name)
|
||||||
|
context = self.context_builder.build(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
system_prompt=template.system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: 构建LLM输入消息
|
||||||
|
messages = self._build_messages(template, context)
|
||||||
|
|
||||||
|
# Step 5: 流式调用LLM生成回答
|
||||||
|
self._init_llm()
|
||||||
|
full_answer = ""
|
||||||
|
|
||||||
|
# 检查LLM是否支持流式输出
|
||||||
|
if hasattr(self.llm, 'stream_chat'):
|
||||||
|
yield {"event": "status", "data": "思考中..."}
|
||||||
|
for chunk in self.llm.stream_chat(
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.config.temperature
|
||||||
|
):
|
||||||
|
full_answer += chunk
|
||||||
|
yield {"event": "content", "data": chunk}
|
||||||
|
else:
|
||||||
|
# 如果不支持流式,回退到普通调用
|
||||||
|
yield {"event": "status", "data": "生成回答中..."}
|
||||||
|
llm_response = self.llm.chat(
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.config.temperature
|
||||||
|
)
|
||||||
|
if llm_response.is_success:
|
||||||
|
full_answer = llm_response.content
|
||||||
|
yield {"event": "content", "data": full_answer}
|
||||||
|
|
||||||
|
# Step 6: 发送完成事件
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"event": "done",
|
||||||
|
"data": {
|
||||||
|
"latency_ms": latency_ms,
|
||||||
|
"model": self.config.llm_model,
|
||||||
|
"retrieved_count": retrieved_count,
|
||||||
|
"context_tokens": context.total_tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式问答失败: {e}")
|
||||||
|
yield {"event": "error", "data": str(e)}
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭Agent资源(不关闭LLM客户端,因为它全局缓存)"""
|
||||||
|
if self.retriever:
|
||||||
|
self.retriever.close()
|
||||||
|
logger.info("问答Agent已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
def ask_compliance_question(
|
||||||
|
query: str,
|
||||||
|
provider: str = "deepseek",
|
||||||
|
model: str = "deepseek-v4-flash",
|
||||||
|
top_k: int = 10
|
||||||
|
) -> AgentResponse:
|
||||||
|
"""
|
||||||
|
便捷函数:问答合规问题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户问题
|
||||||
|
provider: LLM提供商
|
||||||
|
model: LLM模型
|
||||||
|
top_k: 检索数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentResponse: 响应结果
|
||||||
|
"""
|
||||||
|
config = AgentConfig(
|
||||||
|
llm_provider=provider,
|
||||||
|
llm_model=model,
|
||||||
|
top_k=top_k
|
||||||
|
)
|
||||||
|
agent = QAAgent(config)
|
||||||
|
response = agent.ask(query)
|
||||||
|
agent.close()
|
||||||
|
return response
|
||||||
247
backend/app/services/agent/session_manager.py
Normal file
247
backend/app/services/agent/session_manager.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
# src/services/agent/session_manager.py
|
||||||
|
"""多轮对话会话管理"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
"""对话消息"""
|
||||||
|
role: str # "user" / "assistant" / "system"
|
||||||
|
content: str
|
||||||
|
timestamp: int
|
||||||
|
sources: List[Dict] = field(default_factory=list)
|
||||||
|
metadata: Dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatSession:
|
||||||
|
"""对话会话"""
|
||||||
|
session_id: str
|
||||||
|
messages: List[ChatMessage] = field(default_factory=list)
|
||||||
|
created_at: int = field(default_factory=lambda: int(time.time()))
|
||||||
|
updated_at: int = field(default_factory=lambda: int(time.time()))
|
||||||
|
metadata: Dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def add_user_message(self, content: str) -> ChatMessage:
|
||||||
|
"""添加用户消息"""
|
||||||
|
message = ChatMessage(
|
||||||
|
role="user",
|
||||||
|
content=content,
|
||||||
|
timestamp=int(time.time())
|
||||||
|
)
|
||||||
|
self.messages.append(message)
|
||||||
|
self.updated_at = int(time.time())
|
||||||
|
return message
|
||||||
|
|
||||||
|
def add_assistant_message(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
sources: List[Dict] = None
|
||||||
|
) -> ChatMessage:
|
||||||
|
"""添加助手消息"""
|
||||||
|
message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=content,
|
||||||
|
timestamp=int(time.time()),
|
||||||
|
sources=sources or []
|
||||||
|
)
|
||||||
|
self.messages.append(message)
|
||||||
|
self.updated_at = int(time.time())
|
||||||
|
return message
|
||||||
|
|
||||||
|
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
|
||||||
|
"""获取历史对话(用于LLM上下文)"""
|
||||||
|
history = []
|
||||||
|
# 获取最近N轮对话(每轮包含user + assistant)
|
||||||
|
recent_messages = self.messages[-(max_turns * 2):]
|
||||||
|
|
||||||
|
for msg in recent_messages:
|
||||||
|
history.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
return history
|
||||||
|
|
||||||
|
def clear_history(self):
|
||||||
|
"""清空对话历史"""
|
||||||
|
self.messages = []
|
||||||
|
self.updated_at = int(time.time())
|
||||||
|
logger.info(f"会话历史已清空: {self.session_id}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message_count(self) -> int:
|
||||||
|
"""消息数量"""
|
||||||
|
return len(self.messages)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""是否为空会话"""
|
||||||
|
return len(self.messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""
|
||||||
|
会话管理器
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 创建/获取/删除会话
|
||||||
|
- 会话超时清理
|
||||||
|
- 会话历史记录管理
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
manager = SessionManager()
|
||||||
|
|
||||||
|
# 创建会话
|
||||||
|
session = manager.create_session()
|
||||||
|
|
||||||
|
# 添加消息
|
||||||
|
session.add_user_message("什么是机动车安全技术检验?")
|
||||||
|
session.add_assistant_message("根据GB 7258...", sources=[...])
|
||||||
|
|
||||||
|
# 获取历史(用于LLM多轮对话)
|
||||||
|
history = session.get_history(max_turns=3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_sessions: int = 100,
|
||||||
|
session_timeout_minutes: int = 30
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化会话管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_sessions: 最大会话数量
|
||||||
|
session_timeout_minutes: 会话超时时间(分钟)
|
||||||
|
"""
|
||||||
|
self.max_sessions = max_sessions
|
||||||
|
self.session_timeout = session_timeout_minutes * 60
|
||||||
|
|
||||||
|
# 会话存储(内存)
|
||||||
|
self._sessions: Dict[str, ChatSession] = {}
|
||||||
|
|
||||||
|
logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
|
||||||
|
|
||||||
|
def create_session(self, metadata: Dict = None) -> ChatSession:
|
||||||
|
"""
|
||||||
|
创建新会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: 会话元数据(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatSession: 新创建的会话
|
||||||
|
"""
|
||||||
|
# 检查会话数量限制
|
||||||
|
if len(self._sessions) >= self.max_sessions:
|
||||||
|
# 清理过期会话
|
||||||
|
self._cleanup_expired_sessions()
|
||||||
|
|
||||||
|
# 如果仍然超出限制,删除最老的会话
|
||||||
|
if len(self._sessions) >= self.max_sessions:
|
||||||
|
oldest_id = min(
|
||||||
|
self._sessions.keys(),
|
||||||
|
key=lambda x: self._sessions[x].created_at
|
||||||
|
)
|
||||||
|
self.delete_session(oldest_id)
|
||||||
|
logger.warning(f"删除最老会话以腾出空间: {oldest_id}")
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())[:8]
|
||||||
|
session = ChatSession(
|
||||||
|
session_id=session_id,
|
||||||
|
metadata=metadata or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._sessions[session_id] = session
|
||||||
|
logger.info(f"创建新会话: {session_id}")
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def get_session(self, session_id: str) -> Optional[ChatSession]:
|
||||||
|
"""
|
||||||
|
获取会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatSession: 会话对象(如不存在返回None)
|
||||||
|
"""
|
||||||
|
session = self._sessions.get(session_id)
|
||||||
|
|
||||||
|
if session:
|
||||||
|
# 检查是否过期
|
||||||
|
if self._is_session_expired(session):
|
||||||
|
self.delete_session(session_id)
|
||||||
|
logger.info(f"会话已过期,已删除: {session_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def delete_session(self, session_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功删除
|
||||||
|
"""
|
||||||
|
if session_id in self._sessions:
|
||||||
|
del self._sessions[session_id]
|
||||||
|
logger.info(f"删除会话: {session_id}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_sessions(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
列出所有会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: 会话列表摘要
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"session_id": s.session_id,
|
||||||
|
"message_count": s.message_count,
|
||||||
|
"created_at": s.created_at,
|
||||||
|
"updated_at": s.updated_at
|
||||||
|
}
|
||||||
|
for s in self._sessions.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
def _is_session_expired(self, session: ChatSession) -> bool:
|
||||||
|
"""检查会话是否过期"""
|
||||||
|
current_time = int(time.time())
|
||||||
|
return (current_time - session.updated_at) > self.session_timeout
|
||||||
|
|
||||||
|
def _cleanup_expired_sessions(self) -> int:
|
||||||
|
"""清理过期会话"""
|
||||||
|
expired_ids = [
|
||||||
|
sid for sid, session in self._sessions.items()
|
||||||
|
if self._is_session_expired(session)
|
||||||
|
]
|
||||||
|
|
||||||
|
for sid in expired_ids:
|
||||||
|
self.delete_session(sid)
|
||||||
|
|
||||||
|
if expired_ids:
|
||||||
|
logger.info(f"清理过期会话: {len(expired_ids)}个")
|
||||||
|
|
||||||
|
return len(expired_ids)
|
||||||
|
|
||||||
|
def get_session_count(self) -> int:
|
||||||
|
"""获取当前会话数量"""
|
||||||
|
return len(self._sessions)
|
||||||
|
|
||||||
|
def clear_all_sessions(self):
|
||||||
|
"""清空所有会话"""
|
||||||
|
self._sessions.clear()
|
||||||
|
logger.info("所有会话已清空")
|
||||||
404
backend/app/services/document_processor.py
Normal file
404
backend/app/services/document_processor.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
# src/services/document_processor.py
|
||||||
|
"""文档处理主流程 - 解析→摘要→分块→嵌入→入库"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from loguru import logger
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .parser.pdf_parser import PDFParser
|
||||||
|
from .parser.docx_parser import DocxParser
|
||||||
|
from .parser.mineru_parser import ParserOrchestrator
|
||||||
|
from .embedding.text_chunker import RegulationChunker, TextChunk
|
||||||
|
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
|
||||||
|
from .storage.milvus_client import MilvusClient
|
||||||
|
from .llm.document_summarizer import DocumentSummarizer, DocumentSummary
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessingResult:
|
||||||
|
"""文档处理结果"""
|
||||||
|
doc_id: str
|
||||||
|
doc_name: str
|
||||||
|
success: bool
|
||||||
|
num_chunks: int = 0
|
||||||
|
message: str = ""
|
||||||
|
markdown_text: str = ""
|
||||||
|
summary: str = ""
|
||||||
|
summary_latency_ms: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentProcessor:
|
||||||
|
"""
|
||||||
|
文档处理服务 - 完整处理流程
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 文档解析(PDF/DOCX → Markdown)
|
||||||
|
2. 智能分块(章节级+条款级)
|
||||||
|
3. LLM摘要生成(可选)
|
||||||
|
4. 向量嵌入(BGE-M3 Dense+Sparse)
|
||||||
|
5. 存储入库(Milvus向量数据库)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = None,
|
||||||
|
embedding_model: str = None,
|
||||||
|
use_mineru: bool = True,
|
||||||
|
generate_summary: bool = False, # 默认不生成摘要,节省约60秒
|
||||||
|
llm_provider: str = None,
|
||||||
|
llm_model: str = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化文档处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: 分块大小
|
||||||
|
embedding_model: 嵌入模型名称
|
||||||
|
use_mineru: 是否优先使用MinerU解析
|
||||||
|
generate_summary: 是否生成文档摘要(默认False,可节省约60秒处理时间)
|
||||||
|
llm_provider: LLM提供商
|
||||||
|
llm_model: LLM模型名称
|
||||||
|
"""
|
||||||
|
self.chunk_size = chunk_size or settings.chunk_size
|
||||||
|
self.embedding_model = embedding_model or settings.embedding_model
|
||||||
|
self.use_mineru = use_mineru
|
||||||
|
self.generate_summary = generate_summary
|
||||||
|
self.llm_provider = llm_provider or settings.llm_provider
|
||||||
|
self.llm_model = llm_model or settings.llm_model
|
||||||
|
|
||||||
|
# 初始化各组件
|
||||||
|
logger.info("初始化文档处理组件...")
|
||||||
|
|
||||||
|
# 解析器
|
||||||
|
self.parser = ParserOrchestrator()
|
||||||
|
|
||||||
|
# 分块器
|
||||||
|
self.chunker = RegulationChunker(chunk_size=self.chunk_size)
|
||||||
|
|
||||||
|
# 嵌入模型(延迟加载)
|
||||||
|
self.embedder: Optional[BGEM3Embedder] = None
|
||||||
|
|
||||||
|
# Milvus客户端(延迟连接)
|
||||||
|
self.milvus: Optional[MilvusClient] = None
|
||||||
|
|
||||||
|
# 摘要生成器(延迟加载)
|
||||||
|
self.summarizer: Optional[DocumentSummarizer] = None
|
||||||
|
|
||||||
|
logger.success("文档处理器初始化完成")
|
||||||
|
|
||||||
|
def _init_embedder(self):
|
||||||
|
"""延迟初始化嵌入模型"""
|
||||||
|
if self.embedder is None:
|
||||||
|
logger.info("加载嵌入模型...")
|
||||||
|
self.embedder = BGEM3Embedder(model_name=self.embedding_model)
|
||||||
|
|
||||||
|
def _init_milvus(self):
|
||||||
|
"""延迟初始化Milvus连接"""
|
||||||
|
if self.milvus is None:
|
||||||
|
logger.info("连接Milvus...")
|
||||||
|
self.milvus = MilvusClient()
|
||||||
|
self.milvus.connect()
|
||||||
|
self.milvus.create_collection(recreate=False)
|
||||||
|
self.milvus.load_collection()
|
||||||
|
|
||||||
|
def _init_summarizer(self):
|
||||||
|
"""延迟初始化摘要生成器"""
|
||||||
|
if self.summarizer is None:
|
||||||
|
logger.info("初始化摘要生成器...")
|
||||||
|
self.summarizer = DocumentSummarizer(
|
||||||
|
provider=self.llm_provider,
|
||||||
|
model=self.llm_model
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
doc_id: Optional[str] = None,
|
||||||
|
doc_name: Optional[str] = None,
|
||||||
|
regulation_type: str = "",
|
||||||
|
version: str = ""
|
||||||
|
) -> ProcessingResult:
|
||||||
|
"""
|
||||||
|
处理单个文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文档文件路径
|
||||||
|
doc_id: 文档ID(可选,默认自动生成)
|
||||||
|
doc_name: 文档名称(可选,默认从文件名获取)
|
||||||
|
regulation_type: 法规类型
|
||||||
|
version: 文档版本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProcessingResult: 处理结果
|
||||||
|
"""
|
||||||
|
# 生成或使用传入的文档ID
|
||||||
|
if doc_id is None:
|
||||||
|
doc_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
|
# 获取文档名称
|
||||||
|
if doc_name is None:
|
||||||
|
doc_name = os.path.basename(file_path)
|
||||||
|
|
||||||
|
logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})")
|
||||||
|
|
||||||
|
# 初始化结果变量
|
||||||
|
summary = ""
|
||||||
|
summary_latency_ms = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 文档解析
|
||||||
|
logger.info("Step 1: 文档解析")
|
||||||
|
markdown_text = self._parse_document(file_path)
|
||||||
|
|
||||||
|
if not markdown_text:
|
||||||
|
return ProcessingResult(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
success=False,
|
||||||
|
message="文档解析失败,内容为空"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. LLM摘要生成(可选)
|
||||||
|
if self.generate_summary:
|
||||||
|
logger.info("Step 2: LLM摘要生成")
|
||||||
|
self._init_summarizer()
|
||||||
|
summary_result = self.summarizer.summarize(
|
||||||
|
doc_name,
|
||||||
|
markdown_text,
|
||||||
|
regulation_type
|
||||||
|
)
|
||||||
|
if summary_result.is_success:
|
||||||
|
summary = summary_result.summary
|
||||||
|
summary_latency_ms = summary_result.latency_ms
|
||||||
|
logger.success(f"摘要生成完成: {summary_latency_ms}ms")
|
||||||
|
else:
|
||||||
|
logger.warning(f"摘要生成失败: {summary_result.error}")
|
||||||
|
else:
|
||||||
|
logger.info("Step 2: 跳过摘要生成(未勾选,节省约60秒)")
|
||||||
|
|
||||||
|
# 3. 智能分块
|
||||||
|
logger.info("Step 3: 智能分块")
|
||||||
|
chunks = self._chunk_document(
|
||||||
|
markdown_text,
|
||||||
|
doc_id,
|
||||||
|
doc_name,
|
||||||
|
regulation_type,
|
||||||
|
version
|
||||||
|
)
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
return ProcessingResult(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
success=False,
|
||||||
|
message="分块失败,无有效内容",
|
||||||
|
markdown_text=markdown_text,
|
||||||
|
summary=summary
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 向量嵌入
|
||||||
|
logger.info("Step 4: 向量嵌入")
|
||||||
|
embeddings = self._embed_chunks(chunks)
|
||||||
|
|
||||||
|
if embeddings is None:
|
||||||
|
return ProcessingResult(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
success=False,
|
||||||
|
message="向量嵌入失败",
|
||||||
|
markdown_text=markdown_text,
|
||||||
|
summary=summary
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 存储入库
|
||||||
|
logger.info("Step 5: 存储入库")
|
||||||
|
inserted_ids = self._insert_to_milvus(chunks, embeddings)
|
||||||
|
|
||||||
|
logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
|
||||||
|
|
||||||
|
return ProcessingResult(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
success=True,
|
||||||
|
num_chunks=len(inserted_ids),
|
||||||
|
message="处理成功",
|
||||||
|
markdown_text=markdown_text,
|
||||||
|
summary=summary,
|
||||||
|
summary_latency_ms=summary_latency_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文档处理失败: {e}")
|
||||||
|
return ProcessingResult(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
success=False,
|
||||||
|
message=f"处理失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_document(self, file_path: str) -> str:
|
||||||
|
"""解析文档"""
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if ext == ".pdf":
|
||||||
|
# PDF文档解析(优先MinerU,回退PyMuPDF)
|
||||||
|
markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
|
||||||
|
elif ext in [".docx", ".doc"]:
|
||||||
|
# Word文档解析
|
||||||
|
markdown_text = self.parser.parse_docx(file_path)
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的文件类型: {ext}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
|
||||||
|
return markdown_text
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文档解析失败: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _chunk_document(
|
||||||
|
self,
|
||||||
|
markdown_text: str,
|
||||||
|
doc_id: str,
|
||||||
|
doc_name: str,
|
||||||
|
regulation_type: str,
|
||||||
|
version: str
|
||||||
|
) -> List[TextChunk]:
|
||||||
|
"""分块文档"""
|
||||||
|
try:
|
||||||
|
chunks = self.chunker.chunk_document(
|
||||||
|
markdown_text,
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
regulation_type=regulation_type,
|
||||||
|
version=version
|
||||||
|
)
|
||||||
|
logger.success(f"分块完成,共{len(chunks)}个chunk")
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"分块失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
|
||||||
|
"""嵌入分块"""
|
||||||
|
try:
|
||||||
|
# 延迟初始化嵌入模型
|
||||||
|
self._init_embedder()
|
||||||
|
|
||||||
|
# 提取文本内容
|
||||||
|
texts = [chunk.content for chunk in chunks]
|
||||||
|
|
||||||
|
# 执行嵌入
|
||||||
|
embeddings = self.embedder.embed(texts)
|
||||||
|
|
||||||
|
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _insert_to_milvus(
|
||||||
|
self,
|
||||||
|
chunks: List[TextChunk],
|
||||||
|
embeddings: EmbeddingResult
|
||||||
|
) -> List[int]:
|
||||||
|
"""插入Milvus"""
|
||||||
|
try:
|
||||||
|
# 延迟初始化Milvus
|
||||||
|
self._init_milvus()
|
||||||
|
|
||||||
|
# 执行插入
|
||||||
|
inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
|
||||||
|
|
||||||
|
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
|
||||||
|
return inserted_ids
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"入库失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 10,
|
||||||
|
filters: Optional[str] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
检索法规内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
top_k: 返回结果数量
|
||||||
|
filters: 过滤条件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: 检索结果
|
||||||
|
"""
|
||||||
|
logger.info(f"执行检索: {query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 延迟初始化
|
||||||
|
self._init_embedder()
|
||||||
|
self._init_milvus()
|
||||||
|
|
||||||
|
# 生成查询向量
|
||||||
|
query_embedding = self.embedder.embed_single(query)
|
||||||
|
|
||||||
|
# 执行混合检索
|
||||||
|
results = self.milvus.hybrid_search(
|
||||||
|
query_dense=query_embedding['dense'].tolist(),
|
||||||
|
query_sparse=query_embedding['sparse'],
|
||||||
|
top_k=top_k,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为字典格式
|
||||||
|
result_dicts = []
|
||||||
|
for r in results:
|
||||||
|
result_dicts.append({
|
||||||
|
"id": r.id,
|
||||||
|
"content": r.content,
|
||||||
|
"score": r.score,
|
||||||
|
"metadata": r.metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.success(f"检索完成,返回{len(result_dicts)}条结果")
|
||||||
|
return result_dicts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检索失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭连接"""
|
||||||
|
if self.milvus:
|
||||||
|
self.milvus.disconnect()
|
||||||
|
logger.info("文档处理器已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
def process_document(
|
||||||
|
file_path: str,
|
||||||
|
doc_name: Optional[str] = None,
|
||||||
|
regulation_type: str = "",
|
||||||
|
version: str = ""
|
||||||
|
) -> ProcessingResult:
|
||||||
|
"""便捷函数:处理单个文档"""
|
||||||
|
processor = DocumentProcessor()
|
||||||
|
result = processor.process(file_path, doc_name, regulation_type, version)
|
||||||
|
processor.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
|
||||||
|
"""便捷函数:检索法规"""
|
||||||
|
processor = DocumentProcessor()
|
||||||
|
results = processor.search(query, top_k)
|
||||||
|
processor.close()
|
||||||
|
return results
|
||||||
7
backend/app/services/embedding/__init__.py
Normal file
7
backend/app/services/embedding/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# src/services/embedding/__init__.py
|
||||||
|
"""嵌入和分块服务"""
|
||||||
|
|
||||||
|
from .text_chunker import RegulationChunker
|
||||||
|
from .bge_m3_embedder import BGEM3Embedder
|
||||||
|
|
||||||
|
__all__ = ["RegulationChunker", "BGEM3Embedder"]
|
||||||
296
backend/app/services/embedding/bge_m3_embedder.py
Normal file
296
backend/app/services/embedding/bge_m3_embedder.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
# src/services/embedding/bge_m3_embedder.py
|
||||||
|
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Optional, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 设置HuggingFace镜像(国内网络)
|
||||||
|
if 'HF_ENDPOINT' not in os.environ:
|
||||||
|
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||||
|
|
||||||
|
# 本地模型路径(按优先级检查)
|
||||||
|
LOCAL_MODEL_PATHS = [
|
||||||
|
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径
|
||||||
|
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingResult:
|
||||||
|
"""嵌入结果"""
|
||||||
|
dense_embeddings: np.ndarray # Dense向量(语义检索)
|
||||||
|
sparse_embeddings: List[Dict[int, float]] # Sparse向量(关键词匹配)
|
||||||
|
texts: List[str]
|
||||||
|
dim: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
class BGEM3Embedder:
|
||||||
|
"""
|
||||||
|
BGE-M3多语言嵌入模型服务
|
||||||
|
|
||||||
|
BGE-M3是BAAI发布的多语言嵌入模型,支持:
|
||||||
|
- Dense向量:用于语义相似度检索
|
||||||
|
- Sparse向量:用于关键词精确匹配(BM25风格)
|
||||||
|
- ColBERT向量:用于细粒度交互匹配(可选)
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 支持100+语言(中英双语优化)
|
||||||
|
- 8192 tokens超长上下文
|
||||||
|
- Dense+Sparse双路检索能力
|
||||||
|
|
||||||
|
GitHub: https://github.com/FlagOpen/FlagEmbedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "BAAI/bge-m3",
|
||||||
|
use_fp16: bool = True,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
batch_size: int = 12,
|
||||||
|
max_length: int = 8192,
|
||||||
|
local_model_path: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化BGE-M3嵌入模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称(如果使用本地路径,此参数会被忽略)
|
||||||
|
use_fp16: 是否使用FP16加速
|
||||||
|
device: 设备类型(cuda/cpu),默认自动选择
|
||||||
|
batch_size: 批处理大小
|
||||||
|
max_length: 最大序列长度
|
||||||
|
local_model_path: 本地模型路径(可选,优先使用)
|
||||||
|
"""
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
# 确定模型路径(优先使用本地路径)
|
||||||
|
if local_model_path and os.path.exists(local_model_path):
|
||||||
|
self.model_path = local_model_path
|
||||||
|
self.model_name = "local"
|
||||||
|
logger.info(f"使用本地模型路径: {local_model_path}")
|
||||||
|
else:
|
||||||
|
# 检查多个可能的本地路径
|
||||||
|
found_local = False
|
||||||
|
for path in LOCAL_MODEL_PATHS:
|
||||||
|
if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
|
||||||
|
self.model_path = path
|
||||||
|
self.model_name = "local"
|
||||||
|
logger.info(f"使用本地模型路径: {path}")
|
||||||
|
found_local = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found_local:
|
||||||
|
self.model_path = model_name
|
||||||
|
self.model_name = model_name
|
||||||
|
logger.info(f"使用远程模型: {model_name}")
|
||||||
|
|
||||||
|
# 自动选择设备
|
||||||
|
if device is None:
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
else:
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
logger.info(f"初始化BGE-M3模型, 设备: {self.device}")
|
||||||
|
|
||||||
|
self.model = None
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载嵌入模型"""
|
||||||
|
try:
|
||||||
|
from FlagEmbedding import BGEM3FlagModel
|
||||||
|
|
||||||
|
self.model = BGEM3FlagModel(
|
||||||
|
self.model_path,
|
||||||
|
use_fp16=self.use_fp16,
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success(f"BGE-M3模型加载成功")
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("FlagEmbedding库未安装,请运行: pip install FlagEmbedding")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型加载失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
return_dense: bool = True,
|
||||||
|
return_sparse: bool = True,
|
||||||
|
return_colbert_vecs: bool = False
|
||||||
|
) -> EmbeddingResult:
|
||||||
|
"""
|
||||||
|
对文本列表生成嵌入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
return_dense: 是否返回Dense向量
|
||||||
|
return_sparse: 是否返回Sparse向量
|
||||||
|
return_colbert_vecs: 是否返回ColBERT向量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResult: 嵌入结果
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
logger.warning("输入文本列表为空")
|
||||||
|
return EmbeddingResult(
|
||||||
|
dense_embeddings=np.array([]),
|
||||||
|
sparse_embeddings=[],
|
||||||
|
texts=[],
|
||||||
|
dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"开始嵌入{len(texts)}个文本块")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 执行嵌入
|
||||||
|
embeddings = self.model.encode(
|
||||||
|
texts,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
max_length=self.max_length,
|
||||||
|
return_dense=return_dense,
|
||||||
|
return_sparse=return_sparse,
|
||||||
|
return_colbert_vecs=return_colbert_vecs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取结果
|
||||||
|
dense_embeddings = embeddings.get('dense_vecs', np.array([]))
|
||||||
|
sparse_embeddings = embeddings.get('lexical_weights', [])
|
||||||
|
|
||||||
|
# 获取维度
|
||||||
|
dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
|
||||||
|
|
||||||
|
logger.success(f"嵌入完成,向量维度: {dim}")
|
||||||
|
|
||||||
|
return EmbeddingResult(
|
||||||
|
dense_embeddings=dense_embeddings,
|
||||||
|
sparse_embeddings=sparse_embeddings,
|
||||||
|
texts=texts,
|
||||||
|
dim=dim
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
|
||||||
|
"""
|
||||||
|
对单个文本生成嵌入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 包含dense和sparse向量
|
||||||
|
"""
|
||||||
|
result = self.embed([text])
|
||||||
|
return {
|
||||||
|
'dense': result.dense_embeddings[0],
|
||||||
|
'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {},
|
||||||
|
'dim': result.dim
|
||||||
|
}
|
||||||
|
|
||||||
|
def embed_dense(self, texts: List[str]) -> np.ndarray:
|
||||||
|
"""只生成Dense向量"""
|
||||||
|
result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
|
||||||
|
return result.dense_embeddings
|
||||||
|
|
||||||
|
def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
|
||||||
|
"""只生成Sparse向量"""
|
||||||
|
result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
|
||||||
|
return result.sparse_embeddings
|
||||||
|
|
||||||
|
def embed_query(self, query: str) -> Dict:
|
||||||
|
"""
|
||||||
|
对查询文本生成嵌入(用于检索)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 包含dense和sparse向量
|
||||||
|
"""
|
||||||
|
return self.embed_single(query)
|
||||||
|
|
||||||
|
def compute_similarity(
|
||||||
|
self,
|
||||||
|
query_embedding: np.ndarray,
|
||||||
|
doc_embeddings: np.ndarray,
|
||||||
|
metric: str = "cosine"
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
计算查询与文档的相似度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_embedding: 查询向量
|
||||||
|
doc_embeddings: 文档向量矩阵
|
||||||
|
metric: 相似度度量(cosine/dot)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 相似度分数数组
|
||||||
|
"""
|
||||||
|
if metric == "cosine":
|
||||||
|
# 余弦相似度
|
||||||
|
query_norm = np.linalg.norm(query_embedding)
|
||||||
|
doc_norms = np.linalg.norm(doc_embeddings, axis=1)
|
||||||
|
|
||||||
|
similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm)
|
||||||
|
|
||||||
|
elif metric == "dot":
|
||||||
|
# 点积相似度
|
||||||
|
similarities = np.dot(doc_embeddings, query_embedding)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的相似度度量: {metric}")
|
||||||
|
|
||||||
|
return similarities
|
||||||
|
|
||||||
|
def sparse_similarity(
|
||||||
|
self,
|
||||||
|
query_sparse: Dict[int, float],
|
||||||
|
doc_sparse: Dict[int, float]
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
计算Sparse向量的相似度(BM25风格)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_sparse: 查询的Sparse向量(词ID -> 权重)
|
||||||
|
doc_sparse: 文档的Sparse向量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: 相似度分数
|
||||||
|
"""
|
||||||
|
# 计算交集词的点积
|
||||||
|
common_keys = set(query_sparse.keys()) & set(doc_sparse.keys())
|
||||||
|
|
||||||
|
score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
def embed_texts(
|
||||||
|
texts: List[str],
|
||||||
|
model_name: str = "BAAI/bge-m3",
|
||||||
|
**kwargs
|
||||||
|
) -> EmbeddingResult:
|
||||||
|
"""便捷函数:对文本列表生成嵌入"""
|
||||||
|
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
|
||||||
|
return embedder.embed(texts)
|
||||||
|
|
||||||
|
|
||||||
|
def embed_single_text(
|
||||||
|
text: str,
|
||||||
|
model_name: str = "BAAI/bge-m3",
|
||||||
|
**kwargs
|
||||||
|
) -> Dict:
|
||||||
|
"""便捷函数:对单个文本生成嵌入"""
|
||||||
|
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
|
||||||
|
return embedder.embed_single(text)
|
||||||
449
backend/app/services/embedding/text_chunker.py
Normal file
449
backend/app/services/embedding/text_chunker.py
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
# src/services/embedding/text_chunker.py
|
||||||
|
"""智能分块器 - 章节级+条款级双粒度切割"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkMetadata:
|
||||||
|
"""分块元数据"""
|
||||||
|
doc_id: str = ""
|
||||||
|
doc_name: str = ""
|
||||||
|
chunk_id: str = ""
|
||||||
|
section_number: str = "" # 章节编号(如 "第一章")
|
||||||
|
section_title: str = "" # 章节标题
|
||||||
|
clause_number: str = "" # 条款编号(如 "第一条")
|
||||||
|
page_number: int = 0
|
||||||
|
start_position: int = 0 # 在原文中的起始位置
|
||||||
|
end_position: int = 0 # 在原文中的结束位置
|
||||||
|
regulation_type: str = "" # 法规类型
|
||||||
|
version: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextChunk:
|
||||||
|
"""文本分块"""
|
||||||
|
content: str
|
||||||
|
metadata: ChunkMetadata
|
||||||
|
token_count: int = 0 # 估算的token数量
|
||||||
|
|
||||||
|
|
||||||
|
class RegulationChunker:
|
||||||
|
"""
|
||||||
|
法规文档智能分块器
|
||||||
|
|
||||||
|
实现章节级/条款级双粒度切割,适配国标GB文档结构:
|
||||||
|
- 国标文档通常有明确的层级结构:章 > 节 > 条
|
||||||
|
- 每个条款应作为一个独立的语义单元
|
||||||
|
- 保留条款完整性,避免跨条款截断
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 法规标题模式
|
||||||
|
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
|
||||||
|
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
|
||||||
|
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
|
||||||
|
|
||||||
|
# 条款子项模式
|
||||||
|
SUB_ITEM_PATTERN = re.compile(r'^[\((][一二三四五六七八九十]+[\))]\s')
|
||||||
|
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 512,
|
||||||
|
chunk_overlap: int = 50,
|
||||||
|
max_chunk_size: int = 2048,
|
||||||
|
min_chunk_size: int = 100
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化分块器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: 默认分块大小(字符数)
|
||||||
|
chunk_overlap: 分块重叠大小
|
||||||
|
max_chunk_size: 最大分块大小(防止单个条款过长)
|
||||||
|
min_chunk_size: 最小分块大小(防止碎片化)
|
||||||
|
"""
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.chunk_overlap = chunk_overlap
|
||||||
|
self.max_chunk_size = max_chunk_size
|
||||||
|
self.min_chunk_size = min_chunk_size
|
||||||
|
|
||||||
|
def chunk_document(
|
||||||
|
self,
|
||||||
|
markdown_text: str,
|
||||||
|
doc_id: str = "",
|
||||||
|
doc_name: str = "",
|
||||||
|
regulation_type: str = "",
|
||||||
|
version: str = ""
|
||||||
|
) -> List[TextChunk]:
|
||||||
|
"""
|
||||||
|
对法规文档进行智能分块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
markdown_text: Markdown格式的文档内容
|
||||||
|
doc_id: 文档ID
|
||||||
|
doc_name: 文档名称
|
||||||
|
regulation_type: 法规类型
|
||||||
|
version: 文档版本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[TextChunk]: 分块列表
|
||||||
|
"""
|
||||||
|
logger.info(f"开始分块文档: {doc_name}")
|
||||||
|
|
||||||
|
# 1. 按章节分割(一级分块)
|
||||||
|
sections = self._split_by_sections(markdown_text)
|
||||||
|
|
||||||
|
# 2. 在每个章节内按条款分割(二级分块)
|
||||||
|
chunks = []
|
||||||
|
global_position = 0
|
||||||
|
|
||||||
|
for section_num, section_title, section_content, section_start in sections:
|
||||||
|
# 在章节内按条款分割
|
||||||
|
clause_chunks = self._split_by_clauses(
|
||||||
|
section_content,
|
||||||
|
section_num,
|
||||||
|
section_title,
|
||||||
|
section_start + global_position
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks:
|
||||||
|
# 处理过长的条款(进一步细分)
|
||||||
|
if len(chunk_content) > self.max_chunk_size:
|
||||||
|
sub_chunks = self._split_long_clause(
|
||||||
|
chunk_content,
|
||||||
|
clause_num,
|
||||||
|
clause_title
|
||||||
|
)
|
||||||
|
for sub_content, sub_start, sub_end in sub_chunks:
|
||||||
|
chunk = self._create_chunk(
|
||||||
|
sub_content,
|
||||||
|
doc_id,
|
||||||
|
doc_name,
|
||||||
|
section_num,
|
||||||
|
section_title,
|
||||||
|
clause_num,
|
||||||
|
sub_start + start_pos,
|
||||||
|
sub_end + start_pos,
|
||||||
|
regulation_type,
|
||||||
|
version
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
chunk = self._create_chunk(
|
||||||
|
chunk_content,
|
||||||
|
doc_id,
|
||||||
|
doc_name,
|
||||||
|
section_num,
|
||||||
|
section_title,
|
||||||
|
clause_num,
|
||||||
|
start_pos,
|
||||||
|
end_pos,
|
||||||
|
regulation_type,
|
||||||
|
version
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
logger.success(f"分块完成,共{len(chunks)}个chunk")
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
|
||||||
|
"""
|
||||||
|
按章节分割文档
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (section_number, section_title, section_content, start_position)
|
||||||
|
"""
|
||||||
|
sections = []
|
||||||
|
lines = markdown_text.split('\n')
|
||||||
|
|
||||||
|
current_section_num = ""
|
||||||
|
current_section_title = ""
|
||||||
|
current_section_content = []
|
||||||
|
current_section_start = 0
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
# 检测章节标题
|
||||||
|
chapter_match = self.CHAPTER_PATTERN.match(line.strip())
|
||||||
|
section_match = self.SECTION_PATTERN.match(line.strip())
|
||||||
|
|
||||||
|
if chapter_match or section_match:
|
||||||
|
# 保存上一个章节
|
||||||
|
if current_section_content:
|
||||||
|
content = '\n'.join(current_section_content)
|
||||||
|
sections.append((
|
||||||
|
current_section_num,
|
||||||
|
current_section_title,
|
||||||
|
content,
|
||||||
|
current_section_start
|
||||||
|
))
|
||||||
|
|
||||||
|
# 开始新章节
|
||||||
|
current_section_start = sum(len(l) + 1 for l in lines[:i])
|
||||||
|
current_section_content = []
|
||||||
|
|
||||||
|
if chapter_match:
|
||||||
|
current_section_num = line.strip()
|
||||||
|
current_section_title = self._extract_title(line.strip())
|
||||||
|
else:
|
||||||
|
current_section_num = line.strip()
|
||||||
|
current_section_title = self._extract_title(line.strip())
|
||||||
|
|
||||||
|
current_section_content.append(line)
|
||||||
|
|
||||||
|
# 保存最后一个章节
|
||||||
|
if current_section_content:
|
||||||
|
content = '\n'.join(current_section_content)
|
||||||
|
sections.append((
|
||||||
|
current_section_num,
|
||||||
|
current_section_title,
|
||||||
|
content,
|
||||||
|
current_section_start
|
||||||
|
))
|
||||||
|
|
||||||
|
# 如果没有检测到章节,将整个文档作为一个大章节
|
||||||
|
if not sections:
|
||||||
|
sections.append((
|
||||||
|
"",
|
||||||
|
"全文",
|
||||||
|
markdown_text,
|
||||||
|
0
|
||||||
|
))
|
||||||
|
|
||||||
|
return sections
|
||||||
|
|
||||||
|
def _split_by_clauses(
|
||||||
|
self,
|
||||||
|
section_content: str,
|
||||||
|
section_num: str,
|
||||||
|
section_title: str,
|
||||||
|
section_start: int
|
||||||
|
) -> List[Tuple[str, str, str, int, int]]:
|
||||||
|
"""
|
||||||
|
在章节内按条款分割
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (content, clause_number, clause_title, start_position, end_position)
|
||||||
|
"""
|
||||||
|
clauses = []
|
||||||
|
lines = section_content.split('\n')
|
||||||
|
|
||||||
|
current_clause_num = ""
|
||||||
|
current_clause_title = ""
|
||||||
|
current_clause_content = []
|
||||||
|
current_clause_start = section_start
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
# 检测条款标题
|
||||||
|
clause_match = self.CLAUSE_PATTERN.match(line.strip())
|
||||||
|
|
||||||
|
if clause_match:
|
||||||
|
# 保存上一个条款
|
||||||
|
if current_clause_content:
|
||||||
|
content = '\n'.join(current_clause_content)
|
||||||
|
end_pos = current_clause_start + len(content)
|
||||||
|
clauses.append((
|
||||||
|
content,
|
||||||
|
current_clause_num,
|
||||||
|
current_clause_title,
|
||||||
|
current_clause_start,
|
||||||
|
end_pos
|
||||||
|
))
|
||||||
|
|
||||||
|
# 开始新条款
|
||||||
|
current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i])
|
||||||
|
current_clause_content = []
|
||||||
|
current_clause_num = self._extract_clause_number(line.strip())
|
||||||
|
current_clause_title = line.strip()
|
||||||
|
|
||||||
|
current_clause_content.append(line)
|
||||||
|
|
||||||
|
# 保存最后一个条款
|
||||||
|
if current_clause_content:
|
||||||
|
content = '\n'.join(current_clause_content)
|
||||||
|
end_pos = current_clause_start + len(content)
|
||||||
|
clauses.append((
|
||||||
|
content,
|
||||||
|
current_clause_num,
|
||||||
|
current_clause_title,
|
||||||
|
current_clause_start,
|
||||||
|
end_pos
|
||||||
|
))
|
||||||
|
|
||||||
|
# 如果没有检测到条款,将整个章节作为一个条款
|
||||||
|
if not clauses:
|
||||||
|
clauses.append((
|
||||||
|
section_content,
|
||||||
|
"",
|
||||||
|
section_title,
|
||||||
|
section_start,
|
||||||
|
section_start + len(section_content)
|
||||||
|
))
|
||||||
|
|
||||||
|
return clauses
|
||||||
|
|
||||||
|
def _split_long_clause(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
clause_num: str,
|
||||||
|
clause_title: str
|
||||||
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
"""
|
||||||
|
分割过长的条款内容
|
||||||
|
|
||||||
|
按条款子项或段落分割,保持语义完整性
|
||||||
|
"""
|
||||||
|
sub_chunks = []
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
# 检测是否有子项结构
|
||||||
|
has_sub_items = any(
|
||||||
|
self.SUB_ITEM_PATTERN.match(line.strip()) or
|
||||||
|
self.NUMBER_ITEM_PATTERN.match(line.strip())
|
||||||
|
for line in lines
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_sub_items:
|
||||||
|
# 按子项分割
|
||||||
|
current_sub_content = []
|
||||||
|
current_sub_start = 0
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
is_sub_item = (
|
||||||
|
self.SUB_ITEM_PATTERN.match(line.strip()) or
|
||||||
|
self.NUMBER_ITEM_PATTERN.match(line.strip())
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sub_item and current_sub_content:
|
||||||
|
sub_content = '\n'.join(current_sub_content)
|
||||||
|
sub_end = current_sub_start + len(sub_content)
|
||||||
|
if len(sub_content) >= self.min_chunk_size:
|
||||||
|
sub_chunks.append((sub_content, current_sub_start, sub_end))
|
||||||
|
current_sub_content = []
|
||||||
|
current_sub_start = sum(len(l) + 1 for l in lines[:i])
|
||||||
|
|
||||||
|
current_sub_content.append(line)
|
||||||
|
|
||||||
|
# 保存最后一个子项
|
||||||
|
if current_sub_content:
|
||||||
|
sub_content = '\n'.join(current_sub_content)
|
||||||
|
sub_end = current_sub_start + len(sub_content)
|
||||||
|
sub_chunks.append((sub_content, current_sub_start, sub_end))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 按段落分割(滑动窗口)
|
||||||
|
paragraphs = []
|
||||||
|
current_para = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.strip():
|
||||||
|
current_para.append(line)
|
||||||
|
else:
|
||||||
|
if current_para:
|
||||||
|
paragraphs.append('\n'.join(current_para))
|
||||||
|
current_para = []
|
||||||
|
|
||||||
|
if current_para:
|
||||||
|
paragraphs.append('\n'.join(current_para))
|
||||||
|
|
||||||
|
# 合并段落直到达到chunk_size
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
chunk_start = 0
|
||||||
|
|
||||||
|
for para in paragraphs:
|
||||||
|
if current_length + len(para) > self.chunk_size and current_chunk:
|
||||||
|
chunk_content = '\n'.join(current_chunk)
|
||||||
|
chunk_end = chunk_start + len(chunk_content)
|
||||||
|
sub_chunks.append((chunk_content, chunk_start, chunk_end))
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
chunk_start = chunk_end
|
||||||
|
|
||||||
|
current_chunk.append(para)
|
||||||
|
current_length += len(para)
|
||||||
|
|
||||||
|
# 保存最后一个chunk
|
||||||
|
if current_chunk:
|
||||||
|
chunk_content = '\n'.join(current_chunk)
|
||||||
|
chunk_end = chunk_start + len(chunk_content)
|
||||||
|
sub_chunks.append((chunk_content, chunk_start, chunk_end))
|
||||||
|
|
||||||
|
return sub_chunks
|
||||||
|
|
||||||
|
def _extract_title(self, header_line: str) -> str:
|
||||||
|
"""从标题行提取标题内容"""
|
||||||
|
# 移除"第X章"、"第X节"前缀
|
||||||
|
title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line)
|
||||||
|
return title.strip()
|
||||||
|
|
||||||
|
def _extract_clause_number(self, clause_line: str) -> str:
|
||||||
|
"""从条款行提取条款编号"""
|
||||||
|
match = self.CLAUSE_PATTERN.match(clause_line)
|
||||||
|
if match:
|
||||||
|
return match.group(0).strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _create_chunk(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
doc_id: str,
|
||||||
|
doc_name: str,
|
||||||
|
section_num: str,
|
||||||
|
section_title: str,
|
||||||
|
clause_num: str,
|
||||||
|
start_pos: int,
|
||||||
|
end_pos: int,
|
||||||
|
regulation_type: str,
|
||||||
|
version: str
|
||||||
|
) -> TextChunk:
|
||||||
|
"""创建文本分块"""
|
||||||
|
# 清理内容
|
||||||
|
content = content.strip()
|
||||||
|
|
||||||
|
# 计算估算token数(中文约1.5字符/token)
|
||||||
|
token_count = int(len(content) * 0.7) # 简化估算
|
||||||
|
|
||||||
|
# 生成chunk_id
|
||||||
|
chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}"
|
||||||
|
|
||||||
|
metadata = ChunkMetadata(
|
||||||
|
doc_id=doc_id,
|
||||||
|
doc_name=doc_name,
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
section_number=section_num,
|
||||||
|
section_title=section_title,
|
||||||
|
clause_number=clause_num,
|
||||||
|
start_position=start_pos,
|
||||||
|
end_position=end_pos,
|
||||||
|
regulation_type=regulation_type,
|
||||||
|
version=version
|
||||||
|
)
|
||||||
|
|
||||||
|
return TextChunk(
|
||||||
|
content=content,
|
||||||
|
metadata=metadata,
|
||||||
|
token_count=token_count
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_regulation_document(
|
||||||
|
markdown_text: str,
|
||||||
|
doc_id: str = "",
|
||||||
|
doc_name: str = "",
|
||||||
|
regulation_type: str = "",
|
||||||
|
version: str = "",
|
||||||
|
chunk_size: int = 512
|
||||||
|
) -> List[TextChunk]:
|
||||||
|
"""便捷函数:对法规文档进行分块"""
|
||||||
|
chunker = RegulationChunker(chunk_size=chunk_size)
|
||||||
|
return chunker.chunk_document(
|
||||||
|
markdown_text,
|
||||||
|
doc_id,
|
||||||
|
doc_name,
|
||||||
|
regulation_type,
|
||||||
|
version
|
||||||
|
)
|
||||||
15
backend/app/services/llm/__init__.py
Normal file
15
backend/app/services/llm/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# src/services/llm/__init__.py
|
||||||
|
"""LLM服务模块"""
|
||||||
|
|
||||||
|
from .llm_factory import LLMFactory, get_llm_client
|
||||||
|
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||||
|
from .deepseek_client import DeepSeekClient
|
||||||
|
from .qwen_client import QwenClient, QwenVLClient
|
||||||
|
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMFactory", "get_llm_client",
|
||||||
|
"BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider",
|
||||||
|
"DeepSeekClient", "QwenClient", "QwenVLClient",
|
||||||
|
"DocumentSummarizer", "summarize_document", "DocumentSummary"
|
||||||
|
]
|
||||||
116
backend/app/services/llm/base_client.py
Normal file
116
backend/app/services/llm/base_client.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# src/services/llm/base_client.py
|
||||||
|
"""LLM客户端基类 - 统一接口定义"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProvider(Enum):
|
||||||
|
"""LLM提供商"""
|
||||||
|
DEEPSEEK = "deepseek"
|
||||||
|
QWEN = "qwen"
|
||||||
|
QWEN_VL = "qwen_vl"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
"""LLM响应结果"""
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
usage: Dict[str, int] = field(default_factory=dict)
|
||||||
|
finish_reason: str = "stop"
|
||||||
|
latency_ms: int = 0
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_success(self) -> bool:
|
||||||
|
return self.error is None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMConfig:
|
||||||
|
"""LLM配置"""
|
||||||
|
provider: LLMProvider
|
||||||
|
model: str
|
||||||
|
api_key: str
|
||||||
|
base_url: str
|
||||||
|
max_tokens: int = 4096
|
||||||
|
temperature: float = 0.7
|
||||||
|
top_p: float = 0.9
|
||||||
|
timeout: int = 300 # 默认超时300秒(摘要/Skills生成可能需要较长时间)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMClient(ABC):
|
||||||
|
"""LLM客户端基类"""
|
||||||
|
|
||||||
|
def __init__(self, config: LLMConfig):
|
||||||
|
self.config = config
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_client(self):
|
||||||
|
"""初始化客户端"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""
|
||||||
|
对话补全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 对话消息列表 [{"role": "user/assistant/system", "content": "..."}]
|
||||||
|
max_tokens: 最大输出token数
|
||||||
|
temperature: 温度参数
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLMResponse: 响应结果
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""
|
||||||
|
单轮补全(便捷方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 用户输入
|
||||||
|
system_prompt: 系统提示词
|
||||||
|
max_tokens: 最大输出token数
|
||||||
|
temperature: 温度参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLMResponse: 响应结果
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
return self.chat(messages, max_tokens, temperature, **kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_available_models(self) -> List[str]:
|
||||||
|
"""获取可用模型列表"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str) -> int:
|
||||||
|
"""估算文本token数(粗略估计)"""
|
||||||
|
# 中文字符约1.5 token,英文约0.25 token
|
||||||
|
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||||
|
other_chars = len(text) - chinese_chars
|
||||||
|
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||||
130
backend/app/services/llm/deepseek_client.py
Normal file
130
backend/app/services/llm/deepseek_client.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# src/services/llm/deepseek_client.py
|
||||||
|
"""DeepSeek LLM客户端 - OpenAI兼容API"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from loguru import logger
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekClient(BaseLLMClient):
|
||||||
|
"""
|
||||||
|
DeepSeek API客户端(OpenAI兼容格式)
|
||||||
|
|
||||||
|
支持模型:
|
||||||
|
- deepseek-chat
|
||||||
|
- deepseek-coder
|
||||||
|
- deepseek-reasoner
|
||||||
|
- deepseek-v3
|
||||||
|
- deepseek-v3.2
|
||||||
|
- deepseek-v4-flash
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_MODELS = [
|
||||||
|
"deepseek-chat",
|
||||||
|
"deepseek-coder",
|
||||||
|
"deepseek-reasoner",
|
||||||
|
"deepseek-v3",
|
||||||
|
"deepseek-v3.2",
|
||||||
|
"deepseek-v4-flash"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, config: LLMConfig):
|
||||||
|
if config.provider != LLMProvider.DEEPSEEK:
|
||||||
|
raise ValueError(f"配置provider应为DEEPSEEK,实际为{config.provider}")
|
||||||
|
super().__init__(config)
|
||||||
|
self._init_client()
|
||||||
|
|
||||||
|
def _init_client(self):
|
||||||
|
"""初始化HTTP客户端"""
|
||||||
|
self._client = httpx.Client(
|
||||||
|
base_url=self.config.base_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {self.config.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
timeout=self.config.timeout
|
||||||
|
)
|
||||||
|
logger.info(f"DeepSeek客户端初始化完成: {self.config.base_url} - {self.config.model}")
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""对话补全"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = {
|
||||||
|
"model": self.config.model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": max_tokens or self.config.max_tokens,
|
||||||
|
"temperature": temperature or self.config.temperature,
|
||||||
|
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self._client.post("/chat/completions", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
choices = data.get("choices", [{}])
|
||||||
|
message = choices[0].get("message", {})
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.config.model),
|
||||||
|
usage=data.get("usage", {}),
|
||||||
|
finish_reason=choices[0].get("finish_reason", "stop"),
|
||||||
|
latency_ms=latency_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"DeepSeek API错误: {e.response.status_code} - {e.response.text}")
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
model=self.config.model,
|
||||||
|
error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DeepSeek调用失败: {e}")
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
model=self.config.model,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_available_models(self) -> List[str]:
|
||||||
|
"""获取可用模型列表"""
|
||||||
|
return self.SUPPORTED_MODELS
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭客户端"""
|
||||||
|
if self._client:
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
|
|
||||||
|
def create_deepseek_client(
|
||||||
|
api_key: str,
|
||||||
|
model: str = "deepseek-v4-flash",
|
||||||
|
base_url: str = "http://6.86.80.4:30080/v1",
|
||||||
|
**kwargs
|
||||||
|
) -> DeepSeekClient:
|
||||||
|
"""便捷函数:创建DeepSeek客户端"""
|
||||||
|
config = LLMConfig(
|
||||||
|
provider=LLMProvider.DEEPSEEK,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return DeepSeekClient(config)
|
||||||
231
backend/app/services/llm/document_summarizer.py
Normal file
231
backend/app/services/llm/document_summarizer.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
# src/services/llm/document_summarizer.py
|
||||||
|
"""文档摘要生成服务 - LLM生成法规文档摘要"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.services.llm import get_llm_client, BaseLLMClient
|
||||||
|
from app.services.rag.prompt_templates import get_prompt_template
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocumentSummary:
|
||||||
|
"""文档摘要结果"""
|
||||||
|
doc_name: str
|
||||||
|
summary: str
|
||||||
|
applicable_scope: str
|
||||||
|
key_clauses: list
|
||||||
|
key_terms: list
|
||||||
|
compliance_points: list
|
||||||
|
model: str
|
||||||
|
latency_ms: int
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_success(self) -> bool:
|
||||||
|
return self.error is None
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentSummarizer:
|
||||||
|
"""
|
||||||
|
文档摘要生成器
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 生成法规文档的核心要点摘要
|
||||||
|
- 提取适用范围
|
||||||
|
- 突出关键条款
|
||||||
|
- 列出合规要点
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
summarizer = DocumentSummarizer()
|
||||||
|
result = summarizer.summarize("GB 7258-2017", markdown_content)
|
||||||
|
print(result.summary)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: str = None,
|
||||||
|
model: str = None,
|
||||||
|
max_tokens: int = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化摘要生成器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: LLM提供商
|
||||||
|
model: LLM模型名称
|
||||||
|
max_tokens: 最大输出token数
|
||||||
|
"""
|
||||||
|
self.provider = provider or settings.llm_provider
|
||||||
|
self.model = model or settings.llm_model
|
||||||
|
self.max_tokens = max_tokens or settings.rag_summary_max_tokens
|
||||||
|
|
||||||
|
# LLM客户端(延迟加载)
|
||||||
|
self.llm: Optional[BaseLLMClient] = None
|
||||||
|
|
||||||
|
logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")
|
||||||
|
|
||||||
|
def _init_llm(self):
|
||||||
|
"""延迟初始化LLM"""
|
||||||
|
if self.llm is None:
|
||||||
|
self.llm = get_llm_client(
|
||||||
|
provider=self.provider,
|
||||||
|
model=self.model
|
||||||
|
)
|
||||||
|
|
||||||
|
def summarize(
|
||||||
|
self,
|
||||||
|
doc_name: str,
|
||||||
|
content: str,
|
||||||
|
regulation_type: str = "",
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
) -> DocumentSummary:
|
||||||
|
"""
|
||||||
|
生成文档摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_name: 文档名称
|
||||||
|
content: 文档内容(Markdown格式)
|
||||||
|
regulation_type: 法规类型
|
||||||
|
max_tokens: 最大输出token数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DocumentSummary: 摘要结果
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.info(f"生成文档摘要: {doc_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._init_llm()
|
||||||
|
|
||||||
|
# 使用摘要模板
|
||||||
|
template = get_prompt_template("document_summary")
|
||||||
|
|
||||||
|
# 构建用户消息
|
||||||
|
user_content = template.user_template.format(
|
||||||
|
doc_name=doc_name,
|
||||||
|
content=content[:8000] # 截取前8000字符(避免超出token限制)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用LLM
|
||||||
|
response = self.llm.chat(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": template.system_prompt},
|
||||||
|
{"role": "user", "content": user_content}
|
||||||
|
],
|
||||||
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
|
temperature=0.3 # 低温度保证摘要准确性
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
if not response.is_success:
|
||||||
|
return DocumentSummary(
|
||||||
|
doc_name=doc_name,
|
||||||
|
summary="",
|
||||||
|
applicable_scope="",
|
||||||
|
key_clauses=[],
|
||||||
|
key_terms=[],
|
||||||
|
compliance_points=[],
|
||||||
|
model=self.model,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
error=response.error
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析摘要结构
|
||||||
|
summary_data = self._parse_summary(response.content)
|
||||||
|
|
||||||
|
logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
|
||||||
|
|
||||||
|
return DocumentSummary(
|
||||||
|
doc_name=doc_name,
|
||||||
|
summary=summary_data.get("summary", response.content),
|
||||||
|
applicable_scope=summary_data.get("applicable_scope", ""),
|
||||||
|
key_clauses=summary_data.get("key_clauses", []),
|
||||||
|
key_terms=summary_data.get("key_terms", []),
|
||||||
|
compliance_points=summary_data.get("compliance_points", []),
|
||||||
|
model=response.model,
|
||||||
|
latency_ms=latency_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"摘要生成失败: {e}")
|
||||||
|
return DocumentSummary(
|
||||||
|
doc_name=doc_name,
|
||||||
|
summary="",
|
||||||
|
applicable_scope="",
|
||||||
|
key_clauses=[],
|
||||||
|
key_terms=[],
|
||||||
|
compliance_points=[],
|
||||||
|
model=self.model,
|
||||||
|
latency_ms=0,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_summary(self, content: str) -> Dict:
|
||||||
|
"""解析摘要内容(提取结构化信息)"""
|
||||||
|
result = {
|
||||||
|
"summary": content,
|
||||||
|
"applicable_scope": "",
|
||||||
|
"key_clauses": [],
|
||||||
|
"key_terms": [],
|
||||||
|
"compliance_points": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 简单解析(提取关键信息)
|
||||||
|
lines = content.split("\n")
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
# 提取适用范围
|
||||||
|
if "适用范围" in line or "适用对象" in line:
|
||||||
|
result["applicable_scope"] = line.split(":")[-1].strip() if ":" in line else line.split(":")[-1].strip()
|
||||||
|
|
||||||
|
# 提取关键条款
|
||||||
|
if line.startswith("- 【条款") or line.startswith("【条款"):
|
||||||
|
result["key_clauses"].append(line)
|
||||||
|
|
||||||
|
# 提取关键术语
|
||||||
|
if "关键术语" in line or "术语定义" in line:
|
||||||
|
# 继续读取后续几行
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 提取合规要点
|
||||||
|
if "合规要点" in line or "必须满足" in line:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def batch_summarize(
|
||||||
|
self,
|
||||||
|
documents: list
|
||||||
|
) -> list:
|
||||||
|
"""
|
||||||
|
批量生成摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: 文档列表 [{"doc_name": str, "content": str}, ...]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 摘要结果列表
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for doc in documents:
|
||||||
|
result = self.summarize(doc["doc_name"], doc["content"])
|
||||||
|
results.append(result)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_document(
|
||||||
|
doc_name: str,
|
||||||
|
content: str,
|
||||||
|
**kwargs
|
||||||
|
) -> DocumentSummary:
|
||||||
|
"""便捷函数:生成文档摘要"""
|
||||||
|
summarizer = DocumentSummarizer(**kwargs)
|
||||||
|
return summarizer.summarize(doc_name, content)
|
||||||
258
backend/app/services/llm/llm_factory.py
Normal file
258
backend/app/services/llm/llm_factory.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
# src/services/llm/llm_factory.py
|
||||||
|
"""LLM工厂 - 统一创建和管理LLM客户端"""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||||
|
from .deepseek_client import DeepSeekClient
|
||||||
|
from .qwen_client import QwenClient, QwenVLClient
|
||||||
|
|
||||||
|
|
||||||
|
# 默认模型映射
|
||||||
|
DEFAULT_MODELS = {
|
||||||
|
LLMProvider.DEEPSEEK: "deepseek-v4-flash",
|
||||||
|
LLMProvider.QWEN: "qwen3.5-flash",
|
||||||
|
LLMProvider.QWEN_VL: "qwen3-vl-plus"
|
||||||
|
}
|
||||||
|
|
||||||
|
# API基础URL(使用统一代理服务)
|
||||||
|
DEFAULT_BASE_URLS = {
|
||||||
|
LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
|
||||||
|
LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
|
||||||
|
LLMProvider.QWEN_VL: "http://6.86.80.4:30080/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFactory:
|
||||||
|
"""
|
||||||
|
LLM客户端工厂(支持全局缓存)
|
||||||
|
|
||||||
|
支持的提供商和模型:
|
||||||
|
- DeepSeek: deepseek-chat (DeepSeek-V3), deepseek-coder
|
||||||
|
- Qwen: qwen-turbo, qwen-plus, qwen-max, qwen-long
|
||||||
|
- QwenVL: qwen-vl-plus, qwen-vl-max (多模态)
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
factory = LLMFactory()
|
||||||
|
|
||||||
|
# 使用默认配置
|
||||||
|
client = factory.create("deepseek")
|
||||||
|
|
||||||
|
# 自定义配置
|
||||||
|
client = factory.create("qwen", model="qwen-max", temperature=0.5)
|
||||||
|
|
||||||
|
# 调用LLM
|
||||||
|
response = client.complete("你好,介绍一下自己")
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 全局客户端缓存(类级别,跨实例共享)
|
||||||
|
_global_instances: Dict[str, BaseLLMClient] = {}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._config_cache: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
**kwargs
|
||||||
|
) -> BaseLLMClient:
|
||||||
|
"""
|
||||||
|
创建LLM客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: 提供商名称 ("deepseek", "qwen", "qwen_vl")
|
||||||
|
api_key: API密钥(如未提供,从环境变量获取)
|
||||||
|
model: 模型名称(如未提供,使用默认模型)
|
||||||
|
base_url: API基础URL
|
||||||
|
max_tokens: 最大输出token数
|
||||||
|
temperature: 温度参数
|
||||||
|
**kwargs: 其他配置参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseLLMClient: LLM客户端实例
|
||||||
|
"""
|
||||||
|
provider_enum = self._parse_provider(provider)
|
||||||
|
|
||||||
|
# 获取配置
|
||||||
|
api_key = api_key or self._get_api_key(provider_enum)
|
||||||
|
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||||
|
base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(f"缺少API密钥,请设置环境变量或传入api_key参数")
|
||||||
|
|
||||||
|
# 检查全局缓存
|
||||||
|
cache_key = f"{provider}_{model}"
|
||||||
|
if cache_key in LLMFactory._global_instances:
|
||||||
|
logger.debug(f"使用缓存的LLM客户端: {cache_key}")
|
||||||
|
return LLMFactory._global_instances[cache_key]
|
||||||
|
|
||||||
|
config = LLMConfig(
|
||||||
|
provider=provider_enum,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建客户端
|
||||||
|
client = self._create_client(config)
|
||||||
|
|
||||||
|
# 缓存到全局实例
|
||||||
|
LLMFactory._global_instances[cache_key] = client
|
||||||
|
|
||||||
|
logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
|
||||||
|
return client
|
||||||
|
|
||||||
|
def _parse_provider(self, provider: str) -> LLMProvider:
|
||||||
|
"""解析提供商名称"""
|
||||||
|
provider_map = {
|
||||||
|
"deepseek": LLMProvider.DEEPSEEK,
|
||||||
|
"deepseek-v3": LLMProvider.DEEPSEEK,
|
||||||
|
"deepseek_chat": LLMProvider.DEEPSEEK,
|
||||||
|
"qwen": LLMProvider.QWEN,
|
||||||
|
"qwen-turbo": LLMProvider.QWEN,
|
||||||
|
"qwen-plus": LLMProvider.QWEN,
|
||||||
|
"qwen-max": LLMProvider.QWEN,
|
||||||
|
"qwen3.5-flash": LLMProvider.QWEN,
|
||||||
|
"qwen3.5-plus": LLMProvider.QWEN,
|
||||||
|
"qwen_vl": LLMProvider.QWEN_VL,
|
||||||
|
"qwen-vl": LLMProvider.QWEN_VL,
|
||||||
|
"qwen-vl-plus": LLMProvider.QWEN_VL,
|
||||||
|
"qwen-vl-max": LLMProvider.QWEN_VL
|
||||||
|
}
|
||||||
|
|
||||||
|
provider_lower = provider.lower()
|
||||||
|
if provider_lower not in provider_map:
|
||||||
|
raise ValueError(f"不支持的提供商: {provider},支持的: {list(provider_map.keys())}")
|
||||||
|
|
||||||
|
return provider_map[provider_lower]
|
||||||
|
|
||||||
|
def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
|
||||||
|
"""从环境变量获取API密钥"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
key_map = {
|
||||||
|
LLMProvider.DEEPSEEK: ["DEEPSEEK_API_KEY", "OPENAI_API_KEY"],
|
||||||
|
LLMProvider.QWEN: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
|
||||||
|
LLMProvider.QWEN_VL: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"]
|
||||||
|
}
|
||||||
|
|
||||||
|
for key_name in key_map.get(provider, []):
|
||||||
|
api_key = os.getenv(key_name)
|
||||||
|
if api_key:
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_client(self, config: LLMConfig) -> BaseLLMClient:
|
||||||
|
"""创建具体客户端"""
|
||||||
|
client_map = {
|
||||||
|
LLMProvider.DEEPSEEK: DeepSeekClient,
|
||||||
|
LLMProvider.QWEN: QwenClient,
|
||||||
|
LLMProvider.QWEN_VL: QwenVLClient
|
||||||
|
}
|
||||||
|
|
||||||
|
client_class = client_map.get(config.provider)
|
||||||
|
if not client_class:
|
||||||
|
raise ValueError(f"不支持的提供商: {config.provider}")
|
||||||
|
|
||||||
|
return client_class(config)
|
||||||
|
|
||||||
|
def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||||
|
"""获取缓存的客户端"""
|
||||||
|
provider_enum = self._parse_provider(provider)
|
||||||
|
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||||
|
cache_key = f"{provider}_{model}"
|
||||||
|
return LLMFactory._global_instances.get(cache_key)
|
||||||
|
|
||||||
|
def list_available_providers(self) -> Dict[str, list]:
|
||||||
|
"""列出可用的提供商和模型"""
|
||||||
|
return {
|
||||||
|
"deepseek": DeepSeekClient.SUPPORTED_MODELS,
|
||||||
|
"qwen": QwenClient.SUPPORTED_MODELS,
|
||||||
|
"qwen_vl": QwenVLClient.SUPPORTED_MODELS
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def preload_clients(cls, providers: list = None):
|
||||||
|
"""
|
||||||
|
预加载LLM客户端(应用启动时调用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
providers: 要预加载的提供商列表,默认加载qwen和deepseek
|
||||||
|
"""
|
||||||
|
if providers is None:
|
||||||
|
providers = ["qwen", "deepseek"]
|
||||||
|
|
||||||
|
factory = cls()
|
||||||
|
for provider in providers:
|
||||||
|
try:
|
||||||
|
client = factory.create(provider)
|
||||||
|
logger.success(f"预加载LLM客户端成功: {provider}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"预加载LLM客户端失败: {provider} - {e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||||
|
"""获取全局缓存的客户端"""
|
||||||
|
provider_lower = provider.lower()
|
||||||
|
# 处理模型名作为provider的情况(如 qwen3.5-flash)
|
||||||
|
if provider_lower.startswith("qwen"):
|
||||||
|
provider_lower = "qwen"
|
||||||
|
model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK)
|
||||||
|
cache_key = f"{provider_lower}_{model}"
|
||||||
|
return cls._global_instances.get(cache_key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cleanup(cls):
|
||||||
|
"""清理所有缓存的客户端"""
|
||||||
|
for cache_key, client in cls._global_instances.items():
|
||||||
|
try:
|
||||||
|
client.close()
|
||||||
|
logger.debug(f"关闭LLM客户端: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"关闭LLM客户端失败: {cache_key} - {e}")
|
||||||
|
cls._global_instances.clear()
|
||||||
|
logger.info("所有LLM客户端已清理")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_llm_factory() -> LLMFactory:
|
||||||
|
"""获取LLM工厂实例(缓存)"""
|
||||||
|
return LLMFactory()
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_client(
|
||||||
|
provider: str = "qwen",
|
||||||
|
model: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> BaseLLMClient:
|
||||||
|
"""
|
||||||
|
便捷函数:获取LLM客户端(优先使用缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: 提供商名称
|
||||||
|
model: 模型名称
|
||||||
|
**kwargs: 其他配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseLLMClient: LLM客户端实例
|
||||||
|
"""
|
||||||
|
factory = get_llm_factory()
|
||||||
|
|
||||||
|
# 先尝试获取缓存的实例
|
||||||
|
cached = factory.get_cached(provider, model)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
return factory.create(provider, model=model, **kwargs)
|
||||||
392
backend/app/services/llm/qwen_client.py
Normal file
392
backend/app/services/llm/qwen_client.py
Normal file
@@ -0,0 +1,392 @@
|
|||||||
|
# src/services/llm/qwen_client.py
|
||||||
|
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Optional, Generator, AsyncGenerator
|
||||||
|
from loguru import logger
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
class QwenClient(BaseLLMClient):
|
||||||
|
"""
|
||||||
|
Qwen API客户端(OpenAI兼容格式)
|
||||||
|
|
||||||
|
支持通过new-api等代理服务调用:
|
||||||
|
- qwen-turbo
|
||||||
|
- qwen-plus
|
||||||
|
- qwen-max
|
||||||
|
- qwen3.5-flash (推荐:快速响应)
|
||||||
|
- qwen3.5-plus
|
||||||
|
- qwen-long
|
||||||
|
- qwen2.5系列
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_MODELS = [
|
||||||
|
"qwen-turbo",
|
||||||
|
"qwen-plus",
|
||||||
|
"qwen-max",
|
||||||
|
"qwen-max-longcontext",
|
||||||
|
"qwen-long",
|
||||||
|
"qwen3.5-flash",
|
||||||
|
"qwen3.5-plus",
|
||||||
|
"qwen3-plus",
|
||||||
|
"qwen2.5-72b-instruct",
|
||||||
|
"qwen2.5-32b-instruct",
|
||||||
|
"qwen2.5-14b-instruct",
|
||||||
|
"qwen2.5-7b-instruct"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, config: LLMConfig):
|
||||||
|
if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
|
||||||
|
raise ValueError(f"配置provider应为Qwen,实际为{config.provider}")
|
||||||
|
super().__init__(config)
|
||||||
|
self._init_client()
|
||||||
|
|
||||||
|
def _init_client(self):
|
||||||
|
"""初始化HTTP客户端"""
|
||||||
|
# OpenAI兼容API格式
|
||||||
|
self._client = httpx.Client(
|
||||||
|
base_url=self.config.base_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {self.config.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
timeout=self.config.timeout
|
||||||
|
)
|
||||||
|
logger.info(f"Qwen客户端初始化完成: {self.config.base_url} - {self.config.model}")
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""对话补全(OpenAI兼容格式)"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# OpenAI兼容格式的请求体
|
||||||
|
payload = {
|
||||||
|
"model": self.config.model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": max_tokens or self.config.max_tokens,
|
||||||
|
"temperature": temperature or self.config.temperature,
|
||||||
|
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
# OpenAI兼容接口路径
|
||||||
|
response = self._client.post("/chat/completions", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
# OpenAI兼容格式的响应解析
|
||||||
|
choices = data.get("choices", [{}])
|
||||||
|
message = choices[0].get("message", {})
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.config.model),
|
||||||
|
usage=data.get("usage", {}),
|
||||||
|
finish_reason=choices[0].get("finish_reason", "stop"),
|
||||||
|
latency_ms=latency_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"Qwen API错误: {e.response.status_code} - {e.response.text}")
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
model=self.config.model,
|
||||||
|
error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Qwen调用失败: {e}")
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
model=self.config.model,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
"""
|
||||||
|
流式对话补全(SSE格式)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
str: 每次返回一个文本片段
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
for chunk in client.stream_chat(messages):
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# OpenAI兼容格式的请求体,启用流式输出
|
||||||
|
payload = {
|
||||||
|
"model": self.config.model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": max_tokens or self.config.max_tokens,
|
||||||
|
"temperature": temperature or self.config.temperature,
|
||||||
|
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||||
|
"stream": True # 启用流式输出
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用stream模式发送请求
|
||||||
|
with self._client.stream("POST", "/chat/completions", json=payload) as response:
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
line = line.strip()
|
||||||
|
# SSE格式: data: {...}
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:] # 移除 "data: " 前缀
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
continue # 跳过空的choices
|
||||||
|
delta = choices[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"Qwen流式API错误: {e.response.status_code}")
|
||||||
|
yield f"[ERROR: API返回错误 {e.response.status_code}]"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Qwen流式调用失败: {e}")
|
||||||
|
yield f"[ERROR: {str(e)}]"
|
||||||
|
|
||||||
|
async def async_stream_chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
异步流式对话补全(用于FastAPI SSE响应)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
str: 每次返回一个文本片段
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 使用同步流式方法,包装为异步
|
||||||
|
for chunk in self.stream_chat(messages, max_tokens, temperature, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
# 给async循环一个小延迟,让其他任务有机会执行
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def get_available_models(self) -> List[str]:
|
||||||
|
"""获取可用模型列表"""
|
||||||
|
return self.SUPPORTED_MODELS
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭客户端"""
|
||||||
|
if self._client:
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVLClient(BaseLLMClient):
|
||||||
|
"""
|
||||||
|
Qwen VL多模态客户端(OpenAI兼容格式)
|
||||||
|
|
||||||
|
支持模型:
|
||||||
|
- qwen-vl-plus
|
||||||
|
- qwen-vl-max
|
||||||
|
- qwen3-vl-plus
|
||||||
|
- qwen2-vl-7b-instruct
|
||||||
|
- qwen2-vl-72b-instruct
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_MODELS = [
|
||||||
|
"qwen-vl-plus",
|
||||||
|
"qwen-vl-max",
|
||||||
|
"qwen3-vl-plus",
|
||||||
|
"qwen2-vl-7b-instruct",
|
||||||
|
"qwen2-vl-72b-instruct"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, config: LLMConfig):
|
||||||
|
if config.provider != LLMProvider.QWEN_VL:
|
||||||
|
raise ValueError(f"配置provider应为QWEN_VL,实际为{config.provider}")
|
||||||
|
super().__init__(config)
|
||||||
|
self._init_client()
|
||||||
|
|
||||||
|
def _init_client(self):
|
||||||
|
"""初始化HTTP客户端"""
|
||||||
|
self._client = httpx.Client(
|
||||||
|
base_url=self.config.base_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {self.config.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
timeout=self.config.timeout
|
||||||
|
)
|
||||||
|
logger.info(f"QwenVL客户端初始化完成: {self.config.base_url} - {self.config.model}")
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""多模态对话补全(OpenAI兼容格式)
|
||||||
|
|
||||||
|
支持图片输入,消息格式:
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
||||||
|
{"type": "text", "text": "描述这张图片"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# OpenAI兼容格式的请求体
|
||||||
|
payload = {
|
||||||
|
"model": self.config.model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": max_tokens or self.config.max_tokens,
|
||||||
|
"temperature": temperature or self.config.temperature,
|
||||||
|
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self._client.post("/chat/completions", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
latency_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
choices = data.get("choices", [{}])
|
||||||
|
message = choices[0].get("message", {})
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=message.get("content", ""),
|
||||||
|
model=data.get("model", self.config.model),
|
||||||
|
usage=data.get("usage", {}),
|
||||||
|
finish_reason=choices[0].get("finish_reason", "stop"),
|
||||||
|
latency_ms=latency_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"QwenVL API错误: {e.response.status_code} - {e.response.text}")
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
model=self.config.model,
|
||||||
|
error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"QwenVL调用失败: {e}")
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
model=self.config.model,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
"""流式多模态对话补全"""
|
||||||
|
try:
|
||||||
|
payload = {
|
||||||
|
"model": self.config.model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": max_tokens or self.config.max_tokens,
|
||||||
|
"temperature": temperature or self.config.temperature,
|
||||||
|
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
|
||||||
|
with self._client.stream("POST", "/chat/completions", json=payload) as response:
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
continue # 跳过空的choices
|
||||||
|
delta = choices[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"QwenVL流式调用失败: {e}")
|
||||||
|
yield f"[ERROR: {str(e)}]"
|
||||||
|
|
||||||
|
def get_available_models(self) -> List[str]:
|
||||||
|
"""获取可用模型列表"""
|
||||||
|
return self.SUPPORTED_MODELS
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭客户端"""
|
||||||
|
if self._client:
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
|
|
||||||
|
def create_qwen_client(
|
||||||
|
api_key: str,
|
||||||
|
model: str = "qwen3.5-flash",
|
||||||
|
base_url: str = "http://6.86.80.4:30080/v1",
|
||||||
|
**kwargs
|
||||||
|
) -> QwenClient:
|
||||||
|
"""便捷函数:创建Qwen客户端"""
|
||||||
|
config = LLMConfig(
|
||||||
|
provider=LLMProvider.QWEN,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return QwenClient(config)
|
||||||
|
|
||||||
|
|
||||||
|
def create_qwen_vl_client(
|
||||||
|
api_key: str,
|
||||||
|
model: str = "qwen3-vl-plus",
|
||||||
|
base_url: str = "http://6.86.80.4:30080/v1",
|
||||||
|
**kwargs
|
||||||
|
) -> QwenVLClient:
|
||||||
|
"""便捷函数:创建QwenVL客户端"""
|
||||||
|
config = LLMConfig(
|
||||||
|
provider=LLMProvider.QWEN_VL,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return QwenVLClient(config)
|
||||||
425
backend/app/services/mock_data.py
Normal file
425
backend/app/services/mock_data.py
Normal file
@@ -0,0 +1,425 @@
|
|||||||
|
"""
|
||||||
|
Mock数据服务 - 提供预设假数据供前后端对接测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# 预设法规文档列表
|
||||||
|
MOCK_DOCUMENTS: List[Dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"id": "doc-001",
|
||||||
|
"name": "道路交通安全法.pdf",
|
||||||
|
"chunks": 156,
|
||||||
|
"status": "indexed",
|
||||||
|
"created_at": datetime(2026, 5, 10, 10, 0, 0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc-002",
|
||||||
|
"name": "机动车登记规定.docx",
|
||||||
|
"chunks": 89,
|
||||||
|
"status": "indexed",
|
||||||
|
"created_at": datetime(2026, 5, 10, 11, 0, 0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc-003",
|
||||||
|
"name": "电动自行车规范.pdf",
|
||||||
|
"chunks": 42,
|
||||||
|
"status": "indexed",
|
||||||
|
"created_at": datetime(2026, 5, 10, 12, 0, 0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc-004",
|
||||||
|
"name": "GB 38031-2020 电动汽车安全要求.pdf",
|
||||||
|
"chunks": 128,
|
||||||
|
"status": "indexed",
|
||||||
|
"created_at": datetime(2026, 5, 10, 13, 0, 0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc-005",
|
||||||
|
"name": "C-NCAP管理规则(2021版).pdf",
|
||||||
|
"chunks": 95,
|
||||||
|
"status": "indexed",
|
||||||
|
"created_at": datetime(2026, 5, 10, 14, 0, 0),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# 预设快捷问题
|
||||||
|
MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
|
||||||
|
{"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"},
|
||||||
|
{"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"},
|
||||||
|
{"id": "q3", "question": "车辆年检的规定是什么?", "category": "年检"},
|
||||||
|
{"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# 预设检索结果
|
||||||
|
MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"id": "chunk-001",
|
||||||
|
"score": 0.95,
|
||||||
|
"preview": "根据《道路交通安全法》第十八条规定,电动自行车经公安机关交通管理部门登记后,方可上道路行驶...",
|
||||||
|
"doc_name": "道路交通安全法",
|
||||||
|
"clause": "第十八条",
|
||||||
|
"content": "根据《道路交通安全法》第十八条规定,电动自行车经公安机关交通管理部门登记后,方可上道路行驶。电动自行车应当符合国家标准,最高设计车速不超过二十五公里每小时,整车质量不超过五十五千克。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "chunk-002",
|
||||||
|
"score": 0.88,
|
||||||
|
"preview": "电动自行车需符合GB17761-2018国家标准,包括最高车速、整车质量、脚踏骑行能力等要求...",
|
||||||
|
"doc_name": "电动自行车规范",
|
||||||
|
"clause": "第4条",
|
||||||
|
"content": "电动自行车需符合GB17761-2018国家标准。主要技术要求包括:最高设计车速不超过25km/h,整车质量不超过55kg,具有脚踏骑行能力,蓄电池标称电压不超过48V,电动机额定连续输出功率不超过400W。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "chunk-003",
|
||||||
|
"score": 0.82,
|
||||||
|
"preview": "机动车登记规定:初次申领机动车号牌、行驶证的,机动车所有人应当向住所地的车辆管理所申请注册登记...",
|
||||||
|
"doc_name": "机动车登记规定",
|
||||||
|
"clause": "第5条",
|
||||||
|
"content": "机动车登记规定:初次申领机动车号牌、行驶证的,机动车所有人应当向住所地的车辆管理所申请注册登记。申请注册登记的,应当提交机动车所有人的身份证明、购车发票等机动车来历证明、机动车整车出厂合格证明或者进口机动车进口凭证。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "chunk-004",
|
||||||
|
"score": 0.75,
|
||||||
|
"preview": "驾驶电动自行车上道路行驶,应当佩戴安全头盔,遵守道路交通安全法律法规...",
|
||||||
|
"doc_name": "道路交通安全法",
|
||||||
|
"clause": "第76条",
|
||||||
|
"content": "驾驶电动自行车上道路行驶,应当佩戴安全头盔,遵守道路交通安全法律法规。电动自行车不得逆向行驶,不得在机动车道内行驶,最高车速不得超过规定的限速。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "chunk-005",
|
||||||
|
"score": 0.68,
|
||||||
|
"preview": "电动汽车动力电池安全要求:电池系统发生热失控后,应在5分钟内不起火不爆炸...",
|
||||||
|
"doc_name": "GB 38031-2020",
|
||||||
|
"clause": "第7条",
|
||||||
|
"content": "电动汽车动力电池安全要求(GB 38031-2020):电池系统发生热失控后,应在5分钟内不起火不爆炸,为乘员预留逃生时间。电池包需通过针刺、过充、短路等安全测试。",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# 预设RAG问答答案模板(按关键词匹配)
|
||||||
|
MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
|
||||||
|
"电动自行车": {
|
||||||
|
"text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。",
|
||||||
|
"retrieval_ids": ["chunk-001", "chunk-002", "chunk-004"],
|
||||||
|
},
|
||||||
|
"驾驶证": {
|
||||||
|
"text": "驾驶证申请流程如下:\n\n1. 到驾校报名并参加培训\n2. 通过科目一(理论考试)\n3. 通过科目二(场地驾驶技能考试)\n4. 通过科目三(道路驾驶技能考试)\n5. 通过科目四(安全文明驾驶常识考试)\n6. 领取驾驶证\n\n初次申领需到住所地车辆管理所申请注册登记。",
|
||||||
|
"retrieval_ids": ["chunk-003"],
|
||||||
|
},
|
||||||
|
"超速": {
|
||||||
|
"text": "超速处罚标准(根据《道路交通安全法》):\n\n- 超速10%以下:警告\n- 超速10%-20%:罚款50-200元\n- 超速20%-50%:罚款200-500元,记3-6分\n- 超速50%以上:罚款500-2000元,记12分,可吊销驾驶证\n\n机动车驾驶人违反道路交通安全法律、法规将处警告或二十元以上二百元以下罚款。",
|
||||||
|
"retrieval_ids": ["chunk-001"],
|
||||||
|
},
|
||||||
|
"年检": {
|
||||||
|
"text": "车辆年检规定:\n\n- 小型私家车:6年内免检(每2年申领标志),6-10年每2年检验,10年以上每年检验\n- 车辆需携带行驶证、交强险保单\n- 检验项目:灯光、制动、排放等\n\n机动车所有人的住所迁出车辆管理所管辖区域的,需在登记证书上签注变更事项。",
|
||||||
|
"retrieval_ids": ["chunk-003"],
|
||||||
|
},
|
||||||
|
"电池": {
|
||||||
|
"text": "电动汽车电池安全标准(GB 38031-2020):\n\n1. 热失控要求:电池系统发生热失控后,应在5分钟内不起火不爆炸,为乘员预留逃生时间\n2. 电池包需通过针刺、过充、短路等安全测试\n3. 充电系统应具备过充保护功能,当电池SOC达到100%时应自动停止充电\n4. 充电接口应符合GB/T 18487.1标准要求\n\n以上要求确保电动汽车的整车安全性。",
|
||||||
|
"retrieval_ids": ["chunk-005"],
|
||||||
|
},
|
||||||
|
"碰撞": {
|
||||||
|
"text": "正面碰撞测试要求(C-NCAP管理规则):\n\n1. 正面100%重叠刚性壁障碰撞试验\n2. 碰撞速度:50km/h\n3. 试验后要求:\n - 车门应能打开\n - 燃油系统无泄漏\n - 座椅及安全带功能正常\n\n此测试用于评估车辆在正面碰撞事故中对乘员的保护能力。",
|
||||||
|
"retrieval_ids": [],
|
||||||
|
},
|
||||||
|
"AEB": {
|
||||||
|
"text": "AEB(自动紧急制动系统)测试标准:\n\n1. 系统应在检测到前方障碍物时主动减速或停车\n2. 测试场景分为三种:\n - 目标车静止\n - 目标车移动\n - 目标车制动\n3. AEB功能是C-NCAP评分的重要加分项\n\n该系统对提升车辆主动安全性能具有重要意义。",
|
||||||
|
"retrieval_ids": [],
|
||||||
|
},
|
||||||
|
"高速公路": {
|
||||||
|
"text": "高速公路安全距离规定:\n\n1. 车速超过100km/h时,与同车道前车保持100米以上距离\n2. 车速低于100km/h时,距离可适当缩短\n3. 执行紧急任务的警车、消防车、救护车、工程救险车不受行驶速度限制\n\n保持安全距离是预防追尾事故的关键措施。",
|
||||||
|
"retrieval_ids": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 预设合规分析结果
|
||||||
|
MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
|
||||||
|
"task_id": "task-001",
|
||||||
|
"dashboard": {
|
||||||
|
"score": 78,
|
||||||
|
"high_risk_count": 2,
|
||||||
|
"medium_risk_count": 1,
|
||||||
|
"low_risk_count": 0,
|
||||||
|
"need_fix_segments": 3,
|
||||||
|
"status": "warning",
|
||||||
|
"status_label": "需优化",
|
||||||
|
},
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"index": 1,
|
||||||
|
"intent": "车身结构设计",
|
||||||
|
"start_pos": 45,
|
||||||
|
"end_pos": 230,
|
||||||
|
"content": "车身采用高强度钢铝混合结构,A柱和B柱使用热成型钢板,厚度2.5mm。车顶结构设计满足GB 26112-2010抗压强度要求,正面碰撞能量吸收区域采用渐进式变形设计,确保碰撞时能量有效分散。",
|
||||||
|
"risk_level": "high",
|
||||||
|
"regulations": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "GB 26112-2010",
|
||||||
|
"clause": "第4.2条",
|
||||||
|
"score": 0.95,
|
||||||
|
"match_keyword": "车顶抗压强度",
|
||||||
|
"category": "high",
|
||||||
|
"full_content": "车顶结构应能承受相当于车辆整备质量1.5倍的载荷,载荷分布应均匀,试验后车顶变形量不超过规定值。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"name": "C-NCAP管理规则",
|
||||||
|
"clause": "第3.1条",
|
||||||
|
"score": 0.88,
|
||||||
|
"match_keyword": "正面碰撞",
|
||||||
|
"category": "high",
|
||||||
|
"full_content": "正面碰撞试验速度为50km/h,碰撞后车门应能打开,燃油系统无泄漏,座椅及安全带功能正常。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"name": "GB 11551-2014",
|
||||||
|
"clause": "第5条",
|
||||||
|
"score": 0.72,
|
||||||
|
"match_keyword": "碰撞能量吸收",
|
||||||
|
"category": "medium",
|
||||||
|
"full_content": "车辆正面碰撞时应有效保护乘员,碰撞能量应通过车身结构合理分散。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4,
|
||||||
|
"name": "机动车安全技术条件",
|
||||||
|
"clause": "第12条",
|
||||||
|
"score": 0.58,
|
||||||
|
"match_keyword": "A柱强度",
|
||||||
|
"category": "medium",
|
||||||
|
"full_content": "A柱应具备足够的抗变形能力,材料强度应符合相关标准要求。",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"index": 2,
|
||||||
|
"intent": "动力系统配置",
|
||||||
|
"start_pos": 298,
|
||||||
|
"end_pos": 425,
|
||||||
|
"content": "搭载永磁同步电机,最大功率150kW,峰值扭矩310Nm。电池组采用三元锂离子电池,容量75kWh,能量密度180Wh/kg。充电接口支持快充(30分钟充至80%)和慢充(8小时充满),符合GB/T 18487.1-2015标准。",
|
||||||
|
"risk_level": "medium",
|
||||||
|
"regulations": [
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"name": "GB/T 18487.1-2015",
|
||||||
|
"clause": "第6条",
|
||||||
|
"score": 0.94,
|
||||||
|
"match_keyword": "充电接口标准",
|
||||||
|
"category": "high",
|
||||||
|
"full_content": "电动汽车传导充电接口应符合GB/T 18487.1标准要求,充电系统应具备过充保护功能。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6,
|
||||||
|
"name": "GB/T 31484-2015",
|
||||||
|
"clause": "第4条",
|
||||||
|
"score": 0.85,
|
||||||
|
"match_keyword": "电池能量密度",
|
||||||
|
"category": "high",
|
||||||
|
"full_content": "动力电池能量密度不低于120Wh/kg,电池系统需通过热失控测试。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7,
|
||||||
|
"name": "新能源汽车生产企业准入",
|
||||||
|
"clause": "第8条",
|
||||||
|
"score": 0.65,
|
||||||
|
"match_keyword": "电机功率",
|
||||||
|
"category": "medium",
|
||||||
|
"full_content": "驱动电机应符合相关技术标准,功率参数应在规定范围内。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8,
|
||||||
|
"name": "电动汽车安全要求",
|
||||||
|
"clause": "第7条",
|
||||||
|
"score": 0.45,
|
||||||
|
"match_keyword": "充电时间",
|
||||||
|
"category": "low",
|
||||||
|
"full_content": "充电系统应具备过充保护功能,当电池SOC达到100%时应自动停止充电。",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"index": 3,
|
||||||
|
"intent": "安全配置设计",
|
||||||
|
"start_pos": 570,
|
||||||
|
"end_pos": 725,
|
||||||
|
"content": "配备6个安全气囊(前排双气囊、侧气囊、侧气帘),采用预紧式安全带。ABS系统采用博世第9代ESP,具备碰撞预警功能(FCW)和自动紧急制动(AEB)。方向盘集成驾驶员疲劳监测摄像头。",
|
||||||
|
"risk_level": "low",
|
||||||
|
"regulations": [
|
||||||
|
{
|
||||||
|
"id": 9,
|
||||||
|
"name": "GB 27887-2011",
|
||||||
|
"clause": "第5条",
|
||||||
|
"score": 0.92,
|
||||||
|
"match_keyword": "安全气囊",
|
||||||
|
"category": "high",
|
||||||
|
"full_content": "乘用车应配备驾驶员和乘客安全气囊,气囊系统应符合相关技术标准。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10,
|
||||||
|
"name": "GB/T 26991-2011",
|
||||||
|
"clause": "第3条",
|
||||||
|
"score": 0.78,
|
||||||
|
"match_keyword": "ABS系统",
|
||||||
|
"category": "medium",
|
||||||
|
"full_content": "车辆应配备防抱死制动系统,系统性能应符合相关标准要求。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 11,
|
||||||
|
"name": "C-NCAP管理规则",
|
||||||
|
"clause": "第4.2条",
|
||||||
|
"score": 0.71,
|
||||||
|
"match_keyword": "AEB自动制动",
|
||||||
|
"category": "medium",
|
||||||
|
"full_content": "主动安全配置评分包含AEB功能,AEB系统应能有效检测障碍物并主动减速。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12,
|
||||||
|
"name": "机动车运行安全技术条件",
|
||||||
|
"clause": "第15条",
|
||||||
|
"score": 0.38,
|
||||||
|
"match_keyword": "疲劳监测",
|
||||||
|
"category": "low",
|
||||||
|
"full_content": "建议配备驾驶员状态监测系统,及时发现驾驶员疲劳或分心状态。",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"priority_actions": [
|
||||||
|
{
|
||||||
|
"regulation": "GB 26112-2010 第4.2条",
|
||||||
|
"issue": "缺少车顶抗压强度测试数据",
|
||||||
|
"suggestion": "补充车顶抗压强度具体测试数据,确保满足1.5倍整备质量载荷要求",
|
||||||
|
"severity": "high",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"regulation": "GB/T 31484-2015 第4条",
|
||||||
|
"issue": "缺少电池热失控测试报告",
|
||||||
|
"suggestion": "补充电池热失控测试报告,验证5分钟内不起火不爆炸",
|
||||||
|
"severity": "high",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"regulation": "C-NCAP管理规则 第3.1条",
|
||||||
|
"issue": "缺少碰撞后车门开启性能数据",
|
||||||
|
"suggestion": "提供碰撞后车门开启性能测试数据",
|
||||||
|
"severity": "medium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 预设合规对话响应模板
|
||||||
|
MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
|
||||||
|
"车身结构设计": {
|
||||||
|
"compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开,需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。",
|
||||||
|
"interpretation": "GB 26112-2010 第4.2条具体要求解读:\n\n车顶抗压强度测试是车辆被动安全的重要指标。该标准要求车顶结构能够承受相当于车辆整备质量1.5倍的均匀分布载荷,试验后车顶变形量不得超过规定限值。\n\n热成型钢板(22MnB5材料)抗拉强度约1500-1650 MPa,理论上能满足要求,但需通过实际测试验证。",
|
||||||
|
"suggestion": "针对车身结构设计的修改建议:\n\n1. 补充车顶抗压强度测试报告\n2. 提供A柱材料认证证书\n3. 完善正面碰撞能量吸收设计说明\n4. 添加碰撞后车门开启性能数据\n\n这些补充材料可有效提升合规评分。",
|
||||||
|
},
|
||||||
|
"动力系统配置": {
|
||||||
|
"compliance": "动力系统配置整体合规性良好,主要检查点:\n\n1. 电池能量密度180Wh/kg超过最低要求120Wh/kg ✓\n2. 充电接口符合GB/T 18487.1标准 ✓\n3. 快充30分钟充至80%符合行业标准 ✓\n\n需补充电池热失控测试报告。",
|
||||||
|
"interpretation": "GB/T 31484-2015对动力电池的要求解读:\n\n1. 能量密度:不低于120Wh/kg(您的设计180Wh/kg满足要求)\n2. 循环寿命:不少于1000次循环后容量保持率≥80%\n3. 安全测试:需通过针刺、过充、短路等测试\n\n建议补充循环寿命测试数据。",
|
||||||
|
"suggestion": "动力系统配置改进建议:\n\n1. 补充电池热失控测试报告\n2. 提供循环寿命测试数据\n3. 添加充电系统过充保护功能说明\n4. 完善电池管理系统(BMS)技术文档",
|
||||||
|
},
|
||||||
|
"安全配置设计": {
|
||||||
|
"compliance": "安全配置设计合规性评估:\n\n1. 安全气囊配置满足GB 27887-2011要求 ✓\n2. ABS/ESP系统符合标准 ✓\n3. AEB功能是C-NCAP加分项 ✓\n\n驾驶员疲劳监测是建议配置,不强制要求。",
|
||||||
|
"interpretation": "C-NCAP主动安全评分规则解读:\n\nAEB(自动紧急制动)系统是C-NCAP评分的重要加分项,最高可获得额外加分。测试场景包括:\n- 目标车静止场景\n- 目标车移动场景\n- 目标车制动场景\n\n建议完善AEB系统测试数据以获取更高评分。",
|
||||||
|
"suggestion": "安全配置优化建议:\n\n1. 提供AEB系统测试数据\n2. 补充FCW预警功能测试报告\n3. 添加安全气囊展开时间数据\n4. 完善驾驶员疲劳监测系统说明(如有)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 预设系统统计数据
|
||||||
|
MOCK_SYSTEM_STATS: Dict[str, int] = {
|
||||||
|
"docs": 5,
|
||||||
|
"chunks": 510,
|
||||||
|
"vectors": 510,
|
||||||
|
"segments": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 预设系统配置
|
||||||
|
MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
|
||||||
|
"llm": {
|
||||||
|
"model": "qwen-max",
|
||||||
|
},
|
||||||
|
"embedding": {
|
||||||
|
"model": "text-embedding-v3",
|
||||||
|
"dimension": 1536,
|
||||||
|
},
|
||||||
|
"milvus": {
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 19530,
|
||||||
|
},
|
||||||
|
"retrieval": {
|
||||||
|
"vector_top_k": 10,
|
||||||
|
"final_top_k": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_documents() -> List[Dict[str, Any]]:
|
||||||
|
"""获取预设法规文档列表"""
|
||||||
|
return MOCK_DOCUMENTS
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_quick_questions() -> List[Dict[str, str]]:
|
||||||
|
"""获取预设快捷问题"""
|
||||||
|
return MOCK_QUICK_QUESTIONS
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
"""根据查询关键词返回预设检索结果"""
|
||||||
|
results = []
|
||||||
|
for keyword, data in MOCK_RAG_ANSWERS.items():
|
||||||
|
if keyword in query:
|
||||||
|
for retrieval_id in data.get("retrieval_ids", []):
|
||||||
|
for item in MOCK_RETRIEVAL_RESULTS:
|
||||||
|
if item["id"] == retrieval_id:
|
||||||
|
results.append({
|
||||||
|
"id": item["id"],
|
||||||
|
"score": item["score"],
|
||||||
|
"preview": item["preview"],
|
||||||
|
"doc_name": item["doc_name"],
|
||||||
|
"clause": item["clause"],
|
||||||
|
})
|
||||||
|
break
|
||||||
|
if not results:
|
||||||
|
results = MOCK_RETRIEVAL_RESULTS[:top_k]
|
||||||
|
return results[:top_k]
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_rag_answer(query: str) -> str:
|
||||||
|
"""根据查询关键词返回预设答案"""
|
||||||
|
for keyword, data in MOCK_RAG_ANSWERS.items():
|
||||||
|
if keyword in query:
|
||||||
|
return data["text"]
|
||||||
|
return "抱歉,暂未找到与您问题直接相关的法规内容。请尝试更具体的问题,或联系交通管理部门获取详细信息。\n\n您可以尝试询问:电动自行车、驾驶证、超速处罚、年检、电池安全、碰撞测试、AEB系统、高速公路规则等话题。"
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_compliance_result(task_id: str) -> Dict[str, Any]:
|
||||||
|
"""获取预设合规分析结果"""
|
||||||
|
result = MOCK_COMPLIANCE_RESULT.copy()
|
||||||
|
result["task_id"] = task_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_compliance_chat_response(intent: str, query: str) -> str:
|
||||||
|
"""获取预设合规对话响应"""
|
||||||
|
responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {})
|
||||||
|
if "合规" in query or "符合" in query:
|
||||||
|
return responses.get("compliance", "根据相关法规分析,该段落的合规性需进一步评估。")
|
||||||
|
elif "解读" in query or "什么" in query or "如何" in query:
|
||||||
|
return responses.get("interpretation", "法规要求详细解读如下...")
|
||||||
|
elif "修改" in query or "建议" in query or "完善" in query:
|
||||||
|
return responses.get("suggestion", "建议进行以下修改以提升合规性...")
|
||||||
|
return f"关于您的问题,{intent}部分涉及多条相关法规。您可以进一步询问合规性评估或修改建议。"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_task_id() -> str:
|
||||||
|
"""生成任务ID"""
|
||||||
|
return f"task-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_doc_id() -> str:
|
||||||
|
"""生成文档ID"""
|
||||||
|
return f"doc-{uuid.uuid4().hex[:8]}"
|
||||||
7
backend/app/services/parser/__init__.py
Normal file
7
backend/app/services/parser/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# src/services/parser/__init__.py
|
||||||
|
"""文档解析服务"""
|
||||||
|
|
||||||
|
from .pdf_parser import PDFParser
|
||||||
|
from .docx_parser import DocxParser
|
||||||
|
|
||||||
|
__all__ = ["PDFParser", "DocxParser"]
|
||||||
287
backend/app/services/parser/docx_parser.py
Normal file
287
backend/app/services/parser/docx_parser.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
# src/services/parser/docx_parser.py
|
||||||
|
"""Word文档解析 - 使用python-docx"""
|
||||||
|
|
||||||
|
from docx import Document
|
||||||
|
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocxParagraph:
|
||||||
|
"""段落内容"""
|
||||||
|
text: str
|
||||||
|
level: int = 0 # 标题级别,0表示正文
|
||||||
|
is_list: bool = False
|
||||||
|
list_number: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocxTable:
|
||||||
|
"""表格内容"""
|
||||||
|
rows: List[List[str]]
|
||||||
|
markdown: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocxDocumentContent:
|
||||||
|
"""Word文档完整内容"""
|
||||||
|
file_path: str
|
||||||
|
paragraphs: List[DocxParagraph]
|
||||||
|
tables: List[DocxTable]
|
||||||
|
metadata: Dict[str, str] = field(default_factory=dict)
|
||||||
|
markdown_text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class DocxParser:
|
||||||
|
"""Word文档解析器 - 基于python-docx"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.document = None
|
||||||
|
|
||||||
|
def parse(self, file_path: str) -> DocxDocumentContent:
|
||||||
|
"""
|
||||||
|
解析Word文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Word文档路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DocxDocumentContent: 解析后的文档内容
|
||||||
|
"""
|
||||||
|
logger.info(f"开始解析Word文档: {file_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.document = Document(file_path)
|
||||||
|
doc_content = DocxDocumentContent(
|
||||||
|
file_path=file_path,
|
||||||
|
paragraphs=[],
|
||||||
|
tables=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取文档元数据
|
||||||
|
doc_content.metadata = self._extract_metadata()
|
||||||
|
|
||||||
|
# 提取段落
|
||||||
|
doc_content.paragraphs = self._extract_paragraphs()
|
||||||
|
|
||||||
|
# 提取表格
|
||||||
|
doc_content.tables = self._extract_tables()
|
||||||
|
|
||||||
|
# 生成Markdown格式文本
|
||||||
|
doc_content.markdown_text = self._generate_markdown(doc_content)
|
||||||
|
|
||||||
|
logger.success(f"Word文档解析完成,共{len(doc_content.paragraphs)}个段落")
|
||||||
|
|
||||||
|
return doc_content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Word文档解析失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _extract_metadata(self) -> Dict[str, str]:
|
||||||
|
"""提取文档元数据"""
|
||||||
|
metadata = {}
|
||||||
|
try:
|
||||||
|
core_props = self.document.core_properties
|
||||||
|
metadata = {
|
||||||
|
"title": core_props.title or "",
|
||||||
|
"author": core_props.author or "",
|
||||||
|
"subject": core_props.subject or "",
|
||||||
|
"keywords": core_props.keywords or "",
|
||||||
|
"created": str(core_props.created) if core_props.created else "",
|
||||||
|
"modified": str(core_props.modified) if core_props.modified else "",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"提取元数据失败: {e}")
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _extract_paragraphs(self) -> List[DocxParagraph]:
|
||||||
|
"""提取所有段落"""
|
||||||
|
paragraphs = []
|
||||||
|
|
||||||
|
for para in self.document.paragraphs:
|
||||||
|
text = para.text.strip()
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 判断标题级别
|
||||||
|
level = self._get_paragraph_level(para)
|
||||||
|
|
||||||
|
# 判断是否是列表项
|
||||||
|
is_list, list_number = self._detect_list_item(para)
|
||||||
|
|
||||||
|
paragraph = DocxParagraph(
|
||||||
|
text=text,
|
||||||
|
level=level,
|
||||||
|
is_list=is_list,
|
||||||
|
list_number=list_number
|
||||||
|
)
|
||||||
|
paragraphs.append(paragraph)
|
||||||
|
|
||||||
|
return paragraphs
|
||||||
|
|
||||||
|
def _get_paragraph_level(self, para) -> int:
|
||||||
|
"""
|
||||||
|
判断段落标题级别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 标题级别,0表示正文
|
||||||
|
"""
|
||||||
|
# 方法1:检查段落样式
|
||||||
|
style_name = para.style.name if para.style else ""
|
||||||
|
|
||||||
|
if "Heading" in style_name or "标题" in style_name:
|
||||||
|
# 从样式名称中提取级别
|
||||||
|
match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
|
||||||
|
if match:
|
||||||
|
level = int(match.group(1) or match.group(2))
|
||||||
|
return level
|
||||||
|
|
||||||
|
# 方法2:检查段落格式(字号)
|
||||||
|
# 标题通常字号较大
|
||||||
|
if para.paragraph_format:
|
||||||
|
# 可以根据字号判断,这里简化处理
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 方法3:根据内容模式判断(法规文档特征)
|
||||||
|
text = para.text.strip()
|
||||||
|
|
||||||
|
# 第一章、第X章 -> 二级标题
|
||||||
|
if re.match(r'^第[一二三四五六七八九十百]+章\s', text):
|
||||||
|
return 2
|
||||||
|
# 第X节 -> 三级标题
|
||||||
|
elif re.match(r'^第[一二三四五六七八九十百]+节\s', text):
|
||||||
|
return 3
|
||||||
|
# 第X条 -> 四级标题
|
||||||
|
elif re.match(r'^第[一二三四五六七八九十百]+条\s', text):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
return 0 # 正文
|
||||||
|
|
||||||
|
def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
|
||||||
|
"""检测是否是列表项"""
|
||||||
|
text = para.text.strip()
|
||||||
|
|
||||||
|
# 数字列表:1.、2.、(1)、[1]等
|
||||||
|
if re.match(r'^[\d]+[.、)\]]\s', text):
|
||||||
|
match = re.match(r'^([\d]+[.、)\]])\s', text)
|
||||||
|
return True, match.group(1) if match else None
|
||||||
|
|
||||||
|
# 中文数字列表:一、二、(一)等
|
||||||
|
if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text):
|
||||||
|
match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text)
|
||||||
|
return True, match.group(1) if match else None
|
||||||
|
|
||||||
|
# 检查段落格式中的列表编号
|
||||||
|
if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'):
|
||||||
|
# 有缩进的可能是列表项
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def _extract_tables(self) -> List[DocxTable]:
|
||||||
|
"""提取所有表格"""
|
||||||
|
tables = []
|
||||||
|
|
||||||
|
for table in self.document.tables:
|
||||||
|
rows = []
|
||||||
|
for row in table.rows:
|
||||||
|
cells = []
|
||||||
|
for cell in row.cells:
|
||||||
|
cells.append(cell.text.strip())
|
||||||
|
rows.append(cells)
|
||||||
|
|
||||||
|
# 转换为Markdown表格
|
||||||
|
markdown = self._table_to_markdown(rows)
|
||||||
|
|
||||||
|
table_content = DocxTable(rows=rows, markdown=markdown)
|
||||||
|
tables.append(table_content)
|
||||||
|
|
||||||
|
return tables
|
||||||
|
|
||||||
|
def _table_to_markdown(self, rows: List[List[str]]) -> str:
|
||||||
|
"""将表格转换为Markdown格式"""
|
||||||
|
if not rows or len(rows) < 1:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
# 表头
|
||||||
|
if len(rows) >= 1:
|
||||||
|
header = rows[0]
|
||||||
|
lines.append("| " + " | ".join(cell for cell in header) + " |")
|
||||||
|
lines.append("| " + " | ".join("---" for _ in header) + " |")
|
||||||
|
|
||||||
|
# 数据行
|
||||||
|
for row in rows[1:]:
|
||||||
|
lines.append("| " + " | ".join(cell for cell in row) + " |")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
|
||||||
|
"""生成Markdown格式文本"""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
# 文档标题
|
||||||
|
title = doc_content.metadata.get("title", "")
|
||||||
|
if title:
|
||||||
|
lines.append(f"# {title}\n")
|
||||||
|
else:
|
||||||
|
# 从第一个段落获取标题(如果是标题样式)
|
||||||
|
for para in doc_content.paragraphs[:5]:
|
||||||
|
if para.level == 1:
|
||||||
|
lines.append(f"# {para.text}\n")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
lines.append(f"# {doc_content.file_path}\n")
|
||||||
|
|
||||||
|
# 元数据信息
|
||||||
|
lines.append("\n## 文档信息\n")
|
||||||
|
for key, value in doc_content.metadata.items():
|
||||||
|
if value:
|
||||||
|
lines.append(f"- **{key}**: {value}")
|
||||||
|
|
||||||
|
# 正文内容
|
||||||
|
lines.append("\n## 正文\n")
|
||||||
|
|
||||||
|
table_index = 0
|
||||||
|
for para in doc_content.paragraphs:
|
||||||
|
if para.level > 0:
|
||||||
|
# 标题
|
||||||
|
prefix = "#" * para.level
|
||||||
|
lines.append(f"\n{prefix} {para.text}\n")
|
||||||
|
elif para.is_list:
|
||||||
|
# 列表项
|
||||||
|
lines.append(f"- {para.text}")
|
||||||
|
else:
|
||||||
|
# 正文
|
||||||
|
lines.append(para.text)
|
||||||
|
|
||||||
|
# 添加表格
|
||||||
|
if doc_content.tables:
|
||||||
|
lines.append("\n## 表格\n")
|
||||||
|
for i, table in enumerate(doc_content.tables):
|
||||||
|
lines.append(f"\n### 表格 {i + 1}\n")
|
||||||
|
lines.append(table.markdown + "\n")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def parse_to_markdown(self, file_path: str) -> str:
|
||||||
|
"""直接解析并返回Markdown文本"""
|
||||||
|
doc_content = self.parse(file_path)
|
||||||
|
return doc_content.markdown_text
|
||||||
|
|
||||||
|
|
||||||
|
def parse_docx(file_path: str) -> DocxDocumentContent:
|
||||||
|
"""便捷函数:解析Word文档"""
|
||||||
|
parser = DocxParser()
|
||||||
|
return parser.parse(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_docx_to_markdown(file_path: str) -> str:
|
||||||
|
"""便捷函数:解析Word并返回Markdown"""
|
||||||
|
parser = DocxParser()
|
||||||
|
return parser.parse_to_markdown(file_path)
|
||||||
204
backend/app/services/parser/mineru_parser.py
Normal file
204
backend/app/services/parser/mineru_parser.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
# src/services/parser/mineru_parser.py
|
||||||
|
"""MinerU多模态PDF解析 - 版面感知解析"""
|
||||||
|
|
||||||
|
from typing import Optional, Dict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MinerUResult:
|
||||||
|
"""MinerU解析结果"""
|
||||||
|
file_path: str
|
||||||
|
markdown_text: str
|
||||||
|
metadata: Dict[str, str] = field(default_factory=dict)
|
||||||
|
success: bool = True
|
||||||
|
error_message: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class MinerUParser:
|
||||||
|
"""
|
||||||
|
MinerU多模态PDF解析器
|
||||||
|
|
||||||
|
MinerU (magic-pdf) 是一个开源的高质量PDF解析工具,
|
||||||
|
支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素,
|
||||||
|
并输出结构化的Markdown格式。
|
||||||
|
|
||||||
|
GitHub: https://github.com/opendatalab/MinerU
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.available = self._check_mineru_available()
|
||||||
|
|
||||||
|
def _check_mineru_available(self) -> bool:
|
||||||
|
"""检查MinerU是否可用"""
|
||||||
|
try:
|
||||||
|
from magic_pdf.pipe.UNIPipe import UNIPipe
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("MinerU (magic-pdf) 未安装,请运行: pip install magic-pdf[full]")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
|
||||||
|
"""
|
||||||
|
使用MinerU解析PDF文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: PDF文件路径
|
||||||
|
output_dir: 输出目录(可选,用于保存解析产物)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MinerUResult: 解析结果
|
||||||
|
"""
|
||||||
|
logger.info(f"尝试使用MinerU解析: {file_path}")
|
||||||
|
|
||||||
|
if not self.available:
|
||||||
|
return MinerUResult(
|
||||||
|
file_path=file_path,
|
||||||
|
markdown_text="",
|
||||||
|
success=False,
|
||||||
|
error_message="MinerU未安装"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from magic_pdf.pipe.UNIPipe import UNIPipe
|
||||||
|
from magic_pdf.libs.MakeContentConfig import DropMode
|
||||||
|
|
||||||
|
# 设置输出目录
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = os.path.dirname(file_path)
|
||||||
|
|
||||||
|
# 创建解析管道
|
||||||
|
# OCR模式可以根据PDF类型选择
|
||||||
|
# auto: 自动判断是否需要OCR
|
||||||
|
# txt: 纯文本PDF(无OCR)
|
||||||
|
# ocr: 扫描件PDF(OCR)
|
||||||
|
pipe = UNIPipe(file_path, output_dir)
|
||||||
|
|
||||||
|
# 执行解析
|
||||||
|
# pipe_mk() 返回Markdown格式文本
|
||||||
|
markdown_content = pipe.pipe_mk()
|
||||||
|
|
||||||
|
logger.success(f"MinerU解析成功")
|
||||||
|
|
||||||
|
return MinerUResult(
|
||||||
|
file_path=file_path,
|
||||||
|
markdown_text=markdown_content,
|
||||||
|
metadata=self._extract_metadata(pipe),
|
||||||
|
success=True
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MinerU解析失败: {e}")
|
||||||
|
return MinerUResult(
|
||||||
|
file_path=file_path,
|
||||||
|
markdown_text="",
|
||||||
|
success=False,
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_metadata(self, pipe) -> Dict[str, str]:
|
||||||
|
"""从解析管道提取元数据"""
|
||||||
|
metadata = {}
|
||||||
|
try:
|
||||||
|
# MinerU解析管道中可能包含的元数据信息
|
||||||
|
if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data:
|
||||||
|
mid_data = pipe.pdf_mid_data
|
||||||
|
# 提取可能的元数据字段
|
||||||
|
metadata = {
|
||||||
|
"page_count": str(mid_data.get("page_count", "")),
|
||||||
|
"language": str(mid_data.get("language", "")),
|
||||||
|
"is_scanned": str(mid_data.get("is_scanned", "")),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"提取MinerU元数据失败: {e}")
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def parse_to_markdown(self, file_path: str) -> str:
|
||||||
|
"""直接解析并返回Markdown文本"""
|
||||||
|
result = self.parse(file_path)
|
||||||
|
return result.markdown_text if result.success else ""
|
||||||
|
|
||||||
|
|
||||||
|
class ParserOrchestrator:
|
||||||
|
"""
|
||||||
|
解析服务编排 - 按优先级选择解析器
|
||||||
|
|
||||||
|
解析策略:
|
||||||
|
1. 优先尝试MinerU(版面感知能力强)
|
||||||
|
2. MinerU失败时回退到基础PyMuPDF解析
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from .pdf_parser import PDFParser
|
||||||
|
self.mineru_parser = MinerUParser()
|
||||||
|
self.pdf_parser = PDFParser()
|
||||||
|
self.mineru_available = self.mineru_parser.available
|
||||||
|
|
||||||
|
def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
|
||||||
|
"""
|
||||||
|
解析PDF文档,按优先级选择解析器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: PDF文件路径
|
||||||
|
prefer_mineru: 是否优先使用MinerU
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Markdown格式文本
|
||||||
|
"""
|
||||||
|
markdown_text = ""
|
||||||
|
|
||||||
|
if prefer_mineru and self.mineru_available:
|
||||||
|
# 优先尝试MinerU
|
||||||
|
result = self.mineru_parser.parse(file_path)
|
||||||
|
if result.success:
|
||||||
|
markdown_text = result.markdown_text
|
||||||
|
logger.info("使用MinerU解析成功")
|
||||||
|
return markdown_text
|
||||||
|
else:
|
||||||
|
logger.warning(f"MinerU解析失败,回退到PyMuPDF: {result.error_message}")
|
||||||
|
|
||||||
|
# 回退到PyMuPDF基础解析
|
||||||
|
logger.info("使用PyMuPDF基础解析")
|
||||||
|
markdown_text = self.pdf_parser.parse_to_markdown(file_path)
|
||||||
|
|
||||||
|
return markdown_text
|
||||||
|
|
||||||
|
def parse_docx(self, file_path: str) -> str:
|
||||||
|
"""解析Word文档"""
|
||||||
|
from .docx_parser import DocxParser
|
||||||
|
docx_parser = DocxParser()
|
||||||
|
return docx_parser.parse_to_markdown(file_path)
|
||||||
|
|
||||||
|
def parse(self, file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
根据文件类型选择解析器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Markdown格式文本
|
||||||
|
"""
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
|
||||||
|
if ext == ".pdf":
|
||||||
|
return self.parse_pdf(file_path)
|
||||||
|
elif ext in [".docx", ".doc"]:
|
||||||
|
return self.parse_docx(file_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的文件类型: {ext}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_with_mineru(file_path: str) -> MinerUResult:
|
||||||
|
"""便捷函数:使用MinerU解析"""
|
||||||
|
parser = MinerUParser()
|
||||||
|
return parser.parse(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_pdf_smart(file_path: str) -> str:
|
||||||
|
"""便捷函数:智能解析PDF(自动选择最佳解析器)"""
|
||||||
|
orchestrator = ParserOrchestrator()
|
||||||
|
return orchestrator.parse_pdf(file_path)
|
||||||
268
backend/app/services/parser/pdf_parser.py
Normal file
268
backend/app/services/parser/pdf_parser.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
# src/services/parser/pdf_parser.py
|
||||||
|
"""PDF文档解析 - 使用PyMuPDF基础解析"""
|
||||||
|
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PDFPageContent:
|
||||||
|
"""PDF页面内容"""
|
||||||
|
page_number: int
|
||||||
|
text: str
|
||||||
|
tables: List[str] = field(default_factory=list)
|
||||||
|
images: List[str] = field(default_factory=list) # 图片路径列表
|
||||||
|
blocks: List[Dict] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PDFDocumentContent:
|
||||||
|
"""PDF文档完整内容"""
|
||||||
|
file_path: str
|
||||||
|
total_pages: int
|
||||||
|
pages: List[PDFPageContent]
|
||||||
|
metadata: Dict[str, str] = field(default_factory=dict)
|
||||||
|
markdown_text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PDFParser:
|
||||||
|
"""PDF文档解析器 - 基于PyMuPDF"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.pdf = None
|
||||||
|
|
||||||
|
def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
|
||||||
|
"""
|
||||||
|
解析PDF文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: PDF文件路径
|
||||||
|
extract_tables: 是否提取表格
|
||||||
|
extract_images: 是否提取图片
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PDFDocumentContent: 解析后的文档内容
|
||||||
|
"""
|
||||||
|
logger.info(f"开始解析PDF文档: {file_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.pdf = fitz.open(file_path)
|
||||||
|
doc_content = PDFDocumentContent(
|
||||||
|
file_path=file_path,
|
||||||
|
total_pages=self.pdf.page_count,
|
||||||
|
pages=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取文档元数据
|
||||||
|
doc_content.metadata = self._extract_metadata()
|
||||||
|
|
||||||
|
# 逐页解析
|
||||||
|
for page_num in range(self.pdf.page_count):
|
||||||
|
page = self.pdf[page_num]
|
||||||
|
page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images)
|
||||||
|
doc_content.pages.append(page_content)
|
||||||
|
|
||||||
|
# 生成Markdown格式文本
|
||||||
|
doc_content.markdown_text = self._generate_markdown(doc_content)
|
||||||
|
|
||||||
|
self.pdf.close()
|
||||||
|
logger.success(f"PDF解析完成,共{doc_content.total_pages}页")
|
||||||
|
|
||||||
|
return doc_content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PDF解析失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _extract_metadata(self) -> Dict[str, str]:
|
||||||
|
"""提取PDF元数据"""
|
||||||
|
metadata = {}
|
||||||
|
try:
|
||||||
|
meta = self.pdf.metadata
|
||||||
|
metadata = {
|
||||||
|
"title": meta.get("title", ""),
|
||||||
|
"author": meta.get("author", ""),
|
||||||
|
"subject": meta.get("subject", ""),
|
||||||
|
"keywords": meta.get("keywords", ""),
|
||||||
|
"creator": meta.get("creator", ""),
|
||||||
|
"producer": meta.get("producer", ""),
|
||||||
|
"creation_date": meta.get("creationDate", ""),
|
||||||
|
"mod_date": meta.get("modDate", ""),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"提取元数据失败: {e}")
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _parse_page(self, page: fitz.Page, page_num: int,
|
||||||
|
extract_tables: bool, extract_images: bool) -> PDFPageContent:
|
||||||
|
"""解析单页内容"""
|
||||||
|
page_content = PDFPageContent(page_number=page_num, text="")
|
||||||
|
|
||||||
|
# 提取文本块(保留结构)
|
||||||
|
blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
|
||||||
|
page_content.blocks = blocks
|
||||||
|
|
||||||
|
# 提取纯文本
|
||||||
|
text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE)
|
||||||
|
page_content.text = text.strip()
|
||||||
|
|
||||||
|
# 提取表格(使用PyMuPDF的表格提取功能)
|
||||||
|
if extract_tables:
|
||||||
|
tables = self._extract_tables_from_page(page)
|
||||||
|
page_content.tables = tables
|
||||||
|
|
||||||
|
# 提取图片
|
||||||
|
if extract_images:
|
||||||
|
images = self._extract_images_from_page(page, page_num)
|
||||||
|
page_content.images = images
|
||||||
|
|
||||||
|
return page_content
|
||||||
|
|
||||||
|
def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
|
||||||
|
"""
|
||||||
|
从页面提取表格(基于文本块分析)
|
||||||
|
注意:PyMuPDF基础版表格提取能力有限,复杂表格建议使用MinerU
|
||||||
|
"""
|
||||||
|
tables = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用PyMuPDF的表格提取方法(2.4+版本)
|
||||||
|
# 对于更复杂的表格,需要在mineru_parser中使用更高级的方法
|
||||||
|
tabs = page.find_tables()
|
||||||
|
if tabs:
|
||||||
|
for tab in tabs:
|
||||||
|
table_text = tab.extract()
|
||||||
|
# 将表格转换为Markdown格式
|
||||||
|
markdown_table = self._table_to_markdown(table_text)
|
||||||
|
tables.append(markdown_table)
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
# 旧版本PyMuPDF没有表格提取功能
|
||||||
|
logger.warning("PyMuPDF版本不支持表格提取,请升级到2.4+版本")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"表格提取失败: {e}")
|
||||||
|
|
||||||
|
return tables
|
||||||
|
|
||||||
|
def _table_to_markdown(self, table_data: List[List[str]]) -> str:
|
||||||
|
"""将表格数据转换为Markdown格式"""
|
||||||
|
if not table_data or len(table_data) < 1:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
# 表头
|
||||||
|
if len(table_data) >= 1:
|
||||||
|
header = table_data[0]
|
||||||
|
lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |")
|
||||||
|
lines.append("| " + " | ".join("---" for _ in header) + " |")
|
||||||
|
|
||||||
|
# 数据行
|
||||||
|
for row in table_data[1:]:
|
||||||
|
lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
|
||||||
|
"""提取页面图片"""
|
||||||
|
images = []
|
||||||
|
# 图片提取功能(可选实现)
|
||||||
|
# 这里仅记录图片信息,实际图片需要额外保存
|
||||||
|
try:
|
||||||
|
image_list = page.get_images()
|
||||||
|
for img_index, img in enumerate(image_list):
|
||||||
|
xref = img[0]
|
||||||
|
images.append(f"image_p{page_num}_i{img_index}_xref{xref}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"图片提取失败: {e}")
|
||||||
|
return images
|
||||||
|
|
||||||
|
def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
|
||||||
|
"""生成Markdown格式文本"""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
# 文档标题
|
||||||
|
title = doc_content.metadata.get("title", "")
|
||||||
|
if title:
|
||||||
|
lines.append(f"# {title}\n")
|
||||||
|
else:
|
||||||
|
lines.append(f"# {doc_content.file_path}\n")
|
||||||
|
|
||||||
|
# 元数据信息
|
||||||
|
lines.append("\n## 文档信息\n")
|
||||||
|
for key, value in doc_content.metadata.items():
|
||||||
|
if value and key in ["author", "subject", "keywords", "creation_date"]:
|
||||||
|
lines.append(f"- **{key}**: {value}")
|
||||||
|
|
||||||
|
# 正文内容
|
||||||
|
lines.append("\n## 正文\n")
|
||||||
|
|
||||||
|
for page in doc_content.pages:
|
||||||
|
# 页码标记
|
||||||
|
lines.append(f"\n---\n**第 {page.page_number} 页**\n")
|
||||||
|
|
||||||
|
# 处理文本内容,识别标题结构
|
||||||
|
text = self._process_page_text(page.text, page.blocks)
|
||||||
|
lines.append(text)
|
||||||
|
|
||||||
|
# 添加表格
|
||||||
|
for table in page.tables:
|
||||||
|
lines.append("\n" + table + "\n")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
|
||||||
|
"""处理页面文本,识别标题结构"""
|
||||||
|
# 基于字体大小识别标题
|
||||||
|
processed_text = text
|
||||||
|
|
||||||
|
# 尝试识别标题(基于字号)
|
||||||
|
# 法规文档通常有明确的层级结构:章、节、条
|
||||||
|
processed_text = self._detect_headers(text, blocks)
|
||||||
|
|
||||||
|
return processed_text
|
||||||
|
|
||||||
|
def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
|
||||||
|
"""检测并标记标题(基于字号或内容模式)"""
|
||||||
|
lines = text.split("\n")
|
||||||
|
processed_lines = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 法规标题模式检测
|
||||||
|
# 第一章、第X章、第X节、第X条等
|
||||||
|
if re.match(r'^第[一二三四五六七八九十百]+章\s', line):
|
||||||
|
processed_lines.append(f"\n## {line}\n")
|
||||||
|
elif re.match(r'^第[一二三四五六七八九十百]+节\s', line):
|
||||||
|
processed_lines.append(f"\n### {line}\n")
|
||||||
|
elif re.match(r'^第[一二三四五六七八九十百]+条\s', line):
|
||||||
|
processed_lines.append(f"\n#### {line}\n")
|
||||||
|
elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line):
|
||||||
|
# 条款子项
|
||||||
|
processed_lines.append(f"- {line}")
|
||||||
|
else:
|
||||||
|
processed_lines.append(line)
|
||||||
|
|
||||||
|
return "\n".join(processed_lines)
|
||||||
|
|
||||||
|
def parse_to_markdown(self, file_path: str) -> str:
|
||||||
|
"""直接解析并返回Markdown文本"""
|
||||||
|
doc_content = self.parse(file_path)
|
||||||
|
return doc_content.markdown_text
|
||||||
|
|
||||||
|
|
||||||
|
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
|
||||||
|
"""便捷函数:解析PDF文档"""
|
||||||
|
parser = PDFParser()
|
||||||
|
return parser.parse(file_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_pdf_to_markdown(file_path: str) -> str:
|
||||||
|
"""便捷函数:解析PDF并返回Markdown"""
|
||||||
|
parser = PDFParser()
|
||||||
|
return parser.parse_to_markdown(file_path)
|
||||||
12
backend/app/services/rag/__init__.py
Normal file
12
backend/app/services/rag/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# src/services/rag/__init__.py
|
||||||
|
"""RAG服务模块"""
|
||||||
|
|
||||||
|
from .retriever import Retriever, retrieve_regulations
|
||||||
|
from .context_builder import ContextBuilder, build_rag_context
|
||||||
|
from .prompt_templates import PromptTemplates, get_prompt_template
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Retriever", "retrieve_regulations",
|
||||||
|
"ContextBuilder", "build_rag_context",
|
||||||
|
"PromptTemplates", "get_prompt_template"
|
||||||
|
]
|
||||||
230
backend/app/services/rag/context_builder.py
Normal file
230
backend/app/services/rag/context_builder.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
# src/services/rag/context_builder.py
|
||||||
|
"""RAG上下文构建服务 - 构建LLM输入上下文"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .retriever import RetrievedDocument
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RAGContext:
|
||||||
|
"""RAG构建的上下文"""
|
||||||
|
system_prompt: str
|
||||||
|
context_text: str
|
||||||
|
user_query: str
|
||||||
|
total_tokens: int
|
||||||
|
sources: List[Dict]
|
||||||
|
truncated: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ContextBuilder:
|
||||||
|
"""
|
||||||
|
RAG上下文构建器
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 格式化检索结果为上下文文本
|
||||||
|
- 控制上下文长度(token限制)
|
||||||
|
- 构建完整的LLM输入格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_context_tokens: int = None,
|
||||||
|
include_metadata: bool = True,
|
||||||
|
citation_format: str = "【条款{clause}】"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化上下文构建器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_context_tokens: 最大上下文token数
|
||||||
|
include_metadata: 是否包含元数据(文档名、条款号等)
|
||||||
|
citation_format: 引用格式模板
|
||||||
|
"""
|
||||||
|
self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
|
||||||
|
self.include_metadata = include_metadata
|
||||||
|
self.citation_format = citation_format
|
||||||
|
|
||||||
|
logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: List[RetrievedDocument],
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
) -> RAGContext:
|
||||||
|
"""
|
||||||
|
构建RAG上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询
|
||||||
|
documents: 检索到的文档列表
|
||||||
|
system_prompt: 系统提示词(可选)
|
||||||
|
max_tokens: 最大token数(可选,覆盖默认值)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RAGContext: 构建的上下文对象
|
||||||
|
"""
|
||||||
|
max_tokens = max_tokens or self.max_context_tokens
|
||||||
|
|
||||||
|
# 格式化文档内容
|
||||||
|
context_text, sources, truncated = self._format_documents(
|
||||||
|
documents,
|
||||||
|
max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建系统提示词
|
||||||
|
system_prompt = system_prompt or self._default_system_prompt()
|
||||||
|
|
||||||
|
# 估算总token数
|
||||||
|
total_tokens = self._estimate_tokens(system_prompt + context_text + query)
|
||||||
|
|
||||||
|
logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
|
||||||
|
|
||||||
|
return RAGContext(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
context_text=context_text,
|
||||||
|
user_query=query,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
sources=sources,
|
||||||
|
truncated=truncated
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_documents(
|
||||||
|
self,
|
||||||
|
documents: List[RetrievedDocument],
|
||||||
|
max_tokens: int
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
格式化文档内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: 文档列表
|
||||||
|
max_tokens: 最大token数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(context_text, sources, truncated)
|
||||||
|
"""
|
||||||
|
context_parts = []
|
||||||
|
sources = []
|
||||||
|
current_tokens = 0
|
||||||
|
truncated = False
|
||||||
|
|
||||||
|
for i, doc in enumerate(documents):
|
||||||
|
# 格式化单个文档
|
||||||
|
formatted = self._format_single_doc(doc, i + 1)
|
||||||
|
|
||||||
|
# 估算token数
|
||||||
|
doc_tokens = self._estimate_tokens(formatted)
|
||||||
|
|
||||||
|
# 检查是否超出限制
|
||||||
|
if current_tokens + doc_tokens > max_tokens:
|
||||||
|
truncated = True
|
||||||
|
logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
|
||||||
|
break
|
||||||
|
|
||||||
|
context_parts.append(formatted)
|
||||||
|
current_tokens += doc_tokens
|
||||||
|
|
||||||
|
# 记录来源
|
||||||
|
sources.append({
|
||||||
|
"index": i + 1,
|
||||||
|
"doc_id": doc.doc_id,
|
||||||
|
"doc_name": doc.doc_name,
|
||||||
|
"section_title": doc.section_title,
|
||||||
|
"clause_number": doc.clause_number,
|
||||||
|
"page_number": doc.page_number,
|
||||||
|
"score": doc.score
|
||||||
|
})
|
||||||
|
|
||||||
|
context_text = "\n\n".join(context_parts)
|
||||||
|
return context_text, sources, truncated
|
||||||
|
|
||||||
|
def _format_single_doc(
|
||||||
|
self,
|
||||||
|
doc: RetrievedDocument,
|
||||||
|
index: int
|
||||||
|
) -> str:
|
||||||
|
"""格式化单个文档"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
# 索引编号
|
||||||
|
parts.append(f"[{index}]")
|
||||||
|
|
||||||
|
# 元数据(可选)
|
||||||
|
if self.include_metadata:
|
||||||
|
meta_parts = []
|
||||||
|
|
||||||
|
if doc.doc_name:
|
||||||
|
meta_parts.append(f"文档: {doc.doc_name}")
|
||||||
|
|
||||||
|
if doc.section_title:
|
||||||
|
meta_parts.append(f"章节: {doc.section_title}")
|
||||||
|
|
||||||
|
if doc.clause_number:
|
||||||
|
clause_text = self.citation_format.format(clause=doc.clause_number)
|
||||||
|
meta_parts.append(clause_text)
|
||||||
|
|
||||||
|
if meta_parts:
|
||||||
|
parts.append(" | ".join(meta_parts))
|
||||||
|
|
||||||
|
# 内容
|
||||||
|
parts.append(doc.content)
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _default_system_prompt(self) -> str:
|
||||||
|
"""默认系统提示词"""
|
||||||
|
return """你是合规专家助手,基于提供的法规条款回答问题。
|
||||||
|
|
||||||
|
回答要求:
|
||||||
|
1. 直接回答问题,必须引用具体条款编号(如【条款5.2.1】)
|
||||||
|
2. 如引用的条款不完整,说明需要进一步查阅原文
|
||||||
|
3. 给出明确的合规建议和操作指导
|
||||||
|
4. 如果检索内容不足以回答问题,如实说明
|
||||||
|
|
||||||
|
回答格式:
|
||||||
|
- 先给出直接结论
|
||||||
|
- 然后引用支撑条款
|
||||||
|
- 最后给出合规建议"""
|
||||||
|
|
||||||
|
def _estimate_tokens(self, text: str) -> int:
|
||||||
|
"""估算文本token数"""
|
||||||
|
# 中文字符约1.5 token,英文约0.25 token
|
||||||
|
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||||
|
other_chars = len(text) - chinese_chars
|
||||||
|
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||||
|
|
||||||
|
def build_messages(
|
||||||
|
self,
|
||||||
|
context: RAGContext
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
构建LLM消息格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: RAG上下文对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: [{"role": "system/user/assistant", "content": "..."}]
|
||||||
|
"""
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": context.system_prompt},
|
||||||
|
{"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
|
||||||
|
]
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def build_rag_context(
|
||||||
|
query: str,
|
||||||
|
documents: List[RetrievedDocument],
|
||||||
|
**kwargs
|
||||||
|
) -> RAGContext:
|
||||||
|
"""便捷函数:构建RAG上下文"""
|
||||||
|
builder = ContextBuilder()
|
||||||
|
return builder.build(query, documents, **kwargs)
|
||||||
296
backend/app/services/rag/prompt_templates.py
Normal file
296
backend/app/services/rag/prompt_templates.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
# src/services/rag/prompt_templates.py
|
||||||
|
"""RAG Prompt模板 - 合规问答专用Prompt"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptTemplate:
|
||||||
|
"""Prompt模板"""
|
||||||
|
name: str
|
||||||
|
system_prompt: str
|
||||||
|
user_template: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplates:
|
||||||
|
"""
|
||||||
|
合规问答Prompt模板库
|
||||||
|
|
||||||
|
包含多种场景的Prompt模板:
|
||||||
|
- 合规问答(标准)
|
||||||
|
- 条款解读(详细解释)
|
||||||
|
- 合规检查(判断合规状态)
|
||||||
|
- 差异对比(新旧法规对比)
|
||||||
|
- 报告生成(合规报告)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 合规问答标准模板
|
||||||
|
COMPLIANCE_QA = PromptTemplate(
|
||||||
|
name="compliance_qa",
|
||||||
|
system_prompt="""你是合规专家助手,专门解答法规合规问题。
|
||||||
|
|
||||||
|
角色定位:
|
||||||
|
- 深谙国家法规标准(GB标准、行业标准)
|
||||||
|
- 熟悉车辆安全、数据安全、EHS等领域合规要求
|
||||||
|
- 提供专业、准确、可操作的合规建议
|
||||||
|
|
||||||
|
回答规范:
|
||||||
|
1. 必须引用具体条款编号(如【条款5.2.1】)
|
||||||
|
2. 优先引用高相关性条款(score ≥ 0.5)
|
||||||
|
3. 如条款内容不完整,明确提示需要查阅原文
|
||||||
|
4. 给出明确的合规结论和建议
|
||||||
|
5. 如检索内容不足以回答,如实说明
|
||||||
|
|
||||||
|
回答格式:
|
||||||
|
【结论】直接给出合规判断或答案
|
||||||
|
|
||||||
|
【条款依据】
|
||||||
|
- 【条款X.X.X】简要内容摘要(相关性: 高/中)
|
||||||
|
- ...
|
||||||
|
|
||||||
|
【合规建议】
|
||||||
|
1. 具体操作建议
|
||||||
|
2. 需要注意的风险点
|
||||||
|
3. 后续行动建议""",
|
||||||
|
user_template="""请根据以下法规条款回答问题。
|
||||||
|
|
||||||
|
【法规条款】
|
||||||
|
{context}
|
||||||
|
|
||||||
|
【用户问题】
|
||||||
|
{query}""",
|
||||||
|
description="标准合规问答模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 条款解读模板(详细解释)
|
||||||
|
CLAUSE_INTERPRETATION = PromptTemplate(
|
||||||
|
name="clause_interpretation",
|
||||||
|
system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
|
||||||
|
|
||||||
|
解读要求:
|
||||||
|
1. 逐句解释条款原文的含义
|
||||||
|
2. 说明条款的目的和背景
|
||||||
|
3. 举例说明条款的实际应用场景
|
||||||
|
4. 解释常见的误解和注意事项
|
||||||
|
|
||||||
|
解读格式:
|
||||||
|
【条款原文】完整引用条款
|
||||||
|
|
||||||
|
【逐句解读】
|
||||||
|
- "原文句1":解读含义
|
||||||
|
- "原文句2":解读含义
|
||||||
|
...
|
||||||
|
|
||||||
|
【应用场景】
|
||||||
|
具体举例说明该条款在实际工作中如何应用
|
||||||
|
|
||||||
|
【注意事项】
|
||||||
|
常见误解、执行难点、合规风险点""",
|
||||||
|
user_template="""请解读以下法规条款:
|
||||||
|
|
||||||
|
条款编号:{clause_number}
|
||||||
|
条款内容:{content}
|
||||||
|
|
||||||
|
用户关注点:{query}""",
|
||||||
|
description="条款详细解读模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 合规检查模板(判断合规状态)
|
||||||
|
COMPLIANCE_CHECK = PromptTemplate(
|
||||||
|
name="compliance_check",
|
||||||
|
system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
|
||||||
|
|
||||||
|
检查流程:
|
||||||
|
1. 理解企业行为/产品描述
|
||||||
|
2. 识别相关的法规条款
|
||||||
|
3. 逐条对照检查合规状态
|
||||||
|
4. 给出综合合规结论和整改建议
|
||||||
|
|
||||||
|
合规状态分类:
|
||||||
|
- ✅ 符合:完全满足法规要求
|
||||||
|
- ⚠️ 需评估:需要进一步核实或补充材料
|
||||||
|
- ❌ 不符合:明确违反法规要求
|
||||||
|
- ❓ 无适用条款:检索内容不足以判断
|
||||||
|
|
||||||
|
检查格式:
|
||||||
|
【合规检查报告】
|
||||||
|
|
||||||
|
一、检查对象
|
||||||
|
{描述企业行为/产品}
|
||||||
|
|
||||||
|
二、条款对照检查
|
||||||
|
| 条款编号 | 要求摘要 | 检查状态 | 说明 |
|
||||||
|
|--------|---------|---------|------|
|
||||||
|
| 【条款X.X.X】 | ... | ✅/⚠️/❌/❓ | ... |
|
||||||
|
|
||||||
|
三、综合结论
|
||||||
|
合规等级:A/B/C/D(完全合规/基本合规/部分合规/不合规)
|
||||||
|
|
||||||
|
四、整改建议(如需要)
|
||||||
|
1. ...
|
||||||
|
2. ...""",
|
||||||
|
user_template="""请对以下企业行为进行合规检查:
|
||||||
|
|
||||||
|
【行为/产品描述】
|
||||||
|
{query}
|
||||||
|
|
||||||
|
【相关法规条款】
|
||||||
|
{context}""",
|
||||||
|
description="合规检查评估模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 差异对比模板(新旧法规对比)
|
||||||
|
COMPARISON = PromptTemplate(
|
||||||
|
name="comparison",
|
||||||
|
system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
|
||||||
|
|
||||||
|
对比任务:
|
||||||
|
1. 识别新旧版本的条款差异
|
||||||
|
2. 分类差异类型(新增/修改/删除)
|
||||||
|
3. 分析差异的影响范围
|
||||||
|
4. 给出企业应对建议
|
||||||
|
|
||||||
|
差异分类:
|
||||||
|
- 🆕 新增条款:原版本不存在
|
||||||
|
- 🔄 修改条款:内容有实质性变更
|
||||||
|
- ❌ 删除条款:原条款被移除
|
||||||
|
- ⚖️ 调整条款:仅格式/编号调整,实质内容不变
|
||||||
|
|
||||||
|
对比格式:
|
||||||
|
【法规变更对比分析】
|
||||||
|
|
||||||
|
一、变更概述
|
||||||
|
- 旧版本:{version_old}
|
||||||
|
- 新版本:{version_new}
|
||||||
|
- 变更条款数:{count}
|
||||||
|
|
||||||
|
二、差异明细
|
||||||
|
| 类型 | 条款编号 | 旧版本内容 | 新版本内容 | 变化要点 |
|
||||||
|
|-----|---------|-----------|-----------|---------|
|
||||||
|
| 🆕 | X.X.X | - | ... | 新增要求... |
|
||||||
|
|
||||||
|
三、影响分析
|
||||||
|
- 高影响:{条款列表}
|
||||||
|
- 中影响:{条款列表}
|
||||||
|
- 低影响:{条款列表}
|
||||||
|
|
||||||
|
四、应对建议
|
||||||
|
1. 立即整改项
|
||||||
|
2. 逐步调整项
|
||||||
|
3. 持续关注项""",
|
||||||
|
user_template="""请对比分析以下法规差异:
|
||||||
|
|
||||||
|
【用户问题】
|
||||||
|
{query}
|
||||||
|
|
||||||
|
【旧版本条款】
|
||||||
|
{context_old}
|
||||||
|
|
||||||
|
【新版本条款】
|
||||||
|
{context_new}""",
|
||||||
|
description="法规版本对比模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 报告生成模板
|
||||||
|
REPORT_GENERATION = PromptTemplate(
|
||||||
|
name="report_generation",
|
||||||
|
system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
|
||||||
|
|
||||||
|
报告要求:
|
||||||
|
1. 结构清晰、逻辑严谨
|
||||||
|
2. 数据准确、引用规范
|
||||||
|
3. 结论明确、建议可操作
|
||||||
|
4. 语言专业、表达简洁
|
||||||
|
|
||||||
|
报告结构:
|
||||||
|
1. 概述(背景、范围)
|
||||||
|
2. 分析内容(主体分析)
|
||||||
|
3. 发现问题(合规差距)
|
||||||
|
4. 整改建议(具体措施)
|
||||||
|
5. 附录(引用条款原文)""",
|
||||||
|
user_template="""请生成以下合规报告:
|
||||||
|
|
||||||
|
【报告主题】
|
||||||
|
{query}
|
||||||
|
|
||||||
|
【分析依据】
|
||||||
|
{context}
|
||||||
|
|
||||||
|
【报告要求】
|
||||||
|
{requirements}""",
|
||||||
|
description="合规报告生成模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 文档摘要生成模板
|
||||||
|
DOCUMENT_SUMMARY = PromptTemplate(
|
||||||
|
name="document_summary",
|
||||||
|
system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
|
||||||
|
|
||||||
|
摘要要求:
|
||||||
|
1. 精炼核心内容,不超过1024字
|
||||||
|
2. 突出关键合规要求和条款编号
|
||||||
|
3. 说明适用范围和生效条件
|
||||||
|
4. 列出重要定义和术语解释
|
||||||
|
|
||||||
|
摘要格式:
|
||||||
|
【法规名称】{doc_name}
|
||||||
|
|
||||||
|
【适用范围】{适用范围描述}
|
||||||
|
|
||||||
|
【核心条款摘要】
|
||||||
|
- 【条款X.X.X】{关键要求}(重要度:高)
|
||||||
|
- ...
|
||||||
|
|
||||||
|
【关键术语】
|
||||||
|
- 术语1:定义解释
|
||||||
|
- ...
|
||||||
|
|
||||||
|
【合规要点】
|
||||||
|
1. 必须满足的核心要求
|
||||||
|
2. 需要特别注意的条款""",
|
||||||
|
user_template="""请生成以下法规文档的摘要:
|
||||||
|
|
||||||
|
【文档名称】
|
||||||
|
{doc_name}
|
||||||
|
|
||||||
|
【文档内容】
|
||||||
|
{content}
|
||||||
|
|
||||||
|
请生成不超过1024字的摘要。""",
|
||||||
|
description="文档摘要生成模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_template(cls, name: str) -> Optional[PromptTemplate]:
|
||||||
|
"""获取指定模板"""
|
||||||
|
templates = {
|
||||||
|
"compliance_qa": cls.COMPLIANCE_QA,
|
||||||
|
"clause_interpretation": cls.CLAUSE_INTERPRETATION,
|
||||||
|
"compliance_check": cls.COMPLIANCE_CHECK,
|
||||||
|
"comparison": cls.COMPARISON,
|
||||||
|
"report": cls.REPORT_GENERATION,
|
||||||
|
"document_summary": cls.DOCUMENT_SUMMARY
|
||||||
|
}
|
||||||
|
return templates.get(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_templates(cls) -> Dict[str, str]:
|
||||||
|
"""列出所有模板"""
|
||||||
|
return {
|
||||||
|
"compliance_qa": cls.COMPLIANCE_QA.description,
|
||||||
|
"clause_interpretation": cls.CLAUSE_INTERPRETATION.description,
|
||||||
|
"compliance_check": cls.COMPLIANCE_CHECK.description,
|
||||||
|
"comparison": cls.COMPARISON.description,
|
||||||
|
"report": cls.REPORT_GENERATION.description,
|
||||||
|
"document_summary": cls.DOCUMENT_SUMMARY.description
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_template(name: str) -> PromptTemplate:
|
||||||
|
"""便捷函数:获取Prompt模板"""
|
||||||
|
template = PromptTemplates.get_template(name)
|
||||||
|
if not template:
|
||||||
|
raise ValueError(f"不存在的模板: {name}")
|
||||||
|
return template
|
||||||
193
backend/app/services/rag/retriever.py
Normal file
193
backend/app/services/rag/retriever.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
# src/services/rag/retriever.py
|
||||||
|
"""RAG检索服务 - 封装Milvus检索"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
|
||||||
|
from app.services.storage.milvus_client import MilvusClient, SearchResult
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievedDocument:
|
||||||
|
"""检索到的文档"""
|
||||||
|
content: str
|
||||||
|
doc_id: str # 文档ID,用于下载
|
||||||
|
doc_name: str
|
||||||
|
section_title: str
|
||||||
|
clause_number: str
|
||||||
|
page_number: int
|
||||||
|
score: float
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class Retriever:
|
||||||
|
"""
|
||||||
|
RAG检索器
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 向量检索(Dense + Sparse混合)
|
||||||
|
- 重排序(可选)
|
||||||
|
- 过滤和筛选
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
top_k: int = None,
|
||||||
|
rerank: bool = False,
|
||||||
|
min_score: float = 0.3
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化检索器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_k: 检索召回数量
|
||||||
|
rerank: 是否启用重排序
|
||||||
|
min_score: 最低相关性分数阈值
|
||||||
|
"""
|
||||||
|
self.top_k = top_k or settings.rag_top_k
|
||||||
|
self.rerank = rerank
|
||||||
|
self.min_score = min_score
|
||||||
|
|
||||||
|
# 嵌入模型(延迟加载)
|
||||||
|
self.embedder: Optional[BGEM3Embedder] = None
|
||||||
|
|
||||||
|
# Milvus客户端(延迟连接)
|
||||||
|
self.milvus: Optional[MilvusClient] = None
|
||||||
|
|
||||||
|
logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")
|
||||||
|
|
||||||
|
def _init_embedder(self):
|
||||||
|
"""延迟初始化嵌入模型"""
|
||||||
|
if self.embedder is None:
|
||||||
|
logger.info("加载嵌入模型...")
|
||||||
|
self.embedder = BGEM3Embedder(model_name=settings.embedding_model)
|
||||||
|
|
||||||
|
def _init_milvus(self):
|
||||||
|
"""延迟初始化Milvus"""
|
||||||
|
if self.milvus is None:
|
||||||
|
logger.info("连接Milvus...")
|
||||||
|
self.milvus = MilvusClient()
|
||||||
|
self.milvus.connect()
|
||||||
|
self.milvus.create_collection(recreate=False)
|
||||||
|
self.milvus.load_collection()
|
||||||
|
|
||||||
|
def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
filters: Optional[str] = None,
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
) -> List[RetrievedDocument]:
|
||||||
|
"""
|
||||||
|
检索相关文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
filters: 过滤条件(如 "regulation_type=='车辆安全'")
|
||||||
|
top_k: 返回数量(可选,覆盖默认值)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[RetrievedDocument]: 检索结果列表
|
||||||
|
"""
|
||||||
|
logger.info(f"执行检索: {query}")
|
||||||
|
|
||||||
|
# 初始化组件
|
||||||
|
self._init_embedder()
|
||||||
|
self._init_milvus()
|
||||||
|
|
||||||
|
# 生成查询向量
|
||||||
|
query_embedding = self.embedder.embed_single(query)
|
||||||
|
|
||||||
|
# 执行混合检索
|
||||||
|
results = self.milvus.hybrid_search(
|
||||||
|
query_dense=query_embedding['dense'].tolist(),
|
||||||
|
query_sparse=query_embedding['sparse'],
|
||||||
|
top_k=top_k or self.top_k,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为RetrievedDocument格式
|
||||||
|
documents = []
|
||||||
|
for r in results:
|
||||||
|
if r.score >= self.min_score:
|
||||||
|
doc = RetrievedDocument(
|
||||||
|
content=r.content,
|
||||||
|
doc_id=r.metadata.get("doc_id", ""),
|
||||||
|
doc_name=r.metadata.get("doc_name", ""),
|
||||||
|
section_title=r.metadata.get("section_title", ""),
|
||||||
|
clause_number=r.metadata.get("clause_number", ""),
|
||||||
|
page_number=r.metadata.get("page_number", 0),
|
||||||
|
score=r.score,
|
||||||
|
metadata=r.metadata
|
||||||
|
)
|
||||||
|
documents.append(doc)
|
||||||
|
|
||||||
|
logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def retrieve_with_scores(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
filters: Optional[str] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
检索并返回完整结果(包含分数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
filters: 过滤条件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: 包含分数的检索结果
|
||||||
|
"""
|
||||||
|
documents = self.retrieve(query, filters)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"content": doc.content,
|
||||||
|
"doc_id": doc.doc_id,
|
||||||
|
"doc_name": doc.doc_name,
|
||||||
|
"section_title": doc.section_title,
|
||||||
|
"clause_number": doc.clause_number,
|
||||||
|
"page_number": doc.page_number,
|
||||||
|
"score": doc.score
|
||||||
|
}
|
||||||
|
for doc in documents
|
||||||
|
]
|
||||||
|
|
||||||
|
def search_by_doc_name(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
doc_name: str
|
||||||
|
) -> List[RetrievedDocument]:
|
||||||
|
"""按文档名称过滤检索"""
|
||||||
|
filters = f'doc_name=="{doc_name}"'
|
||||||
|
return self.retrieve(query, filters)
|
||||||
|
|
||||||
|
def search_by_regulation_type(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
regulation_type: str
|
||||||
|
) -> List[RetrievedDocument]:
|
||||||
|
"""按法规类型过滤检索"""
|
||||||
|
filters = f'regulation_type=="{regulation_type}"'
|
||||||
|
return self.retrieve(query, filters)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭连接"""
|
||||||
|
if self.milvus:
|
||||||
|
self.milvus.disconnect()
|
||||||
|
logger.info("检索器已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_regulations(
|
||||||
|
query: str,
|
||||||
|
top_k: int = 10,
|
||||||
|
filters: Optional[str] = None
|
||||||
|
) -> List[RetrievedDocument]:
|
||||||
|
"""便捷函数:检索法规"""
|
||||||
|
retriever = Retriever(top_k=top_k)
|
||||||
|
results = retriever.retrieve(query, filters)
|
||||||
|
retriever.close()
|
||||||
|
return results
|
||||||
7
backend/app/services/storage/__init__.py
Normal file
7
backend/app/services/storage/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# src/services/storage/__init__.py
|
||||||
|
"""存储服务"""
|
||||||
|
|
||||||
|
from .milvus_client import MilvusClient
|
||||||
|
from .minio_client import MinIOClient
|
||||||
|
|
||||||
|
__all__ = ["MilvusClient", "MinIOClient"]
|
||||||
485
backend/app/services/storage/milvus_client.py
Normal file
485
backend/app/services/storage/milvus_client.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
# src/services/storage/milvus_client.py
|
||||||
|
"""Milvus向量数据库客户端 - 存储与检索服务"""
|
||||||
|
|
||||||
|
from pymilvus import (
|
||||||
|
connections,
|
||||||
|
Collection,
|
||||||
|
FieldSchema,
|
||||||
|
CollectionSchema,
|
||||||
|
DataType,
|
||||||
|
utility
|
||||||
|
)
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from loguru import logger
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..embedding.text_chunker import TextChunk
|
||||||
|
from ..embedding.bge_m3_embedder import EmbeddingResult
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchResult:
|
||||||
|
"""检索结果"""
|
||||||
|
id: int
|
||||||
|
content: str
|
||||||
|
score: float
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MilvusDocument:
|
||||||
|
"""Milvus文档数据结构"""
|
||||||
|
doc_id: str
|
||||||
|
chunk_id: str
|
||||||
|
content: str
|
||||||
|
dense_vector: List[float]
|
||||||
|
sparse_vector: Dict[int, float]
|
||||||
|
doc_name: str
|
||||||
|
section_title: str
|
||||||
|
clause_number: str
|
||||||
|
page_number: int
|
||||||
|
regulation_type: str
|
||||||
|
version: str
|
||||||
|
create_time: int
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusClient:
|
||||||
|
"""Milvus向量数据库客户端"""
|
||||||
|
|
||||||
|
COLLECTION_NAME = "regulations"
|
||||||
|
|
||||||
|
SCHEMA_FIELDS = [
|
||||||
|
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||||
|
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
|
||||||
|
FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
|
||||||
|
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192),
|
||||||
|
FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
|
||||||
|
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
|
||||||
|
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
|
||||||
|
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
|
||||||
|
FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64),
|
||||||
|
FieldSchema(name="page_number", dtype=DataType.INT64),
|
||||||
|
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32),
|
||||||
|
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32),
|
||||||
|
FieldSchema(name="create_time", dtype=DataType.INT64),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = None,
|
||||||
|
port: int = None,
|
||||||
|
collection_name: str = None,
|
||||||
|
db_name: str = None
|
||||||
|
):
|
||||||
|
self.host = host or settings.milvus_host
|
||||||
|
self.port = port or settings.milvus_port
|
||||||
|
self.collection_name = collection_name or settings.milvus_collection
|
||||||
|
self.db_name = db_name or settings.milvus_db_name
|
||||||
|
|
||||||
|
self.collection: Optional[Collection] = None
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
|
||||||
|
|
||||||
|
def connect(self) -> bool:
|
||||||
|
"""连接到Milvus服务器"""
|
||||||
|
try:
|
||||||
|
connections.connect(
|
||||||
|
alias="default",
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
db_name=self.db_name
|
||||||
|
)
|
||||||
|
self.connected = True
|
||||||
|
logger.success(f"Milvus连接成功: {self.host}:{self.port}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Milvus连接失败: {e}")
|
||||||
|
self.connected = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"""断开连接"""
|
||||||
|
try:
|
||||||
|
connections.disconnect("default")
|
||||||
|
self.connected = False
|
||||||
|
logger.info("Milvus连接已断开")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"断开连接时出错: {e}")
|
||||||
|
|
||||||
|
def create_collection(self, recreate: bool = False) -> bool:
|
||||||
|
"""创建Collection"""
|
||||||
|
if not self.connected:
|
||||||
|
logger.warning("未连接到Milvus,请先调用connect()")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if utility.has_collection(self.collection_name):
|
||||||
|
if recreate:
|
||||||
|
logger.info(f"删除已存在的Collection: {self.collection_name}")
|
||||||
|
utility.drop_collection(self.collection_name)
|
||||||
|
else:
|
||||||
|
logger.info(f"Collection已存在: {self.collection_name}")
|
||||||
|
self.collection = Collection(self.collection_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=self.SCHEMA_FIELDS,
|
||||||
|
description="法规文档向量存储",
|
||||||
|
enable_dynamic_field=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.collection = Collection(
|
||||||
|
name=self.collection_name,
|
||||||
|
schema=schema
|
||||||
|
)
|
||||||
|
|
||||||
|
self._create_indexes()
|
||||||
|
|
||||||
|
logger.success(f"Collection创建成功: {self.collection_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Collection创建失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _create_indexes(self):
|
||||||
|
"""创建向量索引"""
|
||||||
|
if not self.collection:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
dense_index_params = {
|
||||||
|
"metric_type": "COSINE",
|
||||||
|
"index_type": "IVF_FLAT",
|
||||||
|
"params": {"nlist": 128}
|
||||||
|
}
|
||||||
|
self.collection.create_index(
|
||||||
|
field_name="dense_vector",
|
||||||
|
index_params=dense_index_params
|
||||||
|
)
|
||||||
|
|
||||||
|
sparse_index_params = {
|
||||||
|
"metric_type": "IP",
|
||||||
|
"index_type": "SPARSE_INVERTED_INDEX",
|
||||||
|
"params": {"drop_ratio_build": 0.2}
|
||||||
|
}
|
||||||
|
self.collection.create_index(
|
||||||
|
field_name="sparse_vector",
|
||||||
|
index_params=sparse_index_params
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success("向量索引创建成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"创建索引时出错: {e}")
|
||||||
|
|
||||||
|
def load_collection(self):
|
||||||
|
"""加载Collection到内存"""
|
||||||
|
if self.collection:
|
||||||
|
self.collection.load()
|
||||||
|
logger.info(f"Collection已加载: {self.collection_name}")
|
||||||
|
|
||||||
|
def release_collection(self):
|
||||||
|
"""释放Collection内存"""
|
||||||
|
if self.collection:
|
||||||
|
self.collection.release()
|
||||||
|
logger.info(f"Collection已释放: {self.collection_name}")
|
||||||
|
|
||||||
|
def insert_chunks(
|
||||||
|
self,
|
||||||
|
chunks: List[TextChunk],
|
||||||
|
embeddings: EmbeddingResult
|
||||||
|
) -> List[int]:
|
||||||
|
"""插入文档分块和嵌入向量"""
|
||||||
|
if not self.collection:
|
||||||
|
logger.warning("Collection未初始化")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(chunks) != len(embeddings.texts):
|
||||||
|
logger.warning(f"Chunks数量与嵌入数量不匹配")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"准备插入{len(chunks)}个文档分块")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = []
|
||||||
|
current_time = int(time.time())
|
||||||
|
|
||||||
|
for chunk, dense_emb, sparse_emb in zip(
|
||||||
|
chunks,
|
||||||
|
embeddings.dense_embeddings,
|
||||||
|
embeddings.sparse_embeddings
|
||||||
|
):
|
||||||
|
row = {
|
||||||
|
"doc_id": chunk.metadata.doc_id,
|
||||||
|
"chunk_id": chunk.metadata.chunk_id,
|
||||||
|
"content": chunk.content,
|
||||||
|
"dense_vector": dense_emb.tolist(),
|
||||||
|
"sparse_vector": sparse_emb,
|
||||||
|
"doc_name": chunk.metadata.doc_name,
|
||||||
|
"section_title": chunk.metadata.section_title,
|
||||||
|
"clause_number": chunk.metadata.clause_number,
|
||||||
|
"page_number": chunk.metadata.page_number,
|
||||||
|
"regulation_type": chunk.metadata.regulation_type,
|
||||||
|
"version": chunk.metadata.version,
|
||||||
|
"create_time": current_time
|
||||||
|
}
|
||||||
|
data.append(row)
|
||||||
|
|
||||||
|
result = self.collection.insert(data)
|
||||||
|
self.collection.flush()
|
||||||
|
|
||||||
|
logger.success(f"插入完成,共{len(result.primary_keys)}条记录")
|
||||||
|
return result.primary_keys
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入数据失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def hybrid_search(
|
||||||
|
self,
|
||||||
|
query_dense: List[float],
|
||||||
|
query_sparse: Dict[int, float],
|
||||||
|
top_k: int = 10,
|
||||||
|
filters: Optional[str] = None
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""混合检索:Dense + Sparse"""
|
||||||
|
if not self.collection:
|
||||||
|
logger.warning("Collection未初始化")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.collection.load()
|
||||||
|
|
||||||
|
# 使用简单的Dense检索(兼容所有版本)
|
||||||
|
dense_results = self.dense_search(query_dense, top_k, filters)
|
||||||
|
|
||||||
|
# 可选:合并Sparse结果
|
||||||
|
if query_sparse:
|
||||||
|
sparse_results = self.sparse_search(query_sparse, top_k, filters)
|
||||||
|
merged = self._merge_results(dense_results, sparse_results, top_k)
|
||||||
|
logger.success(f"混合检索完成,返回{len(merged)}条结果")
|
||||||
|
return merged
|
||||||
|
|
||||||
|
return dense_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"混合检索失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
self,
|
||||||
|
dense_results: List[SearchResult],
|
||||||
|
sparse_results: List[SearchResult],
|
||||||
|
top_k: int,
|
||||||
|
dense_weight: float = 0.6
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""手动融合Dense和Sparse结果"""
|
||||||
|
sparse_weight = 1 - dense_weight
|
||||||
|
merged_dict = {}
|
||||||
|
|
||||||
|
for r in dense_results:
|
||||||
|
merged_dict[r.id] = {
|
||||||
|
"result": r,
|
||||||
|
"dense_score": r.score * dense_weight,
|
||||||
|
"sparse_score": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for r in sparse_results:
|
||||||
|
if r.id in merged_dict:
|
||||||
|
merged_dict[r.id]["sparse_score"] = r.score * sparse_weight
|
||||||
|
else:
|
||||||
|
merged_dict[r.id] = {
|
||||||
|
"result": r,
|
||||||
|
"dense_score": 0,
|
||||||
|
"sparse_score": r.score * sparse_weight
|
||||||
|
}
|
||||||
|
|
||||||
|
final_results = []
|
||||||
|
for id_, data in merged_dict.items():
|
||||||
|
result = data["result"]
|
||||||
|
final_score = data["dense_score"] + data["sparse_score"]
|
||||||
|
final_results.append(SearchResult(
|
||||||
|
id=result.id,
|
||||||
|
content=result.content,
|
||||||
|
score=final_score,
|
||||||
|
metadata=result.metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
final_results.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return final_results[:top_k]
|
||||||
|
|
||||||
|
def dense_search(
|
||||||
|
self,
|
||||||
|
query_dense: List[float],
|
||||||
|
top_k: int = 10,
|
||||||
|
filters: Optional[str] = None
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""纯Dense向量检索"""
|
||||||
|
if not self.collection:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.collection.load()
|
||||||
|
|
||||||
|
search_params = {
|
||||||
|
"metric_type": "COSINE",
|
||||||
|
"params": {"nprobe": 16}
|
||||||
|
}
|
||||||
|
|
||||||
|
results = self.collection.search(
|
||||||
|
data=[query_dense],
|
||||||
|
anns_field="dense_vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=top_k,
|
||||||
|
filter=filters,
|
||||||
|
output_fields=[
|
||||||
|
"doc_id", "chunk_id", "content",
|
||||||
|
"doc_name", "section_title", "clause_number",
|
||||||
|
"page_number", "regulation_type", "version"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
search_results = []
|
||||||
|
for hits in results:
|
||||||
|
for hit in hits:
|
||||||
|
result = SearchResult(
|
||||||
|
id=hit.id,
|
||||||
|
content=hit.entity.get("content", ""),
|
||||||
|
score=hit.score,
|
||||||
|
metadata={
|
||||||
|
"doc_id": hit.entity.get("doc_id", ""),
|
||||||
|
"chunk_id": hit.entity.get("chunk_id", ""),
|
||||||
|
"doc_name": hit.entity.get("doc_name", ""),
|
||||||
|
"section_title": hit.entity.get("section_title", ""),
|
||||||
|
"clause_number": hit.entity.get("clause_number", ""),
|
||||||
|
"page_number": hit.entity.get("page_number", 0),
|
||||||
|
"regulation_type": hit.entity.get("regulation_type", ""),
|
||||||
|
"version": hit.entity.get("version", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
search_results.append(result)
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Dense检索失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def sparse_search(
|
||||||
|
self,
|
||||||
|
query_sparse: Dict[int, float],
|
||||||
|
top_k: int = 10,
|
||||||
|
filters: Optional[str] = None
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""纯Sparse向量检索"""
|
||||||
|
if not self.collection:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.collection.load()
|
||||||
|
|
||||||
|
search_params = {
|
||||||
|
"metric_type": "IP",
|
||||||
|
"params": {"drop_ratio_search": 0.2}
|
||||||
|
}
|
||||||
|
|
||||||
|
results = self.collection.search(
|
||||||
|
data=[query_sparse],
|
||||||
|
anns_field="sparse_vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=top_k,
|
||||||
|
filter=filters,
|
||||||
|
output_fields=[
|
||||||
|
"doc_id", "chunk_id", "content",
|
||||||
|
"doc_name", "section_title", "clause_number",
|
||||||
|
"page_number", "regulation_type", "version"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
search_results = []
|
||||||
|
for hits in results:
|
||||||
|
for hit in hits:
|
||||||
|
result = SearchResult(
|
||||||
|
id=hit.id,
|
||||||
|
content=hit.entity.get("content", ""),
|
||||||
|
score=hit.score,
|
||||||
|
metadata={
|
||||||
|
"doc_id": hit.entity.get("doc_id", ""),
|
||||||
|
"chunk_id": hit.entity.get("chunk_id", ""),
|
||||||
|
"doc_name": hit.entity.get("doc_name", ""),
|
||||||
|
"section_title": hit.entity.get("section_title", ""),
|
||||||
|
"clause_number": hit.entity.get("clause_number", ""),
|
||||||
|
"page_number": hit.entity.get("page_number", 0),
|
||||||
|
"regulation_type": hit.entity.get("regulation_type", ""),
|
||||||
|
"version": hit.entity.get("version", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
search_results.append(result)
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Sparse检索失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete_by_doc_id(self, doc_id: str) -> int:
|
||||||
|
"""根据doc_id删除记录"""
|
||||||
|
if not self.collection:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
expr = f'doc_id=="{doc_id}"'
|
||||||
|
result = self.collection.delete(expr)
|
||||||
|
logger.info(f"删除记录: doc_id={doc_id}, 数量={len(result.primary_keys)}")
|
||||||
|
return len(result.primary_keys)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除失败: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def get_collection_stats(self) -> Dict[str, Any]:
|
||||||
|
"""获取Collection统计信息"""
|
||||||
|
if not self.collection:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
stats = {
|
||||||
|
"name": self.collection_name,
|
||||||
|
"num_entities": self.collection.num_entities,
|
||||||
|
"description": self.collection.description,
|
||||||
|
}
|
||||||
|
return stats
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取统计信息失败: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_milvus_client() -> MilvusClient:
|
||||||
|
"""便捷函数:创建Milvus客户端"""
|
||||||
|
client = MilvusClient()
|
||||||
|
client.connect()
|
||||||
|
client.create_collection(recreate=False)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def insert_documents(
|
||||||
|
client: MilvusClient,
|
||||||
|
chunks: List[TextChunk],
|
||||||
|
embeddings: EmbeddingResult
|
||||||
|
) -> List[int]:
|
||||||
|
"""便捷函数:插入文档"""
|
||||||
|
return client.insert_chunks(chunks, embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def search_regulations(
|
||||||
|
client: MilvusClient,
|
||||||
|
query_dense: List[float],
|
||||||
|
query_sparse: Dict[int, float],
|
||||||
|
top_k: int = 10
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""便捷函数:检索法规"""
|
||||||
|
return client.hybrid_search(query_dense, query_sparse, top_k)
|
||||||
352
backend/app/services/storage/minio_client.py
Normal file
352
backend/app/services/storage/minio_client.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
# src/services/storage/minio_client.py
|
||||||
|
"""MinIO对象存储客户端 - 文档文件存储"""
|
||||||
|
|
||||||
|
from minio import Minio
|
||||||
|
from minio.error import S3Error
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
from io import BytesIO
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class MinIOClient:
|
||||||
|
"""MinIO对象存储客户端"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str = None,
|
||||||
|
access_key: str = None,
|
||||||
|
secret_key: str = None,
|
||||||
|
bucket: str = None,
|
||||||
|
secure: bool = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化MinIO客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
endpoint: MinIO服务地址
|
||||||
|
access_key: 访问密钥
|
||||||
|
secret_key: 秘密密钥
|
||||||
|
bucket: 存储桶名称
|
||||||
|
secure: 是否使用HTTPS
|
||||||
|
"""
|
||||||
|
self.endpoint = endpoint or settings.minio_endpoint
|
||||||
|
self.access_key = access_key or settings.minio_access_key
|
||||||
|
self.secret_key = secret_key or settings.minio_secret_key
|
||||||
|
self.bucket = bucket or settings.minio_bucket
|
||||||
|
self.secure = secure or settings.minio_secure
|
||||||
|
|
||||||
|
self.client: Optional[Minio] = None
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")
|
||||||
|
|
||||||
|
def connect(self) -> bool:
|
||||||
|
"""连接MinIO服务"""
|
||||||
|
try:
|
||||||
|
self.client = Minio(
|
||||||
|
self.endpoint,
|
||||||
|
access_key=self.access_key,
|
||||||
|
secret_key=self.secret_key,
|
||||||
|
secure=self.secure
|
||||||
|
)
|
||||||
|
self.connected = True
|
||||||
|
logger.success(f"MinIO连接成功: {self.endpoint}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MinIO连接失败: {e}")
|
||||||
|
self.connected = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def ensure_bucket(self) -> bool:
|
||||||
|
"""确保存储桶存在"""
|
||||||
|
if not self.connected:
|
||||||
|
logger.warning("未连接MinIO,请先调用connect()")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.client.bucket_exists(self.bucket):
|
||||||
|
self.client.make_bucket(self.bucket)
|
||||||
|
logger.success(f"创建存储桶: {self.bucket}")
|
||||||
|
else:
|
||||||
|
logger.info(f"存储桶已存在: {self.bucket}")
|
||||||
|
return True
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"存储桶操作失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def upload_file(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
object_name: str,
|
||||||
|
metadata: Dict[str, Any] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
上传本地文件到MinIO
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 本地文件路径
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
metadata: 元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
self.ensure_bucket()
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_size = os.stat(file_path).st_size
|
||||||
|
content_type = self._get_content_type(file_path)
|
||||||
|
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
self.client.put_object(
|
||||||
|
self.bucket,
|
||||||
|
object_name,
|
||||||
|
f,
|
||||||
|
file_size,
|
||||||
|
content_type=content_type,
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success(f"文件上传成功: {object_name}, 大小={file_size}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"文件上传失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def upload_bytes(
|
||||||
|
self,
|
||||||
|
data: bytes,
|
||||||
|
object_name: str,
|
||||||
|
content_type: str = "application/octet-stream",
|
||||||
|
metadata: Dict[str, Any] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
上传字节数据到MinIO
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 文件字节数据
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
content_type: 内容类型
|
||||||
|
metadata: 元数据(注意:MinIO仅支持US-ASCII字符)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
self.ensure_bucket()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_stream = BytesIO(data)
|
||||||
|
|
||||||
|
# 处理metadata:仅保留ASCII安全字符
|
||||||
|
safe_metadata = None
|
||||||
|
if metadata:
|
||||||
|
safe_metadata = {}
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
# 只保留ASCII字符或转换为安全格式
|
||||||
|
try:
|
||||||
|
value.encode('ascii')
|
||||||
|
safe_metadata[key] = value
|
||||||
|
except UnicodeEncodeError:
|
||||||
|
# 中文字符跳过或用占位符
|
||||||
|
safe_metadata[key] = ""
|
||||||
|
else:
|
||||||
|
safe_metadata[key] = str(value)
|
||||||
|
|
||||||
|
self.client.put_object(
|
||||||
|
self.bucket,
|
||||||
|
object_name,
|
||||||
|
data_stream,
|
||||||
|
len(data),
|
||||||
|
content_type=content_type,
|
||||||
|
metadata=safe_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success(f"数据上传成功: {object_name}, 大小={len(data)}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"数据上传失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_file(
|
||||||
|
self,
|
||||||
|
object_name: str,
|
||||||
|
file_path: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
从MinIO下载文件到本地
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
file_path: 本地保存路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.fget_object(
|
||||||
|
self.bucket,
|
||||||
|
object_name,
|
||||||
|
file_path
|
||||||
|
)
|
||||||
|
logger.success(f"文件下载成功: {object_name} -> {file_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"文件下载失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_object_url(
|
||||||
|
self,
|
||||||
|
object_name: str,
|
||||||
|
expires: int = 3600
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取对象下载URL(临时URL)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
expires: URL有效期(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 下载URL
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = self.client.presigned_get_object(
|
||||||
|
self.bucket,
|
||||||
|
object_name,
|
||||||
|
expires=expires
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"获取URL失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_object_data(self, object_name: str) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
获取对象数据(字节)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: 文件数据
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.get_object(self.bucket, object_name)
|
||||||
|
data = response.read()
|
||||||
|
response.close()
|
||||||
|
response.release_conn()
|
||||||
|
return data
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"获取对象数据失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_object(self, object_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.remove_object(self.bucket, object_name)
|
||||||
|
logger.info(f"对象删除成功: {object_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"对象删除失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_objects(self, prefix: str = "") -> list:
|
||||||
|
"""
|
||||||
|
列出存储桶中的对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: 对象名称前缀
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 对象列表
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
objects = self.client.list_objects(self.bucket, prefix=prefix)
|
||||||
|
return [obj.object_name for obj in objects]
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"列出对象失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def object_exists(self, object_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查对象是否存在
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_name: MinIO对象名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否存在
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.stat_object(self.bucket, object_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except S3Error:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_content_type(self, file_path: str) -> str:
|
||||||
|
"""根据文件扩展名获取Content-Type"""
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
content_types = {
|
||||||
|
'.pdf': 'application/pdf',
|
||||||
|
'.doc': 'application/msword',
|
||||||
|
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
|
'.txt': 'text/plain',
|
||||||
|
'.json': 'application/json',
|
||||||
|
'.xml': 'application/xml',
|
||||||
|
}
|
||||||
|
return content_types.get(ext, 'application/octet-stream')
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭连接(MinIO客户端无需显式关闭)"""
|
||||||
|
self.connected = False
|
||||||
|
logger.info("MinIO客户端已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
def create_minio_client() -> MinIOClient:
|
||||||
|
"""便捷函数:创建MinIO客户端"""
|
||||||
|
client = MinIOClient()
|
||||||
|
client.connect()
|
||||||
|
client.ensure_bucket()
|
||||||
|
return client
|
||||||
4
backend/app/utils/__init__.py
Normal file
4
backend/app/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .chunking import TextChunker, chunker
|
||||||
|
from .logger import logger, setup_logging
|
||||||
|
|
||||||
|
__all__ = ["TextChunker", "chunker", "logger", "setup_logging"]
|
||||||
78
backend/app/utils/chunking.py
Normal file
78
backend/app/utils/chunking.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TextChunker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = settings.chunk_size,
|
||||||
|
chunk_overlap: int = settings.chunk_overlap,
|
||||||
|
):
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.chunk_overlap = chunk_overlap
|
||||||
|
|
||||||
|
def chunk_by_clause(self, text: str) -> List[dict]:
|
||||||
|
"""按条款边界分块(适用于法规文档)"""
|
||||||
|
clause_pattern = r"(第[一二三四五六七八九十百]+条)"
|
||||||
|
parts = re.split(clause_pattern, text)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
current_clause = None
|
||||||
|
current_text = ""
|
||||||
|
chunk_index = 0
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
if re.match(clause_pattern, part):
|
||||||
|
if current_clause and current_text.strip():
|
||||||
|
chunks.append({
|
||||||
|
"clause_id": current_clause,
|
||||||
|
"content": current_text.strip(),
|
||||||
|
"chunk_index": chunk_index,
|
||||||
|
})
|
||||||
|
chunk_index += 1
|
||||||
|
current_clause = part
|
||||||
|
current_text = ""
|
||||||
|
else:
|
||||||
|
current_text += part
|
||||||
|
|
||||||
|
if current_clause and current_text.strip():
|
||||||
|
chunks.append({
|
||||||
|
"clause_id": current_clause,
|
||||||
|
"content": current_text.strip(),
|
||||||
|
"chunk_index": chunk_index,
|
||||||
|
})
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def chunk_by_size(self, text: str) -> List[dict]:
|
||||||
|
"""按固定大小分块"""
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
chunk_index = 0
|
||||||
|
|
||||||
|
while start < len(text):
|
||||||
|
end = start + self.chunk_size
|
||||||
|
chunk_text = text[start:end]
|
||||||
|
|
||||||
|
if chunk_text.strip():
|
||||||
|
chunks.append({
|
||||||
|
"content": chunk_text.strip(),
|
||||||
|
"chunk_index": chunk_index,
|
||||||
|
"start_pos": start,
|
||||||
|
"end_pos": end,
|
||||||
|
})
|
||||||
|
chunk_index += 1
|
||||||
|
|
||||||
|
start = end - self.chunk_overlap
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str) -> int:
|
||||||
|
"""估算token数量"""
|
||||||
|
chinese_chars = len(re.findall(r"[^\x00-\xff]", text))
|
||||||
|
english_chars = len(text) - chinese_chars
|
||||||
|
return int(chinese_chars / 1.5 + english_chars / 4)
|
||||||
|
|
||||||
|
|
||||||
|
chunker = TextChunker()
|
||||||
24
backend/app/utils/logger.py
Normal file
24
backend/app/utils/logger.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging() -> logging.Logger:
|
||||||
|
"""配置日志"""
|
||||||
|
logger = logging.getLogger("app")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = setup_logging()
|
||||||
2
backend/app/workers/__init__.py
Normal file
2
backend/app/workers/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# src/workers/__init__.py
|
||||||
|
"""异步任务Worker模块"""
|
||||||
12
backend/app/workflows/__init__.py
Normal file
12
backend/app/workflows/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from .rag_workflow import RagState, rag_workflow, run_rag_workflow, stream_rag_workflow
|
||||||
|
from .compliance_workflow import ComplianceState, compliance_workflow, run_compliance_workflow
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RagState",
|
||||||
|
"rag_workflow",
|
||||||
|
"run_rag_workflow",
|
||||||
|
"stream_rag_workflow",
|
||||||
|
"ComplianceState",
|
||||||
|
"compliance_workflow",
|
||||||
|
"run_compliance_workflow",
|
||||||
|
]
|
||||||
175
backend/app/workflows/compliance_workflow.py
Normal file
175
backend/app/workflows/compliance_workflow.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
from typing import TypedDict, List
|
||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceState(TypedDict):
|
||||||
|
document_path: str
|
||||||
|
raw_text: str
|
||||||
|
segments: List[dict]
|
||||||
|
matched_regulations: List[dict]
|
||||||
|
risk_dashboard: dict
|
||||||
|
priority_actions: List[dict]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_document(state: ComplianceState) -> dict:
|
||||||
|
"""解析文档"""
|
||||||
|
from app.services import get_document_service
|
||||||
|
doc_service = get_document_service(
|
||||||
|
"/airegulation/demo-mao/backend/data/raw",
|
||||||
|
"/airegulation/demo-mao/backend/data/parsed",
|
||||||
|
)
|
||||||
|
text = doc_service.parse_document(state["document_path"])
|
||||||
|
return {"raw_text": text}
|
||||||
|
|
||||||
|
|
||||||
|
def segment_document(state: ComplianceState) -> dict:
|
||||||
|
"""AI语义分段"""
|
||||||
|
from app.services import llm_service
|
||||||
|
prompt = f"""请分析以下设计方案文档,按照设计意图将其分成若干语义段落。
|
||||||
|
|
||||||
|
文档内容:
|
||||||
|
{state['raw_text'][:3000]}
|
||||||
|
|
||||||
|
请输出JSON格式的分段结果,每个段落包含:
|
||||||
|
- intent: 段落意图/主题
|
||||||
|
- startPos: 在原文中的起始位置(大致)
|
||||||
|
- endPos: 在原文中的结束位置(大致)
|
||||||
|
- keywords: 关键词列表
|
||||||
|
|
||||||
|
输出格式:
|
||||||
|
[{{"intent": "...", "startPos": 0, "endPos": 100, "keywords": [...]}}]"""
|
||||||
|
|
||||||
|
# 简化处理:返回基本分段
|
||||||
|
segments = [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"intent": "整体设计概述",
|
||||||
|
"content": state["raw_text"][:500],
|
||||||
|
"keywords": ["设计", "方案"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"segments": segments}
|
||||||
|
|
||||||
|
|
||||||
|
def match_regulations(state: ComplianceState) -> dict:
|
||||||
|
"""法规匹配"""
|
||||||
|
from app.services import embedding_service, milvus_service
|
||||||
|
matched = []
|
||||||
|
|
||||||
|
for segment in state["segments"]:
|
||||||
|
keyword_text = " ".join(segment.get("keywords", []))
|
||||||
|
embedding = embedding_service.embed_single(keyword_text)
|
||||||
|
|
||||||
|
docs = milvus_service.search(embedding, top_k=5)
|
||||||
|
|
||||||
|
segment_regs = []
|
||||||
|
for doc in docs:
|
||||||
|
category = "high" if doc["score"] > 0.85 else ("medium" if doc["score"] > 0.6 else "low")
|
||||||
|
segment_regs.append({
|
||||||
|
"id": doc["id"],
|
||||||
|
"name": doc["doc_name"],
|
||||||
|
"clause": doc.get("clause_id"),
|
||||||
|
"score": doc["score"],
|
||||||
|
"match_keyword": keyword_text,
|
||||||
|
"category": category,
|
||||||
|
"full_content": doc["content"],
|
||||||
|
})
|
||||||
|
|
||||||
|
segment["regulations"] = segment_regs
|
||||||
|
matched.append(segment)
|
||||||
|
|
||||||
|
return {"matched_regulations": matched}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_risk(state: ComplianceState) -> dict:
|
||||||
|
"""计算风险等级"""
|
||||||
|
segments = state["matched_regulations"]
|
||||||
|
|
||||||
|
high_count = 0
|
||||||
|
medium_count = 0
|
||||||
|
low_count = 0
|
||||||
|
need_fix = 0
|
||||||
|
total_score = 0
|
||||||
|
|
||||||
|
for segment in segments:
|
||||||
|
regs = segment.get("regulations", [])
|
||||||
|
high_regs = [r for r in regs if r["category"] == "high"]
|
||||||
|
|
||||||
|
if high_regs:
|
||||||
|
avg_score = sum(r["score"] for r in high_regs) / len(high_regs)
|
||||||
|
if avg_score < 0.9:
|
||||||
|
segment["risk_level"] = "high"
|
||||||
|
high_count += 1
|
||||||
|
need_fix += 1
|
||||||
|
elif avg_score < 0.92:
|
||||||
|
segment["risk_level"] = "medium"
|
||||||
|
medium_count += 1
|
||||||
|
else:
|
||||||
|
segment["risk_level"] = "low"
|
||||||
|
low_count += 1
|
||||||
|
else:
|
||||||
|
segment["risk_level"] = "low"
|
||||||
|
low_count += 1
|
||||||
|
|
||||||
|
total_score += avg_score if high_regs else 100
|
||||||
|
|
||||||
|
avg_score = total_score / len(segments) if segments else 100
|
||||||
|
|
||||||
|
status = "pass" if avg_score >= 90 else ("warning" if avg_score >= 70 else "fail")
|
||||||
|
status_label = "合规" if status == "pass" else ("需要修改" if status == "warning" else "高风险")
|
||||||
|
|
||||||
|
dashboard = {
|
||||||
|
"score": avg_score,
|
||||||
|
"high_risk_count": high_count,
|
||||||
|
"medium_risk_count": medium_count,
|
||||||
|
"low_risk_count": low_count,
|
||||||
|
"need_fix_segments": need_fix,
|
||||||
|
"status": status,
|
||||||
|
"status_label": status_label,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"risk_dashboard": dashboard, "segments": segments}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_suggestions(state: ComplianceState) -> dict:
|
||||||
|
"""生成优先建议"""
|
||||||
|
actions = []
|
||||||
|
|
||||||
|
for segment in state["segments"]:
|
||||||
|
for reg in segment.get("regulations", []):
|
||||||
|
if reg["category"] == "high" and reg["score"] < 0.9:
|
||||||
|
actions.append({
|
||||||
|
"regulation": reg["name"],
|
||||||
|
"issue": reg["match_keyword"],
|
||||||
|
"suggestion": f"建议对照{reg['name']}第{reg.get('clause', '')}条进行修改",
|
||||||
|
"severity": "high",
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"priority_actions": actions}
|
||||||
|
|
||||||
|
|
||||||
|
# 构建工作流图
|
||||||
|
compliance_graph = StateGraph(ComplianceState)
|
||||||
|
|
||||||
|
compliance_graph.add_node("parse", parse_document)
|
||||||
|
compliance_graph.add_node("segment", segment_document)
|
||||||
|
compliance_graph.add_node("match", match_regulations)
|
||||||
|
compliance_graph.add_node("score", calculate_risk)
|
||||||
|
compliance_graph.add_node("suggest", generate_suggestions)
|
||||||
|
|
||||||
|
compliance_graph.set_entry_point("parse")
|
||||||
|
compliance_graph.add_edge("parse", "segment")
|
||||||
|
compliance_graph.add_edge("segment", "match")
|
||||||
|
compliance_graph.add_edge("match", "score")
|
||||||
|
compliance_graph.add_edge("score", "suggest")
|
||||||
|
compliance_graph.add_edge("suggest", END)
|
||||||
|
|
||||||
|
compliance_workflow = compliance_graph.compile()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_compliance_workflow(document_path: str) -> ComplianceState:
|
||||||
|
"""运行合规分析工作流"""
|
||||||
|
initial_state: ComplianceState = {"document_path": document_path}
|
||||||
|
result = compliance_workflow.invoke(initial_state)
|
||||||
|
return result
|
||||||
114
backend/app/workflows/rag_workflow.py
Normal file
114
backend/app/workflows/rag_workflow.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
from typing import TypedDict, List
|
||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
|
||||||
|
|
||||||
|
class RagState(TypedDict):
|
||||||
|
query: str
|
||||||
|
query_embedding: List[float]
|
||||||
|
retrieved_docs: List[dict]
|
||||||
|
context: str
|
||||||
|
answer: str
|
||||||
|
sources: List[dict]
|
||||||
|
|
||||||
|
|
||||||
|
def embed_query(state: RagState) -> dict:
|
||||||
|
"""将查询转为向量"""
|
||||||
|
from app.services import embedding_service
|
||||||
|
embedding = embedding_service.embed_single(state["query"])
|
||||||
|
return {"query_embedding": embedding}
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_docs(state: RagState) -> dict:
|
||||||
|
"""向量检索"""
|
||||||
|
from app.services import milvus_service
|
||||||
|
from app.core.config import settings
|
||||||
|
docs = milvus_service.search(
|
||||||
|
state["query_embedding"],
|
||||||
|
top_k=settings.vector_top_k,
|
||||||
|
)
|
||||||
|
return {"retrieved_docs": docs[:settings.final_top_k]}
|
||||||
|
|
||||||
|
|
||||||
|
def build_context(state: RagState) -> dict:
|
||||||
|
"""构建上下文"""
|
||||||
|
context_parts = []
|
||||||
|
sources = []
|
||||||
|
|
||||||
|
for doc in state["retrieved_docs"]:
|
||||||
|
context_parts.append(f"【{doc['doc_name']} - {doc.get('clause_id', '')}】\n{doc['content']}")
|
||||||
|
sources.append({
|
||||||
|
"name": doc["doc_name"],
|
||||||
|
"clause": doc.get("clause_id"),
|
||||||
|
})
|
||||||
|
|
||||||
|
context = "\n\n".join(context_parts)
|
||||||
|
return {"context": context, "sources": sources}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_answer(state: RagState) -> dict:
|
||||||
|
"""生成答案"""
|
||||||
|
from app.services import llm_service
|
||||||
|
prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
|
||||||
|
|
||||||
|
法规内容:
|
||||||
|
{state['context']}
|
||||||
|
|
||||||
|
用户问题:{state['query']}
|
||||||
|
|
||||||
|
请给出准确、简洁的回答,并引用相关法规条款。"""
|
||||||
|
|
||||||
|
answer = ""
|
||||||
|
for chunk in llm_service.generate_stream(prompt):
|
||||||
|
answer += chunk
|
||||||
|
|
||||||
|
return {"answer": answer}
|
||||||
|
|
||||||
|
|
||||||
|
# 构建工作流图
|
||||||
|
rag_graph = StateGraph(RagState)
|
||||||
|
|
||||||
|
rag_graph.add_node("embed", embed_query)
|
||||||
|
rag_graph.add_node("retrieve", retrieve_docs)
|
||||||
|
rag_graph.add_node("build_context", build_context)
|
||||||
|
rag_graph.add_node("generate", generate_answer)
|
||||||
|
|
||||||
|
rag_graph.set_entry_point("embed")
|
||||||
|
rag_graph.add_edge("embed", "retrieve")
|
||||||
|
rag_graph.add_edge("retrieve", "build_context")
|
||||||
|
rag_graph.add_edge("build_context", "generate")
|
||||||
|
rag_graph.add_edge("generate", END)
|
||||||
|
|
||||||
|
rag_workflow = rag_graph.compile()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_rag_workflow(query: str) -> RagState:
|
||||||
|
"""运行RAG工作流"""
|
||||||
|
initial_state: RagState = {"query": query}
|
||||||
|
result = rag_workflow.invoke(initial_state)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def stream_rag_workflow(query: str):
|
||||||
|
"""流式运行RAG工作流"""
|
||||||
|
from app.services import llm_service
|
||||||
|
|
||||||
|
# 先完成检索阶段
|
||||||
|
state: RagState = {"query": query}
|
||||||
|
state.update(embed_query(state))
|
||||||
|
state.update(retrieve_docs(state))
|
||||||
|
state.update(build_context(state))
|
||||||
|
|
||||||
|
# 流式生成阶段
|
||||||
|
prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
|
||||||
|
|
||||||
|
法规内容:
|
||||||
|
{state['context']}
|
||||||
|
|
||||||
|
用户问题:{state['query']}
|
||||||
|
|
||||||
|
请给出准确、简洁的回答,并引用相关法规条款。"""
|
||||||
|
|
||||||
|
for chunk in llm_service.generate_stream(prompt):
|
||||||
|
yield {"type": "chunk", "text": chunk}
|
||||||
|
|
||||||
|
yield {"type": "done", "sources": state["sources"]}
|
||||||
1
backend/data/raw/compliance_task-32e64724_test_doc.txt
Normal file
1
backend/data/raw/compliance_task-32e64724_test_doc.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
test content
|
||||||
2
backend/data/raw/doc-3b47abd7_requirement.txt
Normal file
2
backend/data/raw/doc-3b47abd7_requirement.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
apache-flink==1.13.2
|
||||||
|
PyMySQL>=1.1.0
|
||||||
2
backend/data/raw/doc-9b01a78a_requirement.txt
Normal file
2
backend/data/raw/doc-9b01a78a_requirement.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
apache-flink==1.13.2
|
||||||
|
PyMySQL>=1.1.0
|
||||||
19
backend/main.py
Normal file
19
backend/main.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Convenience launcher for the migrated backend app."""
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
uvicorn.run(
|
||||||
|
"app.main:app",
|
||||||
|
host=settings.api_host,
|
||||||
|
port=settings.api_port,
|
||||||
|
reload=settings.debug,
|
||||||
|
log_level="info",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
35
backend/pyproject.toml
Normal file
35
backend/pyproject.toml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
[project]
|
||||||
|
name = "ai-regulations-backend"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Migrated FastAPI backend for AI regulations demo"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.110.0",
|
||||||
|
"uvicorn[standard]>=0.27.0",
|
||||||
|
"python-multipart>=0.0.9",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
|
"pydantic-settings>=2.0.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"loguru>=0.7.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"tiktoken>=0.5.0",
|
||||||
|
"tenacity>=8.2.0",
|
||||||
|
"pymilvus>=2.4.0",
|
||||||
|
"minio>=7.1.0",
|
||||||
|
"pymupdf>=1.24.0",
|
||||||
|
"python-docx>=1.1.0",
|
||||||
|
"FlagEmbedding>=1.2.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"numpy>=1.24.0",
|
||||||
|
"langchain>=0.1.0",
|
||||||
|
"langchain-milvus>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
backend = "main:main"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
29
backend/requirements.txt
Normal file
29
backend/requirements.txt
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
fastapi>=0.110.0
|
||||||
|
uvicorn[standard]>=0.27.0
|
||||||
|
python-multipart>=0.0.9
|
||||||
|
|
||||||
|
pydantic>=2.0.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
loguru>=0.7.0
|
||||||
|
|
||||||
|
httpx>=0.25.0
|
||||||
|
tiktoken>=0.5.0
|
||||||
|
tenacity>=8.2.0
|
||||||
|
|
||||||
|
pymilvus>=2.4.0
|
||||||
|
minio>=7.1.0
|
||||||
|
|
||||||
|
pymupdf>=1.24.0
|
||||||
|
python-docx>=1.1.0
|
||||||
|
|
||||||
|
FlagEmbedding>=1.2.0
|
||||||
|
sentence-transformers>=2.2.0
|
||||||
|
torch>=2.0.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
|
||||||
|
langchain>=0.1.0
|
||||||
|
langchain-milvus>=0.1.0
|
||||||
|
|
||||||
|
pytest>=7.4.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
8
backend/uv.lock
generated
Normal file
8
backend/uv.lock
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
version = 1
|
||||||
|
revision = 3
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "backend"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = { virtual = "." }
|
||||||
24
frontend/.gitignore
vendored
Normal file
24
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
140
frontend/README.md
Normal file
140
frontend/README.md
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# Regulation RAG - 法规合规智能分析系统
|
||||||
|
|
||||||
|
一个基于 RAG (Retrieval-Augmented Generation) 技术的法规合规智能分析与问答系统原型。支持文档上传、语义分段分析、法规匹配标注、风险评估及交互式合规问答。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
### 📋 合规分析 (Compliance)
|
||||||
|
- **文档上传与解析** - 支持 PDF、DOCX、TXT 格式文档上传
|
||||||
|
- **AI 语义分段** - 自动识别文档语义段落与设计意图
|
||||||
|
- **法规匹配标注** - 根据段落内容匹配相关法规条款,计算相关性得分
|
||||||
|
- **风险仪表盘** - 可折叠的风险评估面板,展示合规评分、高风险项、待修改段落
|
||||||
|
- **优先行动建议** - 基于风险等级生成修改建议列表
|
||||||
|
|
||||||
|
### 💬 RAG 对话 (Rag Chat)
|
||||||
|
- **法规问答** - 基于向量检索的法规问答系统
|
||||||
|
- **快捷问题** - 预设常用问题快速提问
|
||||||
|
- **检索片段展示** - 右侧面板实时显示引用的法规片段及相似度得分
|
||||||
|
- **答案重生成** - 支持重新生成上一个回答
|
||||||
|
|
||||||
|
### 📚 文档管理 (Docs)
|
||||||
|
- **文档上传** - 支持多种格式文档导入
|
||||||
|
- **索引状态** - 显示已索引文档列表及分块数量
|
||||||
|
- **处理流水线** - 展示 Load → Parse → Chunk → Embed → Store 全流程状态
|
||||||
|
|
||||||
|
### 📊 系统状态 (Status)
|
||||||
|
- **系统统计** - 文档数、分块数、向量维度、条款数
|
||||||
|
- **配置展示** - LLM 模型、Embedding 模型、向量数据库、检索策略等参数
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
- **前端框架**: React 19 + TypeScript
|
||||||
|
- **构建工具**: Vite 8
|
||||||
|
- **样式方案**: TailwindCSS 4
|
||||||
|
- **状态管理**: React Context API
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm install
|
||||||
|
```
|
||||||
|
|
||||||
|
### 开发模式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run dev
|
||||||
|
```
|
||||||
|
|
||||||
|
启动本地开发服务器,默认访问 `http://localhost:5173`
|
||||||
|
|
||||||
|
### 构建生产版本
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run build
|
||||||
|
```
|
||||||
|
|
||||||
|
### 预览生产版本
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run preview
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── components/
|
||||||
|
│ ├── common/ # 通用组件 (Logo, Pattern, ThemeToggle)
|
||||||
|
│ ├── layout/ # 布局组件 (Header, Tabs, Content)
|
||||||
|
│ └── ui/ # UI 基础组件 (Badge, Button, Card, Input...)
|
||||||
|
├── contexts/ # React Context (AppContext, ThemeContext)
|
||||||
|
├── data/ # Mock 数据
|
||||||
|
├── pages/
|
||||||
|
│ ├── Compliance/ # 合规分析页面
|
||||||
|
│ ├── Docs/ # 文档管理页面
|
||||||
|
│ ├── RagChat/ # RAG 对话页面
|
||||||
|
│ └── Status/ # 系统状态页面
|
||||||
|
├── styles/ # 全局样式
|
||||||
|
├── types/ # TypeScript 类型定义
|
||||||
|
└── App.tsx # 应用入口
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心类型定义
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 法规信息
|
||||||
|
interface Regulation {
|
||||||
|
id: number;
|
||||||
|
name: string; // 法规名称
|
||||||
|
clause: string; // 条款编号
|
||||||
|
score: number; // 相关性得分 (0-1)
|
||||||
|
matchKeyword: string; // 匹配关键词
|
||||||
|
category: 'high' | 'medium' | 'low'; // 相关性等级
|
||||||
|
fullContent: string; // 法规完整内容
|
||||||
|
}
|
||||||
|
|
||||||
|
// 语义段落
|
||||||
|
interface ComplianceChunk {
|
||||||
|
id: number;
|
||||||
|
index: number; // 段落序号
|
||||||
|
intent: string; // 段落意图
|
||||||
|
startPos: number; // 文档起始位置
|
||||||
|
endPos: number; // 文档结束位置
|
||||||
|
content: string; // 段落内容
|
||||||
|
regulations: Regulation[];
|
||||||
|
}
|
||||||
|
|
||||||
|
// 风险仪表盘数据
|
||||||
|
interface RiskDashboardData {
|
||||||
|
score: number; // 合规评分
|
||||||
|
highRiskCount: number; // 高风险项数量
|
||||||
|
mediumRiskCount: number;
|
||||||
|
lowRiskCount: number;
|
||||||
|
needFixSegments: number;// 待修改段落数
|
||||||
|
status: 'pass' | 'warning' | 'fail';
|
||||||
|
statusLabel: string; // 状态标签
|
||||||
|
segmentRisks: SegmentRisk[];
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 界面设计
|
||||||
|
|
||||||
|
- **主题**: 支持深色/浅色主题切换
|
||||||
|
- **配色**: T-Mobile 品牌色系 (E20074 主色调)
|
||||||
|
- **布局**: 响应式设计,固定 Header + Tab 导航
|
||||||
|
- **交互**: 卡片式布局,悬浮效果,进度条动画
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
本项目为原型演示系统,使用 Mock 数据模拟后端服务。生产环境需接入:
|
||||||
|
|
||||||
|
- 文档解析服务
|
||||||
|
- 向量数据库 (如 ChromaDB、Milvus)
|
||||||
|
- Embedding 模型 API
|
||||||
|
- LLM 服务 API
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
22
frontend/eslint.config.js
Normal file
22
frontend/eslint.config.js
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import js from '@eslint/js'
|
||||||
|
import globals from 'globals'
|
||||||
|
import reactHooks from 'eslint-plugin-react-hooks'
|
||||||
|
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||||
|
import tseslint from 'typescript-eslint'
|
||||||
|
import { defineConfig, globalIgnores } from 'eslint/config'
|
||||||
|
|
||||||
|
export default defineConfig([
|
||||||
|
globalIgnores(['dist']),
|
||||||
|
{
|
||||||
|
files: ['**/*.{ts,tsx}'],
|
||||||
|
extends: [
|
||||||
|
js.configs.recommended,
|
||||||
|
tseslint.configs.recommended,
|
||||||
|
reactHooks.configs.flat.recommended,
|
||||||
|
reactRefresh.configs.vite,
|
||||||
|
],
|
||||||
|
languageOptions: {
|
||||||
|
globals: globals.browser,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
])
|
||||||
1161
frontend/index.html
1161
frontend/index.html
File diff suppressed because it is too large
Load Diff
3150
frontend/package-lock.json
generated
Normal file
3150
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
34
frontend/package.json
Normal file
34
frontend/package.json
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"name": "regulation-rag",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc -b && vite build",
|
||||||
|
"lint": "eslint .",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^19.2.5",
|
||||||
|
"react-dom": "^19.2.5"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@eslint/js": "^10.0.1",
|
||||||
|
"@tailwindcss/postcss": "^4.2.4",
|
||||||
|
"@types/node": "^24.12.2",
|
||||||
|
"@types/react": "^19.2.14",
|
||||||
|
"@types/react-dom": "^19.2.3",
|
||||||
|
"@vitejs/plugin-react": "^6.0.1",
|
||||||
|
"autoprefixer": "^10.5.0",
|
||||||
|
"eslint": "^10.2.1",
|
||||||
|
"eslint-plugin-react-hooks": "^7.1.1",
|
||||||
|
"eslint-plugin-react-refresh": "^0.5.2",
|
||||||
|
"globals": "^17.5.0",
|
||||||
|
"postcss": "^8.5.14",
|
||||||
|
"tailwindcss": "^4.2.4",
|
||||||
|
"typescript": "~6.0.2",
|
||||||
|
"typescript-eslint": "^8.58.2",
|
||||||
|
"vite": "^8.0.10"
|
||||||
|
}
|
||||||
|
}
|
||||||
6
frontend/postcss.config.js
Normal file
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
export default {
|
||||||
|
plugins: {
|
||||||
|
'@tailwindcss/postcss': {},
|
||||||
|
autoprefixer: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
1
frontend/public/favicon.svg
Normal file
1
frontend/public/favicon.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 9.3 KiB |
24
frontend/public/icons.svg
Normal file
24
frontend/public/icons.svg
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<symbol id="bluesky-icon" viewBox="0 0 16 17">
|
||||||
|
<g clip-path="url(#bluesky-clip)"><path fill="#08060d" d="M7.75 7.735c-.693-1.348-2.58-3.86-4.334-5.097-1.68-1.187-2.32-.981-2.74-.79C.188 2.065.1 2.812.1 3.251s.241 3.602.398 4.13c.52 1.744 2.367 2.333 4.07 2.145-2.495.37-4.71 1.278-1.805 4.512 3.196 3.309 4.38-.71 4.987-2.746.608 2.036 1.307 5.91 4.93 2.746 2.72-2.746.747-4.143-1.747-4.512 1.702.189 3.55-.4 4.07-2.145.156-.528.397-3.691.397-4.13s-.088-1.186-.575-1.406c-.42-.19-1.06-.395-2.741.79-1.755 1.24-3.64 3.752-4.334 5.099"/></g>
|
||||||
|
<defs><clipPath id="bluesky-clip"><path fill="#fff" d="M.1.85h15.3v15.3H.1z"/></clipPath></defs>
|
||||||
|
</symbol>
|
||||||
|
<symbol id="discord-icon" viewBox="0 0 20 19">
|
||||||
|
<path fill="#08060d" d="M16.224 3.768a14.5 14.5 0 0 0-3.67-1.153c-.158.286-.343.67-.47.976a13.5 13.5 0 0 0-4.067 0c-.128-.306-.317-.69-.476-.976A14.4 14.4 0 0 0 3.868 3.77C1.546 7.28.916 10.703 1.231 14.077a14.7 14.7 0 0 0 4.5 2.306q.545-.748.965-1.587a9.5 9.5 0 0 1-1.518-.74q.191-.14.372-.293c2.927 1.369 6.107 1.369 8.999 0q.183.152.372.294-.723.437-1.52.74.418.838.963 1.588a14.6 14.6 0 0 0 4.504-2.308c.37-3.911-.63-7.302-2.644-10.309m-9.13 8.234c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.894 0 1.614.82 1.599 1.82.001 1-.705 1.82-1.6 1.82m5.91 0c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.893 0 1.614.82 1.599 1.82 0 1-.706 1.82-1.6 1.82"/>
|
||||||
|
</symbol>
|
||||||
|
<symbol id="documentation-icon" viewBox="0 0 21 20">
|
||||||
|
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="m15.5 13.333 1.533 1.322c.645.555.967.833.967 1.178s-.322.623-.967 1.179L15.5 18.333m-3.333-5-1.534 1.322c-.644.555-.966.833-.966 1.178s.322.623.966 1.179l1.534 1.321"/>
|
||||||
|
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M17.167 10.836v-4.32c0-1.41 0-2.117-.224-2.68-.359-.906-1.118-1.621-2.08-1.96-.599-.21-1.349-.21-2.848-.21-2.623 0-3.935 0-4.983.369-1.684.591-3.013 1.842-3.641 3.428C3 6.449 3 7.684 3 10.154v2.122c0 2.558 0 3.838.706 4.726q.306.383.713.671c.76.536 1.79.64 3.581.66"/>
|
||||||
|
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M3 10a2.78 2.78 0 0 1 2.778-2.778c.555 0 1.209.097 1.748-.047.48-.129.854-.503.982-.982.145-.54.048-1.194.048-1.749a2.78 2.78 0 0 1 2.777-2.777"/>
|
||||||
|
</symbol>
|
||||||
|
<symbol id="github-icon" viewBox="0 0 19 19">
|
||||||
|
<path fill="#08060d" fill-rule="evenodd" d="M9.356 1.85C5.05 1.85 1.57 5.356 1.57 9.694a7.84 7.84 0 0 0 5.324 7.44c.387.079.528-.168.528-.376 0-.182-.013-.805-.013-1.454-2.165.467-2.616-.935-2.616-.935-.349-.91-.864-1.143-.864-1.143-.71-.48.051-.48.051-.48.787.051 1.2.805 1.2.805.695 1.194 1.817.857 2.268.649.064-.507.27-.857.49-1.052-1.728-.182-3.545-.857-3.545-3.87 0-.857.31-1.558.8-2.104-.078-.195-.349-1 .077-2.078 0 0 .657-.208 2.14.805a7.5 7.5 0 0 1 1.946-.26c.657 0 1.328.092 1.946.26 1.483-1.013 2.14-.805 2.14-.805.426 1.078.155 1.883.078 2.078.502.546.799 1.247.799 2.104 0 3.013-1.818 3.675-3.558 3.87.284.247.528.714.528 1.454 0 1.052-.012 1.896-.012 2.156 0 .208.142.455.528.377a7.84 7.84 0 0 0 5.324-7.441c.013-4.338-3.48-7.844-7.773-7.844" clip-rule="evenodd"/>
|
||||||
|
</symbol>
|
||||||
|
<symbol id="social-icon" viewBox="0 0 20 20">
|
||||||
|
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M12.5 6.667a4.167 4.167 0 1 0-8.334 0 4.167 4.167 0 0 0 8.334 0"/>
|
||||||
|
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M2.5 16.667a5.833 5.833 0 0 1 8.75-5.053m3.837.474.513 1.035c.07.144.257.282.414.309l.93.155c.596.1.736.536.307.965l-.723.73a.64.64 0 0 0-.152.531l.207.903c.164.715-.213.991-.84.618l-.872-.52a.63.63 0 0 0-.577 0l-.872.52c-.624.373-1.003.094-.84-.618l.207-.903a.64.64 0 0 0-.152-.532l-.723-.729c-.426-.43-.289-.864.306-.964l.93-.156a.64.64 0 0 0 .412-.31l.513-1.034c.28-.562.735-.562 1.012 0"/>
|
||||||
|
</symbol>
|
||||||
|
<symbol id="x-icon" viewBox="0 0 19 19">
|
||||||
|
<path fill="#08060d" fill-rule="evenodd" d="M1.893 1.98c.052.072 1.245 1.769 2.653 3.77l2.892 4.114c.183.261.333.48.333.486s-.068.089-.152.183l-.522.593-.765.867-3.597 4.087c-.375.426-.734.834-.798.905a1 1 0 0 0-.118.148c0 .01.236.017.664.017h.663l.729-.83c.4-.457.796-.906.879-.999a692 692 0 0 0 1.794-2.038c.034-.037.301-.34.594-.675l.551-.624.345-.392a7 7 0 0 1 .34-.374c.006 0 .93 1.306 2.052 2.903l2.084 2.965.045.063h2.275c1.87 0 2.273-.003 2.266-.021-.008-.02-1.098-1.572-3.894-5.547-2.013-2.862-2.28-3.246-2.273-3.266.008-.019.282-.332 2.085-2.38l2-2.274 1.567-1.782c.022-.028-.016-.03-.65-.03h-.674l-.3.342a871 871 0 0 1-1.782 2.025c-.067.075-.405.458-.75.852a100 100 0 0 1-.803.91c-.148.172-.299.344-.99 1.127-.304.343-.32.358-.345.327-.015-.019-.904-1.282-1.976-2.808L6.365 1.85H1.8zm1.782.91 8.078 11.294c.772 1.08 1.413 1.973 1.425 1.984.016.017.241.02 1.05.017l1.03-.004-2.694-3.766L7.796 5.75 5.722 2.852l-1.039-.004-1.039-.004z" clip-rule="evenodd"/>
|
||||||
|
</symbol>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 4.9 KiB |
BIN
frontend/public/logo/t_mobile_logo_transparent.png
Normal file
BIN
frontend/public/logo/t_mobile_logo_transparent.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
184
frontend/src/App.css
Normal file
184
frontend/src/App.css
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
.counter {
|
||||||
|
font-size: 16px;
|
||||||
|
padding: 5px 10px;
|
||||||
|
border-radius: 5px;
|
||||||
|
color: var(--accent);
|
||||||
|
background: var(--accent-bg);
|
||||||
|
border: 2px solid transparent;
|
||||||
|
transition: border-color 0.3s;
|
||||||
|
margin-bottom: 24px;
|
||||||
|
|
||||||
|
&:hover {
|
||||||
|
border-color: var(--accent-border);
|
||||||
|
}
|
||||||
|
&:focus-visible {
|
||||||
|
outline: 2px solid var(--accent);
|
||||||
|
outline-offset: 2px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.hero {
|
||||||
|
position: relative;
|
||||||
|
|
||||||
|
.base,
|
||||||
|
.framework,
|
||||||
|
.vite {
|
||||||
|
inset-inline: 0;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.base {
|
||||||
|
width: 170px;
|
||||||
|
position: relative;
|
||||||
|
z-index: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.framework,
|
||||||
|
.vite {
|
||||||
|
position: absolute;
|
||||||
|
}
|
||||||
|
|
||||||
|
.framework {
|
||||||
|
z-index: 1;
|
||||||
|
top: 34px;
|
||||||
|
height: 28px;
|
||||||
|
transform: perspective(2000px) rotateZ(300deg) rotateX(44deg) rotateY(39deg)
|
||||||
|
scale(1.4);
|
||||||
|
}
|
||||||
|
|
||||||
|
.vite {
|
||||||
|
z-index: 0;
|
||||||
|
top: 107px;
|
||||||
|
height: 26px;
|
||||||
|
width: auto;
|
||||||
|
transform: perspective(2000px) rotateZ(300deg) rotateX(40deg) rotateY(39deg)
|
||||||
|
scale(0.8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#center {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 25px;
|
||||||
|
place-content: center;
|
||||||
|
place-items: center;
|
||||||
|
flex-grow: 1;
|
||||||
|
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
padding: 32px 20px 24px;
|
||||||
|
gap: 18px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#next-steps {
|
||||||
|
display: flex;
|
||||||
|
border-top: 1px solid var(--border);
|
||||||
|
text-align: left;
|
||||||
|
|
||||||
|
& > div {
|
||||||
|
flex: 1 1 0;
|
||||||
|
padding: 32px;
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
padding: 24px 20px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.icon {
|
||||||
|
margin-bottom: 16px;
|
||||||
|
width: 22px;
|
||||||
|
height: 22px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
flex-direction: column;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#docs {
|
||||||
|
border-right: 1px solid var(--border);
|
||||||
|
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
border-right: none;
|
||||||
|
border-bottom: 1px solid var(--border);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#next-steps ul {
|
||||||
|
list-style: none;
|
||||||
|
padding: 0;
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
margin: 32px 0 0;
|
||||||
|
|
||||||
|
.logo {
|
||||||
|
height: 18px;
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
color: var(--text-h);
|
||||||
|
font-size: 16px;
|
||||||
|
border-radius: 6px;
|
||||||
|
background: var(--social-bg);
|
||||||
|
display: flex;
|
||||||
|
padding: 6px 12px;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
text-decoration: none;
|
||||||
|
transition: box-shadow 0.3s;
|
||||||
|
|
||||||
|
&:hover {
|
||||||
|
box-shadow: var(--shadow);
|
||||||
|
}
|
||||||
|
.button-icon {
|
||||||
|
height: 18px;
|
||||||
|
width: 18px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
margin-top: 20px;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
justify-content: center;
|
||||||
|
|
||||||
|
li {
|
||||||
|
flex: 1 1 calc(50% - 8px);
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
width: 100%;
|
||||||
|
justify-content: center;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#spacer {
|
||||||
|
height: 88px;
|
||||||
|
border-top: 1px solid var(--border);
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
height: 48px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.ticks {
|
||||||
|
position: relative;
|
||||||
|
width: 100%;
|
||||||
|
|
||||||
|
&::before,
|
||||||
|
&::after {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
top: -4.5px;
|
||||||
|
border: 5px solid transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
&::before {
|
||||||
|
left: 0;
|
||||||
|
border-left-color: var(--border);
|
||||||
|
}
|
||||||
|
&::after {
|
||||||
|
right: 0;
|
||||||
|
border-right-color: var(--border);
|
||||||
|
}
|
||||||
|
}
|
||||||
48
frontend/src/App.tsx
Normal file
48
frontend/src/App.tsx
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import './styles/globals.css';
|
||||||
|
import { ThemeProvider, AppProvider, useApp, useTheme } from './contexts';
|
||||||
|
import { Header, Tabs } from './components/layout';
|
||||||
|
import { CompliancePage } from './pages/Compliance';
|
||||||
|
import { DocsPage } from './pages/Docs';
|
||||||
|
import { StatusPage } from './pages/Status';
|
||||||
|
import { RagChatPage } from './pages/RagChat';
|
||||||
|
|
||||||
|
const PageContent = () => {
|
||||||
|
const { activeTab } = useApp();
|
||||||
|
|
||||||
|
switch (activeTab) {
|
||||||
|
case 'docs':
|
||||||
|
return <DocsPage />;
|
||||||
|
case 'compliance':
|
||||||
|
return <CompliancePage />;
|
||||||
|
case 'status':
|
||||||
|
return <StatusPage />;
|
||||||
|
case 'rag':
|
||||||
|
return <RagChatPage />;
|
||||||
|
default:
|
||||||
|
return <CompliancePage />;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const AppContent = () => {
|
||||||
|
const { theme } = useTheme();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="h-full flex flex-col min-h-screen" style={{ backgroundColor: theme.bg }}>
|
||||||
|
<Header />
|
||||||
|
<Tabs />
|
||||||
|
<PageContent />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
return (
|
||||||
|
<ThemeProvider>
|
||||||
|
<AppProvider>
|
||||||
|
<AppContent />
|
||||||
|
</AppProvider>
|
||||||
|
</ThemeProvider>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App;
|
||||||
43
frontend/src/api/compliance.ts
Normal file
43
frontend/src/api/compliance.ts
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import { streamSSE, type ComplianceResult, type SSEMessage } from './index';
|
||||||
|
|
||||||
|
// Upload and analyze a design document
|
||||||
|
export async function analyzeDocument(file: File): Promise<{ task_id: string; status: string }> {
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('file', file);
|
||||||
|
|
||||||
|
const response = await fetch('/api/compliance/analyze', {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Upload failed: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get analysis result
|
||||||
|
export async function getComplianceResult(taskId: string): Promise<ComplianceResult | { status: string; message: string }> {
|
||||||
|
const response = await fetch(`/api/compliance/result/${taskId}`);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Get result failed: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compliance chat with SSE streaming
|
||||||
|
export function complianceChat(
|
||||||
|
segmentId: number,
|
||||||
|
query: string,
|
||||||
|
onMessage: (data: SSEMessage) => void,
|
||||||
|
onError?: (error: Error) => void,
|
||||||
|
onComplete?: () => void
|
||||||
|
): void {
|
||||||
|
streamSSE(`/compliance/chat/${segmentId}`, { query }, onMessage, onError, onComplete);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export types
|
||||||
|
export type { ComplianceResult, SSEMessage };
|
||||||
144
frontend/src/api/docs.ts
Normal file
144
frontend/src/api/docs.ts
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import type { DocInfo, DocListResponse, DocUploadResponse } from './index';
|
||||||
|
|
||||||
|
const DOCS_API_BASE = '/api/v1';
|
||||||
|
|
||||||
|
interface BackendDocumentItem {
|
||||||
|
doc_id: string;
|
||||||
|
filename: string;
|
||||||
|
size: number;
|
||||||
|
object_name: string;
|
||||||
|
download_url: string;
|
||||||
|
last_modified?: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface BackendDocumentListResponse {
|
||||||
|
documents: BackendDocumentItem[];
|
||||||
|
total: number;
|
||||||
|
limit?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface BackendKnowledgeResult {
|
||||||
|
id: number;
|
||||||
|
content: string;
|
||||||
|
score: number;
|
||||||
|
metadata: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface BackendKnowledgeResponse {
|
||||||
|
query: string;
|
||||||
|
total: number;
|
||||||
|
results: BackendKnowledgeResult[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RegulationSearchItem {
|
||||||
|
id: number;
|
||||||
|
file: string;
|
||||||
|
clause: string;
|
||||||
|
score: number;
|
||||||
|
content: string;
|
||||||
|
tags: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RegulationSearchResponse {
|
||||||
|
query: string;
|
||||||
|
total: number;
|
||||||
|
results: RegulationSearchItem[];
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatFileSize(bytes: number): string {
|
||||||
|
if (!bytes) return '0 B';
|
||||||
|
if (bytes >= 1024 * 1024) return `${(bytes / (1024 * 1024)).toFixed(1)}MB`;
|
||||||
|
if (bytes >= 1024) return `${(bytes / 1024).toFixed(1)}KB`;
|
||||||
|
return `${bytes}B`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapDoc(item: BackendDocumentItem): DocInfo {
|
||||||
|
return {
|
||||||
|
id: item.doc_id,
|
||||||
|
name: item.filename,
|
||||||
|
chunks: 0,
|
||||||
|
status: 'indexed',
|
||||||
|
created_at: item.last_modified || undefined,
|
||||||
|
download_url: `${DOCS_API_BASE}/documents/download/${item.doc_id}`,
|
||||||
|
size_text: formatFileSize(item.size),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function uploadDocument(file: File): Promise<DocUploadResponse> {
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('file', file);
|
||||||
|
formData.append('doc_name', file.name);
|
||||||
|
formData.append('generate_summary', 'true');
|
||||||
|
|
||||||
|
const response = await fetch(`${DOCS_API_BASE}/documents/upload`, {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Upload failed: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
return {
|
||||||
|
doc_id: data.doc_id,
|
||||||
|
filename: data.doc_name || file.name,
|
||||||
|
size: file.size,
|
||||||
|
status: data.status,
|
||||||
|
num_chunks: data.num_chunks,
|
||||||
|
summary: data.summary,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getDocumentList(): Promise<DocListResponse> {
|
||||||
|
const response = await fetch(`${DOCS_API_BASE}/documents/management-list`);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`List failed: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json() as BackendDocumentListResponse;
|
||||||
|
return {
|
||||||
|
docs: data.documents.map(mapDoc),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function searchRegulations(query: string, topK: number = 8): Promise<RegulationSearchResponse> {
|
||||||
|
const response = await fetch(`${DOCS_API_BASE}/knowledge/retrieval`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Accept: 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ query, top_k: topK }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Search failed: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json() as BackendKnowledgeResponse;
|
||||||
|
return {
|
||||||
|
query: data.query,
|
||||||
|
total: data.total,
|
||||||
|
results: data.results.map((item) => {
|
||||||
|
const metadata = item.metadata || {};
|
||||||
|
return {
|
||||||
|
id: item.id,
|
||||||
|
file: String(metadata.doc_name || metadata.filename || metadata.source || '法规知识库'),
|
||||||
|
clause: String(metadata.chunk_type || metadata.section || metadata.clause || '法规片段'),
|
||||||
|
score: item.score,
|
||||||
|
content: item.content,
|
||||||
|
tags: [
|
||||||
|
metadata.regulation_type ? String(metadata.regulation_type) : '',
|
||||||
|
metadata.version ? `v${String(metadata.version)}` : '',
|
||||||
|
].filter(Boolean),
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getDocumentDownloadUrl(docId: string): string {
|
||||||
|
return `${DOCS_API_BASE}/documents/download/${docId}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type { DocInfo, DocListResponse, DocUploadResponse };
|
||||||
213
frontend/src/api/index.ts
Normal file
213
frontend/src/api/index.ts
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
// API configuration - 使用相对路径,通过 Vite proxy 转发
|
||||||
|
const API_BASE_URL = '/api';
|
||||||
|
|
||||||
|
// Helper function for fetch requests
|
||||||
|
async function fetchAPI<T>(endpoint: string, options?: RequestInit): Promise<T> {
|
||||||
|
const response = await fetch(`${API_BASE_URL}${endpoint}`, {
|
||||||
|
...options,
|
||||||
|
headers: {
|
||||||
|
...options?.headers,
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`API Error: ${response.status} ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE helper for streaming responses
|
||||||
|
function createSSEConnection(endpoint: string, body: unknown): EventSource {
|
||||||
|
// For POST requests with SSE, we need to use fetch with ReadableStream
|
||||||
|
// since EventSource only supports GET requests
|
||||||
|
const url = `${API_BASE_URL}${endpoint}`;
|
||||||
|
|
||||||
|
return new EventSource(url); // This won't work for POST, we'll handle it differently
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE streaming helper for POST requests
|
||||||
|
async function streamSSE(
|
||||||
|
endpoint: string,
|
||||||
|
body: unknown,
|
||||||
|
onMessage: (data: unknown) => void,
|
||||||
|
onError?: (error: Error) => void,
|
||||||
|
onComplete?: () => void
|
||||||
|
): Promise<void> {
|
||||||
|
const url = `${API_BASE_URL}${endpoint}`;
|
||||||
|
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Accept': 'text/event-stream',
|
||||||
|
},
|
||||||
|
body: JSON.stringify(body),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
if (onError) {
|
||||||
|
onError(new Error(`HTTP error! status: ${response.status}`));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body?.getReader();
|
||||||
|
if (!reader) {
|
||||||
|
if (onError) {
|
||||||
|
onError(new Error('No response body'));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let buffer = '';
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
buffer += decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
// Process SSE events
|
||||||
|
const lines = buffer.split('\n');
|
||||||
|
buffer = '';
|
||||||
|
|
||||||
|
for (const line of lines) {
|
||||||
|
if (line.startsWith('data:')) {
|
||||||
|
const data = line.slice(5).trim();
|
||||||
|
if (data) {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(data);
|
||||||
|
onMessage(parsed);
|
||||||
|
} catch {
|
||||||
|
// Handle non-JSON data
|
||||||
|
onMessage({ type: 'raw', text: data });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (onComplete) {
|
||||||
|
onComplete();
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
if (onError) {
|
||||||
|
onError(error instanceof Error ? error : new Error(String(error)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export types
|
||||||
|
export interface DocInfo {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
chunks: number;
|
||||||
|
status: string;
|
||||||
|
created_at?: string;
|
||||||
|
download_url?: string;
|
||||||
|
size_text?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DocListResponse {
|
||||||
|
docs: DocInfo[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DocUploadResponse {
|
||||||
|
doc_id: string;
|
||||||
|
filename: string;
|
||||||
|
size: number;
|
||||||
|
status: string;
|
||||||
|
num_chunks?: number;
|
||||||
|
summary?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface QuickQuestion {
|
||||||
|
id: string;
|
||||||
|
question: string;
|
||||||
|
category: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface QuickQuestionsResponse {
|
||||||
|
questions: QuickQuestion[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RetrievedDoc {
|
||||||
|
id: string;
|
||||||
|
score: number;
|
||||||
|
preview: string;
|
||||||
|
doc_name: string;
|
||||||
|
clause: string;
|
||||||
|
doc_id?: string;
|
||||||
|
download_url?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SSEMessage {
|
||||||
|
type: string;
|
||||||
|
text?: string;
|
||||||
|
docs?: RetrievedDoc[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Regulation {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
clause: string;
|
||||||
|
score: number;
|
||||||
|
match_keyword: string;
|
||||||
|
category: string;
|
||||||
|
full_content: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ComplianceSegment {
|
||||||
|
id: number;
|
||||||
|
index: number;
|
||||||
|
intent: string;
|
||||||
|
start_pos: number;
|
||||||
|
end_pos: number;
|
||||||
|
content: string;
|
||||||
|
risk_level: string;
|
||||||
|
regulations: Regulation[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RiskDashboard {
|
||||||
|
score: number;
|
||||||
|
high_risk_count: number;
|
||||||
|
medium_risk_count: number;
|
||||||
|
low_risk_count: number;
|
||||||
|
need_fix_segments: number;
|
||||||
|
status: string;
|
||||||
|
status_label: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PriorityAction {
|
||||||
|
regulation: string;
|
||||||
|
issue: string;
|
||||||
|
suggestion: string;
|
||||||
|
severity: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ComplianceResult {
|
||||||
|
task_id: string;
|
||||||
|
dashboard: RiskDashboard;
|
||||||
|
segments: ComplianceSegment[];
|
||||||
|
priority_actions: PriorityAction[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SystemStats {
|
||||||
|
docs: number;
|
||||||
|
chunks: number;
|
||||||
|
vectors: number;
|
||||||
|
segments: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SystemConfig {
|
||||||
|
llm: { model: string };
|
||||||
|
embedding: { model: string; dimension: number };
|
||||||
|
milvus: { host: string; port: number };
|
||||||
|
retrieval: { vector_top_k: number; final_top_k: number };
|
||||||
|
}
|
||||||
|
|
||||||
|
export { fetchAPI, streamSSE, API_BASE_URL };
|
||||||
114
frontend/src/api/rag.ts
Normal file
114
frontend/src/api/rag.ts
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import type { QuickQuestionsResponse, SSEMessage } from './index';
|
||||||
|
|
||||||
|
const AGENT_API_BASE = '/api/v1';
|
||||||
|
|
||||||
|
export async function getQuickQuestions(): Promise<QuickQuestionsResponse> {
|
||||||
|
return {
|
||||||
|
questions: [
|
||||||
|
{ id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' },
|
||||||
|
{ id: '2', question: '我上传的制度文档与新能源法规有哪些潜在冲突?', category: '差距分析' },
|
||||||
|
{ id: '3', question: '请给出法规依据,并按条款列出整改建议', category: '整改建议' },
|
||||||
|
{ id: '4', question: '请解释 UN-ECE 与 GB 标准在网络安全方面的差异', category: '标准对比' },
|
||||||
|
],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) {
|
||||||
|
const blocks = raw.split('\n\n');
|
||||||
|
for (const block of blocks) {
|
||||||
|
if (!block.trim()) continue;
|
||||||
|
|
||||||
|
let eventName = 'message';
|
||||||
|
const dataLines: string[] = [];
|
||||||
|
|
||||||
|
for (const line of block.split('\n')) {
|
||||||
|
if (line.startsWith('event:')) {
|
||||||
|
eventName = line.slice(6).trim();
|
||||||
|
} else if (line.startsWith('data:')) {
|
||||||
|
dataLines.push(line.slice(5).trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const joined = dataLines.join('\n');
|
||||||
|
if (!joined) continue;
|
||||||
|
|
||||||
|
if (eventName === 'sources') {
|
||||||
|
try {
|
||||||
|
const docs = JSON.parse(joined) as Array<Record<string, unknown>>;
|
||||||
|
onMessage({
|
||||||
|
type: 'retrieved',
|
||||||
|
docs: docs.map((doc, index) => ({
|
||||||
|
id: String(doc.doc_id || doc.index || index + 1),
|
||||||
|
score: Number(doc.score || 0),
|
||||||
|
preview: String(doc.content || doc.snippet || ''),
|
||||||
|
doc_name: String(doc.doc_name || doc.filename || `引用 ${index + 1}`),
|
||||||
|
clause: String(doc.clause_number || doc.section_title || '法规片段'),
|
||||||
|
doc_id: doc.doc_id ? String(doc.doc_id) : undefined,
|
||||||
|
download_url: doc.doc_id ? `${AGENT_API_BASE}/documents/download/${String(doc.doc_id)}` : undefined,
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
|
// Ignore malformed source payloads.
|
||||||
|
}
|
||||||
|
} else if (eventName === 'content') {
|
||||||
|
onMessage({ type: 'chunk', text: joined });
|
||||||
|
} else if (eventName === 'done') {
|
||||||
|
onMessage({ type: 'done', text: joined });
|
||||||
|
} else if (eventName === 'error') {
|
||||||
|
onMessage({ type: 'error', text: joined });
|
||||||
|
} else if (eventName === 'status') {
|
||||||
|
onMessage({ type: 'status', text: joined });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function ragChat(
|
||||||
|
query: string,
|
||||||
|
topK: number = 5,
|
||||||
|
onMessage: (data: SSEMessage) => void,
|
||||||
|
onError?: (error: Error) => void,
|
||||||
|
onComplete?: () => void
|
||||||
|
): Promise<void> {
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Accept: 'text/event-stream',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ query, top_k: topK }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok || !response.body) {
|
||||||
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let buffer = '';
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
buffer += decoder.decode(value, { stream: true });
|
||||||
|
const parts = buffer.split('\n\n');
|
||||||
|
buffer = parts.pop() || '';
|
||||||
|
parseSSEChunk(parts.join('\n\n'), onMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (buffer.trim()) {
|
||||||
|
parseSSEChunk(buffer, onMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (onComplete) {
|
||||||
|
onComplete();
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
if (onError) {
|
||||||
|
onError(error instanceof Error ? error : new Error(String(error)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export type { QuickQuestionsResponse, SSEMessage };
|
||||||
19
frontend/src/api/status.ts
Normal file
19
frontend/src/api/status.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { fetchAPI, type SystemStats, type SystemConfig } from './index';
|
||||||
|
|
||||||
|
// Get system statistics
|
||||||
|
export async function getSystemStats(): Promise<SystemStats> {
|
||||||
|
return fetchAPI<SystemStats>('/status/stats');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get system configuration
|
||||||
|
export async function getSystemConfig(): Promise<SystemConfig> {
|
||||||
|
return fetchAPI<SystemConfig>('/status/config');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Milvus health status
|
||||||
|
export async function getMilvusHealth(): Promise<{ connected: boolean; collections: string[] }> {
|
||||||
|
return fetchAPI('/status/milvus/health');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export types
|
||||||
|
export type { SystemStats, SystemConfig };
|
||||||
BIN
frontend/src/assets/hero.png
Normal file
BIN
frontend/src/assets/hero.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 13 KiB |
1
frontend/src/assets/react.svg
Normal file
1
frontend/src/assets/react.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 4.0 KiB |
1
frontend/src/assets/vite.svg
Normal file
1
frontend/src/assets/vite.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 8.5 KiB |
15
frontend/src/components/common/TLogo.tsx
Normal file
15
frontend/src/components/common/TLogo.tsx
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import React from 'react';
|
||||||
|
|
||||||
|
interface TLogoProps {
|
||||||
|
size?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const TLogo: React.FC<TLogoProps> = ({ size = 40 }) => (
|
||||||
|
<img
|
||||||
|
src="/logo/t_mobile_logo_transparent.png"
|
||||||
|
alt="T-Systems"
|
||||||
|
width={size}
|
||||||
|
height={size}
|
||||||
|
style={{ objectFit: 'contain' }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
30
frontend/src/components/common/TPattern.tsx
Normal file
30
frontend/src/components/common/TPattern.tsx
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { useTheme } from '../../contexts/ThemeContext';
|
||||||
|
|
||||||
|
export const TPattern: React.FC = () => {
|
||||||
|
const { theme, isDark } = useTheme();
|
||||||
|
const patternOpacity = isDark ? 0.03 : 0.04;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
right: 0,
|
||||||
|
width: 300,
|
||||||
|
height: 300,
|
||||||
|
opacity: patternOpacity,
|
||||||
|
pointerEvents: 'none',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<svg width="300" height="300" viewBox="0 0 300 300">
|
||||||
|
<defs>
|
||||||
|
<pattern id="grid" width="30" height="30" patternUnits="userSpaceOnUse">
|
||||||
|
<path d="M 30 0 L 0 0 0 30" fill="none" stroke={theme.accent} strokeWidth="1"/>
|
||||||
|
</pattern>
|
||||||
|
</defs>
|
||||||
|
<rect width="300" height="300" fill="url(#grid)"/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
35
frontend/src/components/common/ThemeToggle.tsx
Normal file
35
frontend/src/components/common/ThemeToggle.tsx
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { useTheme } from '../../contexts/ThemeContext';
|
||||||
|
|
||||||
|
export const ThemeToggle: React.FC = () => {
|
||||||
|
const { isDark, toggleTheme, theme } = useTheme();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
onClick={toggleTheme}
|
||||||
|
style={{
|
||||||
|
width: 44,
|
||||||
|
height: 44,
|
||||||
|
borderRadius: 10,
|
||||||
|
background: isDark ? theme.bgHover : theme.bgCard,
|
||||||
|
border: `1px solid ${theme.border}`,
|
||||||
|
cursor: 'pointer',
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
transition: 'all 0.3s ease',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{isDark ? (
|
||||||
|
<svg width="20" height="20" viewBox="0 0 24 24" fill="none">
|
||||||
|
<circle cx="12" cy="12" r="4" fill={theme.accent}/>
|
||||||
|
<path d="M12 2V4M12 20V22M4 12H2M22 12H20M6.34 6.34L4.93 4.93M19.07 19.07L17.66 17.66M6.34 17.66L4.93 19.07M19.07 4.93L17.66 6.34" stroke={theme.accent} strokeWidth="2" strokeLinecap="round"/>
|
||||||
|
</svg>
|
||||||
|
) : (
|
||||||
|
<svg width="20" height="20" viewBox="0 0 24 24" fill="none">
|
||||||
|
<path d="M21 12.79A9 9 0 1 1 11.21 3 7 7 0 0 0 21 12.79z" fill={theme.accent} stroke={theme.accent} strokeWidth="1"/>
|
||||||
|
</svg>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
};
|
||||||
3
frontend/src/components/common/index.ts
Normal file
3
frontend/src/components/common/index.ts
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
export { TLogo } from './TLogo';
|
||||||
|
export { ThemeToggle } from './ThemeToggle';
|
||||||
|
export { TPattern } from './TPattern';
|
||||||
27
frontend/src/components/layout/Content.tsx
Normal file
27
frontend/src/components/layout/Content.tsx
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { useTheme } from '../../contexts/ThemeContext';
|
||||||
|
|
||||||
|
interface ContentProps {
|
||||||
|
children: React.ReactNode;
|
||||||
|
wide?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const Content: React.FC<ContentProps> = ({ children, wide = false }) => {
|
||||||
|
const { theme } = useTheme();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<main
|
||||||
|
style={{
|
||||||
|
flex: 1,
|
||||||
|
padding: '48px 56px',
|
||||||
|
maxWidth: wide ? 1400 : 1100,
|
||||||
|
margin: '0 auto',
|
||||||
|
width: '100%',
|
||||||
|
position: 'relative',
|
||||||
|
backgroundColor: theme.bg,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</main>
|
||||||
|
);
|
||||||
|
};
|
||||||
47
frontend/src/components/layout/Header.tsx
Normal file
47
frontend/src/components/layout/Header.tsx
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { useTheme } from '../../contexts/ThemeContext';
|
||||||
|
import { TLogo } from '../common/TLogo';
|
||||||
|
import { ThemeToggle } from '../common/ThemeToggle';
|
||||||
|
|
||||||
|
export const Header: React.FC = () => {
|
||||||
|
const { theme } = useTheme();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<header
|
||||||
|
className="h-[72px] flex items-center justify-between sticky top-0 z-[100]"
|
||||||
|
style={{
|
||||||
|
padding: '0 48px',
|
||||||
|
borderBottom: `1px solid ${theme.border}`,
|
||||||
|
backgroundColor: theme.bg,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="flex items-center" style={{ gap: 20 }}>
|
||||||
|
<TLogo size={80} />
|
||||||
|
<div className="flex items-baseline" style={{ gap: 12 }}>
|
||||||
|
<span style={{ fontWeight: 700, fontSize: 20, letterSpacing: '-0.5px', color: theme.text }}>
|
||||||
|
T-Systems
|
||||||
|
</span>
|
||||||
|
<span style={{ fontWeight: 300, fontSize: 16, color: theme.text2 }}>
|
||||||
|
Regulation
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center" style={{ gap: 16 }}>
|
||||||
|
<ThemeToggle />
|
||||||
|
<div
|
||||||
|
className="flex items-center rounded-lg"
|
||||||
|
style={{
|
||||||
|
padding: '8px 16px',
|
||||||
|
gap: 8,
|
||||||
|
backgroundColor: theme.bgHover,
|
||||||
|
borderRadius: 8,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<span className="mono" style={{ fontSize: 11, color: theme.text3 }}>v1.0.0</span>
|
||||||
|
<div style={{ width: 1, height: 12, background: theme.border }} />
|
||||||
|
<span className="mono" style={{ fontSize: 12, color: theme.green }}>● ONLINE</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</header>
|
||||||
|
);
|
||||||
|
};
|
||||||
47
frontend/src/components/layout/Tabs.tsx
Normal file
47
frontend/src/components/layout/Tabs.tsx
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { useTheme, useApp } from '../../contexts';
|
||||||
|
|
||||||
|
const tabs = [
|
||||||
|
{ id: 'docs', label: '文档管理' },
|
||||||
|
{ id: 'compliance', label: '合规分析' },
|
||||||
|
{ id: 'status', label: '系统状态' },
|
||||||
|
{ id: 'rag', label: '法规对话' },
|
||||||
|
];
|
||||||
|
|
||||||
|
export const Tabs: React.FC = () => {
|
||||||
|
const { theme } = useTheme();
|
||||||
|
const { activeTab, setActiveTab } = useApp();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<nav
|
||||||
|
className="h-[56px] flex items-center"
|
||||||
|
style={{
|
||||||
|
padding: '0 48px',
|
||||||
|
borderBottom: `1px solid ${theme.border}`,
|
||||||
|
backgroundColor: theme.bg,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{tabs.map((tab) => (
|
||||||
|
<button
|
||||||
|
key={tab.id}
|
||||||
|
onClick={() => setActiveTab(tab.id as any)}
|
||||||
|
style={{
|
||||||
|
height: 56,
|
||||||
|
padding: '0 32px',
|
||||||
|
fontSize: 15,
|
||||||
|
fontWeight: activeTab === tab.id ? 600 : 400,
|
||||||
|
color: activeTab === tab.id ? theme.accent : theme.text3,
|
||||||
|
background: 'transparent',
|
||||||
|
border: 'none',
|
||||||
|
borderBottom: activeTab === tab.id ? `3px solid ${theme.accent}` : '3px solid transparent',
|
||||||
|
marginBottom: -1,
|
||||||
|
cursor: 'pointer',
|
||||||
|
transition: 'all 0.2s ease',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{tab.label}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</nav>
|
||||||
|
);
|
||||||
|
};
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user