first commit

This commit is contained in:
2026-04-28 11:29:33 +08:00
commit c2a398930d
44 changed files with 5723 additions and 0 deletions

48
.env Normal file
View File

@@ -0,0 +1,48 @@
# 环境变量配置 - 已有数据库服务
# 应用配置
APP_NAME=AI+合规智能中枢
APP_VERSION=0.1.0
DEBUG=false
# Milvus向量数据库配置已有
MILVUS_HOST=localhost
MILVUS_PORT=19530
MILVUS_COLLECTION=regulations
MILVUS_DB_NAME=default
# MinIO对象存储配置已有
MINIO_ENDPOINT=localhost:9000
MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin
MINIO_BUCKET=compliance-docs
MINIO_SECURE=false
# Redis配置已有
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=redis@123
REDIS_DB=0
# PostgreSQL配置已有
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_USER=postgresql
POSTGRES_PASSWORD=postgresql123456
POSTGRES_DB=compliance_db
# 嵌入模型配置
EMBEDDING_MODEL=BAAI/bge-m3
EMBEDDING_DIM=1024
EMBEDDING_MAX_LENGTH=8192
EMBEDDING_BATCH_SIZE=12
EMBEDDING_USE_FP16=true
# 文档处理配置
CHUNK_SIZE=512
CHUNK_OVERLAP=50
MAX_FILE_SIZE_MB=100
# API配置
API_HOST=0.0.0.0
API_PORT=8000

31
.env.example Normal file
View File

@@ -0,0 +1,31 @@
# .env.example - 环境变量配置示例
# Milvus向量数据库配置
MILVUS_HOST=localhost
MILVUS_PORT=19530
MILVUS_COLLECTION=regulations
# 嵌入模型配置
EMBEDDING_MODEL=BAAI/bge-m3
EMBEDDING_DIM=1024
# MinIO对象存储配置
MINIO_ENDPOINT=localhost:9000
MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin123
MINIO_BUCKET=compliance-docs
# Redis配置
REDIS_HOST=localhost
REDIS_PORT=6379
# PostgreSQL配置
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_USER=compliance
POSTGRES_PASSWORD=compliance123
POSTGRES_DB=compliance_db
# 文档处理配置
CHUNK_SIZE=512
CHUNK_OVERLAP=50

143
README.md Normal file
View File

