commit c2a398930dcdcda5577e1fef3b570a9bf4389559 Author: wangwei Date: Tue Apr 28 11:29:33 2026 +0800 first commit diff --git a/.env b/.env new file mode 100644 index 0000000..a18007d --- /dev/null +++ b/.env @@ -0,0 +1,48 @@ +# 环境变量配置 - 已有数据库服务 + +# 应用配置 +APP_NAME=AI+合规智能中枢 +APP_VERSION=0.1.0 +DEBUG=false + +# Milvus向量数据库配置(已有) +MILVUS_HOST=localhost +MILVUS_PORT=19530 +MILVUS_COLLECTION=regulations +MILVUS_DB_NAME=default + +# MinIO对象存储配置(已有) +MINIO_ENDPOINT=localhost:9000 +MINIO_ACCESS_KEY=minioadmin +MINIO_SECRET_KEY=minioadmin +MINIO_BUCKET=compliance-docs +MINIO_SECURE=false + +# Redis配置(已有) +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_PASSWORD=redis@123 +REDIS_DB=0 + +# PostgreSQL配置(已有) +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=postgresql +POSTGRES_PASSWORD=postgresql123456 +POSTGRES_DB=compliance_db + +# 嵌入模型配置 +EMBEDDING_MODEL=BAAI/bge-m3 +EMBEDDING_DIM=1024 +EMBEDDING_MAX_LENGTH=8192 +EMBEDDING_BATCH_SIZE=12 +EMBEDDING_USE_FP16=true + +# 文档处理配置 +CHUNK_SIZE=512 +CHUNK_OVERLAP=50 +MAX_FILE_SIZE_MB=100 + +# API配置 +API_HOST=0.0.0.0 +API_PORT=8000 \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ce9232c --- /dev/null +++ b/.env.example @@ -0,0 +1,31 @@ +# .env.example - 环境变量配置示例 + +# Milvus向量数据库配置 +MILVUS_HOST=localhost +MILVUS_PORT=19530 +MILVUS_COLLECTION=regulations + +# 嵌入模型配置 +EMBEDDING_MODEL=BAAI/bge-m3 +EMBEDDING_DIM=1024 + +# MinIO对象存储配置 +MINIO_ENDPOINT=localhost:9000 +MINIO_ACCESS_KEY=minioadmin +MINIO_SECRET_KEY=minioadmin123 +MINIO_BUCKET=compliance-docs + +# Redis配置 +REDIS_HOST=localhost +REDIS_PORT=6379 + +# PostgreSQL配置 +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=compliance +POSTGRES_PASSWORD=compliance123 +POSTGRES_DB=compliance_db + +# 文档处理配置 +CHUNK_SIZE=512 +CHUNK_OVERLAP=50 \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..daf5831 --- /dev/null +++ b/README.md @@ -0,0 +1,143 @@ +# AI+合规智能中枢 - 法律法规文档解析入库 + +面向车企与工厂的合规智能平台,实现法规文档的解析、分块、嵌入和向量存储。 + +## MVP功能 + +本次实现的核心功能(最小可用版本): + +- ✅ PDF/DOCX文档解析(MinerU + PyMuPDF) +- ✅ 智能分块(章节级+条款级双粒度切割) +- ✅ BGE-M3嵌入(Dense+Sparse双路向量) +- ✅ Milvus向量数据库存储与混合检索 +- ✅ FastAPI接口封装 + +## 项目结构 + +``` +Demo-glm/ +├── src/ +│ ├── api/ # FastAPI接口层 +│ │ ├── main.py # API入口 +│ │ ├── routes/ +│ │ │ ├── documents.py # 文档上传接口 +│ │ │ └── knowledge.py # 知识库检索接口 +│ │ └── models/ +│ │ └── document.py # Pydantic数据模型 +│ ├── services/ +│ │ ├── parser/ # 文档解析服务 +│ │ │ ├── pdf_parser.py # PDF解析(PyMuPDF) +│ │ │ ├── docx_parser.py # Word解析 +│ │ │ └── mineru_parser.py # MinerU多模态解析 +│ │ ├── embedding/ # 嵌入服务 +│ │ │ ├── text_chunker.py # 智能分块器 +│ │ │ └── bge_m3_embedder.py # BGE-M3嵌入 +│ │ ├── storage/ +│ │ │ └── milvus_client.py # Milvus客户端 +│ │ └── document_processor.py # 文档处理主流程 +│ └── config/ +│ │ ├── settings.py # 配置管理 +│ │ └── logging.py # 日志配置 +├── tests/ +│ ├── test_parser.py # 解析测试 +│ ├── test_embedding.py # 嵌入测试 +│ ├── test_milvus.py # Milvus测试 +│ └── verify_mvp.py # MVP验证脚本 +├── docker/ +│ └── docker-compose.yml # Milvus/MinIO部署 +├── requirements.txt +├── pyproject.toml +└── .env.example +``` + +## 快速开始 + +### 1. 安装依赖 + +```bash +pip install -r requirements.txt +``` + +### 2. 启动Milvus向量数据库 + +```bash +cd docker +docker-compose up -d +``` + +等待Milvus启动完成(约30秒): +```bash +docker-compose logs -f milvus +``` + +### 3. 运行验证脚本 + +```bash +python tests/verify_mvp.py +``` + +### 4. 启动API服务 + +```bash +uvicorn src.api.main:app --reload --port 8000 +``` + +访问API文档:http://localhost:8000/docs + +## API接口 + +### 上传文档 + +```bash +curl -X POST http://localhost:8000/api/v1/documents/upload \ + -F "file=@your_regulation.pdf" \ + -F "doc_name=GB 7258-2017" \ + -F "regulation_type=车辆安全" +``` + +### 检索法规 + +```bash +curl -X POST http://localhost:8000/api/v1/knowledge/search \ + -H "Content-Type: application/json" \ + -d '{"query": "机动车安全技术要求", "top_k": 10}' +``` + +## 技术栈 + +| 类别 | 技术 | +|------|------| +| 文档解析 | MinerU + PyMuPDF + python-docx | +| 分块策略 | 章节级+条款级双粒度切割 | +| 嵌入模型 | BGE-M3(1024维 Dense + Sparse) | +| 向量数据库 | Milvus 2.4(本地Docker部署) | +| 检索方式 | Dense+Sparse混合检索 + RRF融合 | +| API框架 | FastAPI | + +## 配置 + +创建 `.env` 文件(参考 `.env.example`): + +```env +# Milvus配置 +MILVUS_HOST=localhost +MILVUS_PORT=19530 + +# 嵌入模型配置 +EMBEDDING_MODEL=BAAI/bge-m3 +EMBEDDING_DIM=1024 + +# 分块配置 +CHUNK_SIZE=512 +``` + +## 后续迭代(不在本次MVP范围) + +- LLM摘要生成(DeepSeek/Qwen API) +- 文档上传UI界面 +- 混合检索问答功能 +- 法规变更监控与自动更新 + +## 许可证 + +MIT License \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000..1c6132b --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,88 @@ +# AI+合规智能中枢 - 基础设施部署配置 +# Milvus向量数据库 + MinIO对象存储 + Redis缓存 + +services: + # Milvus向量数据库 (Standalone模式) + milvus: + image: milvusdb/milvus:v2.4-latest + container_name: milvus-standalone + ports: + - "19530:19530" # SDK连接端口 + - "9091:9091" # 健康检查端口 + environment: + ETCD_USE_EMBED: "true" + COMMON_LOG_LEVEL: "info" + volumes: + - milvus_data:/var/lib/milvus + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + timeout: 10s + retries: 5 + restart: unless-stopped + command: ["milvus", "run", "standalone"] + + # MinIO对象存储 + minio: + image: minio/minio:latest + container_name: minio + ports: + - "9000:9000" # API端口 + - "9001:9001" # Console端口 + environment: + MINIO_ROOT_USER: minioadmin + MINIO_ROOT_PASSWORD: minioadmin123 + volumes: + - minio_data:/data + command: server /data --console-address ":9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 10s + retries: 5 + restart: unless-stopped + + # Redis缓存 + redis: + image: redis:7-alpine + container_name: redis + ports: + - "6379:6379" + volumes: + - redis_data:/data + command: redis-server --appendonly yes + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 10s + retries: 5 + restart: unless-stopped + + # PostgreSQL数据库 (可选) + postgres: + image: postgres:15-alpine + container_name: postgres + ports: + - "5432:5432" + environment: + POSTGRES_USER: compliance + POSTGRES_PASSWORD: compliance123 + POSTGRES_DB: compliance_db + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U compliance"] + interval: 30s + timeout: 10s + retries: 5 + restart: unless-stopped + +volumes: + milvus_data: + minio_data: + redis_data: + postgres_data: + +networks: + default: + name: compliance-network \ No newline at end of file diff --git a/download_model.sh b/download_model.sh new file mode 100644 index 0000000..23906dc --- /dev/null +++ b/download_model.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# download_model.sh - 下载BGE-M3模型文件(可在本地Windows运行后上传到服务器) + +MODEL_DIR="bge-m3-model" +MODEL_URL="https://modelscope.cn/models/Xorbits/bge-m3/resolve/master" + +echo "========================================" +echo "下载 BGE-M3 模型文件" +echo "========================================" +echo "" +echo "目标目录: $MODEL_DIR" +echo "" + +# 创建目录 +mkdir -p "$MODEL_DIR" + +# 下载模型文件 +FILES=( + "config.json" + "model.safetensors" + "tokenizer.json" + "tokenizer_config.json" + "special_tokens_map.json" + "vocab.txt" + "sentencepiece.bpe.model" +) + +for file in "${FILES[@]}"; do + echo "下载: $file" + wget -c "$MODEL_URL/$file" -O "$MODEL_DIR/$file" || curl -L "$MODEL_URL/$file" -o "$MODEL_DIR/$file" +done + +echo "" +echo "========================================" +echo "下载完成!" +echo "========================================" +echo "" +echo "模型文件列表:" +ls -lh "$MODEL_DIR" +echo "" +echo "下一步:" +echo "1. 将 $MODEL_DIR 目录上传到服务器" +echo "2. 在服务器上放置到: ~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main/" +echo "" \ No newline at end of file diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..b50ef1b --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,1150 @@ + + + + + + AI+合规智能中枢 - 文档测试平台 + + + + + + +
+
+ + +
+ +
+
+
+ API +
+
+
+ MILVUS +
+
+
+ + +
+ +
+
文档上传测试
+ +
+
+ + + + +
+
拖拽文件到此处,或点击选择
+
支持上传法规文档进行智能解析
+
+ .PDF + .DOCX + .DOC +
+ +
+ +
+
PDF
+
+
-
+
-
+
+ +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
解析
+
+
+
+
分块
+
+
+
+
嵌入
+
+
+
+
入库
+
+
+ +
+ + +
+ +
+
+
+ + + +
+
处理成功
+
+
+
+
DOC ID
+
-
+
+
+
CHUNKS
+
-
+
+
+
STATUS
+
-
+
+
+
TIME
+
-
+
+
+
+
+ + +
+
法规检索测试
+ +
+ + +
+ +
+
+
+ + + + +
+
请输入关键词检索法规内容
+
+
+
+
+ + + + \ No newline at end of file diff --git a/pip.conf b/pip.conf new file mode 100644 index 0000000..d50bd4c --- /dev/null +++ b/pip.conf @@ -0,0 +1,6 @@ +[global] +index-url = https://mirrors.aliyun.com/pypi/simple +trusted-host = mirrors.aliyun.com + +[install] +trusted-host = mirrors.aliyun.com \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f74ab11 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "ai-compliance-hub" +version = "0.1.0" +description = "AI+合规智能中枢 - 法律法规文档解析入库功能" +authors = [ + {name = "T-systems AI Regulations Team"} +] +readme = "README.md" +requires-python = ">=3.10" +license = {text = "MIT"} + +dependencies = [ + "pymilvus>=2.4.0", + "fastapi>=0.100.0", + "uvicorn[standard]>=0.23.0", + "python-multipart>=0.0.6", + "langchain>=0.1.0", + "langchain-milvus>=0.1.0", + "pymupdf>=1.24.0", + "python-docx>=0.8.11", + "FlagEmbedding>=1.2.0", + "sentence-transformers>=2.2.0", + "pydantic>=2.0.0", + "pydantic-settings>=2.0.0", + "python-dotenv>=1.0.0", + "loguru>=0.7.0", + "tenacity>=8.2.0", + "httpx>=0.24.0", +] + +[project.optional-dependencies] +mineru = ["magic-pdf[full]>=0.6.0"] +queue = ["celery>=5.3.0", "redis>=4.5.0"] +storage = ["minio>=7.1.0"] +db = ["psycopg2-binary>=2.9.0"] +dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0"] + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" \ No newline at end of file diff --git a/quick_start.sh b/quick_start.sh new file mode 100644 index 0000000..a7ff18a --- /dev/null +++ b/quick_start.sh @@ -0,0 +1,246 @@ +#!/bin/bash +# quick_start.sh - 快速启动脚本 +# 适配Docker部署的数据库环境 + +set -e + +echo "========================================" +echo "AI+合规智能中枢 - 快速启动脚本" +echo "========================================" +echo "" + +# 颜色定义 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +VENV_DIR=".venv" + +# 检查Python版本 +echo -e "${YELLOW}[1/8] 检查Python环境...${NC}" +if command -v python3 &> /dev/null; then + PYTHON_CMD=python3 +else + PYTHON_CMD=python +fi + +PYTHON_VERSION=$($PYTHON_CMD --version 2>&1 | awk '{print $2}') +echo "Python版本: $PYTHON_VERSION" + +if [[ ! "$PYTHON_VERSION" =~ ^3\.1[0-9] ]]; then + echo -e "${RED}需要Python 3.10+,当前版本: $PYTHON_VERSION${NC}" + exit 1 +fi +echo -e "${GREEN}Python环境检查通过${NC}" +echo "" + +# 创建虚拟环境 +echo -e "${YELLOW}[2/8] 创建虚拟环境...${NC}" +if [ -d "$VENV_DIR" ]; then + echo "虚拟环境已存在: $VENV_DIR" +else + echo "正在创建虚拟环境..." + $PYTHON_CMD -m venv $VENV_DIR + echo -e "${GREEN}虚拟环境创建成功${NC}" +fi + +# 激活虚拟环境 +source $VENV_DIR/bin/activate +echo "已激活虚拟环境: $VENV_DIR" +echo "" + +# 安装依赖(使用国内镜像源) +echo -e "${YELLOW}[3/8] 安装Python依赖...${NC}" + +# 配置pip使用清华镜像源(国内加速) +PIP_MIRROR="https://mirrors.aliyun.com/pypi/simple" + +echo "使用镜像源: $PIP_MIRROR" +pip config set global.index-url $PIP_MIRROR -q + +pip install --upgrade pip -q +pip install -r requirements.txt -q + +if [ $? -eq 0 ]; then + echo -e "${GREEN}依赖安装完成${NC}" +else + echo -e "${RED}依赖安装失败,请检查requirements.txt${NC}" + exit 1 +fi +echo "" + +# 检查Docker容器状态 +echo -e "${YELLOW}[4/8] 检查Docker容器状态...${NC}" +REQUIRED_CONTAINERS="milvus minio redis postgres" + +for container in $REQUIRED_CONTAINERS; do + if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then + echo -e "${GREEN}✓ ${container} 容器运行正常${NC}" + elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then + echo -e "${YELLOW}⚠ ${container} 容器已停止,尝试启动...${NC}" + docker start ${container} + sleep 2 + if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then + echo -e "${GREEN}✓ ${container} 已启动${NC}" + else + echo -e "${RED}✗ ${container} 启动失败${NC}" + fi + else + echo -e "${RED}✗ ${container} 容器不存在${NC}" + fi +done +echo "" + +# 检查PostgreSQL连接(通过Python) +echo -e "${YELLOW}[5/8] 检查PostgreSQL连接...${NC}" +python << 'EOF' +import sys +try: + import psycopg2 + conn = psycopg2.connect( + host="localhost", + port=5432, + user="postgresql", + password="postgresql123456", + database="postgres" + ) + cur = conn.cursor() + + # 检查compliance_db是否存在 + cur.execute("SELECT 1 FROM pg_database WHERE datname='compliance_db'") + exists = cur.fetchone() + + if not exists: + cur.execute("CREATE DATABASE compliance_db") + print("数据库 compliance_db 创建成功") + else: + print("数据库 compliance_db 已存在") + + conn.commit() + cur.close() + conn.close() + print("PostgreSQL连接成功") +except Exception as e: + print(f"PostgreSQL连接失败: {e}") + sys.exit(1) +EOF + +if [ $? -eq 0 ]; then + echo -e "${GREEN}PostgreSQL服务运行正常${NC}" +else + echo -e "${RED}PostgreSQL连接失败${NC}" + exit 1 +fi +echo "" + +# 检查Redis连接(通过Python) +echo -e "${YELLOW}[6/8] 检查Redis连接...${NC}" +python << 'EOF' +import sys +try: + import redis + r = redis.Redis( + host="localhost", + port=6379, + password="redis@123", + decode_responses=True + ) + result = r.ping() + if result: + print("Redis连接成功") + else: + sys.exit(1) +except Exception as e: + print(f"Redis连接失败: {e}") + sys.exit(1) +EOF + +if [ $? -eq 0 ]; then + echo -e "${GREEN}Redis服务运行正常${NC}" +else + echo -e "${RED}Redis连接失败${NC}" + exit 1 +fi +echo "" + +# 检查Milvus连接(通过Python) +echo -e "${YELLOW}[7/8] 检查Milvus连接...${NC}" +python << 'EOF' +import sys +try: + from pymilvus import connections, utility + connections.connect(host="localhost", port=19530) + print("Milvus连接成功") + connections.disconnect("default") +except Exception as e: + print(f"Milvus连接失败: {e}") + sys.exit(1) +EOF + +if [ $? -eq 0 ]; then + echo -e "${GREEN}Milvus服务运行正常${NC}" +else + echo -e "${RED}Milvus连接失败${NC}" + exit 1 +fi +echo "" + +# 检查MinIO连接(通过Python) +echo -e "${YELLOW}[8/8] 检查MinIO连接...${NC}" +python << 'EOF' +import sys +try: + from minio import Minio + client = Minio("localhost:9000", "minioadmin", "minioadmin", secure=False) + + # 检查bucket是否存在 + bucket = "compliance-docs" + if not client.bucket_exists(bucket): + client.make_bucket(bucket) + print(f"MinIO bucket '{bucket}' 创建成功") + else: + print(f"MinIO bucket '{bucket}' 已存在") + print("MinIO连接成功") +except Exception as e: + print(f"MinIO连接失败: {e}") + sys.exit(1) +EOF + +if [ $? -eq 0 ]; then + echo -e "${GREEN}MinIO服务运行正常${NC}" +else + echo -e "${RED}MinIO连接失败${NC}" + exit 1 +fi +echo "" + +# 预下载BGE-M3模型(可选) +echo -e "${YELLOW}[提示] BGE-M3嵌入模型...${NC}" +MODEL_CACHE=~/.cache/huggingface/hub/models--BAAI--bge-m3 + +if [ -d "$MODEL_CACHE" ]; then + echo -e "${GREEN}BGE-M3模型已存在${NC}" +else + echo -e "${YELLOW}模型未下载,首次运行时将自动下载(约2GB)${NC}" + echo "手动预下载: python -c \"from FlagEmbedding import BGEM3FlagModel; BGEM3FlagModel('BAAI/bge-m3')\"" +fi +echo "" + +# 输出启动命令 +echo "========================================" +echo -e "${GREEN}环境检查完成!${NC}" +echo "========================================" +echo "" +echo "虚拟环境: $VENV_DIR" +echo "" +echo "启动API服务:" +echo " ./start_api.sh" +echo "" +echo "后台启动:" +echo " ./start_api_background.sh" +echo "" +echo "API文档地址:" +echo " http://localhost:8000/docs" +echo " http://localhost:8000/health" +echo "" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..592d3f3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,46 @@ +# AI+合规智能中枢 - 法律法规文档解析入库 +# MVP核心依赖包 + +# 向量数据库 +pymilvus>=2.4.0 + +# API框架 +fastapi>=0.100.0 +uvicorn[standard]>=0.23.0 +python-multipart>=0.0.6 + +# RAG框架 +langchain>=0.1.0 +langchain-milvus>=0.1.0 + +# PDF解析 +pymupdf>=1.24.0 # PyMuPDF + +# Word文档解析 +python-docx>=0.8.11 + +# MinerU多模态PDF解析(可选,需要额外配置) +# magic-pdf[full]>=0.6.0 + +# 嵌入模型 +FlagEmbedding>=1.2.0 +sentence-transformers>=2.2.0 + +# 任务队列(可选) +celery>=5.3.0 +redis>=4.5.0 + +# 对象存储(可选) +minio>=7.1.0 + +# 数据库 +psycopg2-binary>=2.9.0 +# mysql-connector-python>=8.0.0 + +# 工具库 +pydantic>=2.0.0 +pydantic-settings>=2.0.0 +python-dotenv>=1.0.0 +loguru>=0.7.0 +tenacity>=8.2.0 +httpx>=0.24.0 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e3df257 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,4 @@ +# src/__init__.py +"""AI+合规智能中枢 - 法律法规文档解析入库功能""" + +__version__ = "0.1.0" \ No newline at end of file diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..0f376a5 --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,2 @@ +# src/api/__init__.py +"""API接口模块""" \ No newline at end of file diff --git a/src/api/main.py b/src/api/main.py new file mode 100644 index 0000000..3badb30 --- /dev/null +++ b/src/api/main.py @@ -0,0 +1,109 @@ +# src/api/main.py +"""FastAPI应用入口""" + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from contextlib import asynccontextmanager +from loguru import logger +import sys +import os + +# 设置日志 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from src.config.logging import setup_logging +from src.config.settings import settings +from src.api.routes import api_router +from src.api.models import ErrorResponse + +# 配置日志 +setup_logging(level="INFO" if not settings.debug else "DEBUG") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + logger.info(f"启动 {settings.app_name} v{settings.app_version}") + logger.info(f"调试模式: {settings.debug}") + + # 启动时的初始化操作 + # 如:预加载模型、建立数据库连接池等 + + yield + + # 关闭时的清理操作 + logger.info("应用关闭,执行清理...") + # 如:关闭数据库连接、清理缓存等 + + +# 创建FastAPI应用 +app = FastAPI( + title=settings.app_name, + description="AI+合规智能中枢 - 法律法规文档解析入库功能\n\n支持PDF/DOCX文档解析、智能分块、向量嵌入、Milvus存储", + version=settings.app_version, + lifespan=lifespan, + docs_url="/docs", + redoc_url="/redoc" +) + +# CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 生产环境应配置具体域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# 注册路由 +app.include_router(api_router, prefix="/api/v1") + + +# 全局异常处理 +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + """全局异常处理""" + logger.error(f"未处理的异常: {exc}") + return JSONResponse( + status_code=500, + content=ErrorResponse( + error="InternalServerError", + message=str(exc) + ).model_dump() + ) + + +# 健康检查接口 +@app.get("/health", tags=["health"]) +async def health_check(): + """健康检查""" + return { + "status": "healthy", + "app": settings.app_name, + "version": settings.app_version + } + + +# 根路径 +@app.get("/", tags=["root"]) +async def root(): + """根路径""" + return { + "message": f"Welcome to {settings.app_name}", + "version": settings.app_version, + "docs": "/docs", + "health": "/health" + } + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "main:app", + host=settings.api_host, + port=settings.api_port, + reload=settings.debug, + log_level="info" + ) \ No newline at end of file diff --git a/src/api/models/__init__.py b/src/api/models/__init__.py new file mode 100644 index 0000000..139463b --- /dev/null +++ b/src/api/models/__init__.py @@ -0,0 +1,22 @@ +# src/api/models/__init__.py +"""API数据模型""" + +from .document import ( + DocumentUploadRequest, + DocumentUploadResponse, + SearchRequest, + SearchResultItem, + SearchResponse, + DocumentStatusResponse, + ErrorResponse +) + +__all__ = [ + "DocumentUploadRequest", + "DocumentUploadResponse", + "SearchRequest", + "SearchResultItem", + "SearchResponse", + "DocumentStatusResponse", + "ErrorResponse" +] \ No newline at end of file diff --git a/src/api/models/document.py b/src/api/models/document.py new file mode 100644 index 0000000..ab5c64b --- /dev/null +++ b/src/api/models/document.py @@ -0,0 +1,61 @@ +# src/api/models/document.py +"""文档相关Pydantic数据模型""" + +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from datetime import datetime + + +class DocumentUploadRequest(BaseModel): + """文档上传请求""" + doc_name: Optional[str] = Field(None, description="文档名称") + regulation_type: Optional[str] = Field(None, description="法规类型") + version: Optional[str] = Field(None, description="文档版本") + + +class DocumentUploadResponse(BaseModel): + """文档上传响应""" + doc_id: str = Field(..., description="文档ID") + doc_name: str = Field(..., description="文档名称") + status: str = Field(..., description="处理状态") + message: str = Field(..., description="状态消息") + num_chunks: Optional[int] = Field(None, description="分块数量") + timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") + + +class SearchRequest(BaseModel): + """检索请求""" + query: str = Field(..., description="查询文本") + top_k: int = Field(default=10, description="返回结果数量") + filters: Optional[str] = Field(None, description="过滤条件") + + +class SearchResultItem(BaseModel): + """单个检索结果""" + id: int = Field(..., description="记录ID") + content: str = Field(..., description="内容") + score: float = Field(..., description="相似度分数") + metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据") + + +class SearchResponse(BaseModel): + """检索响应""" + query: str = Field(..., description="查询文本") + total: int = Field(..., description="结果总数") + results: List[SearchResultItem] = Field(default_factory=list, description="结果列表") + timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") + + +class DocumentStatusResponse(BaseModel): + """文档状态响应""" + doc_id: str = Field(..., description="文档ID") + status: str = Field(..., description="状态") + num_chunks: Optional[int] = Field(None, description="分块数量") + timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") + + +class ErrorResponse(BaseModel): + """错误响应""" + error: str = Field(..., description="错误类型") + message: str = Field(..., description="错误消息") + timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") \ No newline at end of file diff --git a/src/api/routes/__init__.py b/src/api/routes/__init__.py new file mode 100644 index 0000000..74c30c6 --- /dev/null +++ b/src/api/routes/__init__.py @@ -0,0 +1,15 @@ +# src/api/routes/__init__.py +"""API路由模块""" + +from fastapi import APIRouter +from .documents import router as documents_router +from .knowledge import router as knowledge_router + +# 主路由 +api_router = APIRouter() + +# 注册子路由 +api_router.include_router(documents_router) +api_router.include_router(knowledge_router) + +__all__ = ["api_router", "documents_router", "knowledge_router"] \ No newline at end of file diff --git a/src/api/routes/documents.py b/src/api/routes/documents.py new file mode 100644 index 0000000..3af9956 --- /dev/null +++ b/src/api/routes/documents.py @@ -0,0 +1,117 @@ +# src/api/routes/documents.py +"""文档上传与处理接口""" + +from fastapi import APIRouter, UploadFile, File, Form, HTTPException +from typing import Optional +import os +import uuid +import tempfile +from loguru import logger + +from ..models import DocumentUploadResponse, ErrorResponse +from src.services.document_processor import DocumentProcessor +from src.config.settings import settings + +router = APIRouter(prefix="/documents", tags=["documents"]) + + +@router.post("/upload", response_model=DocumentUploadResponse) +async def upload_document( + file: UploadFile = File(..., description="上传的文档文件"), + doc_name: Optional[str] = Form(None, description="文档名称"), + regulation_type: Optional[str] = Form(None, description="法规类型"), + version: Optional[str] = Form(None, description="文档版本") +): + """ + 上传文档并处理 + + 支持格式:PDF、DOCX、DOC + 处理流程:解析 → 分块 → 嵌入 → 入库 + """ + # 验证文件类型 + ext = os.path.splitext(file.filename)[1].lower() + if ext not in [".pdf", ".docx", ".doc"]: + raise HTTPException( + status_code=400, + detail=f"不支持的文件类型: {ext},仅支持PDF、DOCX、DOC" + ) + + # 验证文件大小 + if file.size and file.size > settings.max_file_size_mb * 1024 * 1024: + raise HTTPException( + status_code=400, + detail=f"文件过大,最大支持{settings.max_file_size_mb}MB" + ) + + # 生成文档ID + doc_id = str(uuid.uuid4())[:8] + + # 文档名称 + final_doc_name = doc_name or file.filename + + logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}") + + try: + # 保存临时文件 + temp_dir = tempfile.gettempdir() + temp_path = os.path.join(temp_dir, f"{doc_id}_{file.filename}") + + with open(temp_path, "wb") as f: + content = await file.read() + f.write(content) + + logger.info(f"文件已保存到: {temp_path}") + + # 处理文档 + processor = DocumentProcessor() + result = processor.process( + file_path=temp_path, + doc_name=final_doc_name, + regulation_type=regulation_type or "", + version=version or "" + ) + processor.close() + + # 清理临时文件 + try: + os.remove(temp_path) + except: + pass + + if result.success: + return DocumentUploadResponse( + doc_id=result.doc_id, + doc_name=result.doc_name, + status="success", + message=result.message, + num_chunks=result.num_chunks + ) + else: + raise HTTPException( + status_code=500, + detail=result.message + ) + + except Exception as e: + logger.error(f"文档处理失败: {e}") + raise HTTPException( + status_code=500, + detail=f"文档处理失败: {str(e)}" + ) + + +@router.get("/status/{doc_id}", response_model=DocumentUploadResponse) +async def get_document_status(doc_id: str): + """ + 查询文档处理状态 + + Args: + doc_id: 文档ID + """ + # TODO: 实现状态查询(需要数据库支持) + return DocumentUploadResponse( + doc_id=doc_id, + doc_name="", + status="unknown", + message="状态查询功能待实现" + ) \ No newline at end of file diff --git a/src/api/routes/knowledge.py b/src/api/routes/knowledge.py new file mode 100644 index 0000000..f13b663 --- /dev/null +++ b/src/api/routes/knowledge.py @@ -0,0 +1,81 @@ +# src/api/routes/knowledge.py +"""知识库检索接口""" + +from fastapi import APIRouter, HTTPException +from loguru import logger + +from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse +from src.services.document_processor import DocumentProcessor + +router = APIRouter(prefix="/knowledge", tags=["knowledge"]) + + +@router.post("/search", response_model=SearchResponse) +async def search_knowledge(request: SearchRequest): + """ + 检索法规知识库 + + 使用混合检索:Dense向量 + Sparse向量 + RRF融合 + + Args: + request: 检索请求参数 + """ + if not request.query or len(request.query.strip()) == 0: + raise HTTPException( + status_code=400, + detail="查询文本不能为空" + ) + + logger.info(f"收到检索请求: {request.query}") + + try: + # 执行检索 + processor = DocumentProcessor() + results = processor.search( + query=request.query, + top_k=request.top_k, + filters=request.filters + ) + processor.close() + + # 转换结果格式 + result_items = [] + for r in results: + item = SearchResultItem( + id=r.get("id", 0), + content=r.get("content", ""), + score=r.get("score", 0.0), + metadata=r.get("metadata", {}) + ) + result_items.append(item) + + return SearchResponse( + query=request.query, + total=len(result_items), + results=result_items + ) + + except Exception as e: + logger.error(f"检索失败: {e}") + raise HTTPException( + status_code=500, + detail=f"检索失败: {str(e)}" + ) + + +@router.post("/retrieval", response_model=SearchResponse) +async def knowledge_retrieval(request: SearchRequest): + """ + 知识检索接口(与架构文档对齐) + + 该接口实现完整的检索流程: + 1. 意图识别 + 2. BM25关键词检索 + 向量语义检索(双路召回) + 3. Cross-Encoder精排 + 4. 返回结果 + + Args: + request: 检索请求 + """ + # 当前版本使用混合检索,后续可添加精排步骤 + return await search_knowledge(request) \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..90d7a49 --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,6 @@ +# src/config/__init__.py +"""配置模块""" + +from .settings import Settings, get_settings + +__all__ = ["Settings", "get_settings"] \ No newline at end of file diff --git a/src/config/logging.py b/src/config/logging.py new file mode 100644 index 0000000..a68820e --- /dev/null +++ b/src/config/logging.py @@ -0,0 +1,32 @@ +# src/config/logging.py +"""日志配置""" + +from loguru import logger +import sys + + +def setup_logging(level: str = "INFO"): + """设置日志配置""" + + # 移除默认handler + logger.remove() + + # 添加控制台输出 + logger.add( + sys.stdout, + level=level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + colorize=True + ) + + # 添加文件输出 + logger.add( + "logs/app_{time:YYYY-MM-DD}.log", + level=level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + rotation="00:00", + retention="7 days", + compression="zip" + ) + + return logger \ No newline at end of file diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 0000000..0acc7ba --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,73 @@ +# src/config/settings.py +"""配置管理 - 环境变量和默认配置""" + +from pydantic_settings import BaseSettings +from pydantic import Field +from typing import Optional +from functools import lru_cache + + +class Settings(BaseSettings): + """应用配置""" + + # 应用基础配置 + app_name: str = Field(default="AI+合规智能中枢", description="应用名称") + app_version: str = Field(default="0.1.0", description="应用版本") + debug: bool = Field(default=False, description="调试模式") + + # Milvus向量数据库配置 + milvus_host: str = Field(default="localhost", description="Milvus服务地址") + milvus_port: int = Field(default=19530, description="Milvus服务端口") + milvus_collection: str = Field(default="regulations", description="法规向量集合名称") + milvus_db_name: str = Field(default="default", description="Milvus数据库名称") + + # 嵌入模型配置 + embedding_model: str = Field(default="BAAI/bge-m3", description="嵌入模型名称") + embedding_dim: int = Field(default=1024, description="嵌入向量维度") + embedding_max_length: int = Field(default=8192, description="最大嵌入长度") + embedding_batch_size: int = Field(default=12, description="嵌入批处理大小") + embedding_use_fp16: bool = Field(default=True, description="使用FP16加速") + + # MinIO对象存储配置 + minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址") + minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥") + minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥") + minio_bucket: str = Field(default="compliance-docs", description="文档存储桶名称") + minio_secure: bool = Field(default=False, description="是否使用HTTPS") + + # Redis配置 + redis_host: str = Field(default="localhost", description="Redis服务地址") + redis_port: int = Field(default=6379, description="Redis服务端口") + redis_password: str = Field(default="", description="Redis密码") + redis_db: int = Field(default=0, description="Redis数据库编号") + + # PostgreSQL配置 + postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址") + postgres_port: int = Field(default=5432, description="PostgreSQL服务端口") + postgres_user: str = Field(default="compliance", description="PostgreSQL用户名") + postgres_password: str = Field(default="compliance123", description="PostgreSQL密码") + postgres_db: str = Field(default="compliance_db", description="PostgreSQL数据库名称") + + # 文档处理配置 + chunk_size: int = Field(default=512, description="分块大小(字符数)") + chunk_overlap: int = Field(default=50, description="分块重叠大小") + max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)") + + # API配置 + api_host: str = Field(default="0.0.0.0", description="API服务地址") + api_port: int = Field(default=8000, description="API服务端口") + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + extra = "ignore" + + +@lru_cache +def get_settings() -> Settings: + """获取配置实例(缓存)""" + return Settings() + + +# 导出默认配置实例 +settings = get_settings() \ No newline at end of file diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..ac6925a --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,2 @@ +# src/services/__init__.py +"""业务服务模块""" \ No newline at end of file diff --git a/src/services/document_processor.py b/src/services/document_processor.py new file mode 100644 index 0000000..a220431 --- /dev/null +++ b/src/services/document_processor.py @@ -0,0 +1,348 @@ +# src/services/document_processor.py +"""文档处理主流程 - 解析→分块→嵌入→入库""" + +import os +from typing import List, Dict, Optional +from dataclasses import dataclass +from loguru import logger +import uuid + +from .parser.pdf_parser import PDFParser +from .parser.docx_parser import DocxParser +from .parser.mineru_parser import ParserOrchestrator +from .embedding.text_chunker import RegulationChunker, TextChunk +from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult +from .storage.milvus_client import MilvusClient +from src.config.settings import settings + + +@dataclass +class ProcessingResult: + """文档处理结果""" + doc_id: str + doc_name: str + success: bool + num_chunks: int = 0 + message: str = "" + markdown_text: str = "" + + +class DocumentProcessor: + """ + 文档处理服务 - 完整处理流程 + + 流程: + 1. 文档解析(PDF/DOCX → Markdown) + 2. 智能分块(章节级+条款级) + 3. 向量嵌入(BGE-M3 Dense+Sparse) + 4. 存储入库(Milvus向量数据库) + """ + + def __init__( + self, + chunk_size: int = None, + embedding_model: str = None, + use_mineru: bool = True + ): + """ + 初始化文档处理器 + + Args: + chunk_size: 分块大小 + embedding_model: 嵌入模型名称 + use_mineru: 是否优先使用MinerU解析 + """ + self.chunk_size = chunk_size or settings.chunk_size + self.embedding_model = embedding_model or settings.embedding_model + self.use_mineru = use_mineru + + # 初始化各组件 + logger.info("初始化文档处理组件...") + + # 解析器 + self.parser = ParserOrchestrator() + + # 分块器 + self.chunker = RegulationChunker(chunk_size=self.chunk_size) + + # 嵌入模型(延迟加载,首次使用时初始化) + self.embedder: Optional[BGEM3Embedder] = None + + # Milvus客户端(延迟连接) + self.milvus: Optional[MilvusClient] = None + + logger.success("文档处理器初始化完成") + + def _init_embedder(self): + """延迟初始化嵌入模型""" + if self.embedder is None: + logger.info("加载嵌入模型...") + self.embedder = BGEM3Embedder(model_name=self.embedding_model) + + def _init_milvus(self): + """延迟初始化Milvus连接""" + if self.milvus is None: + logger.info("连接Milvus...") + self.milvus = MilvusClient() + self.milvus.connect() + self.milvus.create_collection(recreate=False) + self.milvus.load_collection() + + def process( + self, + file_path: str, + doc_name: Optional[str] = None, + regulation_type: str = "", + version: str = "" + ) -> ProcessingResult: + """ + 处理单个文档 + + Args: + file_path: 文档文件路径 + doc_name: 文档名称(可选,默认从文件名获取) + regulation_type: 法规类型 + version: 文档版本 + + Returns: + ProcessingResult: 处理结果 + """ + # 生成文档ID + doc_id = str(uuid.uuid4())[:8] + + # 获取文档名称 + if doc_name is None: + doc_name = os.path.basename(file_path) + + logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})") + + try: + # 1. 文档解析 + logger.info("Step 1: 文档解析") + markdown_text = self._parse_document(file_path) + + if not markdown_text: + return ProcessingResult( + doc_id=doc_id, + doc_name=doc_name, + success=False, + message="文档解析失败,内容为空" + ) + + # 2. 智能分块 + logger.info("Step 2: 智能分块") + chunks = self._chunk_document( + markdown_text, + doc_id, + doc_name, + regulation_type, + version + ) + + if not chunks: + return ProcessingResult( + doc_id=doc_id, + doc_name=doc_name, + success=False, + message="分块失败,无有效内容" + ) + + # 3. 向量嵌入 + logger.info("Step 3: 向量嵌入") + embeddings = self._embed_chunks(chunks) + + if embeddings is None: + return ProcessingResult( + doc_id=doc_id, + doc_name=doc_name, + success=False, + message="向量嵌入失败" + ) + + # 4. 存储入库 + logger.info("Step 4: 存储入库") + inserted_ids = self._insert_to_milvus(chunks, embeddings) + + logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录") + + return ProcessingResult( + doc_id=doc_id, + doc_name=doc_name, + success=True, + num_chunks=len(inserted_ids), + message="处理成功", + markdown_text=markdown_text + ) + + except Exception as e: + logger.error(f"文档处理失败: {e}") + return ProcessingResult( + doc_id=doc_id, + doc_name=doc_name, + success=False, + message=f"处理失败: {str(e)}" + ) + + def _parse_document(self, file_path: str) -> str: + """解析文档""" + ext = os.path.splitext(file_path)[1].lower() + + try: + if ext == ".pdf": + # PDF文档解析(优先MinerU,回退PyMuPDF) + markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru) + elif ext in [".docx", ".doc"]: + # Word文档解析 + markdown_text = self.parser.parse_docx(file_path) + else: + logger.warning(f"不支持的文件类型: {ext}") + return "" + + logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符") + return markdown_text + + except Exception as e: + logger.error(f"文档解析失败: {e}") + return "" + + def _chunk_document( + self, + markdown_text: str, + doc_id: str, + doc_name: str, + regulation_type: str, + version: str + ) -> List[TextChunk]: + """分块文档""" + try: + chunks = self.chunker.chunk_document( + markdown_text, + doc_id=doc_id, + doc_name=doc_name, + regulation_type=regulation_type, + version=version + ) + logger.success(f"分块完成,共{len(chunks)}个chunk") + return chunks + + except Exception as e: + logger.error(f"分块失败: {e}") + return [] + + def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]: + """嵌入分块""" + try: + # 延迟初始化嵌入模型 + self._init_embedder() + + # 提取文本内容 + texts = [chunk.content for chunk in chunks] + + # 执行嵌入 + embeddings = self.embedder.embed(texts) + + logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}") + return embeddings + + except Exception as e: + logger.error(f"嵌入失败: {e}") + return None + + def _insert_to_milvus( + self, + chunks: List[TextChunk], + embeddings: EmbeddingResult + ) -> List[int]: + """插入Milvus""" + try: + # 延迟初始化Milvus + self._init_milvus() + + # 执行插入 + inserted_ids = self.milvus.insert_chunks(chunks, embeddings) + + logger.success(f"入库完成,共{len(inserted_ids)}条记录") + return inserted_ids + + except Exception as e: + logger.error(f"入库失败: {e}") + return [] + + def search( + self, + query: str, + top_k: int = 10, + filters: Optional[str] = None + ) -> List[Dict]: + """ + 检索法规内容 + + Args: + query: 查询文本 + top_k: 返回结果数量 + filters: 过滤条件 + + Returns: + List[Dict]: 检索结果 + """ + logger.info(f"执行检索: {query}") + + try: + # 延迟初始化 + self._init_embedder() + self._init_milvus() + + # 生成查询向量 + query_embedding = self.embedder.embed_single(query) + + # 执行混合检索 + results = self.milvus.hybrid_search( + query_dense=query_embedding['dense'].tolist(), + query_sparse=query_embedding['sparse'], + top_k=top_k, + filters=filters + ) + + # 转换为字典格式 + result_dicts = [] + for r in results: + result_dicts.append({ + "id": r.id, + "content": r.content, + "score": r.score, + "metadata": r.metadata + }) + + logger.success(f"检索完成,返回{len(result_dicts)}条结果") + return result_dicts + + except Exception as e: + logger.error(f"检索失败: {e}") + return [] + + def close(self): + """关闭连接""" + if self.milvus: + self.milvus.disconnect() + logger.info("文档处理器已关闭") + + +def process_document( + file_path: str, + doc_name: Optional[str] = None, + regulation_type: str = "", + version: str = "" +) -> ProcessingResult: + """便捷函数:处理单个文档""" + processor = DocumentProcessor() + result = processor.process(file_path, doc_name, regulation_type, version) + processor.close() + return result + + +def search_regulations(query: str, top_k: int = 10) -> List[Dict]: + """便捷函数:检索法规""" + processor = DocumentProcessor() + results = processor.search(query, top_k) + processor.close() + return results \ No newline at end of file diff --git a/src/services/embedding/__init__.py b/src/services/embedding/__init__.py new file mode 100644 index 0000000..56c0451 --- /dev/null +++ b/src/services/embedding/__init__.py @@ -0,0 +1,7 @@ +# src/services/embedding/__init__.py +"""嵌入和分块服务""" + +from .text_chunker import RegulationChunker +from .bge_m3_embedder import BGEM3Embedder + +__all__ = ["RegulationChunker", "BGEM3Embedder"] \ No newline at end of file diff --git a/src/services/embedding/bge_m3_embedder.py b/src/services/embedding/bge_m3_embedder.py new file mode 100644 index 0000000..73daa0e --- /dev/null +++ b/src/services/embedding/bge_m3_embedder.py @@ -0,0 +1,296 @@ +# src/services/embedding/bge_m3_embedder.py +"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成""" + +import numpy as np +from typing import List, Dict, Optional, Union +from dataclasses import dataclass, field +from loguru import logger +import torch +import os + +# 设置HuggingFace镜像(国内网络) +if 'HF_ENDPOINT' not in os.environ: + os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + +# 本地模型路径(按优先级检查) +LOCAL_MODEL_PATHS = [ + os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径 + os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径 +] + + +@dataclass +class EmbeddingResult: + """嵌入结果""" + dense_embeddings: np.ndarray # Dense向量(语义检索) + sparse_embeddings: List[Dict[int, float]] # Sparse向量(关键词匹配) + texts: List[str] + dim: int = 1024 + + +class BGEM3Embedder: + """ + BGE-M3多语言嵌入模型服务 + + BGE-M3是BAAI发布的多语言嵌入模型,支持: + - Dense向量:用于语义相似度检索 + - Sparse向量:用于关键词精确匹配(BM25风格) + - ColBERT向量:用于细粒度交互匹配(可选) + + 特点: + - 支持100+语言(中英双语优化) + - 8192 tokens超长上下文 + - Dense+Sparse双路检索能力 + + GitHub: https://github.com/FlagOpen/FlagEmbedding + """ + + def __init__( + self, + model_name: str = "BAAI/bge-m3", + use_fp16: bool = True, + device: Optional[str] = None, + batch_size: int = 12, + max_length: int = 8192, + local_model_path: Optional[str] = None + ): + """ + 初始化BGE-M3嵌入模型 + + Args: + model_name: 模型名称(如果使用本地路径,此参数会被忽略) + use_fp16: 是否使用FP16加速 + device: 设备类型(cuda/cpu),默认自动选择 + batch_size: 批处理大小 + max_length: 最大序列长度 + local_model_path: 本地模型路径(可选,优先使用) + """ + self.use_fp16 = use_fp16 + self.batch_size = batch_size + self.max_length = max_length + + # 确定模型路径(优先使用本地路径) + if local_model_path and os.path.exists(local_model_path): + self.model_path = local_model_path + self.model_name = "local" + logger.info(f"使用本地模型路径: {local_model_path}") + else: + # 检查多个可能的本地路径 + found_local = False + for path in LOCAL_MODEL_PATHS: + if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")): + self.model_path = path + self.model_name = "local" + logger.info(f"使用本地模型路径: {path}") + found_local = True + break + + if not found_local: + self.model_path = model_name + self.model_name = model_name + logger.info(f"使用远程模型: {model_name}") + + # 自动选择设备 + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device + + logger.info(f"初始化BGE-M3模型, 设备: {self.device}") + + self.model = None + self._load_model() + + def _load_model(self): + """加载嵌入模型""" + try: + from FlagEmbedding import BGEM3FlagModel + + self.model = BGEM3FlagModel( + self.model_path, + use_fp16=self.use_fp16, + device=self.device + ) + + logger.success(f"BGE-M3模型加载成功") + + except ImportError: + logger.warning("FlagEmbedding库未安装,请运行: pip install FlagEmbedding") + raise + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + def embed( + self, + texts: List[str], + return_dense: bool = True, + return_sparse: bool = True, + return_colbert_vecs: bool = False + ) -> EmbeddingResult: + """ + 对文本列表生成嵌入向量 + + Args: + texts: 文本列表 + return_dense: 是否返回Dense向量 + return_sparse: 是否返回Sparse向量 + return_colbert_vecs: 是否返回ColBERT向量 + + Returns: + EmbeddingResult: 嵌入结果 + """ + if not texts: + logger.warning("输入文本列表为空") + return EmbeddingResult( + dense_embeddings=np.array([]), + sparse_embeddings=[], + texts=[], + dim=0 + ) + + logger.info(f"开始嵌入{len(texts)}个文本块") + + try: + # 执行嵌入 + embeddings = self.model.encode( + texts, + batch_size=self.batch_size, + max_length=self.max_length, + return_dense=return_dense, + return_sparse=return_sparse, + return_colbert_vecs=return_colbert_vecs + ) + + # 提取结果 + dense_embeddings = embeddings.get('dense_vecs', np.array([])) + sparse_embeddings = embeddings.get('lexical_weights', []) + + # 获取维度 + dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024 + + logger.success(f"嵌入完成,向量维度: {dim}") + + return EmbeddingResult( + dense_embeddings=dense_embeddings, + sparse_embeddings=sparse_embeddings, + texts=texts, + dim=dim + ) + + except Exception as e: + logger.error(f"嵌入失败: {e}") + raise + + def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]: + """ + 对单个文本生成嵌入向量 + + Args: + text: 输入文本 + + Returns: + Dict: 包含dense和sparse向量 + """ + result = self.embed([text]) + return { + 'dense': result.dense_embeddings[0], + 'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {}, + 'dim': result.dim + } + + def embed_dense(self, texts: List[str]) -> np.ndarray: + """只生成Dense向量""" + result = self.embed(texts, return_sparse=False, return_colbert_vecs=False) + return result.dense_embeddings + + def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]: + """只生成Sparse向量""" + result = self.embed(texts, return_dense=False, return_colbert_vecs=False) + return result.sparse_embeddings + + def embed_query(self, query: str) -> Dict: + """ + 对查询文本生成嵌入(用于检索) + + Args: + query: 查询文本 + + Returns: + Dict: 包含dense和sparse向量 + """ + return self.embed_single(query) + + def compute_similarity( + self, + query_embedding: np.ndarray, + doc_embeddings: np.ndarray, + metric: str = "cosine" + ) -> np.ndarray: + """ + 计算查询与文档的相似度 + + Args: + query_embedding: 查询向量 + doc_embeddings: 文档向量矩阵 + metric: 相似度度量(cosine/dot) + + Returns: + np.ndarray: 相似度分数数组 + """ + if metric == "cosine": + # 余弦相似度 + query_norm = np.linalg.norm(query_embedding) + doc_norms = np.linalg.norm(doc_embeddings, axis=1) + + similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm) + + elif metric == "dot": + # 点积相似度 + similarities = np.dot(doc_embeddings, query_embedding) + + else: + raise ValueError(f"不支持的相似度度量: {metric}") + + return similarities + + def sparse_similarity( + self, + query_sparse: Dict[int, float], + doc_sparse: Dict[int, float] + ) -> float: + """ + 计算Sparse向量的相似度(BM25风格) + + Args: + query_sparse: 查询的Sparse向量(词ID -> 权重) + doc_sparse: 文档的Sparse向量 + + Returns: + float: 相似度分数 + """ + # 计算交集词的点积 + common_keys = set(query_sparse.keys()) & set(doc_sparse.keys()) + + score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys) + return score + + +def embed_texts( + texts: List[str], + model_name: str = "BAAI/bge-m3", + **kwargs +) -> EmbeddingResult: + """便捷函数:对文本列表生成嵌入""" + embedder = BGEM3Embedder(model_name=model_name, **kwargs) + return embedder.embed(texts) + + +def embed_single_text( + text: str, + model_name: str = "BAAI/bge-m3", + **kwargs +) -> Dict: + """便捷函数:对单个文本生成嵌入""" + embedder = BGEM3Embedder(model_name=model_name, **kwargs) + return embedder.embed_single(text) \ No newline at end of file diff --git a/src/services/embedding/text_chunker.py b/src/services/embedding/text_chunker.py new file mode 100644 index 0000000..0240cde --- /dev/null +++ b/src/services/embedding/text_chunker.py @@ -0,0 +1,449 @@ +# src/services/embedding/text_chunker.py +"""智能分块器 - 章节级+条款级双粒度切割""" + +import re +from typing import List, Dict, Optional, Tuple +from dataclasses import dataclass, field +from loguru import logger + + +@dataclass +class ChunkMetadata: + """分块元数据""" + doc_id: str = "" + doc_name: str = "" + chunk_id: str = "" + section_number: str = "" # 章节编号(如 "第一章") + section_title: str = "" # 章节标题 + clause_number: str = "" # 条款编号(如 "第一条") + page_number: int = 0 + start_position: int = 0 # 在原文中的起始位置 + end_position: int = 0 # 在原文中的结束位置 + regulation_type: str = "" # 法规类型 + version: str = "" + + +@dataclass +class TextChunk: + """文本分块""" + content: str + metadata: ChunkMetadata + token_count: int = 0 # 估算的token数量 + + +class RegulationChunker: + """ + 法规文档智能分块器 + + 实现章节级/条款级双粒度切割,适配国标GB文档结构: + - 国标文档通常有明确的层级结构:章 > 节 > 条 + - 每个条款应作为一个独立的语义单元 + - 保留条款完整性,避免跨条款截断 + """ + + # 法规标题模式 + CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+') + SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+') + CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s') + + # 条款子项模式 + SUB_ITEM_PATTERN = re.compile(r'^[\((][一二三四五六七八九十]+[\))]\s') + NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s') + + def __init__( + self, + chunk_size: int = 512, + chunk_overlap: int = 50, + max_chunk_size: int = 2048, + min_chunk_size: int = 100 + ): + """ + 初始化分块器 + + Args: + chunk_size: 默认分块大小(字符数) + chunk_overlap: 分块重叠大小 + max_chunk_size: 最大分块大小(防止单个条款过长) + min_chunk_size: 最小分块大小(防止碎片化) + """ + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.max_chunk_size = max_chunk_size + self.min_chunk_size = min_chunk_size + + def chunk_document( + self, + markdown_text: str, + doc_id: str = "", + doc_name: str = "", + regulation_type: str = "", + version: str = "" + ) -> List[TextChunk]: + """ + 对法规文档进行智能分块 + + Args: + markdown_text: Markdown格式的文档内容 + doc_id: 文档ID + doc_name: 文档名称 + regulation_type: 法规类型 + version: 文档版本 + + Returns: + List[TextChunk]: 分块列表 + """ + logger.info(f"开始分块文档: {doc_name}") + + # 1. 按章节分割(一级分块) + sections = self._split_by_sections(markdown_text) + + # 2. 在每个章节内按条款分割(二级分块) + chunks = [] + global_position = 0 + + for section_num, section_title, section_content, section_start in sections: + # 在章节内按条款分割 + clause_chunks = self._split_by_clauses( + section_content, + section_num, + section_title, + section_start + global_position + ) + + for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks: + # 处理过长的条款(进一步细分) + if len(chunk_content) > self.max_chunk_size: + sub_chunks = self._split_long_clause( + chunk_content, + clause_num, + clause_title + ) + for sub_content, sub_start, sub_end in sub_chunks: + chunk = self._create_chunk( + sub_content, + doc_id, + doc_name, + section_num, + section_title, + clause_num, + sub_start + start_pos, + sub_end + start_pos, + regulation_type, + version + ) + chunks.append(chunk) + else: + chunk = self._create_chunk( + chunk_content, + doc_id, + doc_name, + section_num, + section_title, + clause_num, + start_pos, + end_pos, + regulation_type, + version + ) + chunks.append(chunk) + + logger.success(f"分块完成,共{len(chunks)}个chunk") + return chunks + + def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]: + """ + 按章节分割文档 + + Returns: + List of (section_number, section_title, section_content, start_position) + """ + sections = [] + lines = markdown_text.split('\n') + + current_section_num = "" + current_section_title = "" + current_section_content = [] + current_section_start = 0 + + for i, line in enumerate(lines): + # 检测章节标题 + chapter_match = self.CHAPTER_PATTERN.match(line.strip()) + section_match = self.SECTION_PATTERN.match(line.strip()) + + if chapter_match or section_match: + # 保存上一个章节 + if current_section_content: + content = '\n'.join(current_section_content) + sections.append(( + current_section_num, + current_section_title, + content, + current_section_start + )) + + # 开始新章节 + current_section_start = sum(len(l) + 1 for l in lines[:i]) + current_section_content = [] + + if chapter_match: + current_section_num = line.strip() + current_section_title = self._extract_title(line.strip()) + else: + current_section_num = line.strip() + current_section_title = self._extract_title(line.strip()) + + current_section_content.append(line) + + # 保存最后一个章节 + if current_section_content: + content = '\n'.join(current_section_content) + sections.append(( + current_section_num, + current_section_title, + content, + current_section_start + )) + + # 如果没有检测到章节,将整个文档作为一个大章节 + if not sections: + sections.append(( + "", + "全文", + markdown_text, + 0 + )) + + return sections + + def _split_by_clauses( + self, + section_content: str, + section_num: str, + section_title: str, + section_start: int + ) -> List[Tuple[str, str, str, int, int]]: + """ + 在章节内按条款分割 + + Returns: + List of (content, clause_number, clause_title, start_position, end_position) + """ + clauses = [] + lines = section_content.split('\n') + + current_clause_num = "" + current_clause_title = "" + current_clause_content = [] + current_clause_start = section_start + + for i, line in enumerate(lines): + # 检测条款标题 + clause_match = self.CLAUSE_PATTERN.match(line.strip()) + + if clause_match: + # 保存上一个条款 + if current_clause_content: + content = '\n'.join(current_clause_content) + end_pos = current_clause_start + len(content) + clauses.append(( + content, + current_clause_num, + current_clause_title, + current_clause_start, + end_pos + )) + + # 开始新条款 + current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i]) + current_clause_content = [] + current_clause_num = self._extract_clause_number(line.strip()) + current_clause_title = line.strip() + + current_clause_content.append(line) + + # 保存最后一个条款 + if current_clause_content: + content = '\n'.join(current_clause_content) + end_pos = current_clause_start + len(content) + clauses.append(( + content, + current_clause_num, + current_clause_title, + current_clause_start, + end_pos + )) + + # 如果没有检测到条款,将整个章节作为一个条款 + if not clauses: + clauses.append(( + section_content, + "", + section_title, + section_start, + section_start + len(section_content) + )) + + return clauses + + def _split_long_clause( + self, + content: str, + clause_num: str, + clause_title: str + ) -> List[Tuple[str, int, int]]: + """ + 分割过长的条款内容 + + 按条款子项或段落分割,保持语义完整性 + """ + sub_chunks = [] + lines = content.split('\n') + + # 检测是否有子项结构 + has_sub_items = any( + self.SUB_ITEM_PATTERN.match(line.strip()) or + self.NUMBER_ITEM_PATTERN.match(line.strip()) + for line in lines + ) + + if has_sub_items: + # 按子项分割 + current_sub_content = [] + current_sub_start = 0 + + for i, line in enumerate(lines): + is_sub_item = ( + self.SUB_ITEM_PATTERN.match(line.strip()) or + self.NUMBER_ITEM_PATTERN.match(line.strip()) + ) + + if is_sub_item and current_sub_content: + sub_content = '\n'.join(current_sub_content) + sub_end = current_sub_start + len(sub_content) + if len(sub_content) >= self.min_chunk_size: + sub_chunks.append((sub_content, current_sub_start, sub_end)) + current_sub_content = [] + current_sub_start = sum(len(l) + 1 for l in lines[:i]) + + current_sub_content.append(line) + + # 保存最后一个子项 + if current_sub_content: + sub_content = '\n'.join(current_sub_content) + sub_end = current_sub_start + len(sub_content) + sub_chunks.append((sub_content, current_sub_start, sub_end)) + + else: + # 按段落分割(滑动窗口) + paragraphs = [] + current_para = [] + + for line in lines: + if line.strip(): + current_para.append(line) + else: + if current_para: + paragraphs.append('\n'.join(current_para)) + current_para = [] + + if current_para: + paragraphs.append('\n'.join(current_para)) + + # 合并段落直到达到chunk_size + current_chunk = [] + current_length = 0 + chunk_start = 0 + + for para in paragraphs: + if current_length + len(para) > self.chunk_size and current_chunk: + chunk_content = '\n'.join(current_chunk) + chunk_end = chunk_start + len(chunk_content) + sub_chunks.append((chunk_content, chunk_start, chunk_end)) + current_chunk = [] + current_length = 0 + chunk_start = chunk_end + + current_chunk.append(para) + current_length += len(para) + + # 保存最后一个chunk + if current_chunk: + chunk_content = '\n'.join(current_chunk) + chunk_end = chunk_start + len(chunk_content) + sub_chunks.append((chunk_content, chunk_start, chunk_end)) + + return sub_chunks + + def _extract_title(self, header_line: str) -> str: + """从标题行提取标题内容""" + # 移除"第X章"、"第X节"前缀 + title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line) + return title.strip() + + def _extract_clause_number(self, clause_line: str) -> str: + """从条款行提取条款编号""" + match = self.CLAUSE_PATTERN.match(clause_line) + if match: + return match.group(0).strip() + return "" + + def _create_chunk( + self, + content: str, + doc_id: str, + doc_name: str, + section_num: str, + section_title: str, + clause_num: str, + start_pos: int, + end_pos: int, + regulation_type: str, + version: str + ) -> TextChunk: + """创建文本分块""" + # 清理内容 + content = content.strip() + + # 计算估算token数(中文约1.5字符/token) + token_count = int(len(content) * 0.7) # 简化估算 + + # 生成chunk_id + chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}" + + metadata = ChunkMetadata( + doc_id=doc_id, + doc_name=doc_name, + chunk_id=chunk_id, + section_number=section_num, + section_title=section_title, + clause_number=clause_num, + start_position=start_pos, + end_position=end_pos, + regulation_type=regulation_type, + version=version + ) + + return TextChunk( + content=content, + metadata=metadata, + token_count=token_count + ) + + +def chunk_regulation_document( + markdown_text: str, + doc_id: str = "", + doc_name: str = "", + regulation_type: str = "", + version: str = "", + chunk_size: int = 512 +) -> List[TextChunk]: + """便捷函数:对法规文档进行分块""" + chunker = RegulationChunker(chunk_size=chunk_size) + return chunker.chunk_document( + markdown_text, + doc_id, + doc_name, + regulation_type, + version + ) \ No newline at end of file diff --git a/src/services/parser/__init__.py b/src/services/parser/__init__.py new file mode 100644 index 0000000..5704664 --- /dev/null +++ b/src/services/parser/__init__.py @@ -0,0 +1,7 @@ +# src/services/parser/__init__.py +"""文档解析服务""" + +from .pdf_parser import PDFParser +from .docx_parser import DocxParser + +__all__ = ["PDFParser", "DocxParser"] \ No newline at end of file diff --git a/src/services/parser/docx_parser.py b/src/services/parser/docx_parser.py new file mode 100644 index 0000000..405096f --- /dev/null +++ b/src/services/parser/docx_parser.py @@ -0,0 +1,287 @@ +# src/services/parser/docx_parser.py +"""Word文档解析 - 使用python-docx""" + +from docx import Document +from docx.enum.text import WD_ALIGN_PARAGRAPH +from typing import List, Dict, Optional +from dataclasses import dataclass, field +from loguru import logger +import re + + +@dataclass +class DocxParagraph: + """段落内容""" + text: str + level: int = 0 # 标题级别,0表示正文 + is_list: bool = False + list_number: Optional[str] = None + + +@dataclass +class DocxTable: + """表格内容""" + rows: List[List[str]] + markdown: str = "" + + +@dataclass +class DocxDocumentContent: + """Word文档完整内容""" + file_path: str + paragraphs: List[DocxParagraph] + tables: List[DocxTable] + metadata: Dict[str, str] = field(default_factory=dict) + markdown_text: str = "" + + +class DocxParser: + """Word文档解析器 - 基于python-docx""" + + def __init__(self): + self.document = None + + def parse(self, file_path: str) -> DocxDocumentContent: + """ + 解析Word文档 + + Args: + file_path: Word文档路径 + + Returns: + DocxDocumentContent: 解析后的文档内容 + """ + logger.info(f"开始解析Word文档: {file_path}") + + try: + self.document = Document(file_path) + doc_content = DocxDocumentContent( + file_path=file_path, + paragraphs=[], + tables=[] + ) + + # 提取文档元数据 + doc_content.metadata = self._extract_metadata() + + # 提取段落 + doc_content.paragraphs = self._extract_paragraphs() + + # 提取表格 + doc_content.tables = self._extract_tables() + + # 生成Markdown格式文本 + doc_content.markdown_text = self._generate_markdown(doc_content) + + logger.success(f"Word文档解析完成,共{len(doc_content.paragraphs)}个段落") + + return doc_content + + except Exception as e: + logger.error(f"Word文档解析失败: {e}") + raise + + def _extract_metadata(self) -> Dict[str, str]: + """提取文档元数据""" + metadata = {} + try: + core_props = self.document.core_properties + metadata = { + "title": core_props.title or "", + "author": core_props.author or "", + "subject": core_props.subject or "", + "keywords": core_props.keywords or "", + "created": str(core_props.created) if core_props.created else "", + "modified": str(core_props.modified) if core_props.modified else "", + } + except Exception as e: + logger.warning(f"提取元数据失败: {e}") + return metadata + + def _extract_paragraphs(self) -> List[DocxParagraph]: + """提取所有段落""" + paragraphs = [] + + for para in self.document.paragraphs: + text = para.text.strip() + if not text: + continue + + # 判断标题级别 + level = self._get_paragraph_level(para) + + # 判断是否是列表项 + is_list, list_number = self._detect_list_item(para) + + paragraph = DocxParagraph( + text=text, + level=level, + is_list=is_list, + list_number=list_number + ) + paragraphs.append(paragraph) + + return paragraphs + + def _get_paragraph_level(self, para) -> int: + """ + 判断段落标题级别 + + Returns: + int: 标题级别,0表示正文 + """ + # 方法1:检查段落样式 + style_name = para.style.name if para.style else "" + + if "Heading" in style_name or "标题" in style_name: + # 从样式名称中提取级别 + match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name) + if match: + level = int(match.group(1) or match.group(2)) + return level + + # 方法2:检查段落格式(字号) + # 标题通常字号较大 + if para.paragraph_format: + # 可以根据字号判断,这里简化处理 + pass + + # 方法3:根据内容模式判断(法规文档特征) + text = para.text.strip() + + # 第一章、第X章 -> 二级标题 + if re.match(r'^第[一二三四五六七八九十百]+章\s', text): + return 2 + # 第X节 -> 三级标题 + elif re.match(r'^第[一二三四五六七八九十百]+节\s', text): + return 3 + # 第X条 -> 四级标题 + elif re.match(r'^第[一二三四五六七八九十百]+条\s', text): + return 4 + + return 0 # 正文 + + def _detect_list_item(self, para) -> tuple[bool, Optional[str]]: + """检测是否是列表项""" + text = para.text.strip() + + # 数字列表:1.、2.、(1)、[1]等 + if re.match(r'^[\d]+[.、)\]]\s', text): + match = re.match(r'^([\d]+[.、)\]])\s', text) + return True, match.group(1) if match else None + + # 中文数字列表:一、二、(一)等 + if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text): + match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text) + return True, match.group(1) if match else None + + # 检查段落格式中的列表编号 + if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'): + # 有缩进的可能是列表项 + pass + + return False, None + + def _extract_tables(self) -> List[DocxTable]: + """提取所有表格""" + tables = [] + + for table in self.document.tables: + rows = [] + for row in table.rows: + cells = [] + for cell in row.cells: + cells.append(cell.text.strip()) + rows.append(cells) + + # 转换为Markdown表格 + markdown = self._table_to_markdown(rows) + + table_content = DocxTable(rows=rows, markdown=markdown) + tables.append(table_content) + + return tables + + def _table_to_markdown(self, rows: List[List[str]]) -> str: + """将表格转换为Markdown格式""" + if not rows or len(rows) < 1: + return "" + + lines = [] + + # 表头 + if len(rows) >= 1: + header = rows[0] + lines.append("| " + " | ".join(cell for cell in header) + " |") + lines.append("| " + " | ".join("---" for _ in header) + " |") + + # 数据行 + for row in rows[1:]: + lines.append("| " + " | ".join(cell for cell in row) + " |") + + return "\n".join(lines) + + def _generate_markdown(self, doc_content: DocxDocumentContent) -> str: + """生成Markdown格式文本""" + lines = [] + + # 文档标题 + title = doc_content.metadata.get("title", "") + if title: + lines.append(f"# {title}\n") + else: + # 从第一个段落获取标题(如果是标题样式) + for para in doc_content.paragraphs[:5]: + if para.level == 1: + lines.append(f"# {para.text}\n") + break + else: + lines.append(f"# {doc_content.file_path}\n") + + # 元数据信息 + lines.append("\n## 文档信息\n") + for key, value in doc_content.metadata.items(): + if value: + lines.append(f"- **{key}**: {value}") + + # 正文内容 + lines.append("\n## 正文\n") + + table_index = 0 + for para in doc_content.paragraphs: + if para.level > 0: + # 标题 + prefix = "#" * para.level + lines.append(f"\n{prefix} {para.text}\n") + elif para.is_list: + # 列表项 + lines.append(f"- {para.text}") + else: + # 正文 + lines.append(para.text) + + # 添加表格 + if doc_content.tables: + lines.append("\n## 表格\n") + for i, table in enumerate(doc_content.tables): + lines.append(f"\n### 表格 {i + 1}\n") + lines.append(table.markdown + "\n") + + return "\n".join(lines) + + def parse_to_markdown(self, file_path: str) -> str: + """直接解析并返回Markdown文本""" + doc_content = self.parse(file_path) + return doc_content.markdown_text + + +def parse_docx(file_path: str) -> DocxDocumentContent: + """便捷函数:解析Word文档""" + parser = DocxParser() + return parser.parse(file_path) + + +def parse_docx_to_markdown(file_path: str) -> str: + """便捷函数:解析Word并返回Markdown""" + parser = DocxParser() + return parser.parse_to_markdown(file_path) \ No newline at end of file diff --git a/src/services/parser/mineru_parser.py b/src/services/parser/mineru_parser.py new file mode 100644 index 0000000..e4baea2 --- /dev/null +++ b/src/services/parser/mineru_parser.py @@ -0,0 +1,204 @@ +# src/services/parser/mineru_parser.py +"""MinerU多模态PDF解析 - 版面感知解析""" + +from typing import Optional, Dict +from dataclasses import dataclass, field +from loguru import logger +import os + + +@dataclass +class MinerUResult: + """MinerU解析结果""" + file_path: str + markdown_text: str + metadata: Dict[str, str] = field(default_factory=dict) + success: bool = True + error_message: str = "" + + +class MinerUParser: + """ + MinerU多模态PDF解析器 + + MinerU (magic-pdf) 是一个开源的高质量PDF解析工具, + 支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素, + 并输出结构化的Markdown格式。 + + GitHub: https://github.com/opendatalab/MinerU + """ + + def __init__(self): + self.available = self._check_mineru_available() + + def _check_mineru_available(self) -> bool: + """检查MinerU是否可用""" + try: + from magic_pdf.pipe.UNIPipe import UNIPipe + return True + except ImportError: + logger.warning("MinerU (magic-pdf) 未安装,请运行: pip install magic-pdf[full]") + return False + + def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult: + """ + 使用MinerU解析PDF文档 + + Args: + file_path: PDF文件路径 + output_dir: 输出目录(可选,用于保存解析产物) + + Returns: + MinerUResult: 解析结果 + """ + logger.info(f"尝试使用MinerU解析: {file_path}") + + if not self.available: + return MinerUResult( + file_path=file_path, + markdown_text="", + success=False, + error_message="MinerU未安装" + ) + + try: + from magic_pdf.pipe.UNIPipe import UNIPipe + from magic_pdf.libs.MakeContentConfig import DropMode + + # 设置输出目录 + if output_dir is None: + output_dir = os.path.dirname(file_path) + + # 创建解析管道 + # OCR模式可以根据PDF类型选择 + # auto: 自动判断是否需要OCR + # txt: 纯文本PDF(无OCR) + # ocr: 扫描件PDF(OCR) + pipe = UNIPipe(file_path, output_dir) + + # 执行解析 + # pipe_mk() 返回Markdown格式文本 + markdown_content = pipe.pipe_mk() + + logger.success(f"MinerU解析成功") + + return MinerUResult( + file_path=file_path, + markdown_text=markdown_content, + metadata=self._extract_metadata(pipe), + success=True + ) + + except Exception as e: + logger.error(f"MinerU解析失败: {e}") + return MinerUResult( + file_path=file_path, + markdown_text="", + success=False, + error_message=str(e) + ) + + def _extract_metadata(self, pipe) -> Dict[str, str]: + """从解析管道提取元数据""" + metadata = {} + try: + # MinerU解析管道中可能包含的元数据信息 + if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data: + mid_data = pipe.pdf_mid_data + # 提取可能的元数据字段 + metadata = { + "page_count": str(mid_data.get("page_count", "")), + "language": str(mid_data.get("language", "")), + "is_scanned": str(mid_data.get("is_scanned", "")), + } + except Exception as e: + logger.warning(f"提取MinerU元数据失败: {e}") + + return metadata + + def parse_to_markdown(self, file_path: str) -> str: + """直接解析并返回Markdown文本""" + result = self.parse(file_path) + return result.markdown_text if result.success else "" + + +class ParserOrchestrator: + """ + 解析服务编排 - 按优先级选择解析器 + + 解析策略: + 1. 优先尝试MinerU(版面感知能力强) + 2. MinerU失败时回退到基础PyMuPDF解析 + """ + + def __init__(self): + from .pdf_parser import PDFParser + self.mineru_parser = MinerUParser() + self.pdf_parser = PDFParser() + self.mineru_available = self.mineru_parser.available + + def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str: + """ + 解析PDF文档,按优先级选择解析器 + + Args: + file_path: PDF文件路径 + prefer_mineru: 是否优先使用MinerU + + Returns: + str: Markdown格式文本 + """ + markdown_text = "" + + if prefer_mineru and self.mineru_available: + # 优先尝试MinerU + result = self.mineru_parser.parse(file_path) + if result.success: + markdown_text = result.markdown_text + logger.info("使用MinerU解析成功") + return markdown_text + else: + logger.warning(f"MinerU解析失败,回退到PyMuPDF: {result.error_message}") + + # 回退到PyMuPDF基础解析 + logger.info("使用PyMuPDF基础解析") + markdown_text = self.pdf_parser.parse_to_markdown(file_path) + + return markdown_text + + def parse_docx(self, file_path: str) -> str: + """解析Word文档""" + from .docx_parser import DocxParser + docx_parser = DocxParser() + return docx_parser.parse_to_markdown(file_path) + + def parse(self, file_path: str) -> str: + """ + 根据文件类型选择解析器 + + Args: + file_path: 文件路径 + + Returns: + str: Markdown格式文本 + """ + ext = os.path.splitext(file_path)[1].lower() + + if ext == ".pdf": + return self.parse_pdf(file_path) + elif ext in [".docx", ".doc"]: + return self.parse_docx(file_path) + else: + raise ValueError(f"不支持的文件类型: {ext}") + + +def parse_with_mineru(file_path: str) -> MinerUResult: + """便捷函数:使用MinerU解析""" + parser = MinerUParser() + return parser.parse(file_path) + + +def parse_pdf_smart(file_path: str) -> str: + """便捷函数:智能解析PDF(自动选择最佳解析器)""" + orchestrator = ParserOrchestrator() + return orchestrator.parse_pdf(file_path) \ No newline at end of file diff --git a/src/services/parser/pdf_parser.py b/src/services/parser/pdf_parser.py new file mode 100644 index 0000000..63f51a7 --- /dev/null +++ b/src/services/parser/pdf_parser.py @@ -0,0 +1,268 @@ +# src/services/parser/pdf_parser.py +"""PDF文档解析 - 使用PyMuPDF基础解析""" + +import fitz # PyMuPDF +from typing import List, Dict, Optional, Tuple +from dataclasses import dataclass, field +from loguru import logger +import re + + +@dataclass +class PDFPageContent: + """PDF页面内容""" + page_number: int + text: str + tables: List[str] = field(default_factory=list) + images: List[str] = field(default_factory=list) # 图片路径列表 + blocks: List[Dict] = field(default_factory=list) + + +@dataclass +class PDFDocumentContent: + """PDF文档完整内容""" + file_path: str + total_pages: int + pages: List[PDFPageContent] + metadata: Dict[str, str] = field(default_factory=dict) + markdown_text: str = "" + + +class PDFParser: + """PDF文档解析器 - 基于PyMuPDF""" + + def __init__(self): + self.pdf = None + + def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent: + """ + 解析PDF文档 + + Args: + file_path: PDF文件路径 + extract_tables: 是否提取表格 + extract_images: 是否提取图片 + + Returns: + PDFDocumentContent: 解析后的文档内容 + """ + logger.info(f"开始解析PDF文档: {file_path}") + + try: + self.pdf = fitz.open(file_path) + doc_content = PDFDocumentContent( + file_path=file_path, + total_pages=self.pdf.page_count, + pages=[] + ) + + # 提取文档元数据 + doc_content.metadata = self._extract_metadata() + + # 逐页解析 + for page_num in range(self.pdf.page_count): + page = self.pdf[page_num] + page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images) + doc_content.pages.append(page_content) + + # 生成Markdown格式文本 + doc_content.markdown_text = self._generate_markdown(doc_content) + + self.pdf.close() + logger.success(f"PDF解析完成,共{doc_content.total_pages}页") + + return doc_content + + except Exception as e: + logger.error(f"PDF解析失败: {e}") + raise + + def _extract_metadata(self) -> Dict[str, str]: + """提取PDF元数据""" + metadata = {} + try: + meta = self.pdf.metadata + metadata = { + "title": meta.get("title", ""), + "author": meta.get("author", ""), + "subject": meta.get("subject", ""), + "keywords": meta.get("keywords", ""), + "creator": meta.get("creator", ""), + "producer": meta.get("producer", ""), + "creation_date": meta.get("creationDate", ""), + "mod_date": meta.get("modDate", ""), + } + except Exception as e: + logger.warning(f"提取元数据失败: {e}") + return metadata + + def _parse_page(self, page: fitz.Page, page_num: int, + extract_tables: bool, extract_images: bool) -> PDFPageContent: + """解析单页内容""" + page_content = PDFPageContent(page_number=page_num, text="") + + # 提取文本块(保留结构) + blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"] + page_content.blocks = blocks + + # 提取纯文本 + text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE) + page_content.text = text.strip() + + # 提取表格(使用PyMuPDF的表格提取功能) + if extract_tables: + tables = self._extract_tables_from_page(page) + page_content.tables = tables + + # 提取图片 + if extract_images: + images = self._extract_images_from_page(page, page_num) + page_content.images = images + + return page_content + + def _extract_tables_from_page(self, page: fitz.Page) -> List[str]: + """ + 从页面提取表格(基于文本块分析) + 注意:PyMuPDF基础版表格提取能力有限,复杂表格建议使用MinerU + """ + tables = [] + + try: + # 使用PyMuPDF的表格提取方法(2.4+版本) + # 对于更复杂的表格,需要在mineru_parser中使用更高级的方法 + tabs = page.find_tables() + if tabs: + for tab in tabs: + table_text = tab.extract() + # 将表格转换为Markdown格式 + markdown_table = self._table_to_markdown(table_text) + tables.append(markdown_table) + + except AttributeError: + # 旧版本PyMuPDF没有表格提取功能 + logger.warning("PyMuPDF版本不支持表格提取,请升级到2.4+版本") + except Exception as e: + logger.warning(f"表格提取失败: {e}") + + return tables + + def _table_to_markdown(self, table_data: List[List[str]]) -> str: + """将表格数据转换为Markdown格式""" + if not table_data or len(table_data) < 1: + return "" + + lines = [] + # 表头 + if len(table_data) >= 1: + header = table_data[0] + lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |") + lines.append("| " + " | ".join("---" for _ in header) + " |") + + # 数据行 + for row in table_data[1:]: + lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |") + + return "\n".join(lines) + + def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]: + """提取页面图片""" + images = [] + # 图片提取功能(可选实现) + # 这里仅记录图片信息,实际图片需要额外保存 + try: + image_list = page.get_images() + for img_index, img in enumerate(image_list): + xref = img[0] + images.append(f"image_p{page_num}_i{img_index}_xref{xref}") + except Exception as e: + logger.warning(f"图片提取失败: {e}") + return images + + def _generate_markdown(self, doc_content: PDFDocumentContent) -> str: + """生成Markdown格式文本""" + lines = [] + + # 文档标题 + title = doc_content.metadata.get("title", "") + if title: + lines.append(f"# {title}\n") + else: + lines.append(f"# {doc_content.file_path}\n") + + # 元数据信息 + lines.append("\n## 文档信息\n") + for key, value in doc_content.metadata.items(): + if value and key in ["author", "subject", "keywords", "creation_date"]: + lines.append(f"- **{key}**: {value}") + + # 正文内容 + lines.append("\n## 正文\n") + + for page in doc_content.pages: + # 页码标记 + lines.append(f"\n---\n**第 {page.page_number} 页**\n") + + # 处理文本内容,识别标题结构 + text = self._process_page_text(page.text, page.blocks) + lines.append(text) + + # 添加表格 + for table in page.tables: + lines.append("\n" + table + "\n") + + return "\n".join(lines) + + def _process_page_text(self, text: str, blocks: List[Dict]) -> str: + """处理页面文本,识别标题结构""" + # 基于字体大小识别标题 + processed_text = text + + # 尝试识别标题(基于字号) + # 法规文档通常有明确的层级结构:章、节、条 + processed_text = self._detect_headers(text, blocks) + + return processed_text + + def _detect_headers(self, text: str, blocks: List[Dict]) -> str: + """检测并标记标题(基于字号或内容模式)""" + lines = text.split("\n") + processed_lines = [] + + for line in lines: + line = line.strip() + if not line: + continue + + # 法规标题模式检测 + # 第一章、第X章、第X节、第X条等 + if re.match(r'^第[一二三四五六七八九十百]+章\s', line): + processed_lines.append(f"\n## {line}\n") + elif re.match(r'^第[一二三四五六七八九十百]+节\s', line): + processed_lines.append(f"\n### {line}\n") + elif re.match(r'^第[一二三四五六七八九十百]+条\s', line): + processed_lines.append(f"\n#### {line}\n") + elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line): + # 条款子项 + processed_lines.append(f"- {line}") + else: + processed_lines.append(line) + + return "\n".join(processed_lines) + + def parse_to_markdown(self, file_path: str) -> str: + """直接解析并返回Markdown文本""" + doc_content = self.parse(file_path) + return doc_content.markdown_text + + +def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent: + """便捷函数:解析PDF文档""" + parser = PDFParser() + return parser.parse(file_path, **kwargs) + + +def parse_pdf_to_markdown(file_path: str) -> str: + """便捷函数:解析PDF并返回Markdown""" + parser = PDFParser() + return parser.parse_to_markdown(file_path) \ No newline at end of file diff --git a/src/services/storage/__init__.py b/src/services/storage/__init__.py new file mode 100644 index 0000000..7227e84 --- /dev/null +++ b/src/services/storage/__init__.py @@ -0,0 +1,6 @@ +# src/services/storage/__init__.py +"""存储服务""" + +from .milvus_client import MilvusClient + +__all__ = ["MilvusClient"] \ No newline at end of file diff --git a/src/services/storage/milvus_client.py b/src/services/storage/milvus_client.py new file mode 100644 index 0000000..9b1751f --- /dev/null +++ b/src/services/storage/milvus_client.py @@ -0,0 +1,485 @@ +# src/services/storage/milvus_client.py +"""Milvus向量数据库客户端 - 存储与检索服务""" + +from pymilvus import ( + connections, + Collection, + FieldSchema, + CollectionSchema, + DataType, + utility +) +from typing import List, Dict, Optional, Any +from dataclasses import dataclass, field +from loguru import logger +import time +import numpy as np + +from ..embedding.text_chunker import TextChunk +from ..embedding.bge_m3_embedder import EmbeddingResult +from src.config.settings import settings + + +@dataclass +class SearchResult: + """检索结果""" + id: int + content: str + score: float + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MilvusDocument: + """Milvus文档数据结构""" + doc_id: str + chunk_id: str + content: str + dense_vector: List[float] + sparse_vector: Dict[int, float] + doc_name: str + section_title: str + clause_number: str + page_number: int + regulation_type: str + version: str + create_time: int + + +class MilvusClient: + """Milvus向量数据库客户端""" + + COLLECTION_NAME = "regulations" + + SCHEMA_FIELDS = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64), + FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192), + FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024), + FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR), + FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256), + FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512), + FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64), + FieldSchema(name="page_number", dtype=DataType.INT64), + FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32), + FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32), + FieldSchema(name="create_time", dtype=DataType.INT64), + ] + + def __init__( + self, + host: str = None, + port: int = None, + collection_name: str = None, + db_name: str = None + ): + self.host = host or settings.milvus_host + self.port = port or settings.milvus_port + self.collection_name = collection_name or settings.milvus_collection + self.db_name = db_name or settings.milvus_db_name + + self.collection: Optional[Collection] = None + self.connected = False + + logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}") + + def connect(self) -> bool: + """连接到Milvus服务器""" + try: + connections.connect( + alias="default", + host=self.host, + port=self.port, + db_name=self.db_name + ) + self.connected = True + logger.success(f"Milvus连接成功: {self.host}:{self.port}") + return True + except Exception as e: + logger.error(f"Milvus连接失败: {e}") + self.connected = False + return False + + def disconnect(self): + """断开连接""" + try: + connections.disconnect("default") + self.connected = False + logger.info("Milvus连接已断开") + except Exception as e: + logger.warning(f"断开连接时出错: {e}") + + def create_collection(self, recreate: bool = False) -> bool: + """创建Collection""" + if not self.connected: + logger.warning("未连接到Milvus,请先调用connect()") + return False + + try: + if utility.has_collection(self.collection_name): + if recreate: + logger.info(f"删除已存在的Collection: {self.collection_name}") + utility.drop_collection(self.collection_name) + else: + logger.info(f"Collection已存在: {self.collection_name}") + self.collection = Collection(self.collection_name) + return True + + schema = CollectionSchema( + fields=self.SCHEMA_FIELDS, + description="法规文档向量存储", + enable_dynamic_field=True + ) + + self.collection = Collection( + name=self.collection_name, + schema=schema + ) + + self._create_indexes() + + logger.success(f"Collection创建成功: {self.collection_name}") + return True + + except Exception as e: + logger.error(f"Collection创建失败: {e}") + return False + + def _create_indexes(self): + """创建向量索引""" + if not self.collection: + return + + try: + dense_index_params = { + "metric_type": "COSINE", + "index_type": "IVF_FLAT", + "params": {"nlist": 128} + } + self.collection.create_index( + field_name="dense_vector", + index_params=dense_index_params + ) + + sparse_index_params = { + "metric_type": "IP", + "index_type": "SPARSE_INVERTED_INDEX", + "params": {"drop_ratio_build": 0.2} + } + self.collection.create_index( + field_name="sparse_vector", + index_params=sparse_index_params + ) + + logger.success("向量索引创建成功") + + except Exception as e: + logger.warning(f"创建索引时出错: {e}") + + def load_collection(self): + """加载Collection到内存""" + if self.collection: + self.collection.load() + logger.info(f"Collection已加载: {self.collection_name}") + + def release_collection(self): + """释放Collection内存""" + if self.collection: + self.collection.release() + logger.info(f"Collection已释放: {self.collection_name}") + + def insert_chunks( + self, + chunks: List[TextChunk], + embeddings: EmbeddingResult + ) -> List[int]: + """插入文档分块和嵌入向量""" + if not self.collection: + logger.warning("Collection未初始化") + return [] + + if len(chunks) != len(embeddings.texts): + logger.warning(f"Chunks数量与嵌入数量不匹配") + return [] + + logger.info(f"准备插入{len(chunks)}个文档分块") + + try: + data = [] + current_time = int(time.time()) + + for chunk, dense_emb, sparse_emb in zip( + chunks, + embeddings.dense_embeddings, + embeddings.sparse_embeddings + ): + row = { + "doc_id": chunk.metadata.doc_id, + "chunk_id": chunk.metadata.chunk_id, + "content": chunk.content, + "dense_vector": dense_emb.tolist(), + "sparse_vector": sparse_emb, + "doc_name": chunk.metadata.doc_name, + "section_title": chunk.metadata.section_title, + "clause_number": chunk.metadata.clause_number, + "page_number": chunk.metadata.page_number, + "regulation_type": chunk.metadata.regulation_type, + "version": chunk.metadata.version, + "create_time": current_time + } + data.append(row) + + result = self.collection.insert(data) + self.collection.flush() + + logger.success(f"插入完成,共{len(result.primary_keys)}条记录") + return result.primary_keys + + except Exception as e: + logger.error(f"插入数据失败: {e}") + return [] + + def hybrid_search( + self, + query_dense: List[float], + query_sparse: Dict[int, float], + top_k: int = 10, + filters: Optional[str] = None + ) -> List[SearchResult]: + """混合检索:Dense + Sparse""" + if not self.collection: + logger.warning("Collection未初始化") + return [] + + try: + self.collection.load() + + # 使用简单的Dense检索(兼容所有版本) + dense_results = self.dense_search(query_dense, top_k, filters) + + # 可选:合并Sparse结果 + if query_sparse: + sparse_results = self.sparse_search(query_sparse, top_k, filters) + merged = self._merge_results(dense_results, sparse_results, top_k) + logger.success(f"混合检索完成,返回{len(merged)}条结果") + return merged + + return dense_results + + except Exception as e: + logger.error(f"混合检索失败: {e}") + return [] + + def _merge_results( + self, + dense_results: List[SearchResult], + sparse_results: List[SearchResult], + top_k: int, + dense_weight: float = 0.6 + ) -> List[SearchResult]: + """手动融合Dense和Sparse结果""" + sparse_weight = 1 - dense_weight + merged_dict = {} + + for r in dense_results: + merged_dict[r.id] = { + "result": r, + "dense_score": r.score * dense_weight, + "sparse_score": 0 + } + + for r in sparse_results: + if r.id in merged_dict: + merged_dict[r.id]["sparse_score"] = r.score * sparse_weight + else: + merged_dict[r.id] = { + "result": r, + "dense_score": 0, + "sparse_score": r.score * sparse_weight + } + + final_results = [] + for id_, data in merged_dict.items(): + result = data["result"] + final_score = data["dense_score"] + data["sparse_score"] + final_results.append(SearchResult( + id=result.id, + content=result.content, + score=final_score, + metadata=result.metadata + )) + + final_results.sort(key=lambda x: x.score, reverse=True) + return final_results[:top_k] + + def dense_search( + self, + query_dense: List[float], + top_k: int = 10, + filters: Optional[str] = None + ) -> List[SearchResult]: + """纯Dense向量检索""" + if not self.collection: + return [] + + try: + self.collection.load() + + search_params = { + "metric_type": "COSINE", + "params": {"nprobe": 16} + } + + results = self.collection.search( + data=[query_dense], + anns_field="dense_vector", + param=search_params, + limit=top_k, + filter=filters, + output_fields=[ + "doc_id", "chunk_id", "content", + "doc_name", "section_title", "clause_number", + "page_number", "regulation_type", "version" + ] + ) + + search_results = [] + for hits in results: + for hit in hits: + result = SearchResult( + id=hit.id, + content=hit.entity.get("content", ""), + score=hit.score, + metadata={ + "doc_id": hit.entity.get("doc_id", ""), + "chunk_id": hit.entity.get("chunk_id", ""), + "doc_name": hit.entity.get("doc_name", ""), + "section_title": hit.entity.get("section_title", ""), + "clause_number": hit.entity.get("clause_number", ""), + "page_number": hit.entity.get("page_number", 0), + "regulation_type": hit.entity.get("regulation_type", ""), + "version": hit.entity.get("version", ""), + } + ) + search_results.append(result) + + return search_results + + except Exception as e: + logger.error(f"Dense检索失败: {e}") + return [] + + def sparse_search( + self, + query_sparse: Dict[int, float], + top_k: int = 10, + filters: Optional[str] = None + ) -> List[SearchResult]: + """纯Sparse向量检索""" + if not self.collection: + return [] + + try: + self.collection.load() + + search_params = { + "metric_type": "IP", + "params": {"drop_ratio_search": 0.2} + } + + results = self.collection.search( + data=[query_sparse], + anns_field="sparse_vector", + param=search_params, + limit=top_k, + filter=filters, + output_fields=[ + "doc_id", "chunk_id", "content", + "doc_name", "section_title", "clause_number", + "page_number", "regulation_type", "version" + ] + ) + + search_results = [] + for hits in results: + for hit in hits: + result = SearchResult( + id=hit.id, + content=hit.entity.get("content", ""), + score=hit.score, + metadata={ + "doc_id": hit.entity.get("doc_id", ""), + "chunk_id": hit.entity.get("chunk_id", ""), + "doc_name": hit.entity.get("doc_name", ""), + "section_title": hit.entity.get("section_title", ""), + "clause_number": hit.entity.get("clause_number", ""), + "page_number": hit.entity.get("page_number", 0), + "regulation_type": hit.entity.get("regulation_type", ""), + "version": hit.entity.get("version", ""), + } + ) + search_results.append(result) + + return search_results + + except Exception as e: + logger.error(f"Sparse检索失败: {e}") + return [] + + def delete_by_doc_id(self, doc_id: str) -> int: + """根据doc_id删除记录""" + if not self.collection: + return 0 + + try: + expr = f'doc_id=="{doc_id}"' + result = self.collection.delete(expr) + logger.info(f"删除记录: doc_id={doc_id}, 数量={len(result.primary_keys)}") + return len(result.primary_keys) + except Exception as e: + logger.error(f"删除失败: {e}") + return 0 + + def get_collection_stats(self) -> Dict[str, Any]: + """获取Collection统计信息""" + if not self.collection: + return {} + + try: + stats = { + "name": self.collection_name, + "num_entities": self.collection.num_entities, + "description": self.collection.description, + } + return stats + except Exception as e: + logger.warning(f"获取统计信息失败: {e}") + return {} + + +def create_milvus_client() -> MilvusClient: + """便捷函数:创建Milvus客户端""" + client = MilvusClient() + client.connect() + client.create_collection(recreate=False) + return client + + +def insert_documents( + client: MilvusClient, + chunks: List[TextChunk], + embeddings: EmbeddingResult +) -> List[int]: + """便捷函数:插入文档""" + return client.insert_chunks(chunks, embeddings) + + +def search_regulations( + client: MilvusClient, + query_dense: List[float], + query_sparse: Dict[int, float], + top_k: int = 10 +) -> List[SearchResult]: + """便捷函数:检索法规""" + return client.hybrid_search(query_dense, query_sparse, top_k) \ No newline at end of file diff --git a/src/workers/__init__.py b/src/workers/__init__.py new file mode 100644 index 0000000..3c737a9 --- /dev/null +++ b/src/workers/__init__.py @@ -0,0 +1,2 @@ +# src/workers/__init__.py +"""异步任务Worker模块""" \ No newline at end of file diff --git a/start_api.sh b/start_api.sh new file mode 100644 index 0000000..a28bd0d --- /dev/null +++ b/start_api.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# start_api.sh - 启动API服务(支持虚拟环境) + +set -e + +VENV_DIR=".venv" + +# 创建日志目录 +mkdir -p logs + +echo "========================================" +echo "启动 AI+合规智能中枢 API服务" +echo "========================================" +echo "" + +# 检查虚拟环境 +if [ ! -d "$VENV_DIR" ]; then + echo "错误: 虚拟环境不存在,请先运行 ./quick_start.sh" + exit 1 +fi + +# 激活虚拟环境 +source $VENV_DIR/bin/activate +echo "已激活虚拟环境: $VENV_DIR" +echo "" + +# 检查.env文件 +if [ ! -f ".env" ]; then + echo "警告: .env文件不存在,使用默认配置" +fi + +# 启动参数 +HOST=${API_HOST:-0.0.0.0} +PORT=${API_PORT:-8000} + +echo "API地址: http://$HOST:$PORT" +echo "API文档: http://$HOST:$PORT/docs" +echo "健康检查: http://$HOST:$PORT/health" +echo "" +echo "前端测试页面:" +echo " 直接打开: frontend/index.html" +echo " 或启动服务: ./start_frontend.sh" +echo "" +echo "正在启动..." +echo "" + +python -m uvicorn src.api.main:app --host $HOST --port $PORT --reload \ No newline at end of file diff --git a/start_api_background.sh b/start_api_background.sh new file mode 100644 index 0000000..71838f0 --- /dev/null +++ b/start_api_background.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# start_api_background.sh - 后台启动API服务(生产环境,支持虚拟环境) + +set -e + +VENV_DIR=".venv" + +# 创建日志目录 +mkdir -p logs + +PID_FILE=logs/api.pid +LOG_FILE=logs/api.log + +echo "========================================" +echo "后台启动 AI+合规智能中枢 API服务" +echo "========================================" +echo "" + +# 检查虚拟环境 +if [ ! -d "$VENV_DIR" ]; then + echo "错误: 虚拟环境不存在,请先运行 ./quick_start.sh" + exit 1 +fi + +# 检查是否已运行 +if [ -f "$PID_FILE" ]; then + PID=$(cat $PID_FILE) + if ps -p $PID > /dev/null 2>&1; then + echo "服务已在运行 (PID: $PID)" + echo "如需重启,请先运行: ./stop_api.sh" + exit 1 + else + rm -f $PID_FILE + fi +fi + +# 启动参数 +HOST=${API_HOST:-0.0.0.0} +PORT=${API_PORT:-8000} + +echo "服务地址: http://$HOST:$PORT" +echo "日志文件: $LOG_FILE" +echo "" + +# 后台启动(使用虚拟环境) +echo "正在后台启动..." +nohup $VENV_DIR/bin/python -m uvicorn src.api.main:app --host $HOST --port $PORT > $LOG_FILE 2>&1 & +PID=$! + +# 保存PID +echo $PID > $PID_FILE + +# 等待服务启动 +sleep 3 + +# 检查是否启动成功 +if ps -p $PID > /dev/null 2>&1; then + echo "服务启动成功 (PID: $PID)" + echo "" + echo "API地址: http://$HOST:$PORT" + echo "API文档: http://$HOST:$PORT/docs" + echo "健康检查: http://$HOST:$PORT/health" + echo "" + echo "前端测试页面:" + echo " 直接打开: frontend/index.html" + echo " 或启动服务: ./start_frontend.sh" + echo "" + echo "查看日志:" + echo " tail -f $LOG_FILE" + echo "" + echo "停止服务:" + echo " ./stop_api.sh" +else + echo "服务启动失败,请查看日志: $LOG_FILE" + rm -f $PID_FILE + exit 1 +fi \ No newline at end of file diff --git a/start_frontend.sh b/start_frontend.sh new file mode 100644 index 0000000..481a04d --- /dev/null +++ b/start_frontend.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# start_frontend.sh - 启动前端测试页面静态服务 + +set -e + +FRONTEND_DIR="frontend" +PORT=${FRONTEND_PORT:-3000} + +echo "========================================" +echo "启动前端测试页面服务" +echo "========================================" +echo "" + +# 检查前端目录 +if [ ! -d "$FRONTEND_DIR" ]; then + echo "错误: 前端目录不存在: $FRONTEND_DIR" + exit 1 +fi + +# 检查index.html +if [ ! -f "$FRONTEND_DIR/index.html" ]; then + echo "错误: 前端页面不存在: $FRONTEND_DIR/index.html" + exit 1 +fi + +echo "前端地址: http://localhost:$PORT" +echo "" +echo "确保API服务已启动:" +echo " ./start_api.sh" +echo "" +echo "正在启动前端服务..." + +# 使用Python内置HTTP服务器 +cd $FRONTEND_DIR + +# 查找Python命令 +if command -v python3 &> /dev/null; then + PYTHON_CMD=python3 +elif command -v python &> /dev/null; then + PYTHON_CMD=python +else + echo "错误: 未找到Python" + exit 1 +fi + +$PYTHON_CMD -m http.server $PORT --bind 0.0.0.0 \ No newline at end of file diff --git a/stop_api.sh b/stop_api.sh new file mode 100644 index 0000000..5b22d30 --- /dev/null +++ b/stop_api.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# stop_api.sh - 停止API服务 + +PID_FILE=logs/api.pid + +echo "========================================" +echo "停止 AI+合规智能中枢 API服务" +echo "========================================" +echo "" + +if [ -f "$PID_FILE" ]; then + PID=$(cat $PID_FILE) + + if ps -p $PID > /dev/null 2>&1; then + echo "正在停止服务 (PID: $PID)..." + kill $PID + + # 等待进程结束 + sleep 2 + + if ps -p $PID > /dev/null 2>&1; then + echo "进程未响应,强制终止..." + kill -9 $PID + fi + + rm -f $PID_FILE + echo "服务已停止" + else + echo "进程已不存在,清理PID文件" + rm -f $PID_FILE + fi +else + echo "PID文件不存在,服务可能未运行" + + # 尝试查找并停止所有uvicorn进程 + UVICORN_PIDS=$(pgrep -f "uvicorn src.api.main") + if [ -n "$UVICORN_PIDS" ]; then + echo "发现运行中的uvicorn进程: $UVICORN_PIDS" + echo "是否停止这些进程? (y/n)" + read -r answer + if [ "$answer" = "y" ]; then + kill $UVICORN_PIDS + echo "进程已停止" + fi + fi +fi \ No newline at end of file diff --git a/test_api.sh b/test_api.sh new file mode 100644 index 0000000..3ebeb4d --- /dev/null +++ b/test_api.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# test_api.sh - API接口测试脚本 + +API_URL=${API_URL:-http://localhost:8000} + +echo "========================================" +echo "API接口测试" +echo "========================================" +echo "" + +# 1. 健康检查 +echo ">>> 测试: 健康检查 GET /health" +curl -s -X GET "$API_URL/health" +echo "" +echo "" + +# 2. 根路径 +echo ">>> 测试: 根路径 GET /" +curl -s -X GET "$API_URL/" +echo "" +echo "" + +# 3. 检索接口(无数据时返回空结果) +echo ">>> 测试: 检索接口 POST /api/v1/knowledge/search" +curl -s -X POST "$API_URL/api/v1/knowledge/search" \ + -H "Content-Type: application/json" \ + -d '{"query": "机动车安全标准", "top_k": 5}' +echo "" +echo "" + +echo "========================================" +echo "测试完成" +echo "========================================" +echo "" +echo "上传文档测试:" +echo " ./test_upload.sh your_file.pdf" +echo "" \ No newline at end of file diff --git a/test_upload.sh b/test_upload.sh new file mode 100644 index 0000000..7fe5521 --- /dev/null +++ b/test_upload.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# test_upload.sh - 测试文档上传(支持虚拟环境) + +set -e + +API_URL=${API_URL:-http://localhost:8000} + +echo "========================================" +echo "测试文档上传功能" +echo "========================================" +echo "" + +echo "API地址: $API_URL" +echo "" + +# 检查健康状态 +echo "1. 检查服务健康状态..." +curl -s $API_URL/health | python -m json.tool 2>/dev/null || curl -s $API_URL/health +echo "" + +# 检查文件参数 +if [ -z "$1" ]; then + echo "使用方法: ./test_upload.sh <文件路径>" + echo "" + echo "示例:" + echo " ./test_upload.sh sample.pdf" + echo " ./test_upload.sh sample.docx" + exit 1 +fi + +FILE_PATH=$1 + +if [ ! -f "$FILE_PATH" ]; then + echo "错误: 文件不存在: $FILE_PATH" + exit 1 +fi + +FILE_NAME=$(basename "$FILE_PATH") +FILE_EXT="${FILE_NAME##*.}" + +echo "文件: $FILE_PATH" +echo "类型: $FILE_EXT" +echo "" + +# 上传文档 +echo "2. 上传文档..." +echo "" + +RESPONSE=$(curl -s -X POST "$API_URL/api/v1/documents/upload" \ + -F "file=@$FILE_PATH" \ + -F "doc_name=$FILE_NAME" \ + -F "regulation_type=测试法规") + +echo "$RESPONSE" | python -m json.tool 2>/dev/null || echo "$RESPONSE" +echo "" + +# 提取doc_id +DOC_ID=$(echo "$RESPONSE" | python -c "import sys,json; d=json.load(sys.stdin); print(d.get('doc_id',''))" 2>/dev/null) + +if [ -n "$DOC_ID" ]; then + echo "文档ID: $DOC_ID" + echo "" + + # 测试检索 + echo "3. 测试检索..." + echo "" + + SEARCH_QUERY="法规安全要求" + + curl -s -X POST "$API_URL/api/v1/knowledge/search" \ + -H "Content-Type: application/json" \ + -d "{\"query\": \"$SEARCH_QUERY\", \"top_k\": 5}" \ + | python -m json.tool 2>/dev/null || cat +else + echo "上传可能失败,请检查响应内容" +fi \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..c024ed8 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# tests/__init__.py +"""测试模块""" \ No newline at end of file diff --git a/tests/test_embedding.py b/tests/test_embedding.py new file mode 100644 index 0000000..821a1bb --- /dev/null +++ b/tests/test_embedding.py @@ -0,0 +1,184 @@ +# tests/test_embedding.py +"""嵌入和分块测试""" + +import pytest +from loguru import logger +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.services.embedding.text_chunker import RegulationChunker, TextChunk, ChunkMetadata +from src.services.embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult + + +class TestRegulationChunker: + """法规分块器测试""" + + @pytest.fixture + def chunker(self): + """创建分块器实例""" + return RegulationChunker(chunk_size=512) + + @pytest.fixture + def sample_regulation(self): + """示例法规文档""" + return """ +# GB 7258-2017 机动车运行安全技术条件 + +第一章 范围 + +第一条 本标准规定了机动车运行安全技术条件。 + +第二条 本标准适用于在我国道路上行驶的所有机动车。 + +第二章 术语和定义 + +第三条 下列术语和定义适用于本标准。 + +(一)机动车:以动力装置驱动或者牵引,上道路行驶的供人员乘用或者用于运送物品以及进行工程专项作业的轮式车辆。 + +(二)整车:完整的机动车,包括所有必要的部件和系统。 + +第三章 技术要求 + +第四条 机动车应满足以下基本要求: + +1. 车辆应具有唯一的产品标识; +2. 车辆结构应安全可靠; +3. 车辆应配备必要的安全装置。 +""" + + def test_chunk_document(self, chunker, sample_regulation): + """测试文档分块""" + chunks = chunker.chunk_document( + sample_regulation, + doc_id="gb7258", + doc_name="GB 7258-2017", + regulation_type="车辆安全" + ) + + # 应该有多个分块 + assert len(chunks) > 3 + + # 每个分块应该有内容 + for chunk in chunks: + assert len(chunk.content) > 0 + assert chunk.metadata.doc_id == "gb7258" + + def test_section_detection(self, chunker, sample_regulation): + """测试章节检测""" + chunks = chunker.chunk_document( + sample_regulation, + doc_id="test", + doc_name="测试" + ) + + # 应该检测到章节 + section_numbers = [c.metadata.section_number for c in chunks] + assert any(s for s in section_numbers) # 至少有一个章节编号 + + def test_clause_detection(self, chunker, sample_regulation): + """测试条款检测""" + chunks = chunker.chunk_document( + sample_regulation, + doc_id="test", + doc_name="测试" + ) + + # 应该检测到条款 + clause_numbers = [c.metadata.clause_number for c in chunks] + assert any(c for c in clause_numbers) # 至少有一个条款编号 + + def test_long_clause_split(self, chunker): + """测试长条款分割""" + long_clause = """ +第一条 本条款内容很长,需要进行分割处理。 + +本条款包含以下多项内容: +1. 第一项内容,这是一个非常长的子项,包含了大量的文字描述,需要进行适当的处理。 +2. 第二项内容,这也是一个较长的子项,包含了相关的技术要求和规范说明。 +3. 第三项内容,继续描述相关要求和注意事项,确保文档的完整性和规范性。 +4. 第四项内容,补充说明其他相关事项,保证内容的全面性。 +""" + + chunks = chunker.chunk_document( + long_clause, + doc_id="test", + doc_name="测试" + ) + + # 长条款应该被分割成多个chunk + assert len(chunks) >= 1 + + +class TestBGEM3Embedder: + """BGE-M3嵌入模型测试""" + + @pytest.fixture + def embedder(self): + """创建嵌入模型实例""" + try: + return BGEM3Embedder() + except Exception as e: + pytest.skip(f"嵌入模型加载失败: {e}") + + def test_embed_single(self, embedder): + """测试单文本嵌入""" + text = "这是一条测试文本" + result = embedder.embed_single(text) + + # 应该包含dense和sparse向量 + assert 'dense' in result + assert 'sparse' in result + + # dense向量维度应该是1024 + assert len(result['dense']) == 1024 + + def test_embed_batch(self, embedder): + """测试批量嵌入""" + texts = [ + "第一条 本标准规定了机动车安全要求", + "第二条 机动车应符合技术条件", + "第三条 生产企业应建立管理体系" + ] + + result = embedder.embed(texts) + + # 应该返回正确数量的向量 + assert len(result.dense_embeddings) == 3 + + # 维度应该是1024 + assert result.dense_embeddings.shape[1] == 1024 + + def test_embed_empty_list(self, embedder): + """测试空列表嵌入""" + result = embedder.embed([]) + + # 应该返回空结果 + assert len(result.dense_embeddings) == 0 + + def test_similarity(self, embedder): + """测试相似度计算""" + import numpy as np + + texts = [ + "机动车安全标准要求", + "汽车安全技术规范", + "食品安全管理规定" # 不相关文本 + ] + + result = embedder.embed(texts) + + # 计算第一个文本与其他文本的相似度 + query = result.dense_embeddings[0] + docs = result.dense_embeddings[1:] + + similarities = embedder.compute_similarity(query, docs) + + # 相关文档的相似度应该更高 + assert similarities[0] > similarities[1] # 车辆安全 > 食品安全 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_milvus.py b/tests/test_milvus.py new file mode 100644 index 0000000..8d599b8 --- /dev/null +++ b/tests/test_milvus.py @@ -0,0 +1,136 @@ +# tests/test_milvus.py +"""Milvus集成测试""" + +import pytest +from loguru import logger +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.services.storage.milvus_client import MilvusClient, SearchResult +from src.services.embedding.bge_m3_embedder import BGEM3Embedder +from src.config.settings import settings + + +class TestMilvusConnection: + """Milvus连接测试""" + + def test_connection(self): + """测试Milvus连接""" + client = MilvusClient() + + result = client.connect() + assert result == True + + client.disconnect() + + def test_create_collection(self): + """测试创建Collection""" + client = MilvusClient() + client.connect() + + result = client.create_collection(recreate=True) + assert result == True + + # 检查Collection是否存在 + stats = client.get_collection_stats() + assert stats["name"] == settings.milvus_collection + + client.disconnect() + + +class TestMilvusOperations: + """Milvus操作测试""" + + @pytest.fixture + def client(self): + """创建测试客户端""" + client = MilvusClient() + client.connect() + client.create_collection(recreate=True) + client.load_collection() + yield client + client.disconnect() + + def test_insert_and_search(self, client): + """测试插入和检索""" + from src.services.embedding.text_chunker import TextChunk, ChunkMetadata + + # 创建测试数据 + chunks = [ + TextChunk( + content="第一条 为保障机动车安全技术性能,预防和减少机动车交通事故,保护人身安全,制定本标准。", + metadata=ChunkMetadata( + doc_id="test_doc", + doc_name="测试文档", + chunk_id="test_chunk_1", + clause_number="第一条", + regulation_type="车辆安全" + ) + ), + TextChunk( + content="第二条 本标准适用于在我国道路上行驶的所有机动车。", + metadata=ChunkMetadata( + doc_id="test_doc", + doc_name="测试文档", + chunk_id="test_chunk_2", + clause_number="第二条", + regulation_type="车辆安全" + ) + ) + ] + + # 生成嵌入 + embedder = BGEM3Embedder() + embeddings = embedder.embed([c.content for c in chunks]) + + # 插入数据 + inserted_ids = client.insert_chunks(chunks, embeddings) + assert len(inserted_ids) == 2 + + # 执行检索 + query = "机动车安全标准" + query_embedding = embedder.embed_single(query) + + results = client.hybrid_search( + query_dense=query_embedding['dense'].tolist(), + query_sparse=query_embedding['sparse'], + top_k=2 + ) + + assert len(results) > 0 + assert "机动车" in results[0].content or "安全" in results[0].content + + +class TestEmbedding: + """嵌入模型测试""" + + def test_embed_single_text(self): + """测试单文本嵌入""" + embedder = BGEM3Embedder() + + result = embedder.embed_single("这是一条测试文本") + + assert 'dense' in result + assert 'sparse' in result + assert len(result['dense']) == 1024 # BGE-M3默认维度 + + def test_embed_batch(self): + """测试批量嵌入""" + embedder = BGEM3Embedder() + + texts = [ + "第一条 本标准规定了机动车安全要求", + "第二条 机动车应符合以下技术条件", + "第三条 生产企业应建立质量管理体系" + ] + + result = embedder.embed(texts) + + assert len(result.dense_embeddings) == 3 + assert result.dense_embeddings.shape[1] == 1024 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..83cc250 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,118 @@ +# tests/test_parser.py +"""文档解析测试""" + +import pytest +from loguru import logger +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.services.parser.pdf_parser import PDFParser, parse_pdf_to_markdown +from src.services.parser.docx_parser import DocxParser, parse_docx_to_markdown +from src.services.parser.mineru_parser import MinerUParser, ParserOrchestrator + + +class TestPDFParser: + """PDF解析测试""" + + def test_parser_initialization(self): + """测试PDF解析器初始化""" + parser = PDFParser() + assert parser is not None + + def test_parse_sample_pdf(self): + """测试解析示例PDF(如果有)""" + # 如果有示例PDF文件,可以在此测试 + sample_pdf = os.path.join(os.path.dirname(__file__), "sample.pdf") + + if os.path.exists(sample_pdf): + parser = PDFParser() + result = parser.parse(sample_pdf) + + assert result.total_pages > 0 + assert len(result.pages) > 0 + assert len(result.markdown_text) > 0 + + +class TestDocxParser: + """Word文档解析测试""" + + def test_parser_initialization(self): + """测试Word解析器初始化""" + parser = DocxParser() + assert parser is not None + + def test_parse_sample_docx(self): + """测试解析示例DOCX""" + sample_docx = os.path.join(os.path.dirname(__file__), "sample.docx") + + if os.path.exists(sample_docx): + parser = DocxParser() + result = parser.parse(sample_docx) + + assert len(result.paragraphs) > 0 + assert len(result.markdown_text) > 0 + + +class TestChunker: + """分块器测试""" + + def test_chunker_initialization(self): + """测试分块器初始化""" + from src.services.embedding.text_chunker import RegulationChunker + + chunker = RegulationChunker(chunk_size=512) + assert chunker is not None + + def test_chunk_sample_text(self): + """测试分块示例文本""" + from src.services.embedding.text_chunker import RegulationChunker + + sample_text = """ +# 测试法规文档 + +第一章 总则 + +第一条 为规范某项行为,制定本规定。 + +第二条 本规定适用于相关主体。 + +第二章 具体要求 + +第三条 相关主体应当遵守以下要求: + +(一)建立管理制度; +(二)配备专业人员; +(三)定期进行检查。 +""" + + chunker = RegulationChunker(chunk_size=256) + chunks = chunker.chunk_document( + sample_text, + doc_id="test", + doc_name="测试法规" + ) + + assert len(chunks) > 0 + + # 验证分块包含章节信息 + has_section = any(c.metadata.section_number for c in chunks) + assert has_section + + +class TestFullPipeline: + """完整流程测试""" + + def test_pipeline_without_files(self): + """测试流程初始化(无文件)""" + from src.services.document_processor import DocumentProcessor + + processor = DocumentProcessor() + assert processor is not None + + processor.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/verify_mvp.py b/tests/verify_mvp.py new file mode 100644 index 0000000..bed0ca4 --- /dev/null +++ b/tests/verify_mvp.py @@ -0,0 +1,222 @@ +""" +MVP功能验证脚本 + +用于验证完整的文档处理流程: +1. PDF/DOCX解析 +2. 智能分块 +3. 向量嵌入 +4. Milvus入库 +5. 混合检索 + +使用方法: +1. 首先启动Milvus: docker-compose up -d +2. 运行此脚本: python verify_mvp.py +""" + +import os +import sys +import time + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from loguru import logger +from src.config.logging import setup_logging +from src.services.document_processor import DocumentProcessor, ProcessingResult +from src.services.storage.milvus_client import MilvusClient +from src.config.settings import settings + +# 设置日志 +setup_logging(level="INFO") + + +def verify_milvus_connection(): + """验证Milvus连接""" + logger.info("=" * 50) + logger.info("Step 1: 验证Milvus连接") + logger.info("=" * 50) + + client = MilvusClient() + + try: + result = client.connect() + if result: + logger.success("Milvus连接成功") + + # 创建Collection + client.create_collection(recreate=True) + stats = client.get_collection_stats() + logger.info(f"Collection信息: {stats}") + + client.disconnect() + return True + else: + logger.error("Milvus连接失败,请检查docker-compose是否启动") + return False + + except Exception as e: + logger.error(f"Milvus连接异常: {e}") + logger.info("请先启动Milvus: cd docker && docker-compose up -d") + return False + + +def verify_embedding_model(): + """验证嵌入模型""" + logger.info("=" * 50) + logger.info("Step 2: 验证BGE-M3嵌入模型") + logger.info("=" * 50) + + try: + from src.services.embedding.bge_m3_embedder import BGEM3Embedder + + embedder = BGEM3Embedder() + logger.success("嵌入模型加载成功") + + # 测试嵌入 + test_text = "这是一条测试文本,用于验证嵌入模型功能" + result = embedder.embed_single(test_text) + + logger.info(f"Dense向量维度: {len(result['dense'])}") + logger.info(f"Sparse向量词数: {len(result['sparse'])}") + + return True + + except Exception as e: + logger.error(f"嵌入模型验证失败: {e}") + logger.info("请确保已安装FlagEmbedding: pip install FlagEmbedding") + return False + + +def verify_sample_document(): + """验证示例文档处理""" + logger.info("=" * 50) + logger.info("Step 3: 验证文档处理流程") + logger.info("=" * 50) + + # 使用内置的示例文本(无需外部文件) + sample_text = """ +# GB 7258-2017 机动车运行安全技术条件 + +第一章 范围 + +第一条 本标准规定了机动车运行安全技术条件,适用于在我国道路上行驶的所有机动车。 + +第二条 本标准包括整车、发动机、传动系、行驶系、制动系、照明与信号装置等技术要求。 + +第二章 术语和定义 + +第三条 下列术语和定义适用于本标准: + +(一)机动车:以动力装置驱动或者牵引,上道路行驶的供人员乘用或者用于运送物品的轮式车辆。 + +(二)整车产品:完整的机动车产品,包括所有必要的部件和系统。 + +第三章 整车技术要求 + +第四条 机动车整车应满足以下基本技术要求: + +1. 车辆外廓尺寸应符合规定限值; +2. 车辆应具有唯一的产品标识; +3. 车辆结构应安全可靠,各部件连接牢固。 + +第五条 车辆应配备必要的安全装置,包括: +- 制动系统 +- 照明与信号装置 +- 安全带 +- 灭火器 +""" + + try: + from src.services.embedding.text_chunker import RegulationChunker + from src.services.embedding.bge_m3_embedder import BGEM3Embedder + from src.services.storage.milvus_client import MilvusClient + + # 1. 分块 + logger.info("测试分块...") + chunker = RegulationChunker(chunk_size=256) + chunks = chunker.chunk_document( + sample_text, + doc_id="gb7258_test", + doc_name="GB 7258-2017 测试", + regulation_type="车辆安全" + ) + logger.success(f"分块完成,共{len(chunks)}个chunk") + + # 2. 嵌入 + logger.info("测试嵌入...") + embedder = BGEM3Embedder() + embeddings = embedder.embed([c.content for c in chunks]) + logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}") + + # 3. 入库 + logger.info("测试入库...") + client = MilvusClient() + client.connect() + client.create_collection(recreate=False) + client.load_collection() + + inserted_ids = client.insert_chunks(chunks, embeddings) + logger.success(f"入库完成,共{len(inserted_ids)}条记录") + + # 4. 检索 + logger.info("测试检索...") + query = "机动车安全技术要求" + query_emb = embedder.embed_single(query) + + results = client.hybrid_search( + query_dense=query_emb['dense'].tolist(), + query_sparse=query_emb['sparse'], + top_k=3 + ) + logger.success(f"检索完成,返回{len(results)}条结果") + + for i, r in enumerate(results): + logger.info(f"结果{i+1}: 分数={r.score:.4f}, 内容={r.content[:50]}...") + + client.disconnect() + return True + + except Exception as e: + logger.error(f"文档处理验证失败: {e}") + return False + + +def main(): + """主验证流程""" + logger.info("\n" + "=" * 60) + logger.info("AI+合规智能中枢 MVP功能验证") + logger.info("=" * 60) + + results = [] + + # 1. Milvus连接验证 + results.append(("Milvus连接", verify_milvus_connection())) + + # 2. 嵌入模型验证 + results.append(("嵌入模型", verify_embedding_model())) + + # 3. 文档处理验证 + results.append(("文档处理", verify_sample_document())) + + # 输出结果汇总 + logger.info("\n" + "=" * 60) + logger.info("验证结果汇总") + logger.info("=" * 60) + + all_passed = True + for name, passed in results: + status = "✅ 通过" if passed else "❌ 失败" + logger.info(f"{name}: {status}") + if not passed: + all_passed = False + + if all_passed: + logger.success("\n🎉 所有验证通过!MVP功能正常") + else: + logger.warning("\n⚠️ 部分验证失败,请检查配置和环境") + + return all_passed + + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1) \ No newline at end of file