@@ -0,0 +1,143 @@
# AI+合规智能中枢 - 法律法规文档解析入库
面向车企与工厂的合规智能平台,实现法规文档的解析、分块、嵌入和向量存储。
## MVP功能
本次实现的核心功能(最小可用版本):
- ✅ PDF/DOCX文档解析MinerU + PyMuPDF
- ✅ 智能分块(章节级+条款级双粒度切割)
- ✅ BGE-M3嵌入Dense+Sparse双路向量
- ✅ Milvus向量数据库存储与混合检索
- ✅ FastAPI接口封装
## 项目结构
```
Demo-glm/
├── src/
│ ├── api/ # FastAPI接口层
│ │ ├── main.py # API入口
│ │ ├── routes/
│ │ │ ├── documents.py # 文档上传接口
│ │ │ └── knowledge.py # 知识库检索接口
│ │ └── models/
│ │ └── document.py # Pydantic数据模型
│ ├── services/
│ │ ├── parser/ # 文档解析服务
│ │ │ ├── pdf_parser.py # PDF解析PyMuPDF
│ │ │ ├── docx_parser.py # Word解析
│ │ │ └── mineru_parser.py # MinerU多模态解析
│ │ ├── embedding/ # 嵌入服务
│ │ │ ├── text_chunker.py # 智能分块器
│ │ │ └── bge_m3_embedder.py # BGE-M3嵌入
│ │ ├── storage/
│ │ │ └── milvus_client.py # Milvus客户端
│ │ └── document_processor.py # 文档处理主流程
│ └── config/
│ │ ├── settings.py # 配置管理
│ │ └── logging.py # 日志配置
├── tests/
│ ├── test_parser.py # 解析测试
│ ├── test_embedding.py # 嵌入测试
│ ├── test_milvus.py # Milvus测试
│ └── verify_mvp.py # MVP验证脚本
├── docker/
│ └── docker-compose.yml # Milvus/MinIO部署
├── requirements.txt
├── pyproject.toml
└── .env.example
```
## 快速开始
### 1. 安装依赖
```bash
pip install -r requirements.txt
```
### 2. 启动Milvus向量数据库
```bash
cd docker
docker-compose up -d
```
等待Milvus启动完成约30秒
```bash
docker-compose logs -f milvus
```
### 3. 运行验证脚本
```bash
python tests/verify_mvp.py
```
### 4. 启动API服务
```bash
uvicorn src.api.main:app --reload --port 8000
```
访问API文档http://localhost:8000/docs
## API接口
### 上传文档
```bash
curl -X POST http://localhost:8000/api/v1/documents/upload \
-F "file=@your_regulation.pdf" \
-F "doc_name=GB 7258-2017" \
-F "regulation_type=车辆安全"
```
### 检索法规
```bash
curl -X POST http://localhost:8000/api/v1/knowledge/search \
-H "Content-Type: application/json" \
-d '{"query": "机动车安全技术要求", "top_k": 10}'
```
## 技术栈
| 类别 | 技术 |
|------|------|
| 文档解析 | MinerU + PyMuPDF + python-docx |
| 分块策略 | 章节级+条款级双粒度切割 |
| 嵌入模型 | BGE-M31024维 Dense + Sparse |
| 向量数据库 | Milvus 2.4本地Docker部署 |
| 检索方式 | Dense+Sparse混合检索 + RRF融合 |
| API框架 | FastAPI |
## 配置
创建 `.env` 文件(参考 `.env.example`
```env
# Milvus配置
MILVUS_HOST=localhost
MILVUS_PORT=19530
# 嵌入模型配置
EMBEDDING_MODEL=BAAI/bge-m3
EMBEDDING_DIM=1024
# 分块配置
CHUNK_SIZE=512
```
## 后续迭代不在本次MVP范围
- LLM摘要生成DeepSeek/Qwen API
- 文档上传UI界面
- 混合检索问答功能
- 法规变更监控与自动更新
## 许可证
MIT License

88
docker/docker-compose.yml Normal file
View File

@@ -0,0 +1,88 @@
# AI+合规智能中枢 - 基础设施部署配置
# Milvus向量数据库 + MinIO对象存储 + Redis缓存
services:
# Milvus向量数据库 (Standalone模式)
milvus:
image: milvusdb/milvus:v2.4-latest
container_name: milvus-standalone
ports:
- "19530:19530" # SDK连接端口
- "9091:9091" # 健康检查端口
environment:
ETCD_USE_EMBED: "true"
COMMON_LOG_LEVEL: "info"
volumes:
- milvus_data:/var/lib/milvus
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
interval: 30s
timeout: 10s
retries: 5
restart: unless-stopped
command: ["milvus", "run", "standalone"]
# MinIO对象存储
minio:
image: minio/minio:latest
container_name: minio
ports:
- "9000:9000" # API端口
- "9001:9001" # Console端口
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin123
volumes:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 30s
timeout: 10s
retries: 5
restart: unless-stopped
# Redis缓存
redis:
image: redis:7-alpine
container_name: redis
ports:
- "6379:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 30s
timeout: 10s
retries: 5
restart: unless-stopped
# PostgreSQL数据库 (可选)
postgres:
image: postgres:15-alpine
container_name: postgres
ports:
- "5432:5432"
environment:
POSTGRES_USER: compliance
POSTGRES_PASSWORD: compliance123
POSTGRES_DB: compliance_db
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U compliance"]
interval: 30s
timeout: 10s
retries: 5
restart: unless-stopped
volumes:
milvus_data:
minio_data:
redis_data:
postgres_data:
networks:
default:
name: compliance-network

44
download_model.sh Normal file
View File

@@ -0,0 +1,44 @@
#!/bin/bash
# download_model.sh - 下载BGE-M3模型文件可在本地Windows运行后上传到服务器
MODEL_DIR="bge-m3-model"
MODEL_URL="https://modelscope.cn/models/Xorbits/bge-m3/resolve/master"
echo "========================================"
echo "下载 BGE-M3 模型文件"
echo "========================================"
echo ""
echo "目标目录: $MODEL_DIR"
echo ""
# 创建目录
mkdir -p "$MODEL_DIR"
# 下载模型文件
FILES=(
"config.json"
"model.safetensors"
"tokenizer.json"
"tokenizer_config.json"
"special_tokens_map.json"
"vocab.txt"
"sentencepiece.bpe.model"
)
for file in "${FILES[@]}"; do
echo "下载: $file"
wget -c "$MODEL_URL/$file" -O "$MODEL_DIR/$file" || curl -L "$MODEL_URL/$file" -o "$MODEL_DIR/$file"
done
echo ""
echo "========================================"
echo "下载完成!"
echo "========================================"
echo ""
echo "模型文件列表:"
ls -lh "$MODEL_DIR"
echo ""
echo "下一步:"
echo "1. 将 $MODEL_DIR 目录上传到服务器"
echo "2. 在服务器上放置到: ~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main/"
echo ""

1150
frontend/index.html Normal file

File diff suppressed because it is too large Load Diff

6
pip.conf Normal file
View File

@@ -0,0 +1,6 @@
[global]
index-url = https://mirrors.aliyun.com/pypi/simple
trusted-host = mirrors.aliyun.com
[install]
trusted-host = mirrors.aliyun.com

47
pyproject.toml Normal file
View File

@@ -0,0 +1,47 @@
[project]
name = "ai-compliance-hub"
version = "0.1.0"
description = "AI+合规智能中枢 - 法律法规文档解析入库功能"
authors = [
{name = "T-systems AI Regulations Team"}
]
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
dependencies = [
"pymilvus>=2.4.0",
"fastapi>=0.100.0",
"uvicorn[standard]>=0.23.0",
"python-multipart>=0.0.6",
"langchain>=0.1.0",
"langchain-milvus>=0.1.0",
"pymupdf>=1.24.0",
"python-docx>=0.8.11",
"FlagEmbedding>=1.2.0",
"sentence-transformers>=2.2.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"python-dotenv>=1.0.0",
"loguru>=0.7.0",
"tenacity>=8.2.0",
"httpx>=0.24.0",
]
[project.optional-dependencies]
mineru = ["magic-pdf[full]>=0.6.0"]
queue = ["celery>=5.3.0", "redis>=4.5.0"]
storage = ["minio>=7.1.0"]
db = ["psycopg2-binary>=2.9.0"]
dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0"]
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"

246
quick_start.sh Normal file
View File

@@ -0,0 +1,246 @@
#!/bin/bash
# quick_start.sh - 快速启动脚本
# 适配Docker部署的数据库环境
set -e
echo "========================================"
echo "AI+合规智能中枢 - 快速启动脚本"
echo "========================================"
echo ""
# 颜色定义
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
VENV_DIR=".venv"
# 检查Python版本
echo -e "${YELLOW}[1/8] 检查Python环境...${NC}"
if command -v python3 &> /dev/null; then
PYTHON_CMD=python3
else
PYTHON_CMD=python
fi
PYTHON_VERSION=$($PYTHON_CMD --version 2>&1 | awk '{print $2}')
echo "Python版本: $PYTHON_VERSION"
if [[ ! "$PYTHON_VERSION" =~ ^3\.1[0-9] ]]; then
echo -e "${RED}需要Python 3.10+,当前版本: $PYTHON_VERSION${NC}"
exit 1
fi
echo -e "${GREEN}Python环境检查通过${NC}"
echo ""
# 创建虚拟环境
echo -e "${YELLOW}[2/8] 创建虚拟环境...${NC}"
if [ -d "$VENV_DIR" ]; then
echo "虚拟环境已存在: $VENV_DIR"
else
echo "正在创建虚拟环境..."
$PYTHON_CMD -m venv $VENV_DIR
echo -e "${GREEN}虚拟环境创建成功${NC}"
fi
# 激活虚拟环境
source $VENV_DIR/bin/activate
echo "已激活虚拟环境: $VENV_DIR"
echo ""
# 安装依赖(使用国内镜像源)
echo -e "${YELLOW}[3/8] 安装Python依赖...${NC}"
# 配置pip使用清华镜像源国内加速
PIP_MIRROR="https://mirrors.aliyun.com/pypi/simple"
echo "使用镜像源: $PIP_MIRROR"
pip config set global.index-url $PIP_MIRROR -q
pip install --upgrade pip -q
pip install -r requirements.txt -q
if [ $? -eq 0 ]; then
echo -e "${GREEN}依赖安装完成${NC}"
else
echo -e "${RED}依赖安装失败请检查requirements.txt${NC}"
exit 1
fi
echo ""
# 检查Docker容器状态
echo -e "${YELLOW}[4/8] 检查Docker容器状态...${NC}"
REQUIRED_CONTAINERS="milvus minio redis postgres"
for container in $REQUIRED_CONTAINERS; do
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
echo -e "${GREEN}${container} 容器运行正常${NC}"
elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then
echo -e "${YELLOW}${container} 容器已停止,尝试启动...${NC}"
docker start ${container}
sleep 2
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
echo -e "${GREEN}${container} 已启动${NC}"
else
echo -e "${RED}${container} 启动失败${NC}"
fi
else
echo -e "${RED}${container} 容器不存在${NC}"
fi
done
echo ""
# 检查PostgreSQL连接通过Python
echo -e "${YELLOW}[5/8] 检查PostgreSQL连接...${NC}"
python << 'EOF'
import sys
try:
import psycopg2
conn = psycopg2.connect(
host="localhost",
port=5432,
user="postgresql",
password="postgresql123456",
database="postgres"
)
cur = conn.cursor()
# 检查compliance_db是否存在
cur.execute("SELECT 1 FROM pg_database WHERE datname='compliance_db'")
exists = cur.fetchone()
if not exists:
cur.execute("CREATE DATABASE compliance_db")
print("数据库 compliance_db 创建成功")
else:
print("数据库 compliance_db 已存在")
conn.commit()
cur.close()
conn.close()
print("PostgreSQL连接成功")
except Exception as e:
print(f"PostgreSQL连接失败: {e}")
sys.exit(1)
EOF
if [ $? -eq 0 ]; then
echo -e "${GREEN}PostgreSQL服务运行正常${NC}"
else
echo -e "${RED}PostgreSQL连接失败${NC}"
exit 1
fi
echo ""
# 检查Redis连接通过Python
echo -e "${YELLOW}[6/8] 检查Redis连接...${NC}"
python << 'EOF'
import sys
try:
import redis
r = redis.Redis(
host="localhost",
port=6379,
password="redis@123",
decode_responses=True
)
result = r.ping()
if result:
print("Redis连接成功")
else:
sys.exit(1)
except Exception as e:
print(f"Redis连接失败: {e}")
sys.exit(1)
EOF
if [ $? -eq 0 ]; then
echo -e "${GREEN}Redis服务运行正常${NC}"
else
echo -e "${RED}Redis连接失败${NC}"
exit 1
fi
echo ""
# 检查Milvus连接通过Python
echo -e "${YELLOW}[7/8] 检查Milvus连接...${NC}"
python << 'EOF'
import sys
try:
from pymilvus import connections, utility
connections.connect(host="localhost", port=19530)
print("Milvus连接成功")
connections.disconnect("default")
except Exception as e:
print(f"Milvus连接失败: {e}")
sys.exit(1)
EOF
if [ $? -eq 0 ]; then
echo -e "${GREEN}Milvus服务运行正常${NC}"
else
echo -e "${RED}Milvus连接失败${NC}"
exit 1
fi
echo ""
# 检查MinIO连接通过Python
echo -e "${YELLOW}[8/8] 检查MinIO连接...${NC}"
python << 'EOF'
import sys
try:
from minio import Minio
client = Minio("localhost:9000", "minioadmin", "minioadmin", secure=False)
# 检查bucket是否存在
bucket = "compliance-docs"
if not client.bucket_exists(bucket):
client.make_bucket(bucket)
print(f"MinIO bucket '{bucket}' 创建成功")
else:
print(f"MinIO bucket '{bucket}' 已存在")
print("MinIO连接成功")
except Exception as e:
print(f"MinIO连接失败: {e}")
sys.exit(1)
EOF
if [ $? -eq 0 ]; then
echo -e "${GREEN}MinIO服务运行正常${NC}"
else
echo -e "${RED}MinIO连接失败${NC}"
exit 1
fi
echo ""
# 预下载BGE-M3模型可选
echo -e "${YELLOW}[提示] BGE-M3嵌入模型...${NC}"
MODEL_CACHE=~/.cache/huggingface/hub/models--BAAI--bge-m3
if [ -d "$MODEL_CACHE" ]; then
echo -e "${GREEN}BGE-M3模型已存在${NC}"
else
echo -e "${YELLOW}模型未下载首次运行时将自动下载约2GB${NC}"
echo "手动预下载: python -c \"from FlagEmbedding import BGEM3FlagModel; BGEM3FlagModel('BAAI/bge-m3')\""
fi
echo ""
# 输出启动命令
echo "========================================"
echo -e "${GREEN}环境检查完成!${NC}"
echo "========================================"
echo ""
echo "虚拟环境: $VENV_DIR"
echo ""
echo "启动API服务:"
echo " ./start_api.sh"
echo ""
echo "后台启动:"
echo " ./start_api_background.sh"
echo ""
echo "API文档地址:"
echo " http://localhost:8000/docs"
echo " http://localhost:8000/health"
echo ""

46
requirements.txt Normal file
View File

@@ -0,0 +1,46 @@
# AI+合规智能中枢 - 法律法规文档解析入库
# MVP核心依赖包
# 向量数据库
pymilvus>=2.4.0
# API框架
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
python-multipart>=0.0.6
# RAG框架
langchain>=0.1.0
langchain-milvus>=0.1.0
# PDF解析
pymupdf>=1.24.0 # PyMuPDF
# Word文档解析
python-docx>=0.8.11
# MinerU多模态PDF解析可选需要额外配置
# magic-pdf[full]>=0.6.0
# 嵌入模型
FlagEmbedding>=1.2.0
sentence-transformers>=2.2.0
# 任务队列(可选)
celery>=5.3.0
redis>=4.5.0
# 对象存储(可选)
minio>=7.1.0
# 数据库
psycopg2-binary>=2.9.0
# mysql-connector-python>=8.0.0
# 工具库
pydantic>=2.0.0
pydantic-settings>=2.0.0
python-dotenv>=1.0.0
loguru>=0.7.0
tenacity>=8.2.0
httpx>=0.24.0

4
src/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
# src/__init__.py
"""AI+合规智能中枢 - 法律法规文档解析入库功能"""
__version__ = "0.1.0"

2
src/api/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# src/api/__init__.py
"""API接口模块"""

109
src/api/main.py Normal file
View File

@@ -0,0 +1,109 @@
# src/api/main.py
"""FastAPI应用入口"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
from loguru import logger
import sys
import os
# 设置日志
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.config.logging import setup_logging
from src.config.settings import settings
from src.api.routes import api_router
from src.api.models import ErrorResponse
# 配置日志
setup_logging(level="INFO" if not settings.debug else "DEBUG")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
logger.info(f"启动 {settings.app_name} v{settings.app_version}")
logger.info(f"调试模式: {settings.debug}")
# 启动时的初始化操作
# 如:预加载模型、建立数据库连接池等
yield
# 关闭时的清理操作
logger.info("应用关闭,执行清理...")
# 如:关闭数据库连接、清理缓存等
# 创建FastAPI应用
app = FastAPI(
title=settings.app_name,
description="AI+合规智能中枢 - 法律法规文档解析入库功能\n\n支持PDF/DOCX文档解析、智能分块、向量嵌入、Milvus存储",
version=settings.app_version,
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc"
)
# CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境应配置具体域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(api_router, prefix="/api/v1")
# 全局异常处理
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理"""
logger.error(f"未处理的异常: {exc}")
return JSONResponse(
status_code=500,
content=ErrorResponse(
error="InternalServerError",
message=str(exc)
).model_dump()
)
# 健康检查接口
@app.get("/health", tags=["health"])
async def health_check():
"""健康检查"""
return {
"status": "healthy",
"app": settings.app_name,
"version": settings.app_version
}
# 根路径
@app.get("/", tags=["root"])
async def root():
"""根路径"""
return {
"message": f"Welcome to {settings.app_name}",
"version": settings.app_version,
"docs": "/docs",
"health": "/health"
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host=settings.api_host,
port=settings.api_port,
reload=settings.debug,
log_level="info"
)

View File

@@ -0,0 +1,22 @@
# src/api/models/__init__.py
"""API数据模型"""
from .document import (
DocumentUploadRequest,
DocumentUploadResponse,
SearchRequest,
SearchResultItem,
SearchResponse,
DocumentStatusResponse,
ErrorResponse
)
__all__ = [
"DocumentUploadRequest",
"DocumentUploadResponse",
"SearchRequest",
"SearchResultItem",
"SearchResponse",
"DocumentStatusResponse",
"ErrorResponse"
]

View File

@@ -0,0 +1,61 @@
# src/api/models/document.py
"""文档相关Pydantic数据模型"""
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime
class DocumentUploadRequest(BaseModel):
"""文档上传请求"""
doc_name: Optional[str] = Field(None, description="文档名称")
regulation_type: Optional[str] = Field(None, description="法规类型")
version: Optional[str] = Field(None, description="文档版本")
class DocumentUploadResponse(BaseModel):
"""文档上传响应"""
doc_id: str = Field(..., description="文档ID")
doc_name: str = Field(..., description="文档名称")
status: str = Field(..., description="处理状态")
message: str = Field(..., description="状态消息")
num_chunks: Optional[int] = Field(None, description="分块数量")
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
class SearchRequest(BaseModel):
"""检索请求"""
query: str = Field(..., description="查询文本")
top_k: int = Field(default=10, description="返回结果数量")
filters: Optional[str] = Field(None, description="过滤条件")
class SearchResultItem(BaseModel):
"""单个检索结果"""
id: int = Field(..., description="记录ID")
content: str = Field(..., description="内容")
score: float = Field(..., description="相似度分数")
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
class SearchResponse(BaseModel):
"""检索响应"""
query: str = Field(..., description="查询文本")
total: int = Field(..., description="结果总数")
results: List[SearchResultItem] = Field(default_factory=list, description="结果列表")
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
class DocumentStatusResponse(BaseModel):
"""文档状态响应"""
doc_id: str = Field(..., description="文档ID")
status: str = Field(..., description="状态")
num_chunks: Optional[int] = Field(None, description="分块数量")
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
class ErrorResponse(BaseModel):
"""错误响应"""
error: str = Field(..., description="错误类型")
message: str = Field(..., description="错误消息")
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")

View File

@@ -0,0 +1,15 @@
# src/api/routes/__init__.py
"""API路由模块"""
from fastapi import APIRouter
from .documents import router as documents_router
from .knowledge import router as knowledge_router
# 主路由
api_router = APIRouter()
# 注册子路由
api_router.include_router(documents_router)
api_router.include_router(knowledge_router)
__all__ = ["api_router", "documents_router", "knowledge_router"]

117
src/api/routes/documents.py Normal file
View File

@@ -0,0 +1,117 @@
# src/api/routes/documents.py
"""文档上传与处理接口"""
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from typing import Optional
import os
import uuid
import tempfile
from loguru import logger
from ..models import DocumentUploadResponse, ErrorResponse
from src.services.document_processor import DocumentProcessor
from src.config.settings import settings
router = APIRouter(prefix="/documents", tags=["documents"])
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
file: UploadFile = File(..., description="上传的文档文件"),
doc_name: Optional[str] = Form(None, description="文档名称"),
regulation_type: Optional[str] = Form(None, description="法规类型"),
version: Optional[str] = Form(None, description="文档版本")
):
"""
上传文档并处理
支持格式PDF、DOCX、DOC
处理流程:解析 → 分块 → 嵌入 → 入库
"""
# 验证文件类型
ext = os.path.splitext(file.filename)[1].lower()
if ext not in [".pdf", ".docx", ".doc"]:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型: {ext}仅支持PDF、DOCX、DOC"
)
# 验证文件大小
if file.size and file.size > settings.max_file_size_mb * 1024 * 1024:
raise HTTPException(
status_code=400,
detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
)
# 生成文档ID
doc_id = str(uuid.uuid4())[:8]
# 文档名称
final_doc_name = doc_name or file.filename
logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}")
try:
# 保存临时文件
temp_dir = tempfile.gettempdir()
temp_path = os.path.join(temp_dir, f"{doc_id}_{file.filename}")
with open(temp_path, "wb") as f:
content = await file.read()
f.write(content)
logger.info(f"文件已保存到: {temp_path}")
# 处理文档
processor = DocumentProcessor()
result = processor.process(
file_path=temp_path,
doc_name=final_doc_name,
regulation_type=regulation_type or "",
version=version or ""
)
processor.close()
# 清理临时文件
try:
os.remove(temp_path)
except:
pass
if result.success:
return DocumentUploadResponse(
doc_id=result.doc_id,
doc_name=result.doc_name,
status="success",
message=result.message,
num_chunks=result.num_chunks
)
else:
raise HTTPException(
status_code=500,
detail=result.message
)
except Exception as e:
logger.error(f"文档处理失败: {e}")
raise HTTPException(
status_code=500,
detail=f"文档处理失败: {str(e)}"
)
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
async def get_document_status(doc_id: str):
"""
查询文档处理状态
Args:
doc_id: 文档ID
"""
# TODO: 实现状态查询(需要数据库支持)
return DocumentUploadResponse(
doc_id=doc_id,
doc_name="",
status="unknown",
message="状态查询功能待实现"
)

View File

@@ -0,0 +1,81 @@
# src/api/routes/knowledge.py
"""知识库检索接口"""
from fastapi import APIRouter, HTTPException
from loguru import logger
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse
from src.services.document_processor import DocumentProcessor
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
@router.post("/search", response_model=SearchResponse)
async def search_knowledge(request: SearchRequest):
"""
检索法规知识库
使用混合检索Dense向量 + Sparse向量 + RRF融合
Args:
request: 检索请求参数
"""
if not request.query or len(request.query.strip()) == 0:
raise HTTPException(
status_code=400,
detail="查询文本不能为空"
)
logger.info(f"收到检索请求: {request.query}")
try:
# 执行检索
processor = DocumentProcessor()
results = processor.search(
query=request.query,
top_k=request.top_k,
filters=request.filters
)
processor.close()
# 转换结果格式
result_items = []
for r in results:
item = SearchResultItem(
id=r.get("id", 0),
content=r.get("content", ""),
score=r.get("score", 0.0),
metadata=r.get("metadata", {})
)
result_items.append(item)
return SearchResponse(
query=request.query,
total=len(result_items),
results=result_items
)
except Exception as e:
logger.error(f"检索失败: {e}")
raise HTTPException(
status_code=500,
detail=f"检索失败: {str(e)}"
)
@router.post("/retrieval", response_model=SearchResponse)
async def knowledge_retrieval(request: SearchRequest):
"""
知识检索接口(与架构文档对齐)
该接口实现完整的检索流程:
1. 意图识别
2. BM25关键词检索 + 向量语义检索(双路召回)
3. Cross-Encoder精排
4. 返回结果
Args:
request: 检索请求
"""
# 当前版本使用混合检索,后续可添加精排步骤
return await search_knowledge(request)

6
src/config/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
# src/config/__init__.py
"""配置模块"""
from .settings import Settings, get_settings
__all__ = ["Settings", "get_settings"]

32
src/config/logging.py Normal file
View File

@@ -0,0 +1,32 @@
# src/config/logging.py
"""日志配置"""
from loguru import logger
import sys
def setup_logging(level: str = "INFO"):
"""设置日志配置"""
# 移除默认handler
logger.remove()
# 添加控制台输出
logger.add(
sys.stdout,
level=level,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
colorize=True
)
# 添加文件输出
logger.add(
"logs/app_{time:YYYY-MM-DD}.log",
level=level,
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
rotation="00:00",
retention="7 days",
compression="zip"
)
return logger

73
src/config/settings.py Normal file
View File

@@ -0,0 +1,73 @@
# src/config/settings.py
"""配置管理 - 环境变量和默认配置"""
from pydantic_settings import BaseSettings
from pydantic import Field
from typing import Optional
from functools import lru_cache
class Settings(BaseSettings):
"""应用配置"""
# 应用基础配置
app_name: str = Field(default="AI+合规智能中枢", description="应用名称")
app_version: str = Field(default="0.1.0", description="应用版本")
debug: bool = Field(default=False, description="调试模式")
# Milvus向量数据库配置
milvus_host: str = Field(default="localhost", description="Milvus服务地址")
milvus_port: int = Field(default=19530, description="Milvus服务端口")
milvus_collection: str = Field(default="regulations", description="法规向量集合名称")
milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
# 嵌入模型配置
embedding_model: str = Field(default="BAAI/bge-m3", description="嵌入模型名称")
embedding_dim: int = Field(default=1024, description="嵌入向量维度")
embedding_max_length: int = Field(default=8192, description="最大嵌入长度")
embedding_batch_size: int = Field(default=12, description="嵌入批处理大小")
embedding_use_fp16: bool = Field(default=True, description="使用FP16加速")
# MinIO对象存储配置
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址")
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
minio_bucket: str = Field(default="compliance-docs", description="文档存储桶名称")
minio_secure: bool = Field(default=False, description="是否使用HTTPS")
# Redis配置
redis_host: str = Field(default="localhost", description="Redis服务地址")
redis_port: int = Field(default=6379, description="Redis服务端口")
redis_password: str = Field(default="", description="Redis密码")
redis_db: int = Field(default=0, description="Redis数据库编号")
# PostgreSQL配置
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址")
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
postgres_db: str = Field(default="compliance_db", description="PostgreSQL数据库名称")
# 文档处理配置
chunk_size: int = Field(default=512, description="分块大小(字符数)")
chunk_overlap: int = Field(default=50, description="分块重叠大小")
max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)")
# API配置
api_host: str = Field(default="0.0.0.0", description="API服务地址")
api_port: int = Field(default=8000, description="API服务端口")
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore"
@lru_cache
def get_settings() -> Settings:
"""获取配置实例(缓存)"""
return Settings()
# 导出默认配置实例
settings = get_settings()

2
src/services/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# src/services/__init__.py
"""业务服务模块"""

View File

@@ -0,0 +1,348 @@
# src/services/document_processor.py
"""文档处理主流程 - 解析→分块→嵌入→入库"""
import os
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
import uuid
from .parser.pdf_parser import PDFParser
from .parser.docx_parser import DocxParser
from .parser.mineru_parser import ParserOrchestrator
from .embedding.text_chunker import RegulationChunker, TextChunk
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
from .storage.milvus_client import MilvusClient
from src.config.settings import settings
@dataclass
class ProcessingResult:
"""文档处理结果"""
doc_id: str
doc_name: str
success: bool
num_chunks: int = 0
message: str = ""
markdown_text: str = ""
class DocumentProcessor:
"""
文档处理服务 - 完整处理流程
流程:
1. 文档解析PDF/DOCX → Markdown
2. 智能分块(章节级+条款级)
3. 向量嵌入BGE-M3 Dense+Sparse
4. 存储入库Milvus向量数据库
"""
def __init__(
self,
chunk_size: int = None,
embedding_model: str = None,
use_mineru: bool = True
):
"""
初始化文档处理器
Args:
chunk_size: 分块大小
embedding_model: 嵌入模型名称
use_mineru: 是否优先使用MinerU解析
"""
self.chunk_size = chunk_size or settings.chunk_size
self.embedding_model = embedding_model or settings.embedding_model
self.use_mineru = use_mineru
# 初始化各组件
logger.info("初始化文档处理组件...")
# 解析器
self.parser = ParserOrchestrator()
# 分块器
self.chunker = RegulationChunker(chunk_size=self.chunk_size)
# 嵌入模型(延迟加载,首次使用时初始化)
self.embedder: Optional[BGEM3Embedder] = None
# Milvus客户端延迟连接
self.milvus: Optional[MilvusClient] = None
logger.success("文档处理器初始化完成")
def _init_embedder(self):
"""延迟初始化嵌入模型"""
if self.embedder is None:
logger.info("加载嵌入模型...")
self.embedder = BGEM3Embedder(model_name=self.embedding_model)
def _init_milvus(self):
"""延迟初始化Milvus连接"""
if self.milvus is None:
logger.info("连接Milvus...")
self.milvus = MilvusClient()
self.milvus.connect()
self.milvus.create_collection(recreate=False)
self.milvus.load_collection()
def process(
self,
file_path: str,
doc_name: Optional[str] = None,
regulation_type: str = "",
version: str = ""
) -> ProcessingResult:
"""
处理单个文档
Args:
file_path: 文档文件路径
doc_name: 文档名称(可选,默认从文件名获取)
regulation_type: 法规类型
version: 文档版本
Returns:
ProcessingResult: 处理结果
"""
# 生成文档ID
doc_id = str(uuid.uuid4())[:8]
# 获取文档名称
if doc_name is None:
doc_name = os.path.basename(file_path)
logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})")
try:
# 1. 文档解析
logger.info("Step 1: 文档解析")
markdown_text = self._parse_document(file_path)
if not markdown_text:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="文档解析失败,内容为空"
)
# 2. 智能分块
logger.info("Step 2: 智能分块")
chunks = self._chunk_document(
markdown_text,
doc_id,
doc_name,
regulation_type,
version
)
if not chunks:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="分块失败,无有效内容"
)
# 3. 向量嵌入
logger.info("Step 3: 向量嵌入")
embeddings = self._embed_chunks(chunks)
if embeddings is None:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="向量嵌入失败"
)
# 4. 存储入库
logger.info("Step 4: 存储入库")
inserted_ids = self._insert_to_milvus(chunks, embeddings)
logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=True,
num_chunks=len(inserted_ids),
message="处理成功",
markdown_text=markdown_text
)
except Exception as e:
logger.error(f"文档处理失败: {e}")
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message=f"处理失败: {str(e)}"
)
def _parse_document(self, file_path: str) -> str:
"""解析文档"""
ext = os.path.splitext(file_path)[1].lower()
try:
if ext == ".pdf":
# PDF文档解析优先MinerU回退PyMuPDF
markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
elif ext in [".docx", ".doc"]:
# Word文档解析
markdown_text = self.parser.parse_docx(file_path)
else:
logger.warning(f"不支持的文件类型: {ext}")
return ""
logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
return markdown_text
except Exception as e:
logger.error(f"文档解析失败: {e}")
return ""
def _chunk_document(
self,
markdown_text: str,
doc_id: str,
doc_name: str,
regulation_type: str,
version: str
) -> List[TextChunk]:
"""分块文档"""
try:
chunks = self.chunker.chunk_document(
markdown_text,
doc_id=doc_id,
doc_name=doc_name,
regulation_type=regulation_type,
version=version
)
logger.success(f"分块完成,共{len(chunks)}个chunk")
return chunks
except Exception as e:
logger.error(f"分块失败: {e}")
return []
def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
"""嵌入分块"""
try:
# 延迟初始化嵌入模型
self._init_embedder()
# 提取文本内容
texts = [chunk.content for chunk in chunks]
# 执行嵌入
embeddings = self.embedder.embed(texts)
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
return embeddings
except Exception as e:
logger.error(f"嵌入失败: {e}")
return None
def _insert_to_milvus(
self,
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""插入Milvus"""
try:
# 延迟初始化Milvus
self._init_milvus()
# 执行插入
inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
return inserted_ids
except Exception as e:
logger.error(f"入库失败: {e}")
return []
def search(
self,
query: str,
top_k: int = 10,
filters: Optional[str] = None
) -> List[Dict]:
"""
检索法规内容
Args:
query: 查询文本
top_k: 返回结果数量
filters: 过滤条件
Returns:
List[Dict]: 检索结果
"""
logger.info(f"执行检索: {query}")
try:
# 延迟初始化
self._init_embedder()
self._init_milvus()
# 生成查询向量
query_embedding = self.embedder.embed_single(query)
# 执行混合检索
results = self.milvus.hybrid_search(
query_dense=query_embedding['dense'].tolist(),
query_sparse=query_embedding['sparse'],
top_k=top_k,
filters=filters
)
# 转换为字典格式
result_dicts = []
for r in results:
result_dicts.append({
"id": r.id,
"content": r.content,
"score": r.score,
"metadata": r.metadata
})
logger.success(f"检索完成,返回{len(result_dicts)}条结果")
return result_dicts
except Exception as e:
logger.error(f"检索失败: {e}")
return []
def close(self):
"""关闭连接"""
if self.milvus:
self.milvus.disconnect()
logger.info("文档处理器已关闭")
def process_document(
file_path: str,
doc_name: Optional[str] = None,
regulation_type: str = "",
version: str = ""
) -> ProcessingResult:
"""便捷函数:处理单个文档"""
processor = DocumentProcessor()
result = processor.process(file_path, doc_name, regulation_type, version)
processor.close()
return result
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
"""便捷函数:检索法规"""
processor = DocumentProcessor()
results = processor.search(query, top_k)
processor.close()
return results

View File

@@ -0,0 +1,7 @@
# src/services/embedding/__init__.py
"""嵌入和分块服务"""
from .text_chunker import RegulationChunker
from .bge_m3_embedder import BGEM3Embedder
__all__ = ["RegulationChunker", "BGEM3Embedder"]

View File

@@ -0,0 +1,296 @@
# src/services/embedding/bge_m3_embedder.py
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成"""
import numpy as np
from typing import List, Dict, Optional, Union
from dataclasses import dataclass, field
from loguru import logger
import torch
import os
# 设置HuggingFace镜像国内网络
if 'HF_ENDPOINT' not in os.environ:
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 本地模型路径(按优先级检查)
LOCAL_MODEL_PATHS = [
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径
]
@dataclass
class EmbeddingResult:
"""嵌入结果"""
dense_embeddings: np.ndarray # Dense向量语义检索
sparse_embeddings: List[Dict[int, float]] # Sparse向量关键词匹配
texts: List[str]
dim: int = 1024
class BGEM3Embedder:
"""
BGE-M3多语言嵌入模型服务
BGE-M3是BAAI发布的多语言嵌入模型支持
- Dense向量用于语义相似度检索
- Sparse向量用于关键词精确匹配BM25风格
- ColBERT向量用于细粒度交互匹配可选
特点:
- 支持100+语言(中英双语优化)
- 8192 tokens超长上下文
- Dense+Sparse双路检索能力
GitHub: https://github.com/FlagOpen/FlagEmbedding
"""
def __init__(
self,
model_name: str = "BAAI/bge-m3",
use_fp16: bool = True,
device: Optional[str] = None,
batch_size: int = 12,
max_length: int = 8192,
local_model_path: Optional[str] = None
):
"""
初始化BGE-M3嵌入模型
Args:
model_name: 模型名称(如果使用本地路径,此参数会被忽略)
use_fp16: 是否使用FP16加速
device: 设备类型cuda/cpu默认自动选择
batch_size: 批处理大小
max_length: 最大序列长度
local_model_path: 本地模型路径(可选,优先使用)
"""
self.use_fp16 = use_fp16
self.batch_size = batch_size
self.max_length = max_length
# 确定模型路径(优先使用本地路径)
if local_model_path and os.path.exists(local_model_path):
self.model_path = local_model_path
self.model_name = "local"
logger.info(f"使用本地模型路径: {local_model_path}")
else:
# 检查多个可能的本地路径
found_local = False
for path in LOCAL_MODEL_PATHS:
if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
self.model_path = path
self.model_name = "local"
logger.info(f"使用本地模型路径: {path}")
found_local = True
break
if not found_local:
self.model_path = model_name
self.model_name = model_name
logger.info(f"使用远程模型: {model_name}")
# 自动选择设备
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
logger.info(f"初始化BGE-M3模型, 设备: {self.device}")
self.model = None
self._load_model()
def _load_model(self):
"""加载嵌入模型"""
try:
from FlagEmbedding import BGEM3FlagModel
self.model = BGEM3FlagModel(
self.model_path,
use_fp16=self.use_fp16,
device=self.device
)
logger.success(f"BGE-M3模型加载成功")
except ImportError:
logger.warning("FlagEmbedding库未安装请运行: pip install FlagEmbedding")
raise
except Exception as e:
logger.error(f"模型加载失败: {e}")
raise
def embed(
self,
texts: List[str],
return_dense: bool = True,
return_sparse: bool = True,
return_colbert_vecs: bool = False
) -> EmbeddingResult:
"""
对文本列表生成嵌入向量
Args:
texts: 文本列表
return_dense: 是否返回Dense向量
return_sparse: 是否返回Sparse向量
return_colbert_vecs: 是否返回ColBERT向量
Returns:
EmbeddingResult: 嵌入结果
"""
if not texts:
logger.warning("输入文本列表为空")
return EmbeddingResult(
dense_embeddings=np.array([]),
sparse_embeddings=[],
texts=[],
dim=0
)
logger.info(f"开始嵌入{len(texts)}个文本块")
try:
# 执行嵌入
embeddings = self.model.encode(
texts,
batch_size=self.batch_size,
max_length=self.max_length,
return_dense=return_dense,
return_sparse=return_sparse,
return_colbert_vecs=return_colbert_vecs
)
# 提取结果
dense_embeddings = embeddings.get('dense_vecs', np.array([]))
sparse_embeddings = embeddings.get('lexical_weights', [])
# 获取维度
dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
logger.success(f"嵌入完成,向量维度: {dim}")
return EmbeddingResult(
dense_embeddings=dense_embeddings,
sparse_embeddings=sparse_embeddings,
texts=texts,
dim=dim
)
except Exception as e:
logger.error(f"嵌入失败: {e}")
raise
def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
"""
对单个文本生成嵌入向量
Args:
text: 输入文本
Returns:
Dict: 包含dense和sparse向量
"""
result = self.embed([text])
return {
'dense': result.dense_embeddings[0],
'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {},
'dim': result.dim
}
def embed_dense(self, texts: List[str]) -> np.ndarray:
"""只生成Dense向量"""
result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
return result.dense_embeddings
def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
"""只生成Sparse向量"""
result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
return result.sparse_embeddings
def embed_query(self, query: str) -> Dict:
"""
对查询文本生成嵌入(用于检索)
Args:
query: 查询文本
Returns:
Dict: 包含dense和sparse向量
"""
return self.embed_single(query)
def compute_similarity(
self,
query_embedding: np.ndarray,
doc_embeddings: np.ndarray,
metric: str = "cosine"
) -> np.ndarray:
"""
计算查询与文档的相似度
Args:
query_embedding: 查询向量
doc_embeddings: 文档向量矩阵
metric: 相似度度量cosine/dot
Returns:
np.ndarray: 相似度分数数组
"""
if metric == "cosine":
# 余弦相似度
query_norm = np.linalg.norm(query_embedding)
doc_norms = np.linalg.norm(doc_embeddings, axis=1)
similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm)
elif metric == "dot":
# 点积相似度
similarities = np.dot(doc_embeddings, query_embedding)
else:
raise ValueError(f"不支持的相似度度量: {metric}")
return similarities
def sparse_similarity(
self,
query_sparse: Dict[int, float],
doc_sparse: Dict[int, float]
) -> float:
"""
计算Sparse向量的相似度BM25风格
Args:
query_sparse: 查询的Sparse向量词ID -> 权重)
doc_sparse: 文档的Sparse向量
Returns:
float: 相似度分数
"""
# 计算交集词的点积
common_keys = set(query_sparse.keys()) & set(doc_sparse.keys())
score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
return score
def embed_texts(
texts: List[str],
model_name: str = "BAAI/bge-m3",
**kwargs
) -> EmbeddingResult:
"""便捷函数:对文本列表生成嵌入"""
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
return embedder.embed(texts)
def embed_single_text(
text: str,
model_name: str = "BAAI/bge-m3",
**kwargs
) -> Dict:
"""便捷函数:对单个文本生成嵌入"""
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
return embedder.embed_single(text)

View File

@@ -0,0 +1,449 @@
# src/services/embedding/text_chunker.py
"""智能分块器 - 章节级+条款级双粒度切割"""
import re
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class ChunkMetadata:
"""分块元数据"""
doc_id: str = ""
doc_name: str = ""
chunk_id: str = ""
section_number: str = "" # 章节编号(如 "第一章"
section_title: str = "" # 章节标题
clause_number: str = "" # 条款编号(如 "第一条"
page_number: int = 0
start_position: int = 0 # 在原文中的起始位置
end_position: int = 0 # 在原文中的结束位置
regulation_type: str = "" # 法规类型
version: str = ""
@dataclass
class TextChunk:
"""文本分块"""
content: str
metadata: ChunkMetadata
token_count: int = 0 # 估算的token数量
class RegulationChunker:
"""
法规文档智能分块器
实现章节级/条款级双粒度切割适配国标GB文档结构
- 国标文档通常有明确的层级结构:章 > 节 > 条
- 每个条款应作为一个独立的语义单元
- 保留条款完整性,避免跨条款截断
"""
# 法规标题模式
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
# 条款子项模式
SUB_ITEM_PATTERN = re.compile(r'^[\(][一二三四五六七八九十]+[\)]\s')
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
def __init__(
self,
chunk_size: int = 512,
chunk_overlap: int = 50,
max_chunk_size: int = 2048,
min_chunk_size: int = 100
):
"""
初始化分块器
Args:
chunk_size: 默认分块大小(字符数)
chunk_overlap: 分块重叠大小
max_chunk_size: 最大分块大小(防止单个条款过长)
min_chunk_size: 最小分块大小(防止碎片化)
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.max_chunk_size = max_chunk_size
self.min_chunk_size = min_chunk_size
def chunk_document(
self,
markdown_text: str,
doc_id: str = "",
doc_name: str = "",
regulation_type: str = "",
version: str = ""
) -> List[TextChunk]:
"""
对法规文档进行智能分块
Args:
markdown_text: Markdown格式的文档内容
doc_id: 文档ID
doc_name: 文档名称
regulation_type: 法规类型
version: 文档版本
Returns:
List[TextChunk]: 分块列表
"""
logger.info(f"开始分块文档: {doc_name}")
# 1. 按章节分割(一级分块)
sections = self._split_by_sections(markdown_text)
# 2. 在每个章节内按条款分割(二级分块)
chunks = []
global_position = 0
for section_num, section_title, section_content, section_start in sections:
# 在章节内按条款分割
clause_chunks = self._split_by_clauses(
section_content,
section_num,
section_title,
section_start + global_position
)
for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks:
# 处理过长的条款(进一步细分)
if len(chunk_content) > self.max_chunk_size:
sub_chunks = self._split_long_clause(
chunk_content,
clause_num,
clause_title
)
for sub_content, sub_start, sub_end in sub_chunks:
chunk = self._create_chunk(
sub_content,
doc_id,
doc_name,
section_num,
section_title,
clause_num,
sub_start + start_pos,
sub_end + start_pos,
regulation_type,
version
)
chunks.append(chunk)
else:
chunk = self._create_chunk(
chunk_content,
doc_id,
doc_name,
section_num,
section_title,
clause_num,
start_pos,
end_pos,
regulation_type,
version
)
chunks.append(chunk)
logger.success(f"分块完成,共{len(chunks)}个chunk")
return chunks
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
"""
按章节分割文档
Returns:
List of (section_number, section_title, section_content, start_position)
"""
sections = []
lines = markdown_text.split('\n')
current_section_num = ""
current_section_title = ""
current_section_content = []
current_section_start = 0
for i, line in enumerate(lines):
# 检测章节标题
chapter_match = self.CHAPTER_PATTERN.match(line.strip())
section_match = self.SECTION_PATTERN.match(line.strip())
if chapter_match or section_match:
# 保存上一个章节
if current_section_content:
content = '\n'.join(current_section_content)
sections.append((
current_section_num,
current_section_title,
content,
current_section_start
))
# 开始新章节
current_section_start = sum(len(l) + 1 for l in lines[:i])
current_section_content = []
if chapter_match:
current_section_num = line.strip()
current_section_title = self._extract_title(line.strip())
else:
current_section_num = line.strip()
current_section_title = self._extract_title(line.strip())
current_section_content.append(line)
# 保存最后一个章节
if current_section_content:
content = '\n'.join(current_section_content)
sections.append((
current_section_num,
current_section_title,
content,
current_section_start
))
# 如果没有检测到章节,将整个文档作为一个大章节
if not sections:
sections.append((
"",
"全文",
markdown_text,
0
))
return sections
def _split_by_clauses(
self,
section_content: str,
section_num: str,
section_title: str,
section_start: int
) -> List[Tuple[str, str, str, int, int]]:
"""
在章节内按条款分割
Returns:
List of (content, clause_number, clause_title, start_position, end_position)
"""
clauses = []
lines = section_content.split('\n')
current_clause_num = ""
current_clause_title = ""
current_clause_content = []
current_clause_start = section_start
for i, line in enumerate(lines):
# 检测条款标题
clause_match = self.CLAUSE_PATTERN.match(line.strip())
if clause_match:
# 保存上一个条款
if current_clause_content:
content = '\n'.join(current_clause_content)
end_pos = current_clause_start + len(content)
clauses.append((
content,
current_clause_num,
current_clause_title,
current_clause_start,
end_pos
))
# 开始新条款
current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i])
current_clause_content = []
current_clause_num = self._extract_clause_number(line.strip())
current_clause_title = line.strip()
current_clause_content.append(line)
# 保存最后一个条款
if current_clause_content:
content = '\n'.join(current_clause_content)
end_pos = current_clause_start + len(content)
clauses.append((
content,
current_clause_num,
current_clause_title,
current_clause_start,
end_pos
))
# 如果没有检测到条款,将整个章节作为一个条款
if not clauses:
clauses.append((
section_content,
"",
section_title,
section_start,
section_start + len(section_content)
))
return clauses
def _split_long_clause(
self,
content: str,
clause_num: str,
clause_title: str
) -> List[Tuple[str, int, int]]:
"""
分割过长的条款内容
按条款子项或段落分割,保持语义完整性
"""
sub_chunks = []
lines = content.split('\n')
# 检测是否有子项结构
has_sub_items = any(
self.SUB_ITEM_PATTERN.match(line.strip()) or
self.NUMBER_ITEM_PATTERN.match(line.strip())
for line in lines
)
if has_sub_items:
# 按子项分割
current_sub_content = []
current_sub_start = 0
for i, line in enumerate(lines):
is_sub_item = (
self.SUB_ITEM_PATTERN.match(line.strip()) or
self.NUMBER_ITEM_PATTERN.match(line.strip())
)
if is_sub_item and current_sub_content:
sub_content = '\n'.join(current_sub_content)
sub_end = current_sub_start + len(sub_content)
if len(sub_content) >= self.min_chunk_size:
sub_chunks.append((sub_content, current_sub_start, sub_end))
current_sub_content = []
current_sub_start = sum(len(l) + 1 for l in lines[:i])
current_sub_content.append(line)
# 保存最后一个子项
if current_sub_content:
sub_content = '\n'.join(current_sub_content)
sub_end = current_sub_start + len(sub_content)
sub_chunks.append((sub_content, current_sub_start, sub_end))
else:
# 按段落分割(滑动窗口)
paragraphs = []
current_para = []
for line in lines:
if line.strip():
current_para.append(line)
else:
if current_para:
paragraphs.append('\n'.join(current_para))
current_para = []
if current_para:
paragraphs.append('\n'.join(current_para))
# 合并段落直到达到chunk_size
current_chunk = []
current_length = 0
chunk_start = 0
for para in paragraphs:
if current_length + len(para) > self.chunk_size and current_chunk:
chunk_content = '\n'.join(current_chunk)
chunk_end = chunk_start + len(chunk_content)
sub_chunks.append((chunk_content, chunk_start, chunk_end))
current_chunk = []
current_length = 0
chunk_start = chunk_end
current_chunk.append(para)
current_length += len(para)
# 保存最后一个chunk
if current_chunk:
chunk_content = '\n'.join(current_chunk)
chunk_end = chunk_start + len(chunk_content)
sub_chunks.append((chunk_content, chunk_start, chunk_end))
return sub_chunks
def _extract_title(self, header_line: str) -> str:
"""从标题行提取标题内容"""
# 移除"第X章"、"第X节"前缀
title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line)
return title.strip()
def _extract_clause_number(self, clause_line: str) -> str:
"""从条款行提取条款编号"""
match = self.CLAUSE_PATTERN.match(clause_line)
if match:
return match.group(0).strip()
return ""
def _create_chunk(
self,
content: str,
doc_id: str,
doc_name: str,
section_num: str,
section_title: str,
clause_num: str,
start_pos: int,
end_pos: int,
regulation_type: str,
version: str
) -> TextChunk:
"""创建文本分块"""
# 清理内容
content = content.strip()
# 计算估算token数中文约1.5字符/token
token_count = int(len(content) * 0.7) # 简化估算
# 生成chunk_id
chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}"
metadata = ChunkMetadata(
doc_id=doc_id,
doc_name=doc_name,
chunk_id=chunk_id,
section_number=section_num,
section_title=section_title,
clause_number=clause_num,
start_position=start_pos,
end_position=end_pos,
regulation_type=regulation_type,
version=version
)
return TextChunk(
content=content,
metadata=metadata,
token_count=token_count
)
def chunk_regulation_document(
markdown_text: str,
doc_id: str = "",
doc_name: str = "",
regulation_type: str = "",
version: str = "",
chunk_size: int = 512
) -> List[TextChunk]:
"""便捷函数:对法规文档进行分块"""
chunker = RegulationChunker(chunk_size=chunk_size)
return chunker.chunk_document(
markdown_text,
doc_id,
doc_name,
regulation_type,
version
)

View File

@@ -0,0 +1,7 @@
# src/services/parser/__init__.py
"""文档解析服务"""
from .pdf_parser import PDFParser
from .docx_parser import DocxParser
__all__ = ["PDFParser", "DocxParser"]

View File

@@ -0,0 +1,287 @@
# src/services/parser/docx_parser.py
"""Word文档解析 - 使用python-docx"""
from docx import Document
from docx.enum.text import WD_ALIGN_PARAGRAPH
from typing import List, Dict, Optional
from dataclasses import dataclass, field
from loguru import logger
import re
@dataclass
class DocxParagraph:
"""段落内容"""
text: str
level: int = 0 # 标题级别0表示正文
is_list: bool = False
list_number: Optional[str] = None
@dataclass
class DocxTable:
"""表格内容"""
rows: List[List[str]]
markdown: str = ""
@dataclass
class DocxDocumentContent:
"""Word文档完整内容"""
file_path: str
paragraphs: List[DocxParagraph]
tables: List[DocxTable]
metadata: Dict[str, str] = field(default_factory=dict)
markdown_text: str = ""
class DocxParser:
"""Word文档解析器 - 基于python-docx"""
def __init__(self):
self.document = None
def parse(self, file_path: str) -> DocxDocumentContent:
"""
解析Word文档
Args:
file_path: Word文档路径
Returns:
DocxDocumentContent: 解析后的文档内容
"""
logger.info(f"开始解析Word文档: {file_path}")
try:
self.document = Document(file_path)
doc_content = DocxDocumentContent(
file_path=file_path,
paragraphs=[],
tables=[]
)
# 提取文档元数据
doc_content.metadata = self._extract_metadata()
# 提取段落
doc_content.paragraphs = self._extract_paragraphs()
# 提取表格
doc_content.tables = self._extract_tables()
# 生成Markdown格式文本
doc_content.markdown_text = self._generate_markdown(doc_content)
logger.success(f"Word文档解析完成{len(doc_content.paragraphs)}个段落")
return doc_content
except Exception as e:
logger.error(f"Word文档解析失败: {e}")
raise
def _extract_metadata(self) -> Dict[str, str]:
"""提取文档元数据"""
metadata = {}
try:
core_props = self.document.core_properties
metadata = {
"title": core_props.title or "",
"author": core_props.author or "",
"subject": core_props.subject or "",
"keywords": core_props.keywords or "",
"created": str(core_props.created) if core_props.created else "",
"modified": str(core_props.modified) if core_props.modified else "",
}
except Exception as e:
logger.warning(f"提取元数据失败: {e}")
return metadata
def _extract_paragraphs(self) -> List[DocxParagraph]:
"""提取所有段落"""
paragraphs = []
for para in self.document.paragraphs:
text = para.text.strip()
if not text:
continue
# 判断标题级别
level = self._get_paragraph_level(para)
# 判断是否是列表项
is_list, list_number = self._detect_list_item(para)
paragraph = DocxParagraph(
text=text,
level=level,
is_list=is_list,
list_number=list_number
)
paragraphs.append(paragraph)
return paragraphs
def _get_paragraph_level(self, para) -> int:
"""
判断段落标题级别
Returns:
int: 标题级别0表示正文
"""
# 方法1检查段落样式
style_name = para.style.name if para.style else ""
if "Heading" in style_name or "标题" in style_name:
# 从样式名称中提取级别
match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
if match:
level = int(match.group(1) or match.group(2))
return level
# 方法2检查段落格式字号
# 标题通常字号较大
if para.paragraph_format:
# 可以根据字号判断,这里简化处理
pass
# 方法3根据内容模式判断法规文档特征
text = para.text.strip()
# 第一章、第X章 -> 二级标题
if re.match(r'^第[一二三四五六七八九十百]+章\s', text):
return 2
# 第X节 -> 三级标题
elif re.match(r'^第[一二三四五六七八九十百]+节\s', text):
return 3
# 第X条 -> 四级标题
elif re.match(r'^第[一二三四五六七八九十百]+条\s', text):
return 4
return 0 # 正文
def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
"""检测是否是列表项"""
text = para.text.strip()
# 数字列表1.、2.、1、[1]等
if re.match(r'^[\d]+[.、)\]]\s', text):
match = re.match(r'^([\d]+[.、)\]])\s', text)
return True, match.group(1) if match else None
# 中文数字列表:一、二、(一)等
if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text):
match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text)
return True, match.group(1) if match else None
# 检查段落格式中的列表编号
if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'):
# 有缩进的可能是列表项
pass
return False, None
def _extract_tables(self) -> List[DocxTable]:
"""提取所有表格"""
tables = []
for table in self.document.tables:
rows = []
for row in table.rows:
cells = []
for cell in row.cells:
cells.append(cell.text.strip())
rows.append(cells)
# 转换为Markdown表格
markdown = self._table_to_markdown(rows)
table_content = DocxTable(rows=rows, markdown=markdown)
tables.append(table_content)
return tables
def _table_to_markdown(self, rows: List[List[str]]) -> str:
"""将表格转换为Markdown格式"""
if not rows or len(rows) < 1:
return ""
lines = []
# 表头
if len(rows) >= 1:
header = rows[0]
lines.append("| " + " | ".join(cell for cell in header) + " |")
lines.append("| " + " | ".join("---" for _ in header) + " |")
# 数据行
for row in rows[1:]:
lines.append("| " + " | ".join(cell for cell in row) + " |")
return "\n".join(lines)
def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
"""生成Markdown格式文本"""
lines = []
# 文档标题
title = doc_content.metadata.get("title", "")
if title:
lines.append(f"# {title}\n")
else:
# 从第一个段落获取标题(如果是标题样式)
for para in doc_content.paragraphs[:5]:
if para.level == 1:
lines.append(f"# {para.text}\n")
break
else:
lines.append(f"# {doc_content.file_path}\n")
# 元数据信息
lines.append("\n## 文档信息\n")
for key, value in doc_content.metadata.items():
if value:
lines.append(f"- **{key}**: {value}")
# 正文内容
lines.append("\n## 正文\n")
table_index = 0
for para in doc_content.paragraphs:
if para.level > 0:
# 标题
prefix = "#" * para.level
lines.append(f"\n{prefix} {para.text}\n")
elif para.is_list:
# 列表项
lines.append(f"- {para.text}")
else:
# 正文
lines.append(para.text)
# 添加表格
if doc_content.tables:
lines.append("\n## 表格\n")
for i, table in enumerate(doc_content.tables):
lines.append(f"\n### 表格 {i + 1}\n")
lines.append(table.markdown + "\n")
return "\n".join(lines)
def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本"""
doc_content = self.parse(file_path)
return doc_content.markdown_text
def parse_docx(file_path: str) -> DocxDocumentContent:
"""便捷函数解析Word文档"""
parser = DocxParser()
return parser.parse(file_path)
def parse_docx_to_markdown(file_path: str) -> str:
"""便捷函数解析Word并返回Markdown"""
parser = DocxParser()
return parser.parse_to_markdown(file_path)

View File

@@ -0,0 +1,204 @@
# src/services/parser/mineru_parser.py
"""MinerU多模态PDF解析 - 版面感知解析"""
from typing import Optional, Dict
from dataclasses import dataclass, field
from loguru import logger
import os
@dataclass
class MinerUResult:
"""MinerU解析结果"""
file_path: str
markdown_text: str
metadata: Dict[str, str] = field(default_factory=dict)
success: bool = True
error_message: str = ""
class MinerUParser:
"""
MinerU多模态PDF解析器
MinerU (magic-pdf) 是一个开源的高质量PDF解析工具
支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素,
并输出结构化的Markdown格式。
GitHub: https://github.com/opendatalab/MinerU
"""
def __init__(self):
self.available = self._check_mineru_available()
def _check_mineru_available(self) -> bool:
"""检查MinerU是否可用"""
try:
from magic_pdf.pipe.UNIPipe import UNIPipe
return True
except ImportError:
logger.warning("MinerU (magic-pdf) 未安装,请运行: pip install magic-pdf[full]")
return False
def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
"""
使用MinerU解析PDF文档
Args:
file_path: PDF文件路径
output_dir: 输出目录(可选,用于保存解析产物)
Returns:
MinerUResult: 解析结果
"""
logger.info(f"尝试使用MinerU解析: {file_path}")
if not self.available:
return MinerUResult(
file_path=file_path,
markdown_text="",
success=False,
error_message="MinerU未安装"
)
try:
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.libs.MakeContentConfig import DropMode
# 设置输出目录
if output_dir is None:
output_dir = os.path.dirname(file_path)
# 创建解析管道
# OCR模式可以根据PDF类型选择
# auto: 自动判断是否需要OCR
# txt: 纯文本PDF无OCR
# ocr: 扫描件PDFOCR
pipe = UNIPipe(file_path, output_dir)
# 执行解析
# pipe_mk() 返回Markdown格式文本
markdown_content = pipe.pipe_mk()
logger.success(f"MinerU解析成功")
return MinerUResult(
file_path=file_path,
markdown_text=markdown_content,
metadata=self._extract_metadata(pipe),
success=True
)
except Exception as e:
logger.error(f"MinerU解析失败: {e}")
return MinerUResult(
file_path=file_path,
markdown_text="",
success=False,
error_message=str(e)
)
def _extract_metadata(self, pipe) -> Dict[str, str]:
"""从解析管道提取元数据"""
metadata = {}
try:
# MinerU解析管道中可能包含的元数据信息
if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data:
mid_data = pipe.pdf_mid_data
# 提取可能的元数据字段
metadata = {
"page_count": str(mid_data.get("page_count", "")),
"language": str(mid_data.get("language", "")),
"is_scanned": str(mid_data.get("is_scanned", "")),
}
except Exception as e:
logger.warning(f"提取MinerU元数据失败: {e}")
return metadata
def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本"""
result = self.parse(file_path)
return result.markdown_text if result.success else ""
class ParserOrchestrator:
"""
解析服务编排 - 按优先级选择解析器
解析策略:
1. 优先尝试MinerU版面感知能力强
2. MinerU失败时回退到基础PyMuPDF解析
"""
def __init__(self):
from .pdf_parser import PDFParser
self.mineru_parser = MinerUParser()
self.pdf_parser = PDFParser()
self.mineru_available = self.mineru_parser.available
def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
"""
解析PDF文档按优先级选择解析器
Args:
file_path: PDF文件路径
prefer_mineru: 是否优先使用MinerU
Returns:
str: Markdown格式文本
"""
markdown_text = ""
if prefer_mineru and self.mineru_available:
# 优先尝试MinerU
result = self.mineru_parser.parse(file_path)
if result.success:
markdown_text = result.markdown_text
logger.info("使用MinerU解析成功")
return markdown_text
else:
logger.warning(f"MinerU解析失败回退到PyMuPDF: {result.error_message}")
# 回退到PyMuPDF基础解析
logger.info("使用PyMuPDF基础解析")
markdown_text = self.pdf_parser.parse_to_markdown(file_path)
return markdown_text
def parse_docx(self, file_path: str) -> str:
"""解析Word文档"""
from .docx_parser import DocxParser
docx_parser = DocxParser()
return docx_parser.parse_to_markdown(file_path)
def parse(self, file_path: str) -> str:
"""
根据文件类型选择解析器
Args:
file_path: 文件路径
Returns:
str: Markdown格式文本
"""
ext = os.path.splitext(file_path)[1].lower()
if ext == ".pdf":
return self.parse_pdf(file_path)
elif ext in [".docx", ".doc"]:
return self.parse_docx(file_path)
else:
raise ValueError(f"不支持的文件类型: {ext}")
def parse_with_mineru(file_path: str) -> MinerUResult:
"""便捷函数使用MinerU解析"""
parser = MinerUParser()
return parser.parse(file_path)
def parse_pdf_smart(file_path: str) -> str:
"""便捷函数智能解析PDF自动选择最佳解析器"""
orchestrator = ParserOrchestrator()
return orchestrator.parse_pdf(file_path)

View File

@@ -0,0 +1,268 @@
# src/services/parser/pdf_parser.py
"""PDF文档解析 - 使用PyMuPDF基础解析"""
import fitz # PyMuPDF
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from loguru import logger
import re
@dataclass
class PDFPageContent:
"""PDF页面内容"""
page_number: int
text: str
tables: List[str] = field(default_factory=list)
images: List[str] = field(default_factory=list) # 图片路径列表
blocks: List[Dict] = field(default_factory=list)
@dataclass
class PDFDocumentContent:
"""PDF文档完整内容"""
file_path: str
total_pages: int
pages: List[PDFPageContent]
metadata: Dict[str, str] = field(default_factory=dict)
markdown_text: str = ""
class PDFParser:
"""PDF文档解析器 - 基于PyMuPDF"""
def __init__(self):
self.pdf = None
def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
"""
解析PDF文档
Args:
file_path: PDF文件路径
extract_tables: 是否提取表格
extract_images: 是否提取图片
Returns:
PDFDocumentContent: 解析后的文档内容
"""
logger.info(f"开始解析PDF文档: {file_path}")
try:
self.pdf = fitz.open(file_path)
doc_content = PDFDocumentContent(
file_path=file_path,
total_pages=self.pdf.page_count,
pages=[]
)
# 提取文档元数据
doc_content.metadata = self._extract_metadata()
# 逐页解析
for page_num in range(self.pdf.page_count):
page = self.pdf[page_num]
page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images)
doc_content.pages.append(page_content)
# 生成Markdown格式文本
doc_content.markdown_text = self._generate_markdown(doc_content)
self.pdf.close()
logger.success(f"PDF解析完成{doc_content.total_pages}")
return doc_content
except Exception as e:
logger.error(f"PDF解析失败: {e}")
raise
def _extract_metadata(self) -> Dict[str, str]:
"""提取PDF元数据"""
metadata = {}
try:
meta = self.pdf.metadata
metadata = {
"title": meta.get("title", ""),
"author": meta.get("author", ""),
"subject": meta.get("subject", ""),
"keywords": meta.get("keywords", ""),
"creator": meta.get("creator", ""),
"producer": meta.get("producer", ""),
"creation_date": meta.get("creationDate", ""),
"mod_date": meta.get("modDate", ""),
}
except Exception as e:
logger.warning(f"提取元数据失败: {e}")
return metadata
def _parse_page(self, page: fitz.Page, page_num: int,
extract_tables: bool, extract_images: bool) -> PDFPageContent:
"""解析单页内容"""
page_content = PDFPageContent(page_number=page_num, text="")
# 提取文本块(保留结构)
blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
page_content.blocks = blocks
# 提取纯文本
text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE)
page_content.text = text.strip()
# 提取表格使用PyMuPDF的表格提取功能
if extract_tables:
tables = self._extract_tables_from_page(page)
page_content.tables = tables
# 提取图片
if extract_images:
images = self._extract_images_from_page(page, page_num)
page_content.images = images
return page_content
def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
"""
从页面提取表格(基于文本块分析)
注意PyMuPDF基础版表格提取能力有限复杂表格建议使用MinerU
"""
tables = []
try:
# 使用PyMuPDF的表格提取方法2.4+版本)
# 对于更复杂的表格需要在mineru_parser中使用更高级的方法
tabs = page.find_tables()
if tabs:
for tab in tabs:
table_text = tab.extract()
# 将表格转换为Markdown格式
markdown_table = self._table_to_markdown(table_text)
tables.append(markdown_table)
except AttributeError:
# 旧版本PyMuPDF没有表格提取功能
logger.warning("PyMuPDF版本不支持表格提取请升级到2.4+版本")
except Exception as e:
logger.warning(f"表格提取失败: {e}")
return tables
def _table_to_markdown(self, table_data: List[List[str]]) -> str:
"""将表格数据转换为Markdown格式"""
if not table_data or len(table_data) < 1:
return ""
lines = []
# 表头
if len(table_data) >= 1:
header = table_data[0]
lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |")
lines.append("| " + " | ".join("---" for _ in header) + " |")
# 数据行
for row in table_data[1:]:
lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |")
return "\n".join(lines)
def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
"""提取页面图片"""
images = []
# 图片提取功能(可选实现)
# 这里仅记录图片信息,实际图片需要额外保存
try:
image_list = page.get_images()
for img_index, img in enumerate(image_list):
xref = img[0]
images.append(f"image_p{page_num}_i{img_index}_xref{xref}")
except Exception as e:
logger.warning(f"图片提取失败: {e}")
return images
def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
"""生成Markdown格式文本"""
lines = []
# 文档标题
title = doc_content.metadata.get("title", "")
if title:
lines.append(f"# {title}\n")
else:
lines.append(f"# {doc_content.file_path}\n")
# 元数据信息
lines.append("\n## 文档信息\n")
for key, value in doc_content.metadata.items():
if value and key in ["author", "subject", "keywords", "creation_date"]:
lines.append(f"- **{key}**: {value}")
# 正文内容
lines.append("\n## 正文\n")
for page in doc_content.pages:
# 页码标记
lines.append(f"\n---\n**第 {page.page_number} 页**\n")
# 处理文本内容,识别标题结构
text = self._process_page_text(page.text, page.blocks)
lines.append(text)
# 添加表格
for table in page.tables:
lines.append("\n" + table + "\n")
return "\n".join(lines)
def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
"""处理页面文本,识别标题结构"""
# 基于字体大小识别标题
processed_text = text
# 尝试识别标题(基于字号)
# 法规文档通常有明确的层级结构:章、节、条
processed_text = self._detect_headers(text, blocks)
return processed_text
def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
"""检测并标记标题(基于字号或内容模式)"""
lines = text.split("\n")
processed_lines = []
for line in lines:
line = line.strip()
if not line:
continue
# 法规标题模式检测
# 第一章、第X章、第X节、第X条等
if re.match(r'^第[一二三四五六七八九十百]+章\s', line):
processed_lines.append(f"\n## {line}\n")
elif re.match(r'^第[一二三四五六七八九十百]+节\s', line):
processed_lines.append(f"\n### {line}\n")
elif re.match(r'^第[一二三四五六七八九十百]+条\s', line):
processed_lines.append(f"\n#### {line}\n")
elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line):
# 条款子项
processed_lines.append(f"- {line}")
else:
processed_lines.append(line)
return "\n".join(processed_lines)
def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本"""
doc_content = self.parse(file_path)
return doc_content.markdown_text
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
"""便捷函数解析PDF文档"""
parser = PDFParser()
return parser.parse(file_path, **kwargs)
def parse_pdf_to_markdown(file_path: str) -> str:
"""便捷函数解析PDF并返回Markdown"""
parser = PDFParser()
return parser.parse_to_markdown(file_path)

View File

@@ -0,0 +1,6 @@
# src/services/storage/__init__.py
"""存储服务"""
from .milvus_client import MilvusClient
__all__ = ["MilvusClient"]

View File

@@ -0,0 +1,485 @@
# src/services/storage/milvus_client.py
"""Milvus向量数据库客户端 - 存储与检索服务"""
from pymilvus import (
connections,
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility
)
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
import time
import numpy as np
from ..embedding.text_chunker import TextChunk
from ..embedding.bge_m3_embedder import EmbeddingResult
from src.config.settings import settings
@dataclass
class SearchResult:
"""检索结果"""
id: int
content: str
score: float
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MilvusDocument:
"""Milvus文档数据结构"""
doc_id: str
chunk_id: str
content: str
dense_vector: List[float]
sparse_vector: Dict[int, float]
doc_name: str
section_title: str
clause_number: str
page_number: int
regulation_type: str
version: str
create_time: int
class MilvusClient:
"""Milvus向量数据库客户端"""
COLLECTION_NAME = "regulations"
SCHEMA_FIELDS = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192),
FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="page_number", dtype=DataType.INT64),
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32),
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32),
FieldSchema(name="create_time", dtype=DataType.INT64),
]
def __init__(
self,
host: str = None,
port: int = None,
collection_name: str = None,
db_name: str = None
):
self.host = host or settings.milvus_host
self.port = port or settings.milvus_port
self.collection_name = collection_name or settings.milvus_collection
self.db_name = db_name or settings.milvus_db_name
self.collection: Optional[Collection] = None
self.connected = False
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
def connect(self) -> bool:
"""连接到Milvus服务器"""
try:
connections.connect(
alias="default",
host=self.host,
port=self.port,
db_name=self.db_name
)
self.connected = True
logger.success(f"Milvus连接成功: {self.host}:{self.port}")
return True
except Exception as e:
logger.error(f"Milvus连接失败: {e}")
self.connected = False
return False
def disconnect(self):
"""断开连接"""
try:
connections.disconnect("default")
self.connected = False
logger.info("Milvus连接已断开")
except Exception as e:
logger.warning(f"断开连接时出错: {e}")
def create_collection(self, recreate: bool = False) -> bool:
"""创建Collection"""
if not self.connected:
logger.warning("未连接到Milvus请先调用connect()")
return False
try:
if utility.has_collection(self.collection_name):
if recreate:
logger.info(f"删除已存在的Collection: {self.collection_name}")
utility.drop_collection(self.collection_name)
else:
logger.info(f"Collection已存在: {self.collection_name}")
self.collection = Collection(self.collection_name)
return True
schema = CollectionSchema(
fields=self.SCHEMA_FIELDS,
description="法规文档向量存储",
enable_dynamic_field=True
)
self.collection = Collection(
name=self.collection_name,
schema=schema
)
self._create_indexes()
logger.success(f"Collection创建成功: {self.collection_name}")
return True
except Exception as e:
logger.error(f"Collection创建失败: {e}")
return False
def _create_indexes(self):
"""创建向量索引"""
if not self.collection:
return
try:
dense_index_params = {
"metric_type": "COSINE",
"index_type": "IVF_FLAT",
"params": {"nlist": 128}
}
self.collection.create_index(
field_name="dense_vector",
index_params=dense_index_params
)
sparse_index_params = {
"metric_type": "IP",
"index_type": "SPARSE_INVERTED_INDEX",
"params": {"drop_ratio_build": 0.2}
}
self.collection.create_index(
field_name="sparse_vector",
index_params=sparse_index_params
)
logger.success("向量索引创建成功")
except Exception as e:
logger.warning(f"创建索引时出错: {e}")
def load_collection(self):
"""加载Collection到内存"""
if self.collection:
self.collection.load()
logger.info(f"Collection已加载: {self.collection_name}")
def release_collection(self):
"""释放Collection内存"""
if self.collection:
self.collection.release()
logger.info(f"Collection已释放: {self.collection_name}")
def insert_chunks(
self,
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""插入文档分块和嵌入向量"""
if not self.collection:
logger.warning("Collection未初始化")
return []
if len(chunks) != len(embeddings.texts):
logger.warning(f"Chunks数量与嵌入数量不匹配")
return []
logger.info(f"准备插入{len(chunks)}个文档分块")
try:
data = []
current_time = int(time.time())
for chunk, dense_emb, sparse_emb in zip(
chunks,
embeddings.dense_embeddings,
embeddings.sparse_embeddings
):
row = {
"doc_id": chunk.metadata.doc_id,
"chunk_id": chunk.metadata.chunk_id,
"content": chunk.content,
"dense_vector": dense_emb.tolist(),
"sparse_vector": sparse_emb,
"doc_name": chunk.metadata.doc_name,
"section_title": chunk.metadata.section_title,
"clause_number": chunk.metadata.clause_number,
"page_number": chunk.metadata.page_number,
"regulation_type": chunk.metadata.regulation_type,
"version": chunk.metadata.version,
"create_time": current_time
}
data.append(row)
result = self.collection.insert(data)
self.collection.flush()
logger.success(f"插入完成,共{len(result.primary_keys)}条记录")
return result.primary_keys
except Exception as e:
logger.error(f"插入数据失败: {e}")
return []
def hybrid_search(
self,
query_dense: List[float],
query_sparse: Dict[int, float],
top_k: int = 10,
filters: Optional[str] = None
) -> List[SearchResult]:
"""混合检索Dense + Sparse"""
if not self.collection:
logger.warning("Collection未初始化")
return []
try:
self.collection.load()
# 使用简单的Dense检索兼容所有版本
dense_results = self.dense_search(query_dense, top_k, filters)
# 可选合并Sparse结果
if query_sparse:
sparse_results = self.sparse_search(query_sparse, top_k, filters)
merged = self._merge_results(dense_results, sparse_results, top_k)
logger.success(f"混合检索完成,返回{len(merged)}条结果")
return merged
return dense_results
except Exception as e:
logger.error(f"混合检索失败: {e}")
return []
def _merge_results(
self,
dense_results: List[SearchResult],
sparse_results: List[SearchResult],
top_k: int,
dense_weight: float = 0.6
) -> List[SearchResult]:
"""手动融合Dense和Sparse结果"""
sparse_weight = 1 - dense_weight
merged_dict = {}
for r in dense_results:
merged_dict[r.id] = {
"result": r,
"dense_score": r.score * dense_weight,
"sparse_score": 0
}
for r in sparse_results:
if r.id in merged_dict:
merged_dict[r.id]["sparse_score"] = r.score * sparse_weight
else:
merged_dict[r.id] = {
"result": r,
"dense_score": 0,
"sparse_score": r.score * sparse_weight
}
final_results = []
for id_, data in merged_dict.items():
result = data["result"]
final_score = data["dense_score"] + data["sparse_score"]
final_results.append(SearchResult(
id=result.id,
content=result.content,
score=final_score,
metadata=result.metadata
))
final_results.sort(key=lambda x: x.score, reverse=True)
return final_results[:top_k]
def dense_search(
self,
query_dense: List[float],
top_k: int = 10,
filters: Optional[str] = None
) -> List[SearchResult]:
"""纯Dense向量检索"""
if not self.collection:
return []
try:
self.collection.load()
search_params = {
"metric_type": "COSINE",
"params": {"nprobe": 16}
}
results = self.collection.search(
data=[query_dense],
anns_field="dense_vector",
param=search_params,
limit=top_k,
filter=filters,
output_fields=[
"doc_id", "chunk_id", "content",
"doc_name", "section_title", "clause_number",
"page_number", "regulation_type", "version"
]
)
search_results = []
for hits in results:
for hit in hits:
result = SearchResult(
id=hit.id,
content=hit.entity.get("content", ""),
score=hit.score,
metadata={
"doc_id": hit.entity.get("doc_id", ""),
"chunk_id": hit.entity.get("chunk_id", ""),
"doc_name": hit.entity.get("doc_name", ""),
"section_title": hit.entity.get("section_title", ""),
"clause_number": hit.entity.get("clause_number", ""),
"page_number": hit.entity.get("page_number", 0),
"regulation_type": hit.entity.get("regulation_type", ""),
"version": hit.entity.get("version", ""),
}
)
search_results.append(result)
return search_results
except Exception as e:
logger.error(f"Dense检索失败: {e}")
return []
def sparse_search(
self,
query_sparse: Dict[int, float],
top_k: int = 10,
filters: Optional[str] = None
) -> List[SearchResult]:
"""纯Sparse向量检索"""
if not self.collection:
return []
try:
self.collection.load()
search_params = {
"metric_type": "IP",
"params": {"drop_ratio_search": 0.2}
}
results = self.collection.search(
data=[query_sparse],
anns_field="sparse_vector",
param=search_params,
limit=top_k,
filter=filters,
output_fields=[
"doc_id", "chunk_id", "content",
"doc_name", "section_title", "clause_number",
"page_number", "regulation_type", "version"
]
)
search_results = []
for hits in results:
for hit in hits:
result = SearchResult(
id=hit.id,
content=hit.entity.get("content", ""),
score=hit.score,
metadata={
"doc_id": hit.entity.get("doc_id", ""),
"chunk_id": hit.entity.get("chunk_id", ""),
"doc_name": hit.entity.get("doc_name", ""),
"section_title": hit.entity.get("section_title", ""),
"clause_number": hit.entity.get("clause_number", ""),
"page_number": hit.entity.get("page_number", 0),
"regulation_type": hit.entity.get("regulation_type", ""),
"version": hit.entity.get("version", ""),
}
)
search_results.append(result)
return search_results
except Exception as e:
logger.error(f"Sparse检索失败: {e}")
return []
def delete_by_doc_id(self, doc_id: str) -> int:
"""根据doc_id删除记录"""
if not self.collection:
return 0
try:
expr = f'doc_id=="{doc_id}"'
result = self.collection.delete(expr)
logger.info(f"删除记录: doc_id={doc_id}, 数量={len(result.primary_keys)}")
return len(result.primary_keys)
except Exception as e:
logger.error(f"删除失败: {e}")
return 0
def get_collection_stats(self) -> Dict[str, Any]:
"""获取Collection统计信息"""
if not self.collection:
return {}
try:
stats = {
"name": self.collection_name,
"num_entities": self.collection.num_entities,
"description": self.collection.description,
}
return stats
except Exception as e:
logger.warning(f"获取统计信息失败: {e}")
return {}
def create_milvus_client() -> MilvusClient:
"""便捷函数创建Milvus客户端"""
client = MilvusClient()
client.connect()
client.create_collection(recreate=False)
return client
def insert_documents(
client: MilvusClient,
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""便捷函数:插入文档"""
return client.insert_chunks(chunks, embeddings)
def search_regulations(
client: MilvusClient,
query_dense: List[float],
query_sparse: Dict[int, float],
top_k: int = 10
) -> List[SearchResult]:
"""便捷函数:检索法规"""
return client.hybrid_search(query_dense, query_sparse, top_k)

2
src/workers/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# src/workers/__init__.py
"""异步任务Worker模块"""

47
start_api.sh Normal file
View File

@@ -0,0 +1,47 @@
#!/bin/bash
# start_api.sh - 启动API服务支持虚拟环境
set -e
VENV_DIR=".venv"
# 创建日志目录
mkdir -p logs
echo "========================================"
echo "启动 AI+合规智能中枢 API服务"
echo "========================================"
echo ""
# 检查虚拟环境
if [ ! -d "$VENV_DIR" ]; then
echo "错误: 虚拟环境不存在,请先运行 ./quick_start.sh"
exit 1
fi
# 激活虚拟环境
source $VENV_DIR/bin/activate
echo "已激活虚拟环境: $VENV_DIR"
echo ""
# 检查.env文件
if [ ! -f ".env" ]; then
echo "警告: .env文件不存在使用默认配置"
fi
# 启动参数
HOST=${API_HOST:-0.0.0.0}
PORT=${API_PORT:-8000}
echo "API地址: http://$HOST:$PORT"
echo "API文档: http://$HOST:$PORT/docs"
echo "健康检查: http://$HOST:$PORT/health"
echo ""
echo "前端测试页面:"
echo " 直接打开: frontend/index.html"
echo " 或启动服务: ./start_frontend.sh"
echo ""
echo "正在启动..."
echo ""
python -m uvicorn src.api.main:app --host $HOST --port $PORT --reload

77
start_api_background.sh Normal file
View File

@@ -0,0 +1,77 @@
#!/bin/bash
# start_api_background.sh - 后台启动API服务生产环境支持虚拟环境
set -e
VENV_DIR=".venv"
# 创建日志目录
mkdir -p logs
PID_FILE=logs/api.pid
LOG_FILE=logs/api.log
echo "========================================"
echo "后台启动 AI+合规智能中枢 API服务"
echo "========================================"
echo ""
# 检查虚拟环境
if [ ! -d "$VENV_DIR" ]; then
echo "错误: 虚拟环境不存在,请先运行 ./quick_start.sh"
exit 1
fi
# 检查是否已运行
if [ -f "$PID_FILE" ]; then
PID=$(cat $PID_FILE)
if ps -p $PID > /dev/null 2>&1; then
echo "服务已在运行 (PID: $PID)"
echo "如需重启,请先运行: ./stop_api.sh"
exit 1
else
rm -f $PID_FILE
fi
fi
# 启动参数
HOST=${API_HOST:-0.0.0.0}
PORT=${API_PORT:-8000}
echo "服务地址: http://$HOST:$PORT"
echo "日志文件: $LOG_FILE"
echo ""
# 后台启动(使用虚拟环境)
echo "正在后台启动..."
nohup $VENV_DIR/bin/python -m uvicorn src.api.main:app --host $HOST --port $PORT > $LOG_FILE 2>&1 &
PID=$!
# 保存PID
echo $PID > $PID_FILE
# 等待服务启动
sleep 3
# 检查是否启动成功
if ps -p $PID > /dev/null 2>&1; then
echo "服务启动成功 (PID: $PID)"
echo ""
echo "API地址: http://$HOST:$PORT"
echo "API文档: http://$HOST:$PORT/docs"
echo "健康检查: http://$HOST:$PORT/health"
echo ""
echo "前端测试页面:"
echo " 直接打开: frontend/index.html"
echo " 或启动服务: ./start_frontend.sh"
echo ""
echo "查看日志:"
echo " tail -f $LOG_FILE"
echo ""
echo "停止服务:"
echo " ./stop_api.sh"
else
echo "服务启动失败,请查看日志: $LOG_FILE"
rm -f $PID_FILE
exit 1
fi

46
start_frontend.sh Normal file
View File

@@ -0,0 +1,46 @@
#!/bin/bash
# start_frontend.sh - 启动前端测试页面静态服务
set -e
FRONTEND_DIR="frontend"
PORT=${FRONTEND_PORT:-3000}
echo "========================================"
echo "启动前端测试页面服务"
echo "========================================"
echo ""
# 检查前端目录
if [ ! -d "$FRONTEND_DIR" ]; then
echo "错误: 前端目录不存在: $FRONTEND_DIR"
exit 1
fi
# 检查index.html
if [ ! -f "$FRONTEND_DIR/index.html" ]; then
echo "错误: 前端页面不存在: $FRONTEND_DIR/index.html"
exit 1
fi
echo "前端地址: http://localhost:$PORT"
echo ""
echo "确保API服务已启动:"
echo " ./start_api.sh"
echo ""
echo "正在启动前端服务..."
# 使用Python内置HTTP服务器
cd $FRONTEND_DIR
# 查找Python命令
if command -v python3 &> /dev/null; then
PYTHON_CMD=python3
elif command -v python &> /dev/null; then
PYTHON_CMD=python
else
echo "错误: 未找到Python"
exit 1
fi
$PYTHON_CMD -m http.server $PORT --bind 0.0.0.0

46
stop_api.sh Normal file
View File

@@ -0,0 +1,46 @@
#!/bin/bash
# stop_api.sh - 停止API服务
PID_FILE=logs/api.pid
echo "========================================"
echo "停止 AI+合规智能中枢 API服务"
echo "========================================"
echo ""
if [ -f "$PID_FILE" ]; then
PID=$(cat $PID_FILE)
if ps -p $PID > /dev/null 2>&1; then
echo "正在停止服务 (PID: $PID)..."
kill $PID
# 等待进程结束
sleep 2
if ps -p $PID > /dev/null 2>&1; then
echo "进程未响应,强制终止..."
kill -9 $PID
fi
rm -f $PID_FILE
echo "服务已停止"
else
echo "进程已不存在清理PID文件"
rm -f $PID_FILE
fi
else
echo "PID文件不存在服务可能未运行"
# 尝试查找并停止所有uvicorn进程
UVICORN_PIDS=$(pgrep -f "uvicorn src.api.main")
if [ -n "$UVICORN_PIDS" ]; then
echo "发现运行中的uvicorn进程: $UVICORN_PIDS"
echo "是否停止这些进程? (y/n)"
read -r answer
if [ "$answer" = "y" ]; then
kill $UVICORN_PIDS
echo "进程已停止"
fi
fi
fi

37
test_api.sh Normal file
View File

@@ -0,0 +1,37 @@
#!/bin/bash
# test_api.sh - API接口测试脚本
API_URL=${API_URL:-http://localhost:8000}
echo "========================================"
echo "API接口测试"
echo "========================================"
echo ""
# 1. 健康检查
echo ">>> 测试: 健康检查 GET /health"
curl -s -X GET "$API_URL/health"
echo ""
echo ""
# 2. 根路径
echo ">>> 测试: 根路径 GET /"
curl -s -X GET "$API_URL/"
echo ""
echo ""
# 3. 检索接口(无数据时返回空结果)
echo ">>> 测试: 检索接口 POST /api/v1/knowledge/search"
curl -s -X POST "$API_URL/api/v1/knowledge/search" \
-H "Content-Type: application/json" \
-d '{"query": "机动车安全标准", "top_k": 5}'
echo ""
echo ""
echo "========================================"
echo "测试完成"
echo "========================================"
echo ""
echo "上传文档测试:"
echo " ./test_upload.sh your_file.pdf"
echo ""

76
test_upload.sh Normal file
View File

@@ -0,0 +1,76 @@
#!/bin/bash
# test_upload.sh - 测试文档上传(支持虚拟环境)
set -e
API_URL=${API_URL:-http://localhost:8000}
echo "========================================"
echo "测试文档上传功能"
echo "========================================"
echo ""
echo "API地址: $API_URL"
echo ""
# 检查健康状态
echo "1. 检查服务健康状态..."
curl -s $API_URL/health | python -m json.tool 2>/dev/null || curl -s $API_URL/health
echo ""
# 检查文件参数
if [ -z "$1" ]; then
echo "使用方法: ./test_upload.sh <文件路径>"
echo ""
echo "示例:"
echo " ./test_upload.sh sample.pdf"
echo " ./test_upload.sh sample.docx"
exit 1
fi
FILE_PATH=$1
if [ ! -f "$FILE_PATH" ]; then
echo "错误: 文件不存在: $FILE_PATH"
exit 1
fi
FILE_NAME=$(basename "$FILE_PATH")
FILE_EXT="${FILE_NAME##*.}"
echo "文件: $FILE_PATH"
echo "类型: $FILE_EXT"
echo ""
# 上传文档
echo "2. 上传文档..."
echo ""
RESPONSE=$(curl -s -X POST "$API_URL/api/v1/documents/upload" \
-F "file=@$FILE_PATH" \
-F "doc_name=$FILE_NAME" \
-F "regulation_type=测试法规")
echo "$RESPONSE" | python -m json.tool 2>/dev/null || echo "$RESPONSE"
echo ""
# 提取doc_id
DOC_ID=$(echo "$RESPONSE" | python -c "import sys,json; d=json.load(sys.stdin); print(d.get('doc_id',''))" 2>/dev/null)
if [ -n "$DOC_ID" ]; then
echo "文档ID: $DOC_ID"
echo ""
# 测试检索
echo "3. 测试检索..."
echo ""
SEARCH_QUERY="法规安全要求"
curl -s -X POST "$API_URL/api/v1/knowledge/search" \
-H "Content-Type: application/json" \
-d "{\"query\": \"$SEARCH_QUERY\", \"top_k\": 5}" \
| python -m json.tool 2>/dev/null || cat
else
echo "上传可能失败,请检查响应内容"
fi

2
tests/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# tests/__init__.py
"""测试模块"""

184
tests/test_embedding.py Normal file
View File

@@ -0,0 +1,184 @@
# tests/test_embedding.py
"""嵌入和分块测试"""
import pytest
from loguru import logger
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.services.embedding.text_chunker import RegulationChunker, TextChunk, ChunkMetadata
from src.services.embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
class TestRegulationChunker:
"""法规分块器测试"""
@pytest.fixture
def chunker(self):
"""创建分块器实例"""
return RegulationChunker(chunk_size=512)
@pytest.fixture
def sample_regulation(self):
"""示例法规文档"""
return """
# GB 7258-2017 机动车运行安全技术条件
第一章 范围
第一条 本标准规定了机动车运行安全技术条件。
第二条 本标准适用于在我国道路上行驶的所有机动车。
第二章 术语和定义
第三条 下列术语和定义适用于本标准。
(一)机动车:以动力装置驱动或者牵引,上道路行驶的供人员乘用或者用于运送物品以及进行工程专项作业的轮式车辆。
(二)整车:完整的机动车,包括所有必要的部件和系统。
第三章 技术要求
第四条 机动车应满足以下基本要求:
1. 车辆应具有唯一的产品标识;
2. 车辆结构应安全可靠;
3. 车辆应配备必要的安全装置。
"""
def test_chunk_document(self, chunker, sample_regulation):
"""测试文档分块"""
chunks = chunker.chunk_document(
sample_regulation,
doc_id="gb7258",
doc_name="GB 7258-2017",
regulation_type="车辆安全"
)
# 应该有多个分块
assert len(chunks) > 3
# 每个分块应该有内容
for chunk in chunks:
assert len(chunk.content) > 0
assert chunk.metadata.doc_id == "gb7258"
def test_section_detection(self, chunker, sample_regulation):
"""测试章节检测"""
chunks = chunker.chunk_document(
sample_regulation,
doc_id="test",
doc_name="测试"
)
# 应该检测到章节
section_numbers = [c.metadata.section_number for c in chunks]
assert any(s for s in section_numbers) # 至少有一个章节编号
def test_clause_detection(self, chunker, sample_regulation):
"""测试条款检测"""
chunks = chunker.chunk_document(
sample_regulation,
doc_id="test",
doc_name="测试"
)
# 应该检测到条款
clause_numbers = [c.metadata.clause_number for c in chunks]
assert any(c for c in clause_numbers) # 至少有一个条款编号
def test_long_clause_split(self, chunker):
"""测试长条款分割"""
long_clause = """
第一条 本条款内容很长,需要进行分割处理。
本条款包含以下多项内容:
1. 第一项内容,这是一个非常长的子项,包含了大量的文字描述,需要进行适当的处理。
2. 第二项内容,这也是一个较长的子项,包含了相关的技术要求和规范说明。
3. 第三项内容,继续描述相关要求和注意事项,确保文档的完整性和规范性。
4. 第四项内容,补充说明其他相关事项,保证内容的全面性。
"""
chunks = chunker.chunk_document(
long_clause,
doc_id="test",
doc_name="测试"
)
# 长条款应该被分割成多个chunk
assert len(chunks) >= 1
class TestBGEM3Embedder:
"""BGE-M3嵌入模型测试"""
@pytest.fixture
def embedder(self):
"""创建嵌入模型实例"""
try:
return BGEM3Embedder()
except Exception as e:
pytest.skip(f"嵌入模型加载失败: {e}")
def test_embed_single(self, embedder):
"""测试单文本嵌入"""
text = "这是一条测试文本"
result = embedder.embed_single(text)
# 应该包含dense和sparse向量
assert 'dense' in result
assert 'sparse' in result
# dense向量维度应该是1024
assert len(result['dense']) == 1024
def test_embed_batch(self, embedder):
"""测试批量嵌入"""
texts = [
"第一条 本标准规定了机动车安全要求",
"第二条 机动车应符合技术条件",
"第三条 生产企业应建立管理体系"
]
result = embedder.embed(texts)
# 应该返回正确数量的向量
assert len(result.dense_embeddings) == 3
# 维度应该是1024
assert result.dense_embeddings.shape[1] == 1024
def test_embed_empty_list(self, embedder):
"""测试空列表嵌入"""
result = embedder.embed([])
# 应该返回空结果
assert len(result.dense_embeddings) == 0
def test_similarity(self, embedder):
"""测试相似度计算"""
import numpy as np
texts = [
"机动车安全标准要求",
"汽车安全技术规范",
"食品安全管理规定" # 不相关文本
]
result = embedder.embed(texts)
# 计算第一个文本与其他文本的相似度
query = result.dense_embeddings[0]
docs = result.dense_embeddings[1:]
similarities = embedder.compute_similarity(query, docs)
# 相关文档的相似度应该更高
assert similarities[0] > similarities[1] # 车辆安全 > 食品安全
if __name__ == "__main__":
pytest.main([__file__, "-v"])

136
tests/test_milvus.py Normal file
View File

@@ -0,0 +1,136 @@
# tests/test_milvus.py
"""Milvus集成测试"""
import pytest
from loguru import logger
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.services.storage.milvus_client import MilvusClient, SearchResult
from src.services.embedding.bge_m3_embedder import BGEM3Embedder
from src.config.settings import settings
class TestMilvusConnection:
"""Milvus连接测试"""
def test_connection(self):
"""测试Milvus连接"""
client = MilvusClient()
result = client.connect()
assert result == True
client.disconnect()
def test_create_collection(self):
"""测试创建Collection"""
client = MilvusClient()
client.connect()
result = client.create_collection(recreate=True)
assert result == True
# 检查Collection是否存在
stats = client.get_collection_stats()
assert stats["name"] == settings.milvus_collection
client.disconnect()
class TestMilvusOperations:
"""Milvus操作测试"""
@pytest.fixture
def client(self):
"""创建测试客户端"""
client = MilvusClient()
client.connect()
client.create_collection(recreate=True)
client.load_collection()
yield client
client.disconnect()
def test_insert_and_search(self, client):
"""测试插入和检索"""
from src.services.embedding.text_chunker import TextChunk, ChunkMetadata
# 创建测试数据
chunks = [
TextChunk(
content="第一条 为保障机动车安全技术性能,预防和减少机动车交通事故,保护人身安全,制定本标准。",
metadata=ChunkMetadata(
doc_id="test_doc",
doc_name="测试文档",
chunk_id="test_chunk_1",
clause_number="第一条",
regulation_type="车辆安全"
)
),
TextChunk(
content="第二条 本标准适用于在我国道路上行驶的所有机动车。",
metadata=ChunkMetadata(
doc_id="test_doc",
doc_name="测试文档",
chunk_id="test_chunk_2",
clause_number="第二条",
regulation_type="车辆安全"
)
)
]
# 生成嵌入
embedder = BGEM3Embedder()
embeddings = embedder.embed([c.content for c in chunks])
# 插入数据
inserted_ids = client.insert_chunks(chunks, embeddings)
assert len(inserted_ids) == 2
# 执行检索
query = "机动车安全标准"
query_embedding = embedder.embed_single(query)
results = client.hybrid_search(
query_dense=query_embedding['dense'].tolist(),
query_sparse=query_embedding['sparse'],
top_k=2
)
assert len(results) > 0
assert "机动车" in results[0].content or "安全" in results[0].content
class TestEmbedding:
"""嵌入模型测试"""
def test_embed_single_text(self):
"""测试单文本嵌入"""
embedder = BGEM3Embedder()
result = embedder.embed_single("这是一条测试文本")
assert 'dense' in result
assert 'sparse' in result
assert len(result['dense']) == 1024 # BGE-M3默认维度
def test_embed_batch(self):
"""测试批量嵌入"""
embedder = BGEM3Embedder()
texts = [
"第一条 本标准规定了机动车安全要求",
"第二条 机动车应符合以下技术条件",
"第三条 生产企业应建立质量管理体系"
]
result = embedder.embed(texts)
assert len(result.dense_embeddings) == 3
assert result.dense_embeddings.shape[1] == 1024
if __name__ == "__main__":
pytest.main([__file__, "-v"])

118
tests/test_parser.py Normal file
View File

@@ -0,0 +1,118 @@
# tests/test_parser.py
"""文档解析测试"""
import pytest
from loguru import logger
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from src.services.parser.pdf_parser import PDFParser, parse_pdf_to_markdown
from src.services.parser.docx_parser import DocxParser, parse_docx_to_markdown
from src.services.parser.mineru_parser import MinerUParser, ParserOrchestrator
class TestPDFParser:
"""PDF解析测试"""
def test_parser_initialization(self):
"""测试PDF解析器初始化"""
parser = PDFParser()
assert parser is not None
def test_parse_sample_pdf(self):
"""测试解析示例PDF如果有"""
# 如果有示例PDF文件可以在此测试
sample_pdf = os.path.join(os.path.dirname(__file__), "sample.pdf")
if os.path.exists(sample_pdf):
parser = PDFParser()
result = parser.parse(sample_pdf)
assert result.total_pages > 0
assert len(result.pages) > 0
assert len(result.markdown_text) > 0
class TestDocxParser:
"""Word文档解析测试"""
def test_parser_initialization(self):
"""测试Word解析器初始化"""
parser = DocxParser()
assert parser is not None
def test_parse_sample_docx(self):
"""测试解析示例DOCX"""
sample_docx = os.path.join(os.path.dirname(__file__), "sample.docx")
if os.path.exists(sample_docx):
parser = DocxParser()
result = parser.parse(sample_docx)
assert len(result.paragraphs) > 0
assert len(result.markdown_text) > 0
class TestChunker:
"""分块器测试"""
def test_chunker_initialization(self):
"""测试分块器初始化"""
from src.services.embedding.text_chunker import RegulationChunker
chunker = RegulationChunker(chunk_size=512)
assert chunker is not None
def test_chunk_sample_text(self):
"""测试分块示例文本"""
from src.services.embedding.text_chunker import RegulationChunker
sample_text = """
# 测试法规文档
第一章 总则
第一条 为规范某项行为,制定本规定。
第二条 本规定适用于相关主体。
第二章 具体要求
第三条 相关主体应当遵守以下要求:
(一)建立管理制度;
(二)配备专业人员;
(三)定期进行检查。
"""
chunker = RegulationChunker(chunk_size=256)
chunks = chunker.chunk_document(
sample_text,
doc_id="test",
doc_name="测试法规"
)
assert len(chunks) > 0
# 验证分块包含章节信息
has_section = any(c.metadata.section_number for c in chunks)
assert has_section
class TestFullPipeline:
"""完整流程测试"""
def test_pipeline_without_files(self):
"""测试流程初始化(无文件)"""
from src.services.document_processor import DocumentProcessor
processor = DocumentProcessor()
assert processor is not None
processor.close()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

222
tests/verify_mvp.py Normal file
View File

@@ -0,0 +1,222 @@
"""
MVP功能验证脚本
用于验证完整的文档处理流程:
1. PDF/DOCX解析
2. 智能分块
3. 向量嵌入
4. Milvus入库
5. 混合检索
使用方法:
1. 首先启动Milvus: docker-compose up -d
2. 运行此脚本: python verify_mvp.py
"""
import os
import sys
import time
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from loguru import logger
from src.config.logging import setup_logging
from src.services.document_processor import DocumentProcessor, ProcessingResult
from src.services.storage.milvus_client import MilvusClient
from src.config.settings import settings
# 设置日志
setup_logging(level="INFO")
def verify_milvus_connection():
"""验证Milvus连接"""
logger.info("=" * 50)
logger.info("Step 1: 验证Milvus连接")
logger.info("=" * 50)
client = MilvusClient()
try:
result = client.connect()
if result:
logger.success("Milvus连接成功")
# 创建Collection
client.create_collection(recreate=True)
stats = client.get_collection_stats()
logger.info(f"Collection信息: {stats}")
client.disconnect()
return True
else:
logger.error("Milvus连接失败请检查docker-compose是否启动")
return False
except Exception as e:
logger.error(f"Milvus连接异常: {e}")
logger.info("请先启动Milvus: cd docker && docker-compose up -d")
return False
def verify_embedding_model():
"""验证嵌入模型"""
logger.info("=" * 50)
logger.info("Step 2: 验证BGE-M3嵌入模型")
logger.info("=" * 50)
try:
from src.services.embedding.bge_m3_embedder import BGEM3Embedder
embedder = BGEM3Embedder()
logger.success("嵌入模型加载成功")
# 测试嵌入
test_text = "这是一条测试文本,用于验证嵌入模型功能"
result = embedder.embed_single(test_text)
logger.info(f"Dense向量维度: {len(result['dense'])}")
logger.info(f"Sparse向量词数: {len(result['sparse'])}")
return True
except Exception as e:
logger.error(f"嵌入模型验证失败: {e}")
logger.info("请确保已安装FlagEmbedding: pip install FlagEmbedding")
return False
def verify_sample_document():
"""验证示例文档处理"""
logger.info("=" * 50)
logger.info("Step 3: 验证文档处理流程")
logger.info("=" * 50)
# 使用内置的示例文本(无需外部文件)
sample_text = """
# GB 7258-2017 机动车运行安全技术条件
第一章 范围
第一条 本标准规定了机动车运行安全技术条件,适用于在我国道路上行驶的所有机动车。
第二条 本标准包括整车、发动机、传动系、行驶系、制动系、照明与信号装置等技术要求。
第二章 术语和定义
第三条 下列术语和定义适用于本标准:
(一)机动车:以动力装置驱动或者牵引,上道路行驶的供人员乘用或者用于运送物品的轮式车辆。
(二)整车产品:完整的机动车产品,包括所有必要的部件和系统。
第三章 整车技术要求
第四条 机动车整车应满足以下基本技术要求:
1. 车辆外廓尺寸应符合规定限值;
2. 车辆应具有唯一的产品标识;
3. 车辆结构应安全可靠,各部件连接牢固。
第五条 车辆应配备必要的安全装置,包括:
- 制动系统
- 照明与信号装置
- 安全带
- 灭火器
"""
try:
from src.services.embedding.text_chunker import RegulationChunker
from src.services.embedding.bge_m3_embedder import BGEM3Embedder
from src.services.storage.milvus_client import MilvusClient
# 1. 分块
logger.info("测试分块...")
chunker = RegulationChunker(chunk_size=256)
chunks = chunker.chunk_document(
sample_text,
doc_id="gb7258_test",
doc_name="GB 7258-2017 测试",
regulation_type="车辆安全"
)
logger.success(f"分块完成,共{len(chunks)}个chunk")
# 2. 嵌入
logger.info("测试嵌入...")
embedder = BGEM3Embedder()
embeddings = embedder.embed([c.content for c in chunks])
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
# 3. 入库
logger.info("测试入库...")
client = MilvusClient()
client.connect()
client.create_collection(recreate=False)
client.load_collection()
inserted_ids = client.insert_chunks(chunks, embeddings)
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
# 4. 检索
logger.info("测试检索...")
query = "机动车安全技术要求"
query_emb = embedder.embed_single(query)
results = client.hybrid_search(
query_dense=query_emb['dense'].tolist(),
query_sparse=query_emb['sparse'],
top_k=3
)
logger.success(f"检索完成,返回{len(results)}条结果")
for i, r in enumerate(results):
logger.info(f"结果{i+1}: 分数={r.score:.4f}, 内容={r.content[:50]}...")
client.disconnect()
return True
except Exception as e:
logger.error(f"文档处理验证失败: {e}")
return False
def main():
"""主验证流程"""
logger.info("\n" + "=" * 60)
logger.info("AI+合规智能中枢 MVP功能验证")
logger.info("=" * 60)
results = []
# 1. Milvus连接验证
results.append(("Milvus连接", verify_milvus_connection()))
# 2. 嵌入模型验证
results.append(("嵌入模型", verify_embedding_model()))
# 3. 文档处理验证
results.append(("文档处理", verify_sample_document()))
# 输出结果汇总
logger.info("\n" + "=" * 60)
logger.info("验证结果汇总")
logger.info("=" * 60)
all_passed = True
for name, passed in results:
status = "✅ 通过" if passed else "❌ 失败"
logger.info(f"{name}: {status}")
if not passed:
all_passed = False
if all_passed:
logger.success("\n🎉 所有验证通过MVP功能正常")
else:
logger.warning("\n⚠️ 部分验证失败,请检查配置和环境")
return all_passed
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)