Fix SSE route dependency and align architecture docs
This commit is contained in:
56
backend/.env
56
backend/.env
@@ -1,56 +0,0 @@
|
||||
APP_NAME=AI+合规智能中枢
|
||||
APP_VERSION=0.1.0
|
||||
DEBUG=false
|
||||
|
||||
MILVUS_HOST=localhost
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_COLLECTION=regulations
|
||||
MILVUS_DB_NAME=default
|
||||
|
||||
MINIO_ENDPOINT=localhost:9000
|
||||
MINIO_ACCESS_KEY=minioadmin
|
||||
MINIO_SECRET_KEY=minioadmin
|
||||
MINIO_BUCKET=compliance-docs
|
||||
MINIO_SECURE=false
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=redis@123
|
||||
REDIS_DB=0
|
||||
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USER=postgresql
|
||||
POSTGRES_PASSWORD=postgresql123456
|
||||
POSTGRES_DB=compliance_db
|
||||
|
||||
EMBEDDING_MODEL=BAAI/bge-m3
|
||||
EMBEDDING_DIM=1024
|
||||
EMBEDDING_MAX_LENGTH=8192
|
||||
EMBEDDING_BATCH_SIZE=12
|
||||
EMBEDDING_USE_FP16=true
|
||||
|
||||
CHUNK_SIZE=512
|
||||
CHUNK_OVERLAP=50
|
||||
MAX_FILE_SIZE_MB=100
|
||||
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
|
||||
LLM_PROVIDER=deepseek
|
||||
LLM_MODEL=deepseek-v4-flash
|
||||
LLM_MAX_TOKENS=4096
|
||||
LLM_TEMPERATURE=0.7
|
||||
|
||||
QWEN_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
|
||||
QWEN_BASE_URL=http://6.86.80.4:30080/v1
|
||||
QWEN_MODEL=qwen3.5-plus
|
||||
QWEN_VL_MODEL=qwen3-vl-plus
|
||||
|
||||
DEEPSEEK_API_KEY=sk-fVr9KmDZNC4pGDBQj0EUWz9bDmFzNxjYC9EzZpe2bVDsxtz8
|
||||
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
|
||||
DEEPSEEK_MODEL=deepseek-v4-flash
|
||||
|
||||
RAG_TOP_K=10
|
||||
RAG_MAX_CONTEXT_TOKENS=4000
|
||||
RAG_SUMMARY_MAX_TOKENS=1024
|
||||
@@ -1,56 +0,0 @@
|
||||
APP_NAME=AI+合规智能中枢
|
||||
APP_VERSION=0.1.0
|
||||
DEBUG=false
|
||||
|
||||
MILVUS_HOST=localhost
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_COLLECTION=regulations
|
||||
MILVUS_DB_NAME=default
|
||||
|
||||
EMBEDDING_MODEL=BAAI/bge-m3
|
||||
EMBEDDING_DIM=1024
|
||||
EMBEDDING_MAX_LENGTH=8192
|
||||
EMBEDDING_BATCH_SIZE=12
|
||||
EMBEDDING_USE_FP16=true
|
||||
|
||||
MINIO_ENDPOINT=localhost:9000
|
||||
MINIO_ACCESS_KEY=minioadmin
|
||||
MINIO_SECRET_KEY=minioadmin123
|
||||
MINIO_BUCKET=compliance-docs
|
||||
MINIO_SECURE=false
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USER=compliance
|
||||
POSTGRES_PASSWORD=compliance123
|
||||
POSTGRES_DB=compliance_db
|
||||
|
||||
CHUNK_SIZE=512
|
||||
CHUNK_OVERLAP=50
|
||||
MAX_FILE_SIZE_MB=100
|
||||
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
|
||||
LLM_PROVIDER=deepseek
|
||||
LLM_MODEL=deepseek-v4-flash
|
||||
LLM_MAX_TOKENS=4096
|
||||
LLM_TEMPERATURE=0.7
|
||||
|
||||
QWEN_API_KEY=your_api_key_here
|
||||
DEEPSEEK_API_KEY=your_api_key_here
|
||||
QWEN_BASE_URL=http://6.86.80.4:30080/v1
|
||||
DEEPSEEK_BASE_URL=http://6.86.80.4:30080/v1
|
||||
|
||||
QWEN_MODEL=qwen3.5-plus
|
||||
QWEN_VL_MODEL=qwen3-vl-plus
|
||||
DEEPSEEK_MODEL=deepseek-v4-flash
|
||||
|
||||
RAG_TOP_K=10
|
||||
RAG_MAX_CONTEXT_TOKENS=4000
|
||||
RAG_SUMMARY_MAX_TOKENS=1024
|
||||
@@ -1,3 +1,14 @@
|
||||
from .main import app
|
||||
"""Initialize the app package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name == "app":
|
||||
from .main import app
|
||||
|
||||
return app
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
阿里云文档智能 API 解析 PDF,输出三层结构 chunks
|
||||
- structure_nodes: 目录树结构
|
||||
- semantic_blocks: 语义块(章节文本、表格、图片)
|
||||
- vector_chunks: 检索块(带 overlap 切分)
|
||||
"""
|
||||
"""Handle Aliyun parsing support for parse pdf."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
@@ -19,16 +15,16 @@ from alibabacloud_tea_openapi import models as open_api_models
|
||||
from alibabacloud_docmind_api20220711 import models as docmind_models
|
||||
from alibabacloud_tea_util import models as util_models
|
||||
|
||||
# ===================== 阿里云配置 =====================
|
||||
ALIBABA_ACCESS_KEY_ID = "LTAI5t6fWvAsvZkoF9WTbtys"
|
||||
ALIBABA_ACCESS_KEY_SECRET = "WX4oaE4FLYRa5L85TMQkqRPHeTJAF0"
|
||||
ALIBABA_ENDPOINT = "docmind-api.cn-hangzhou.aliyuncs.com"
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
ALIBABA_ACCESS_KEY_ID = os.getenv("ALIBABA_ACCESS_KEY_ID", "")
|
||||
ALIBABA_ACCESS_KEY_SECRET = os.getenv("ALIBABA_ACCESS_KEY_SECRET", "")
|
||||
ALIBABA_ENDPOINT = os.getenv("ALIBABA_ENDPOINT", "docmind-api.cn-hangzhou.aliyuncs.com")
|
||||
|
||||
# ===================== 切分参数 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
MAX_CHARS = 600
|
||||
OVERLAP_CHARS = 80
|
||||
|
||||
# ===================== 布局类型常量 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
TOC_TITLES = {"目次", "目录"}
|
||||
TITLE_SUBTYPES = {"doc_title", "para_title"}
|
||||
TEXT_SUBTYPES = {"para", "none"}
|
||||
@@ -36,8 +32,11 @@ FIGURE_TYPES = {"figure", "figure_name", "figure_note"}
|
||||
FIGURE_SUBTYPES = {"picture", "pic_title", "pic_caption"}
|
||||
|
||||
|
||||
# ===================== 阿里云 API 客户端 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def init_client() -> DocmindClient:
|
||||
"""Handle init client."""
|
||||
if not ALIBABA_ACCESS_KEY_ID or not ALIBABA_ACCESS_KEY_SECRET:
|
||||
raise ValueError("缺少阿里云文档解析凭据,请设置 ALIBABA_ACCESS_KEY_ID 和 ALIBABA_ACCESS_KEY_SECRET")
|
||||
config = open_api_models.Config(
|
||||
access_key_id=ALIBABA_ACCESS_KEY_ID,
|
||||
access_key_secret=ALIBABA_ACCESS_KEY_SECRET,
|
||||
@@ -47,7 +46,7 @@ def init_client() -> DocmindClient:
|
||||
|
||||
|
||||
def submit_job(client: DocmindClient, file_path: str) -> str:
|
||||
"""提交文档解析任务"""
|
||||
"""Submit job."""
|
||||
file_name = Path(file_path).name
|
||||
request = docmind_models.SubmitDocParserJobAdvanceRequest(
|
||||
file_url_object=open(file_path, "rb"),
|
||||
@@ -62,14 +61,14 @@ def submit_job(client: DocmindClient, file_path: str) -> str:
|
||||
|
||||
|
||||
def query_status(client: DocmindClient, task_id: str) -> Dict:
|
||||
"""查询任务状态"""
|
||||
"""Handle query status."""
|
||||
request = docmind_models.QueryDocParserStatusRequest(id=task_id)
|
||||
response = client.query_doc_parser_status(request)
|
||||
return response.body.data.to_map() if response.body.data else None
|
||||
|
||||
|
||||
def wait_for_completion(client: DocmindClient, task_id: str, poll_interval: int = 5) -> bool:
|
||||
"""等待任务完成"""
|
||||
"""Wait for for completion."""
|
||||
while True:
|
||||
status_data = query_status(client, task_id)
|
||||
if not status_data:
|
||||
@@ -85,7 +84,7 @@ def wait_for_completion(client: DocmindClient, task_id: str, poll_interval: int
|
||||
|
||||
|
||||
def get_result(client: DocmindClient, task_id: str, layout_num: int = 0, layout_step_size: int = 50) -> Dict:
|
||||
"""获取解析结果"""
|
||||
"""Return result."""
|
||||
request = docmind_models.GetDocParserResultRequest(
|
||||
id=task_id,
|
||||
layout_step_size=layout_step_size,
|
||||
@@ -96,7 +95,7 @@ def get_result(client: DocmindClient, task_id: str, layout_num: int = 0, layout_
|
||||
|
||||
|
||||
def collect_all_results(client: DocmindClient, task_id: str, layout_step_size: int = 50) -> List[Dict]:
|
||||
"""收集所有解析结果"""
|
||||
"""Collect all results."""
|
||||
all_layouts = []
|
||||
layout_num = 0
|
||||
while True:
|
||||
@@ -113,8 +112,9 @@ def collect_all_results(client: DocmindClient, task_id: str, layout_step_size: i
|
||||
return all_layouts
|
||||
|
||||
|
||||
# ===================== 文本处理 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def normalize_text(text: str) -> str:
|
||||
"""Normalize text."""
|
||||
text = text.replace("\r", "\n")
|
||||
text = text.replace(" ", " ")
|
||||
text = re.sub(r"\n+", "\n", text)
|
||||
@@ -123,34 +123,41 @@ def normalize_text(text: str) -> str:
|
||||
|
||||
|
||||
def get_page(layout: Dict) -> int:
|
||||
"""Return page."""
|
||||
return layout.get("pageNum", layout.get("pageNumber", 0))
|
||||
|
||||
|
||||
def get_text(layout: Dict) -> str:
|
||||
"""Return text."""
|
||||
text = normalize_text(layout.get("text", ""))
|
||||
if text:
|
||||
return text
|
||||
return normalize_text(layout.get("markdownContent", ""))
|
||||
|
||||
|
||||
# ===================== 布局类型判断 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def is_title(layout: Dict) -> bool:
|
||||
"""Return whether title."""
|
||||
return layout.get("type") == "title" or layout.get("subType") in TITLE_SUBTYPES
|
||||
|
||||
|
||||
def is_text(layout: Dict) -> bool:
|
||||
"""Return whether text."""
|
||||
return layout.get("type") == "text" and layout.get("subType", "none") in TEXT_SUBTYPES
|
||||
|
||||
|
||||
def is_figure(layout: Dict) -> bool:
|
||||
"""Return whether figure."""
|
||||
return layout.get("type") in FIGURE_TYPES or layout.get("subType") in FIGURE_SUBTYPES
|
||||
|
||||
|
||||
def is_table(layout: Dict) -> bool:
|
||||
"""Return whether table."""
|
||||
return layout.get("type") == "table"
|
||||
|
||||
|
||||
def is_toc_layout(layout: Dict) -> bool:
|
||||
"""Return whether toc layout."""
|
||||
text = get_text(layout)
|
||||
if text in TOC_TITLES:
|
||||
return True
|
||||
@@ -160,6 +167,7 @@ def is_toc_layout(layout: Dict) -> bool:
|
||||
|
||||
|
||||
def extract_table_text(layout: Dict) -> str:
|
||||
"""Extract table text."""
|
||||
rows = []
|
||||
for cell in layout.get("cells", []):
|
||||
texts = []
|
||||
@@ -172,8 +180,9 @@ def extract_table_text(layout: Dict) -> str:
|
||||
return "\n".join(rows).strip()
|
||||
|
||||
|
||||
# ===================== 结构层:目录树 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def build_structure_nodes(layouts: List[Dict]) -> List[Dict]:
|
||||
"""Build structure nodes."""
|
||||
nodes = []
|
||||
for layout in layouts:
|
||||
if not is_title(layout):
|
||||
@@ -195,8 +204,9 @@ def build_structure_nodes(layouts: List[Dict]) -> List[Dict]:
|
||||
return nodes
|
||||
|
||||
|
||||
# ===================== 语义层:章节内容 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def update_section_path(section_stack: List[Dict], layout: Dict) -> List[Dict]:
|
||||
"""Update section path."""
|
||||
level = layout.get("level", 0)
|
||||
title = get_text(layout)
|
||||
while section_stack and section_stack[-1]["level"] >= level:
|
||||
@@ -213,10 +223,12 @@ def update_section_path(section_stack: List[Dict], layout: Dict) -> List[Dict]:
|
||||
|
||||
|
||||
def section_path_titles(section_stack: List[Dict]) -> List[str]:
|
||||
"""Handle section path titles."""
|
||||
return [item["title"] for item in section_stack]
|
||||
|
||||
|
||||
def flush_text_block(blocks: List[Dict], semantic_blocks: List[Dict], block_id: int) -> int:
|
||||
"""Handle flush text block."""
|
||||
if not blocks:
|
||||
return block_id
|
||||
|
||||
@@ -242,6 +254,7 @@ def flush_text_block(blocks: List[Dict], semantic_blocks: List[Dict], block_id:
|
||||
|
||||
|
||||
def build_semantic_blocks(layouts: List[Dict]) -> List[Dict]:
|
||||
"""Build semantic blocks."""
|
||||
semantic_blocks = []
|
||||
section_stack = []
|
||||
pending_text_blocks = []
|
||||
@@ -327,8 +340,9 @@ def build_semantic_blocks(layouts: List[Dict]) -> List[Dict]:
|
||||
return semantic_blocks
|
||||
|
||||
|
||||
# ===================== 检索层:向量 chunks =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def split_text_with_overlap(text: str, max_chars: int, overlap_chars: int) -> List[str]:
|
||||
"""Handle split text with overlap."""
|
||||
text = text.strip()
|
||||
if len(text) <= max_chars:
|
||||
return [text] if text else []
|
||||
@@ -351,6 +365,7 @@ def build_vector_chunks(
|
||||
max_chars: int,
|
||||
overlap_chars: int,
|
||||
) -> List[Dict]:
|
||||
"""Build vector chunks."""
|
||||
vector_chunks = []
|
||||
chunk_index = 1
|
||||
|
||||
@@ -385,7 +400,31 @@ def build_vector_chunks(
|
||||
return vector_chunks
|
||||
|
||||
|
||||
# ===================== 主转换函数 =====================
|
||||
def parse_pdf_to_structured_chunks(
|
||||
pdf_path: str,
|
||||
*,
|
||||
doc_id: str,
|
||||
doc_title: str,
|
||||
max_chars: int = MAX_CHARS,
|
||||
overlap_chars: int = OVERLAP_CHARS,
|
||||
poll_interval: int = 5,
|
||||
) -> Dict:
|
||||
"""Parse pdf to structured chunks."""
|
||||
client = init_client()
|
||||
task_id = submit_job(client, pdf_path)
|
||||
if not wait_for_completion(client, task_id, poll_interval):
|
||||
raise RuntimeError("阿里云文档解析任务失败")
|
||||
layouts = collect_all_results(client, task_id)
|
||||
return convert_layouts(
|
||||
layouts,
|
||||
doc_id=doc_id,
|
||||
doc_title=doc_title,
|
||||
max_chars=max_chars,
|
||||
overlap_chars=overlap_chars,
|
||||
)
|
||||
|
||||
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def convert_layouts(
|
||||
layouts: List[Dict],
|
||||
doc_id: str,
|
||||
@@ -393,6 +432,7 @@ def convert_layouts(
|
||||
max_chars: int,
|
||||
overlap_chars: int,
|
||||
) -> Dict:
|
||||
"""Handle convert layouts."""
|
||||
structure_nodes = build_structure_nodes(layouts)
|
||||
semantic_blocks = build_semantic_blocks(layouts)
|
||||
vector_chunks = build_vector_chunks(
|
||||
@@ -411,8 +451,9 @@ def convert_layouts(
|
||||
}
|
||||
|
||||
|
||||
# ===================== CLI 入口 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def main() -> None:
|
||||
"""Run the module entrypoint."""
|
||||
parser = argparse.ArgumentParser(description="阿里云文档智能解析 PDF,输出三层结构 chunks")
|
||||
parser.add_argument("pdf_path", help="PDF 文件路径")
|
||||
parser.add_argument("--out", default="vector_chunks.json", help="输出 JSON 文件路径")
|
||||
@@ -428,30 +469,30 @@ def main() -> None:
|
||||
if not pdf_path.exists():
|
||||
raise FileNotFoundError(f"PDF 文件不存在: {pdf_path}")
|
||||
|
||||
# 1. 提交阿里云任务
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
client = init_client()
|
||||
print(f"提交任务: {pdf_path}")
|
||||
task_id = submit_job(client, str(pdf_path))
|
||||
print(f"任务 ID: {task_id}")
|
||||
|
||||
# 2. 等待完成
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
print("等待任务完成...")
|
||||
if not wait_for_completion(client, task_id, args.poll_interval):
|
||||
print("任务失败,退出")
|
||||
return
|
||||
|
||||
# 3. 获取 layouts
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
print("获取解析结果...")
|
||||
layouts = collect_all_results(client, task_id)
|
||||
print(f"获取到 {len(layouts)} 个布局块")
|
||||
|
||||
# 4. 输出原始 layouts(可选)
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
if args.layouts_output:
|
||||
layouts_path = Path(args.layouts_output).expanduser().resolve()
|
||||
layouts_path.write_text(json.dumps(layouts, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"原始 layouts 已写入: {layouts_path}")
|
||||
|
||||
# 5. 转换为三层结构
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
print("转换为三层结构...")
|
||||
data = convert_layouts(
|
||||
layouts,
|
||||
@@ -461,7 +502,7 @@ def main() -> None:
|
||||
overlap_chars=args.overlap_chars,
|
||||
)
|
||||
|
||||
# 6. 输出结果
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
output_path = Path(args.out).expanduser().resolve()
|
||||
output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
@@ -472,4 +513,4 @@ def main() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
将 vector_chunks.json 向量化并上传到 Milvus 和 PostgreSQL
|
||||
使用中转站的 OpenAI 兼容 API
|
||||
"""
|
||||
"""Handle Aliyun parsing support for upload to milvus."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
@@ -23,18 +20,18 @@ from pymilvus import (
|
||||
)
|
||||
from openai import OpenAI
|
||||
|
||||
# ===================== 配置 =====================
|
||||
# 中转站配置
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
RELAY_BASE_URL = "http://6.86.80.4:30080/v1"
|
||||
RELAY_API_KEY = "sk-5HeY7gfSIlyZMacfuXOf5cphpymsNqufEu1ou4U3avbULcyY"
|
||||
EMBEDDING_MODEL = "text-embedding-v3" # 中转站支持的 embedding 模型
|
||||
EMBEDDING_MODEL = "text-embedding-v3" # Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
|
||||
# Milvus 配置
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
MILVUS_HOST = "localhost"
|
||||
MILVUS_PORT = "19530"
|
||||
COLLECTION_NAME = "regulation_chunks"
|
||||
|
||||
# PostgreSQL 配置
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
PG_HOST = "6.86.80.10"
|
||||
PG_PORT = 5432
|
||||
PG_USER = "postgresql"
|
||||
@@ -44,12 +41,12 @@ PG_DATABASE = "postgres"
|
||||
|
||||
# ===================== Embedding =====================
|
||||
def get_openai_client(api_key: str, base_url: str) -> OpenAI:
|
||||
"""创建 OpenAI 客户端连接到中转站"""
|
||||
"""Return openai client."""
|
||||
return OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
|
||||
def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10) -> List[List[float]]:
|
||||
"""批量获取文本向量"""
|
||||
"""Return embeddings batch."""
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
@@ -69,12 +66,13 @@ def get_embeddings_batch(client: OpenAI, texts: List[str], batch_size: int = 10)
|
||||
|
||||
# ===================== Milvus =====================
|
||||
def init_milvus(host: str, port: str):
|
||||
"""Handle init milvus."""
|
||||
connections.connect("default", host=host, port=port)
|
||||
print(f"已连接 Milvus: {host}:{port}")
|
||||
|
||||
|
||||
def create_collection(name: str, dim: int) -> Collection:
|
||||
"""创建或获取 collection"""
|
||||
"""Create collection."""
|
||||
if utility.has_collection(name):
|
||||
print(f"Collection '{name}' 已存在,删除重建")
|
||||
utility.drop_collection(name)
|
||||
@@ -90,14 +88,14 @@ def create_collection(name: str, dim: int) -> Collection:
|
||||
FieldSchema(name="page_end", dtype=DataType.INT64),
|
||||
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048),
|
||||
FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096), # JSON 字符串
|
||||
FieldSchema(name="source_ids", dtype=DataType.VARCHAR, max_length=4096), # Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
|
||||
]
|
||||
|
||||
schema = CollectionSchema(fields, description="法规文档检索 chunks")
|
||||
collection = Collection(name, schema)
|
||||
|
||||
# 创建向量索引(IVF_FLAT,适合中小规模)
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
index_params = {
|
||||
"metric_type": "COSINE",
|
||||
"index_type": "IVF_FLAT",
|
||||
@@ -110,7 +108,7 @@ def create_collection(name: str, dim: int) -> Collection:
|
||||
|
||||
|
||||
def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[List[float]]):
|
||||
"""插入 chunks 到 Milvus"""
|
||||
"""Handle insert chunks."""
|
||||
data = [
|
||||
[c["chunk_id"] for c in chunks],
|
||||
[c["doc_id"] for c in chunks],
|
||||
@@ -122,7 +120,7 @@ def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[L
|
||||
[c["page_end"] for c in chunks],
|
||||
[c["section_title"] for c in chunks],
|
||||
[c["text"] for c in chunks],
|
||||
[json.dumps(c.get("source_ids", [])) for c in chunks], # JSON 字符串
|
||||
[json.dumps(c.get("source_ids", [])) for c in chunks], # Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
embeddings,
|
||||
]
|
||||
|
||||
@@ -132,14 +130,14 @@ def insert_chunks(collection: Collection, chunks: List[Dict], embeddings: List[L
|
||||
|
||||
|
||||
def load_collection(collection: Collection):
|
||||
"""加载 collection 到内存(搜索前必须)"""
|
||||
"""Load collection."""
|
||||
collection.load()
|
||||
print(f"Collection 已加载到内存")
|
||||
|
||||
|
||||
# ===================== PostgreSQL =====================
|
||||
def get_pg_connection(host: str, port: int, user: str, password: str, database: str):
|
||||
"""获取 PostgreSQL 连接"""
|
||||
"""Return pg connection."""
|
||||
conn = psycopg2.connect(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -152,18 +150,18 @@ def get_pg_connection(host: str, port: int, user: str, password: str, database:
|
||||
|
||||
|
||||
def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
|
||||
"""插入 chunks 和相关数据到 PostgreSQL"""
|
||||
"""Handle insert chunks to pg."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# 1. 插入文档
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
cursor.execute("""
|
||||
INSERT INTO documents (doc_id, title, standard_number, upload_time)
|
||||
VALUES (%s, %s, %s, NOW())
|
||||
ON CONFLICT (doc_id) DO UPDATE SET title = EXCLUDED.title, updated_at = NOW()
|
||||
""", (doc_data["doc_id"], doc_data["doc_title"], doc_data.get("standard_number")))
|
||||
|
||||
# 2. 插入语义块
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
semantic_blocks = doc_data.get("semantic_blocks", [])
|
||||
if semantic_blocks:
|
||||
block_rows = [
|
||||
@@ -192,7 +190,7 @@ def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
|
||||
)
|
||||
print(f"已插入 {len(semantic_blocks)} 个语义块")
|
||||
|
||||
# 3. 插入向量块元数据
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
chunk_rows = [
|
||||
(
|
||||
doc_data["doc_id"],
|
||||
@@ -230,9 +228,9 @@ def insert_chunks_to_pg(conn, chunks: List[Dict], doc_data: Dict):
|
||||
cursor.close()
|
||||
|
||||
|
||||
# ===================== 主流程 =====================
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
def load_data(file_path: Path) -> Dict:
|
||||
"""加载 vector_chunks.json,返回完整数据"""
|
||||
"""Load data."""
|
||||
data = json.loads(file_path.read_text(encoding="utf-8"))
|
||||
return data
|
||||
|
||||
@@ -251,7 +249,8 @@ def upload_to_milvus_and_pg(
|
||||
pg_password: str,
|
||||
pg_database: str,
|
||||
):
|
||||
# 1. 加载完整数据
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
"""Handle upload to milvus and pg."""
|
||||
chunks_path = Path(chunks_file).expanduser().resolve()
|
||||
if not chunks_path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {chunks_path}")
|
||||
@@ -262,29 +261,29 @@ def upload_to_milvus_and_pg(
|
||||
raise ValueError("vector_chunks 为空")
|
||||
print(f"加载 {len(chunks)} 个 chunks")
|
||||
|
||||
# 2. 初始化连接
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
client = get_openai_client(api_key, base_url)
|
||||
init_milvus(milvus_host, milvus_port)
|
||||
pg_conn = get_pg_connection(pg_host, pg_port, pg_user, pg_password, pg_database)
|
||||
|
||||
# 3. 获取 embeddings
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
texts = [c["embedding_text"] for c in chunks]
|
||||
embeddings = get_embeddings_batch(client, texts, batch_size)
|
||||
print(f"生成 {len(embeddings)} 个向量")
|
||||
|
||||
# 4. 获取 embedding 维度
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
embedding_dim = len(embeddings[0])
|
||||
print(f"Embedding 维度: {embedding_dim}")
|
||||
|
||||
# 5. 创建 collection 并插入 Milvus
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
collection = create_collection(collection_name, embedding_dim)
|
||||
insert_chunks(collection, chunks, embeddings)
|
||||
load_collection(collection)
|
||||
|
||||
# 6. 插入 PostgreSQL
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
insert_chunks_to_pg(pg_conn, chunks, data)
|
||||
|
||||
# 7. 关闭连接
|
||||
# Keep parser integration steps explicit so external workflow behavior stays traceable.
|
||||
pg_conn.close()
|
||||
|
||||
print("上传完成!")
|
||||
@@ -292,6 +291,7 @@ def upload_to_milvus_and_pg(
|
||||
|
||||
# ===================== CLI =====================
|
||||
def main():
|
||||
"""Run the module entrypoint."""
|
||||
parser = argparse.ArgumentParser(description="将 vector_chunks 向量化并上传到 Milvus 和 PostgreSQL")
|
||||
parser.add_argument("chunks_file", help="vector_chunks.json 文件路径")
|
||||
parser.add_argument("--api-key", default=RELAY_API_KEY, help="中转站 API Key")
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
"""API接口模块"""
|
||||
"""Initialize the app.api package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.api.routes import api_router
|
||||
from app.config.logging import setup_logging
|
||||
from app.config.settings import settings
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
setup_logging(level="INFO" if not settings.debug else "DEBUG")
|
||||
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
"""API数据模型"""
|
||||
"""Initialize the app.api.models package."""
|
||||
|
||||
from .agent import (
|
||||
AskRequest,
|
||||
AskResponse,
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
FeedbackRequest,
|
||||
SessionInfo,
|
||||
TemplateListResponse,
|
||||
)
|
||||
from .document import (
|
||||
DocumentUploadRequest,
|
||||
DocumentUploadResponse,
|
||||
@@ -9,8 +18,17 @@ from .document import (
|
||||
DocumentStatusResponse,
|
||||
ErrorResponse
|
||||
)
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AskRequest",
|
||||
"AskResponse",
|
||||
"ChatRequest",
|
||||
"ChatResponse",
|
||||
"FeedbackRequest",
|
||||
"SessionInfo",
|
||||
"TemplateListResponse",
|
||||
"DocumentUploadRequest",
|
||||
"DocumentUploadResponse",
|
||||
"SearchRequest",
|
||||
|
||||
79
backend/app/api/models/agent.py
Normal file
79
backend/app/api/models/agent.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Define API models for agent endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Group agent transport schemas together so route modules stay focused on HTTP flow.
|
||||
|
||||
|
||||
class AskRequest(BaseModel):
|
||||
"""Define the Ask Request API model."""
|
||||
|
||||
query: str = Field(..., min_length=1, max_length=2000)
|
||||
filters: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
top_k: Optional[int] = Field(default=None, ge=1, le=20)
|
||||
prompt_template: Optional[str] = None
|
||||
|
||||
|
||||
class AskResponse(BaseModel):
|
||||
"""Define the Ask Response API model."""
|
||||
|
||||
answer: str
|
||||
sources: List[Dict] = Field(default_factory=list)
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
retrieved_count: int = 0
|
||||
context_tokens: int = 0
|
||||
truncated: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Define the Chat Request API model."""
|
||||
|
||||
query: str = Field(..., min_length=1, max_length=2000)
|
||||
session_id: Optional[str] = None
|
||||
filters: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
top_k: Optional[int] = Field(default=None, ge=1, le=20)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Define the Chat Response API model."""
|
||||
|
||||
session_id: str
|
||||
answer: str
|
||||
sources: List[Dict] = Field(default_factory=list)
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
|
||||
"""Define the Session Info API model."""
|
||||
|
||||
session_id: str
|
||||
message_count: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
|
||||
"""Define the Feedback Request API model."""
|
||||
|
||||
session_id: str
|
||||
message_index: int
|
||||
rating: int = Field(..., ge=1, le=5)
|
||||
comment: Optional[str] = None
|
||||
|
||||
|
||||
class TemplateListResponse(BaseModel):
|
||||
"""Define the Template List Response API model."""
|
||||
|
||||
templates: Dict[str, str]
|
||||
@@ -1,19 +1,21 @@
|
||||
"""文档相关Pydantic数据模型"""
|
||||
"""Define API models for document."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
# Group related schema definitions so validation rules stay consistent.
|
||||
|
||||
|
||||
|
||||
class DocumentUploadRequest(BaseModel):
|
||||
"""文档上传请求"""
|
||||
"""Define the Document Upload Request API model."""
|
||||
doc_name: Optional[str] = Field(None, description="文档名称")
|
||||
regulation_type: Optional[str] = Field(None, description="法规类型")
|
||||
version: Optional[str] = Field(None, description="文档版本")
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
|
||||
"""文档上传响应"""
|
||||
"""Define the Document Upload Response API model."""
|
||||
doc_id: str = Field(..., description="文档ID")
|
||||
doc_name: str = Field(..., description="文档名称")
|
||||
status: str = Field(..., description="处理状态")
|
||||
@@ -25,14 +27,14 @@ class DocumentUploadResponse(BaseModel):
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""检索请求"""
|
||||
"""Define the Search Request API model."""
|
||||
query: str = Field(..., description="查询文本")
|
||||
top_k: int = Field(default=10, description="返回结果数量")
|
||||
filters: Optional[str] = Field(None, description="过滤条件")
|
||||
|
||||
|
||||
class SearchResultItem(BaseModel):
|
||||
"""单个检索结果"""
|
||||
"""Define the Search Result Item API model."""
|
||||
id: int = Field(..., description="记录ID")
|
||||
content: str = Field(..., description="内容")
|
||||
score: float = Field(..., description="相似度分数")
|
||||
@@ -40,7 +42,7 @@ class SearchResultItem(BaseModel):
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
"""检索响应"""
|
||||
"""Define the Search Response API model."""
|
||||
query: str = Field(..., description="查询文本")
|
||||
total: int = Field(..., description="结果总数")
|
||||
results: List[SearchResultItem] = Field(default_factory=list, description="结果列表")
|
||||
@@ -48,7 +50,7 @@ class SearchResponse(BaseModel):
|
||||
|
||||
|
||||
class DocumentStatusResponse(BaseModel):
|
||||
"""文档状态响应"""
|
||||
"""Define the Document Status Response API model."""
|
||||
doc_id: str = Field(..., description="文档ID")
|
||||
status: str = Field(..., description="状态")
|
||||
num_chunks: Optional[int] = Field(None, description="分块数量")
|
||||
@@ -56,7 +58,7 @@ class DocumentStatusResponse(BaseModel):
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""错误响应"""
|
||||
"""Define the Error Response API model."""
|
||||
error: str = Field(..., description="错误类型")
|
||||
message: str = Field(..., description="错误消息")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||
|
||||
@@ -1,16 +1,29 @@
|
||||
"""API路由模块"""
|
||||
"""Initialize the app.api.routes package."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from .compliance import router as compliance_router
|
||||
from .documents import router as documents_router
|
||||
from .knowledge import router as knowledge_router
|
||||
from .agent import router as agent_router
|
||||
from .status import router as status_router
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
# 主路由
|
||||
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册子路由
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
api_router.include_router(documents_router)
|
||||
api_router.include_router(knowledge_router)
|
||||
api_router.include_router(agent_router)
|
||||
api_router.include_router(compliance_router)
|
||||
api_router.include_router(status_router)
|
||||
|
||||
__all__ = ["api_router", "documents_router", "knowledge_router", "agent_router"]
|
||||
__all__ = [
|
||||
"api_router",
|
||||
"documents_router",
|
||||
"knowledge_router",
|
||||
"agent_router",
|
||||
"compliance_router",
|
||||
"status_router",
|
||||
]
|
||||
|
||||
@@ -1,186 +1,83 @@
|
||||
"""Agent API接口 - 问答对话接口"""
|
||||
"""Define API routes for agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from app.services.agent.qa_agent import QAAgent, AgentConfig
|
||||
from app.services.agent.session_manager import SessionManager
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.api.models import (
|
||||
AskRequest,
|
||||
AskResponse,
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
FeedbackRequest,
|
||||
SessionInfo,
|
||||
)
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
# 会话管理器(全局实例)
|
||||
session_manager = SessionManager()
|
||||
|
||||
|
||||
# ===== Pydantic Models =====
|
||||
|
||||
class AskRequest(BaseModel):
|
||||
"""单次问答请求"""
|
||||
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
|
||||
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
|
||||
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
|
||||
|
||||
|
||||
class AskResponse(BaseModel):
|
||||
"""问答响应"""
|
||||
answer: str
|
||||
sources: List[Dict] = []
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
retrieved_count: int = 0
|
||||
context_tokens: int = 0
|
||||
truncated: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""多轮对话请求"""
|
||||
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||
session_id: Optional[str] = Field(None, description="会话ID(首次对话可不传)")
|
||||
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||
provider: Optional[str] = Field(None, description="LLM提供商")
|
||||
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""多轮对话响应"""
|
||||
session_id: str
|
||||
answer: str
|
||||
sources: List[Dict] = []
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
|
||||
"""会话信息"""
|
||||
session_id: str
|
||||
message_count: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
|
||||
"""反馈请求"""
|
||||
session_id: str
|
||||
message_index: int
|
||||
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
|
||||
comment: Optional[str] = Field(None, description="反馈内容")
|
||||
|
||||
|
||||
class TemplateListResponse(BaseModel):
|
||||
"""模板列表响应"""
|
||||
templates: Dict[str, str]
|
||||
|
||||
|
||||
# ===== API Endpoints =====
|
||||
|
||||
@router.post("/ask", response_model=AskResponse)
|
||||
async def ask_question(request: AskRequest):
|
||||
"""
|
||||
单次问答接口
|
||||
|
||||
不保存会话历史,适合单次查询场景。
|
||||
"""
|
||||
logger.info(f"收到问答请求: {request.query}")
|
||||
|
||||
"""Handle ask question."""
|
||||
try:
|
||||
# 构建Agent配置
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k
|
||||
)
|
||||
|
||||
# 创建Agent并执行问答
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(
|
||||
_, result = get_agent_conversation_service().ask(
|
||||
query=request.query,
|
||||
filters=request.filters,
|
||||
prompt_template=request.prompt_template
|
||||
provider=request.provider or settings.llm_provider,
|
||||
model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
prompt_template=request.prompt_template,
|
||||
)
|
||||
agent.close()
|
||||
|
||||
return AskResponse(
|
||||
answer=response.answer,
|
||||
sources=response.sources,
|
||||
model=response.model,
|
||||
latency_ms=response.latency_ms,
|
||||
retrieved_count=response.retrieved_count,
|
||||
context_tokens=response.context_tokens,
|
||||
truncated=response.truncated,
|
||||
error=response.error
|
||||
answer=result.answer,
|
||||
sources=[asdict(source) for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
retrieved_count=result.retrieved_count,
|
||||
context_tokens=result.context_tokens,
|
||||
truncated=result.truncated,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat_with_session(request: ChatRequest):
|
||||
"""
|
||||
多轮对话接口
|
||||
|
||||
支持会话历史记录,适合连续对话场景。
|
||||
"""
|
||||
logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
|
||||
|
||||
"""Handle chat with session."""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if request.session_id:
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(request.query)
|
||||
|
||||
# 执行问答
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(
|
||||
session_id, result = get_agent_conversation_service().chat(
|
||||
query=request.query,
|
||||
filters=request.filters
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
provider=request.provider or settings.llm_provider,
|
||||
model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
)
|
||||
agent.close()
|
||||
|
||||
# 添加助手消息
|
||||
session.add_assistant_message(
|
||||
response.answer,
|
||||
response.sources
|
||||
)
|
||||
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
return ChatResponse(
|
||||
session_id=session.session_id,
|
||||
answer=response.answer,
|
||||
sources=response.sources,
|
||||
model=response.model,
|
||||
latency_ms=response.latency_ms,
|
||||
message_count=session.message_count
|
||||
session_id=session_id,
|
||||
answer=result.answer,
|
||||
sources=[asdict(source) for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
message_count=len(session.messages) if session else 0,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"对话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/chat/stream")
|
||||
@@ -189,260 +86,93 @@ async def chat_stream_get(
|
||||
session_id: Optional[str] = None,
|
||||
filters: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
流式对话接口(SSE)- GET版本
|
||||
|
||||
EventSource只能发送GET请求,因此提供此接口。
|
||||
query参数通过URL传递。
|
||||
|
||||
SSE事件格式:
|
||||
- event: session - 会话ID
|
||||
- event: status - 状态更新(检索中、生成中)
|
||||
- event: sources - 引用来源
|
||||
- event: content - 回答内容片段
|
||||
- event: done - 完成,包含统计信息
|
||||
- event: error - 错误信息
|
||||
"""
|
||||
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
|
||||
|
||||
"""Handle chat stream get."""
|
||||
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
"""Handle generate sse."""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if session_id:
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||
return
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 发送session_id
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(query)
|
||||
|
||||
# 创建Agent
|
||||
config = AgentConfig(
|
||||
llm_provider=provider or settings.llm_provider,
|
||||
llm_model=model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
|
||||
# 执行流式问答
|
||||
full_answer = ""
|
||||
sources = []
|
||||
done_data = {}
|
||||
|
||||
for event_data in agent.ask_stream(
|
||||
session_id_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
filters=filters
|
||||
):
|
||||
session_id=session_id,
|
||||
filters=filters,
|
||||
provider=provider or settings.llm_provider,
|
||||
model=model or settings.llm_model,
|
||||
top_k=settings.rag_top_k,
|
||||
)
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
|
||||
for event_data in event_stream:
|
||||
event_type = event_data.get("event", "content")
|
||||
data = event_data.get("data", "")
|
||||
|
||||
# 收集完整回答和来源
|
||||
if event_type == "content":
|
||||
full_answer += str(data)
|
||||
elif event_type == "sources":
|
||||
sources = data
|
||||
elif event_type == "done":
|
||||
done_data = data
|
||||
|
||||
# 发送SSE事件
|
||||
if isinstance(data, (dict, list)):
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
# 小延迟让其他任务有机会执行
|
||||
await asyncio.sleep(0)
|
||||
|
||||
agent.close()
|
||||
|
||||
# 保存到会话历史
|
||||
session.add_assistant_message(full_answer, sources)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式对话失败: {e}")
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
except Exception as exc:
|
||||
yield f"event: error\ndata: {str(exc)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
"""
|
||||
流式对话接口(SSE)
|
||||
|
||||
返回Server-Sent Events格式的流式响应,用户可实时看到思考过程和回答生成。
|
||||
|
||||
SSE事件格式:
|
||||
- event: status - 状态更新(检索中、生成中)
|
||||
- event: sources - 引用来源
|
||||
- event: content - 回答内容片段
|
||||
- event: done - 完成,包含统计信息
|
||||
- event: error - 错误信息
|
||||
"""
|
||||
logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
|
||||
|
||||
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if request.session_id:
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||
return
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 发送session_id
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(request.query)
|
||||
|
||||
# 创建Agent
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
|
||||
# 执行流式问答
|
||||
full_answer = ""
|
||||
sources = []
|
||||
done_data = {}
|
||||
|
||||
for event_data in agent.ask_stream(
|
||||
query=request.query,
|
||||
filters=request.filters
|
||||
):
|
||||
event_type = event_data.get("event", "content")
|
||||
data = event_data.get("data", "")
|
||||
|
||||
# 收集完整回答和来源
|
||||
if event_type == "content":
|
||||
full_answer += str(data)
|
||||
elif event_type == "sources":
|
||||
sources = data
|
||||
elif event_type == "done":
|
||||
done_data = data
|
||||
|
||||
# 发送SSE事件
|
||||
if isinstance(data, (dict, list)):
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
else:
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
# 小延迟让其他任务有机会执行
|
||||
await asyncio.sleep(0)
|
||||
|
||||
agent.close()
|
||||
|
||||
# 保存到会话历史
|
||||
session.add_assistant_message(full_answer, sources)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式对话失败: {e}")
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
"""Handle chat stream."""
|
||||
return await chat_stream_get(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session/{session_id}", response_model=SessionInfo)
|
||||
async def get_session_info(session_id: str):
|
||||
"""获取会话信息"""
|
||||
session = session_manager.get_session(session_id)
|
||||
"""Return session info."""
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
|
||||
return SessionInfo(
|
||||
session_id=session.session_id,
|
||||
message_count=session.message_count,
|
||||
message_count=len(session.messages),
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at
|
||||
updated_at=session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session/{session_id}/history")
|
||||
async def get_session_history(session_id: str, max_turns: int = 5):
|
||||
"""获取会话历史"""
|
||||
session = session_manager.get_session(session_id)
|
||||
"""Return session history."""
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
|
||||
history = session.get_history(max_turns)
|
||||
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-(max_turns * 2):]]
|
||||
return {"session_id": session_id, "history": history}
|
||||
|
||||
|
||||
@router.delete("/session/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""删除会话"""
|
||||
success = session_manager.delete_session(session_id)
|
||||
if not success:
|
||||
"""Delete session."""
|
||||
if not get_conversation_store().delete_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
return {"message": "会话已删除", "session_id": session_id}
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionInfo])
|
||||
async def list_sessions():
|
||||
"""列出所有活跃会话"""
|
||||
sessions = session_manager.list_sessions()
|
||||
return [SessionInfo(**s) for s in sessions]
|
||||
"""List sessions."""
|
||||
return [SessionInfo(**item) for item in get_conversation_store().list_sessions()]
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
async def submit_feedback(request: FeedbackRequest):
|
||||
"""提交问答反馈"""
|
||||
session = session_manager.get_session(request.session_id)
|
||||
"""Submit feedback."""
|
||||
session = get_conversation_store().get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
# 记录反馈(实际应用中可存储到数据库)
|
||||
logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
|
||||
|
||||
return {"message": "反馈已记录", "rating": request.rating}
|
||||
|
||||
|
||||
@router.get("/templates", response_model=TemplateListResponse)
|
||||
async def list_prompt_templates():
|
||||
"""列出可用的Prompt模板"""
|
||||
from app.services.rag.prompt_templates import PromptTemplates
|
||||
|
||||
templates = PromptTemplates.list_templates()
|
||||
return TemplateListResponse(templates=templates)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_available_models():
|
||||
"""列出可用的LLM模型"""
|
||||
from app.services.llm import LLMFactory
|
||||
|
||||
factory = LLMFactory()
|
||||
models = factory.list_available_providers()
|
||||
return {"models": models}
|
||||
return {"message": "反馈已提交", "session_id": request.session_id, "message_index": request.message_index}
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
"""Define API routes for compliance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
@@ -13,38 +19,42 @@ from app.services.mock_data import (
|
||||
get_mock_compliance_result,
|
||||
get_mock_compliance_chat_response,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
|
||||
# 临时存储分析任务
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store: dict[str, dict] = {}
|
||||
|
||||
# Store uploaded compliance files inside the local backend data directory.
|
||||
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||
async def analyze_document(file: UploadFile = File(...)):
|
||||
"""上传设计方案进行分析"""
|
||||
# 生成任务ID
|
||||
"""Handle analyze document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
task_id = generate_task_id()
|
||||
|
||||
# 保存文件
|
||||
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||
os.makedirs(raw_dir, exist_ok=True)
|
||||
file_path = os.path.join(raw_dir, f"compliance_{task_id}_{file.filename}")
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
|
||||
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
with file_path.open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 记录任务
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id] = {
|
||||
"task_id": task_id,
|
||||
"file_path": file_path,
|
||||
"file_path": str(file_path),
|
||||
"status": "processing",
|
||||
"result": None,
|
||||
}
|
||||
|
||||
# 模拟异步处理完成(立即返回结果)
|
||||
# 实际应用中这应该是后台任务
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id]["status"] = "completed"
|
||||
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
|
||||
|
||||
@@ -53,9 +63,9 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
|
||||
@router.get("/result/{task_id}")
|
||||
async def get_result(task_id: str):
|
||||
"""获取分析结果"""
|
||||
"""Return result."""
|
||||
if task_id not in tasks_store:
|
||||
# 如果任务ID不存在,返回默认mock结果
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
return get_mock_compliance_result(task_id)
|
||||
|
||||
task = tasks_store[task_id]
|
||||
@@ -68,8 +78,8 @@ async def get_result(task_id: str):
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""针对段落进行合规对话"""
|
||||
# 根据segment_id获取对应的intent
|
||||
"""Handle compliance chat."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
intent_map = {
|
||||
1: "车身结构设计",
|
||||
2: "动力系统配置",
|
||||
@@ -77,11 +87,12 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
}
|
||||
intent = intent_map.get(segment_id, "车身结构设计")
|
||||
|
||||
async def generate():
|
||||
# 获取预设响应
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
response = get_mock_compliance_chat_response(intent, request.query)
|
||||
|
||||
# 流式输出响应
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = response.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
@@ -89,8 +100,15 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05)
|
||||
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
yield {"event": "message", "data": json.dumps({"type": "done"})}
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return EventSourceResponse(generate())
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Define API routes for docs."""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
import os
|
||||
import uuid
|
||||
@@ -10,30 +12,32 @@ from app.schemas.doc import (
|
||||
EmbedResponse,
|
||||
)
|
||||
from app.services.mock_data import get_mock_documents, generate_doc_id
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/docs", tags=["文档管理"])
|
||||
|
||||
# 临时存储文档信息(包含预设的mock文档)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
documents_store: dict[str, dict] = {}
|
||||
|
||||
# 初始化时加载mock文档
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
for doc in get_mock_documents():
|
||||
documents_store[doc["id"]] = doc
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(file: UploadFile = File(...)):
|
||||
"""上传法规文档"""
|
||||
# 检查文件格式
|
||||
"""Handle upload document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in allowed_ext:
|
||||
raise HTTPException(400, f"Unsupported file format: {ext}")
|
||||
|
||||
# 生成文档ID
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
doc_id = generate_doc_id()
|
||||
|
||||
# 保存文件
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||
os.makedirs(raw_dir, exist_ok=True)
|
||||
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
|
||||
@@ -42,7 +46,7 @@ async def upload_document(file: UploadFile = File(...)):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 记录文档信息
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
documents_store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"name": file.filename,
|
||||
@@ -62,7 +66,7 @@ async def upload_document(file: UploadFile = File(...)):
|
||||
|
||||
@router.get("/list", response_model=DocumentListResponse)
|
||||
async def list_documents():
|
||||
"""获取已索引文档列表"""
|
||||
"""List documents."""
|
||||
docs = [
|
||||
DocumentInfo(
|
||||
id=d["id"],
|
||||
@@ -78,14 +82,14 @@ async def list_documents():
|
||||
|
||||
@router.post("/parse/{doc_id}", response_model=ParseResponse)
|
||||
async def parse_document(doc_id: str):
|
||||
"""解析文档并分块"""
|
||||
"""Parse document."""
|
||||
if doc_id not in documents_store:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟解析逻辑
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
doc["status"] = "parsed"
|
||||
# 根据文件大小计算chunks数量
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
file_size = doc.get("size", 100000)
|
||||
doc["chunks"] = max(20, file_size // 8000)
|
||||
|
||||
@@ -94,12 +98,12 @@ async def parse_document(doc_id: str):
|
||||
|
||||
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
|
||||
async def embed_document(doc_id: str):
|
||||
"""嵌入并存入向量库"""
|
||||
"""Embed document."""
|
||||
if doc_id not in documents_store:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟嵌入逻辑
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
doc["status"] = "indexed"
|
||||
|
||||
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
|
||||
@@ -107,7 +111,7 @@ async def embed_document(doc_id: str):
|
||||
|
||||
@router.delete("/delete/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""删除文档"""
|
||||
"""Delete document."""
|
||||
if doc_id not in documents_store:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
|
||||
@@ -1,290 +1,140 @@
|
||||
"""文档上传与处理接口"""
|
||||
"""Define API routes for documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from typing import Optional
|
||||
import os
|
||||
import uuid
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from io import BytesIO
|
||||
from urllib.parse import quote
|
||||
|
||||
from ..models import DocumentUploadResponse, ErrorResponse
|
||||
from app.services.document_processor import DocumentProcessor
|
||||
from app.services.storage.minio_client import MinIOClient
|
||||
from app.config.settings import settings
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.api.models import DocumentUploadResponse
|
||||
from app.application.documents import DocumentProcessResult
|
||||
from app.shared.bootstrap import get_document_command_service, get_document_query_service
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
|
||||
# MinIO客户端(用于文档存储)
|
||||
minio_client: Optional[MinIOClient] = None
|
||||
|
||||
|
||||
def get_minio_client() -> MinIOClient:
|
||||
"""获取MinIO客户端实例"""
|
||||
global minio_client
|
||||
if minio_client is None:
|
||||
minio_client = MinIOClient()
|
||||
minio_client.connect()
|
||||
minio_client.ensure_bucket()
|
||||
return minio_client
|
||||
|
||||
|
||||
def _build_document_records(limit: Optional[int] = None):
|
||||
"""构建文档列表记录,支持按最近更新时间倒序截断。"""
|
||||
minio = get_minio_client()
|
||||
|
||||
document_records = []
|
||||
objects = minio.client.list_objects(minio.bucket, recursive=True)
|
||||
for obj in objects:
|
||||
parts = obj.object_name.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
doc_id, filename = parts
|
||||
last_modified = getattr(obj, "last_modified", None)
|
||||
document_records.append({
|
||||
"doc_id": doc_id,
|
||||
"filename": filename,
|
||||
"size": getattr(obj, "size", 0) or 0,
|
||||
"object_name": obj.object_name,
|
||||
"download_url": f"/api/v1/documents/download/{doc_id}",
|
||||
"last_modified": last_modified.isoformat() if last_modified else None,
|
||||
"_sort_key": last_modified.timestamp() if last_modified else 0,
|
||||
})
|
||||
|
||||
document_records.sort(key=lambda item: item["_sort_key"], reverse=True)
|
||||
if limit is not None:
|
||||
document_records = document_records[:limit]
|
||||
|
||||
for item in document_records:
|
||||
item.pop("_sort_key", None)
|
||||
|
||||
return document_records
|
||||
def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
|
||||
"""Handle document response for this module."""
|
||||
return DocumentUploadResponse(
|
||||
doc_id=result.doc_id,
|
||||
doc_name=result.doc_name,
|
||||
status=result.status,
|
||||
message=result.message,
|
||||
num_chunks=result.num_chunks,
|
||||
summary=result.summary,
|
||||
summary_latency_ms=result.summary_latency_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(
|
||||
file: UploadFile = File(..., description="上传的文档文件"),
|
||||
doc_name: Optional[str] = Form(None, description="文档名称"),
|
||||
regulation_type: Optional[str] = Form(None, description="法规类型"),
|
||||
version: Optional[str] = Form(None, description="文档版本"),
|
||||
generate_summary: bool = Form(False, description="是否生成摘要(默认不生成,可节省约60秒)")
|
||||
doc_name: str | None = Form(None, description="文档名称"),
|
||||
regulation_type: str | None = Form(None, description="法规类型"),
|
||||
version: str | None = Form(None, description="文档版本"),
|
||||
generate_summary: bool = Form(False, description="是否生成摘要"),
|
||||
):
|
||||
"""
|
||||
上传文档并处理
|
||||
|
||||
支持格式:PDF、DOCX、DOC
|
||||
处理流程:解析 → 分块 → 嵌入 → 入库(摘要可选)
|
||||
文件存储:MinIO对象存储
|
||||
|
||||
参数说明:
|
||||
- generate_summary: 是否生成LLM摘要,默认False。勾选后处理时间增加约60秒。
|
||||
"""
|
||||
# 验证文件类型
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in [".pdf", ".docx", ".doc"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型: {ext},仅支持PDF、DOCX、DOC"
|
||||
)
|
||||
|
||||
# 验证文件大小
|
||||
if file.size and file.size > settings.max_file_size_mb * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
|
||||
)
|
||||
|
||||
# 生成文档ID
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 文档名称
|
||||
final_doc_name = doc_name or file.filename
|
||||
|
||||
# MinIO对象名称
|
||||
object_name = f"{doc_id}/{file.filename}"
|
||||
|
||||
logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}, doc_id={doc_id}")
|
||||
"""Handle upload document."""
|
||||
content = await file.read()
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="上传文件为空")
|
||||
|
||||
try:
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
|
||||
# 保存临时文件用于处理
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_path = os.path.join(temp_dir, f"{doc_id}_{file.filename}")
|
||||
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info(f"临时文件已保存到: {temp_path}")
|
||||
|
||||
# 上传到MinIO
|
||||
minio = get_minio_client()
|
||||
upload_success = minio.upload_bytes(
|
||||
data=content,
|
||||
object_name=object_name,
|
||||
content_type=minio._get_content_type(file.filename),
|
||||
metadata={
|
||||
"doc_id": doc_id # 仅传递ASCII安全的metadata
|
||||
}
|
||||
)
|
||||
|
||||
if upload_success:
|
||||
logger.success(f"文件已上传到MinIO: {object_name}")
|
||||
else:
|
||||
logger.warning(f"MinIO上传失败,仅使用本地临时文件")
|
||||
|
||||
# 处理文档(传入相同的doc_id,保持一致性)
|
||||
processor = DocumentProcessor(generate_summary=generate_summary)
|
||||
result = processor.process(
|
||||
file_path=temp_path,
|
||||
doc_id=doc_id, # 使用相同的doc_id
|
||||
doc_name=final_doc_name,
|
||||
result = get_document_command_service().upload_and_process(
|
||||
file_name=file.filename,
|
||||
content=content,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or ""
|
||||
)
|
||||
processor.close()
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
if result.success:
|
||||
return DocumentUploadResponse(
|
||||
doc_id=result.doc_id,
|
||||
doc_name=result.doc_name,
|
||||
status="success",
|
||||
message=result.message,
|
||||
num_chunks=result.num_chunks,
|
||||
summary=result.summary,
|
||||
summary_latency_ms=result.summary_latency_ms
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=result.message
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文档处理失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"文档处理失败: {str(e)}"
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
)
|
||||
if result.status == "failed":
|
||||
raise HTTPException(status_code=500, detail=result.message)
|
||||
return _document_response(result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("文档上传失败")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
|
||||
async def get_document_status(doc_id: str):
|
||||
"""
|
||||
查询文档处理状态
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
"""
|
||||
# TODO: 实现状态查询(需要数据库支持)
|
||||
"""Return document status."""
|
||||
document = get_document_query_service().get(doc_id)
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="文档不存在")
|
||||
return DocumentUploadResponse(
|
||||
doc_id=doc_id,
|
||||
doc_name="",
|
||||
status="unknown",
|
||||
message="状态查询功能待实现"
|
||||
doc_id=document.doc_id,
|
||||
doc_name=document.doc_name,
|
||||
status=document.status.value,
|
||||
message=document.error_message or "查询成功",
|
||||
num_chunks=document.chunk_count,
|
||||
summary=document.summary,
|
||||
summary_latency_ms=document.summary_latency_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/download/{doc_id}")
|
||||
async def download_document(doc_id: str):
|
||||
"""
|
||||
下载文档(从MinIO获取)
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
|
||||
Returns:
|
||||
文件下载响应
|
||||
"""
|
||||
logger.info(f"请求下载文档: doc_id={doc_id}")
|
||||
|
||||
"""Handle download document."""
|
||||
try:
|
||||
minio = get_minio_client()
|
||||
|
||||
# 查找该doc_id下的文件(MinIO对象名称格式: {doc_id}/{filename})
|
||||
objects = minio.list_objects(prefix=f"{doc_id}/")
|
||||
|
||||
if not objects:
|
||||
logger.warning(f"MinIO中未找到文档: doc_id={doc_id}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"文档不存在: doc_id={doc_id}"
|
||||
)
|
||||
|
||||
# 获取第一个匹配的对象
|
||||
object_name = objects[0]
|
||||
logger.info(f"找到MinIO对象: {object_name}")
|
||||
|
||||
# 获取文件数据
|
||||
file_data = minio.get_object_data(object_name)
|
||||
if file_data is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取文档数据失败"
|
||||
)
|
||||
|
||||
# 解析原始文件名
|
||||
original_name = object_name.split("/", 1)[1] if "/" in object_name else object_name
|
||||
|
||||
# 获取Content-Type
|
||||
content_type = minio._get_content_type(original_name)
|
||||
|
||||
logger.success(f"文档下载成功: {original_name}, 大小={len(file_data)}")
|
||||
|
||||
# 返回文件流(URL编码文件名以支持中文)
|
||||
encoded_name = quote(original_name)
|
||||
document, file_data = get_document_query_service().download(doc_id)
|
||||
encoded_name = quote(document.file_name)
|
||||
return StreamingResponse(
|
||||
BytesIO(file_data),
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"文档下载失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"文档下载失败: {str(e)}"
|
||||
media_type=document.content_type or "application/octet-stream",
|
||||
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger.exception("文档下载失败")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_documents():
|
||||
"""
|
||||
列出所有已上传的文档(从MinIO获取)
|
||||
"""
|
||||
try:
|
||||
documents = _build_document_records()
|
||||
return {"documents": documents, "total": len(documents)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出文档失败: {e}")
|
||||
return {"documents": [], "total": 0, "error": str(e)}
|
||||
"""List documents."""
|
||||
documents = get_document_query_service().list_documents()
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/management-list")
|
||||
async def get_document_management_list():
|
||||
"""
|
||||
文档管理清单接口:仅返回最近的10条文档。
|
||||
"""
|
||||
try:
|
||||
documents = _build_document_records(limit=10)
|
||||
return {"documents": documents, "total": len(documents), "limit": 10}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取文档管理清单失败: {e}")
|
||||
return {"documents": [], "total": 0, "limit": 10, "error": str(e)}
|
||||
"""Return document management list."""
|
||||
documents = get_document_query_service().list_documents(limit=10)
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
@@ -1,80 +1,51 @@
|
||||
"""知识库检索接口"""
|
||||
"""Define API routes for knowledge."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse
|
||||
from app.services.document_processor import DocumentProcessor
|
||||
from app.api.models import SearchResponse, SearchResultItem, SearchRequest
|
||||
from app.shared.bootstrap import get_retrieval_service
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
@router.post("/search", response_model=SearchResponse)
|
||||
async def search_knowledge(request: SearchRequest):
|
||||
"""
|
||||
检索法规知识库
|
||||
"""Search knowledge."""
|
||||
if not request.query or not request.query.strip():
|
||||
raise HTTPException(status_code=400, detail="查询文本不能为空")
|
||||
|
||||
使用混合检索:Dense向量 + Sparse向量 + RRF融合
|
||||
|
||||
Args:
|
||||
request: 检索请求参数
|
||||
"""
|
||||
if not request.query or len(request.query.strip()) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="查询文本不能为空"
|
||||
)
|
||||
|
||||
logger.info(f"收到检索请求: {request.query}")
|
||||
|
||||
try:
|
||||
# 执行检索
|
||||
processor = DocumentProcessor()
|
||||
results = processor.search(
|
||||
query=request.query,
|
||||
top_k=request.top_k,
|
||||
filters=request.filters
|
||||
)
|
||||
processor.close()
|
||||
|
||||
# 转换结果格式
|
||||
result_items = []
|
||||
for r in results:
|
||||
item = SearchResultItem(
|
||||
id=r.get("id", 0),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
metadata=r.get("metadata", {})
|
||||
results = get_retrieval_service().retrieve(
|
||||
query=request.query,
|
||||
top_k=request.top_k,
|
||||
filters=request.filters,
|
||||
)
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
total=len(results),
|
||||
results=[
|
||||
SearchResultItem(
|
||||
id=index + 1,
|
||||
content=item.content,
|
||||
score=item.score,
|
||||
metadata={
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"chunk_id": item.chunk_id,
|
||||
"section_title": item.section_title,
|
||||
"page_number": item.page_number,
|
||||
**item.metadata,
|
||||
},
|
||||
)
|
||||
result_items.append(item)
|
||||
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
total=len(result_items),
|
||||
results=result_items
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"检索失败: {str(e)}"
|
||||
)
|
||||
for index, item in enumerate(results)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/retrieval", response_model=SearchResponse)
|
||||
async def knowledge_retrieval(request: SearchRequest):
|
||||
"""
|
||||
知识检索接口(与架构文档对齐)
|
||||
|
||||
该接口实现完整的检索流程:
|
||||
1. 意图识别
|
||||
2. BM25关键词检索 + 向量语义检索(双路召回)
|
||||
3. Cross-Encoder精排
|
||||
4. 返回结果
|
||||
|
||||
Args:
|
||||
request: 检索请求
|
||||
"""
|
||||
# 当前版本使用混合检索,后续可添加精排步骤
|
||||
"""Handle knowledge retrieval."""
|
||||
return await search_knowledge(request)
|
||||
|
||||
@@ -1,29 +1,39 @@
|
||||
"""Define API routes for rag."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.services.mock_data import (
|
||||
get_mock_quick_questions,
|
||||
get_mock_retrieval,
|
||||
get_mock_rag_answer,
|
||||
)
|
||||
import json
|
||||
import asyncio
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
"""SSE流式问答"""
|
||||
"""Handle rag chat."""
|
||||
|
||||
async def generate():
|
||||
# 发送检索开始事件
|
||||
yield {"event": "message", "data": json.dumps({"type": "retrieving"})}
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 模拟检索延迟
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# 执行检索
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
docs = get_mock_retrieval(request.query, top_k=request.top_k)
|
||||
|
||||
retrieved_data = [
|
||||
@@ -36,39 +46,49 @@ async def rag_chat(request: RagChatRequest):
|
||||
}
|
||||
for d in docs
|
||||
]
|
||||
yield {"event": "message", "data": json.dumps({"type": "retrieved", "docs": retrieved_data})}
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送生成开始事件
|
||||
yield {"event": "message", "data": json.dumps({"type": "generating", "text": "正在生成答案..."})}
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
f"event: message\ndata: "
|
||||
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# 模拟生成延迟
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 获取预设答案
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
answer = get_mock_rag_answer(request.query)
|
||||
|
||||
# 流式输出答案(按句子分割)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = answer.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
# 进一步分割长句子
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05) # 模拟生成延迟
|
||||
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
|
||||
await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# 发送完成事件
|
||||
yield {"event": "message", "data": json.dumps({"type": "done"})}
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return EventSourceResponse(generate())
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
|
||||
async def get_quick_questions():
|
||||
"""获取预设快捷问题"""
|
||||
"""Return quick questions."""
|
||||
questions = [
|
||||
QuickQuestion(id=q["id"], question=q["question"], category=q["category"])
|
||||
for q in get_mock_quick_questions()
|
||||
]
|
||||
return QuickQuestionsResponse(questions=questions)
|
||||
return QuickQuestionsResponse(questions=questions)
|
||||
|
||||
@@ -1,28 +1,44 @@
|
||||
"""Define API routes for status."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from app.core.config import settings
|
||||
from app.services.mock_data import MOCK_SYSTEM_STATS, MOCK_SYSTEM_CONFIG
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_document_query_service, get_vector_index
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/status", tags=["系统状态"])
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats():
|
||||
"""获取系统统计"""
|
||||
# 返回预设统计数据
|
||||
return MOCK_SYSTEM_STATS
|
||||
"""Return stats."""
|
||||
documents = get_document_query_service().list_documents()
|
||||
indexed = sum(1 for item in documents if item.status.value == "indexed")
|
||||
failed = sum(1 for item in documents if item.status.value == "failed")
|
||||
return {
|
||||
"documents_total": len(documents),
|
||||
"documents_indexed": indexed,
|
||||
"documents_failed": failed,
|
||||
"chunks_total": sum(item.chunk_count for item in documents),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config():
|
||||
"""获取当前配置"""
|
||||
return MOCK_SYSTEM_CONFIG
|
||||
"""Return config."""
|
||||
return {
|
||||
"embedding_model": settings.embedding_model,
|
||||
"embedding_dim": settings.embedding_dim,
|
||||
"embedding_base_url": settings.embedding_base_url,
|
||||
"milvus_collection": settings.milvus_collection,
|
||||
"llm_provider": settings.llm_provider,
|
||||
"llm_model": settings.llm_model,
|
||||
"document_metadata_path": settings.document_metadata_path,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/milvus/health")
|
||||
async def milvus_health():
|
||||
"""Milvus健康检查"""
|
||||
# 模拟连接状态(假数据模式下始终返回连接成功)
|
||||
return {
|
||||
"connected": True,
|
||||
"collections": ["vehicle_regulations"],
|
||||
}
|
||||
"""Handle milvus health."""
|
||||
return get_vector_index().health()
|
||||
|
||||
5
backend/app/application/__init__.py
Normal file
5
backend/app/application/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.application package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
7
backend/app/application/agent/__init__.py
Normal file
7
backend/app/application/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Initialize the app.application.agent package."""
|
||||
|
||||
from .services import AgentConversationService
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["AgentConversationService"]
|
||||
145
backend/app/application/agent/services.py
Normal file
145
backend/app/application/agent/services.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Implement application-layer logic for services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from app.domain.conversation import AnswerGenerator, AnswerResult, ConversationStore
|
||||
from app.domain.retrieval import RetrievedChunk
|
||||
|
||||
from app.application.knowledge import KnowledgeRetrievalService
|
||||
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
||||
|
||||
|
||||
|
||||
class AgentConversationService:
|
||||
"""Provide the Agent Conversation Service service."""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
retrieval_service: KnowledgeRetrievalService,
|
||||
answer_generator: AnswerGenerator,
|
||||
conversation_store: ConversationStore,
|
||||
) -> None:
|
||||
"""Initialize the Agent Conversation Service instance."""
|
||||
self.retrieval_service = retrieval_service
|
||||
self.answer_generator = answer_generator
|
||||
self.conversation_store = conversation_store
|
||||
|
||||
def ask(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
filters: str | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
top_k: int = 5,
|
||||
prompt_template: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> tuple[str | None, AnswerResult]:
|
||||
"""Handle ask for the Agent Conversation Service instance."""
|
||||
history = None
|
||||
active_session_id = None
|
||||
if session_id:
|
||||
session = self.conversation_store.get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("会话不存在或已过期")
|
||||
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-10:]]
|
||||
active_session_id = session.session_id
|
||||
self.conversation_store.save_message(session_id, role="user", content=query)
|
||||
retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
|
||||
result = self.answer_generator.generate(
|
||||
query=query,
|
||||
retrieved_chunks=retrieved,
|
||||
history=history,
|
||||
provider=provider,
|
||||
model=model,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
if active_session_id:
|
||||
self.conversation_store.save_message(
|
||||
active_session_id,
|
||||
role="assistant",
|
||||
content=result.answer,
|
||||
sources=[source.__dict__ for source in result.sources],
|
||||
)
|
||||
return active_session_id, result
|
||||
|
||||
def chat(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
session_id: str | None = None,
|
||||
filters: str | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
top_k: int = 5,
|
||||
) -> tuple[str, AnswerResult]:
|
||||
"""Handle chat for the Agent Conversation Service instance."""
|
||||
session = self.conversation_store.get_session(session_id) if session_id else None
|
||||
if session is None:
|
||||
session = self.conversation_store.create_session()
|
||||
self.conversation_store.save_message(session.session_id, role="user", content=query)
|
||||
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-10:]]
|
||||
retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
|
||||
result = self.answer_generator.generate(
|
||||
query=query,
|
||||
retrieved_chunks=retrieved,
|
||||
history=history,
|
||||
provider=provider,
|
||||
model=model,
|
||||
)
|
||||
self.conversation_store.save_message(
|
||||
session.session_id,
|
||||
role="assistant",
|
||||
content=result.answer,
|
||||
sources=[source.__dict__ for source in result.sources],
|
||||
)
|
||||
return session.session_id, result
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
session_id: str | None = None,
|
||||
filters: str | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
top_k: int = 5,
|
||||
prompt_template: str | None = None,
|
||||
) -> tuple[str, Generator[dict, None, None]]:
|
||||
"""Stream chat for the Agent Conversation Service instance."""
|
||||
session = self.conversation_store.get_session(session_id) if session_id else None
|
||||
if session is None:
|
||||
session = self.conversation_store.create_session()
|
||||
self.conversation_store.save_message(session.session_id, role="user", content=query)
|
||||
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-10:]]
|
||||
retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
|
||||
|
||||
def event_stream() -> Generator[dict, None, None]:
|
||||
"""Handle event stream for the Agent Conversation Service instance."""
|
||||
yield {"event": "status", "data": f"找到{len(retrieved)}条相关法规,正在生成回答..."}
|
||||
answer_parts: list[str] = []
|
||||
sources_payload: list[dict] = []
|
||||
for event in self.answer_generator.stream_generate(
|
||||
query=query,
|
||||
retrieved_chunks=retrieved,
|
||||
history=history,
|
||||
provider=provider,
|
||||
model=model,
|
||||
prompt_template=prompt_template,
|
||||
):
|
||||
if event.get("event") == "sources":
|
||||
sources_payload = event.get("data", [])
|
||||
if event.get("event") == "content":
|
||||
answer_parts.append(str(event.get("data", "")))
|
||||
yield event
|
||||
full_answer = "".join(answer_parts)
|
||||
self.conversation_store.save_message(
|
||||
session.session_id,
|
||||
role="assistant",
|
||||
content=full_answer,
|
||||
sources=sources_payload,
|
||||
)
|
||||
|
||||
return session.session_id, event_stream()
|
||||
7
backend/app/application/documents/__init__.py
Normal file
7
backend/app/application/documents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Initialize the app.application.documents package."""
|
||||
|
||||
from .services import DocumentCommandService, DocumentProcessResult, DocumentQueryService
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["DocumentCommandService", "DocumentProcessResult", "DocumentQueryService"]
|
||||
186
backend/app/application/documents/services.py
Normal file
186
backend/app/application/documents/services.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Implement application-layer logic for services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.documents import (
|
||||
ChunkBuilder,
|
||||
Document,
|
||||
DocumentBinaryStore,
|
||||
DocumentParser,
|
||||
DocumentRepository,
|
||||
DocumentStatus,
|
||||
)
|
||||
from app.domain.retrieval import EmbeddingProvider, VectorIndex
|
||||
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentProcessResult:
|
||||
"""Represent document process result data."""
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
status: str
|
||||
message: str
|
||||
num_chunks: int = 0
|
||||
summary: str = ""
|
||||
summary_latency_ms: int = 0
|
||||
|
||||
|
||||
class DocumentCommandService:
|
||||
"""Provide the Document Command Service service."""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
document_repository: DocumentRepository,
|
||||
binary_store: DocumentBinaryStore,
|
||||
parser: DocumentParser,
|
||||
chunk_builder: ChunkBuilder,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
vector_index: VectorIndex,
|
||||
) -> None:
|
||||
"""Initialize the Document Command Service instance."""
|
||||
self.document_repository = document_repository
|
||||
self.binary_store = binary_store
|
||||
self.parser = parser
|
||||
self.chunk_builder = chunk_builder
|
||||
self.embedding_provider = embedding_provider
|
||||
self.vector_index = vector_index
|
||||
|
||||
def upload_and_process(
|
||||
self,
|
||||
*,
|
||||
doc_id: str | None = None,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
doc_name: str | None,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
generate_summary: bool,
|
||||
) -> DocumentProcessResult:
|
||||
"""Handle upload and process for the Document Command Service instance."""
|
||||
doc_id = doc_id or str(uuid.uuid4())[:8]
|
||||
final_doc_name = doc_name or file_name
|
||||
object_name = f"{doc_id}/{file_name}"
|
||||
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
file_name=file_name,
|
||||
object_name=object_name,
|
||||
content_type=content_type,
|
||||
size_bytes=len(content),
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
metadata={"generate_summary": generate_summary},
|
||||
)
|
||||
self.document_repository.create(document)
|
||||
|
||||
temp_path = ""
|
||||
try:
|
||||
self.binary_store.save(
|
||||
object_name=object_name,
|
||||
data=content,
|
||||
content_type=content_type,
|
||||
metadata={"doc_id": doc_id},
|
||||
)
|
||||
self.document_repository.update_status(doc_id, DocumentStatus.STORED)
|
||||
|
||||
suffix = os.path.splitext(file_name)[1]
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_path = temp_file.name
|
||||
|
||||
parsed_document = self.parser.parse(
|
||||
file_path=temp_path,
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
)
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
DocumentStatus.PARSED,
|
||||
parser_name=parsed_document.parser_name,
|
||||
metadata={"structure_nodes": len(parsed_document.structure_nodes)},
|
||||
)
|
||||
|
||||
chunks = self.chunk_builder.build(
|
||||
parsed_document=parsed_document,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
)
|
||||
if not chunks:
|
||||
raise ValueError("解析完成但没有生成可入库的 chunks")
|
||||
|
||||
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
|
||||
inserted = self.vector_index.upsert(chunks, vectors)
|
||||
if inserted != len(chunks):
|
||||
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
|
||||
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
DocumentStatus.INDEXED,
|
||||
chunk_count=len(chunks),
|
||||
summary="",
|
||||
summary_latency_ms=0,
|
||||
index_name=self.vector_index.health().get("collection_name", ""),
|
||||
)
|
||||
stored = self.document_repository.get(doc_id)
|
||||
return DocumentProcessResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
|
||||
message="处理成功",
|
||||
num_chunks=len(chunks),
|
||||
summary=stored.summary if stored else "",
|
||||
summary_latency_ms=stored.summary_latency_ms if stored else 0,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("文档处理失败: doc_id={}", doc_id)
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
DocumentStatus.FAILED,
|
||||
error_message=str(exc),
|
||||
)
|
||||
return DocumentProcessResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
status=DocumentStatus.FAILED.value,
|
||||
message=f"文档处理失败: {exc}",
|
||||
)
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
logger.warning("临时文件清理失败: {}", temp_path)
|
||||
|
||||
|
||||
class DocumentQueryService:
|
||||
"""Provide the Document Query Service service."""
|
||||
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore) -> None:
|
||||
"""Initialize the Document Query Service instance."""
|
||||
self.document_repository = document_repository
|
||||
self.binary_store = binary_store
|
||||
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
"""Handle get for the Document Query Service instance."""
|
||||
return self.document_repository.get(doc_id)
|
||||
|
||||
def list_documents(self, limit: int | None = None) -> list[Document]:
|
||||
"""List documents for the Document Query Service instance."""
|
||||
return self.document_repository.list(limit=limit)
|
||||
|
||||
def download(self, doc_id: str) -> tuple[Document, bytes]:
|
||||
"""Handle download for the Document Query Service instance."""
|
||||
document = self.document_repository.get(doc_id)
|
||||
if not document:
|
||||
raise FileNotFoundError(f"文档不存在: {doc_id}")
|
||||
return document, self.binary_store.read(document.object_name)
|
||||
7
backend/app/application/knowledge/__init__.py
Normal file
7
backend/app/application/knowledge/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Initialize the app.application.knowledge package."""
|
||||
|
||||
from .services import KnowledgeRetrievalService
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["KnowledgeRetrievalService"]
|
||||
19
backend/app/application/knowledge/services.py
Normal file
19
backend/app/application/knowledge/services.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Implement application-layer logic for services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk
|
||||
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
||||
|
||||
|
||||
|
||||
class KnowledgeRetrievalService:
|
||||
"""Provide the Knowledge Retrieval Service service."""
|
||||
def __init__(self, *, retriever: Retriever) -> None:
|
||||
"""Initialize the Knowledge Retrieval Service instance."""
|
||||
self.retriever = retriever
|
||||
|
||||
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle retrieve for the Knowledge Retrieval Service instance."""
|
||||
retrieval_query = RetrievalQuery(query=query, top_k=top_k, filters=filters)
|
||||
return self.retriever.retrieve(retrieval_query)
|
||||
@@ -1,5 +1,7 @@
|
||||
"""配置模块"""
|
||||
"""Initialize the app.config package."""
|
||||
|
||||
from .settings import Settings, get_settings, settings
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["Settings", "get_settings", "settings"]
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""日志配置"""
|
||||
"""Configure backend settings for logging."""
|
||||
|
||||
from loguru import logger
|
||||
import sys
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
|
||||
|
||||
|
||||
def setup_logging(level: str = "INFO"):
|
||||
"""设置日志配置"""
|
||||
"""Handle setup logging."""
|
||||
|
||||
# 移除默认handler
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
logger.remove()
|
||||
|
||||
# 添加控制台输出
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=level,
|
||||
@@ -18,7 +20,7 @@ def setup_logging(level: str = "INFO"):
|
||||
colorize=True
|
||||
)
|
||||
|
||||
# 添加文件输出
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
logger.add(
|
||||
"logs/app_{time:YYYY-MM-DD}.log",
|
||||
level=level,
|
||||
|
||||
@@ -1,94 +1,119 @@
|
||||
"""配置管理 - 环境变量和默认配置"""
|
||||
"""Configure backend settings for settings."""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[3]
|
||||
ROOT_ENV_FILES = (
|
||||
ROOT_DIR / ".env",
|
||||
ROOT_DIR / ".env.development",
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置"""
|
||||
"""Define configuration for settings."""
|
||||
|
||||
# 应用基础配置
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=tuple(str(env_file) for env_file in ROOT_ENV_FILES),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
app_name: str = Field(default="AI Regulations Demo", description="Application name")
|
||||
app_version: str = Field(default="0.1.0", description="应用版本")
|
||||
debug: bool = Field(default=False, description="调试模式")
|
||||
|
||||
# Milvus向量数据库配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
milvus_host: str = Field(default="localhost", description="Milvus服务地址")
|
||||
milvus_port: int = Field(default=19530, description="Milvus服务端口")
|
||||
milvus_collection: str = Field(default="regulations", description="法规向量集合名称")
|
||||
milvus_collection: str = Field(default="regulations_dense_1536", description="法规向量集合名称")
|
||||
milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
|
||||
|
||||
# 嵌入模型配置
|
||||
embedding_model: str = Field(default="BAAI/bge-m3", description="嵌入模型名称")
|
||||
embedding_dim: int = Field(default=1024, description="嵌入向量维度")
|
||||
embedding_max_length: int = Field(default=8192, description="最大嵌入长度")
|
||||
embedding_batch_size: int = Field(default=12, description="嵌入批处理大小")
|
||||
embedding_use_fp16: bool = Field(default=True, description="使用FP16加速")
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
embedding_model: str = Field(default="text-embedding-v3", description="嵌入模型名称")
|
||||
embedding_dim: int = Field(default=1536, description="嵌入向量维度")
|
||||
embedding_api_key: str = Field(default="", description="Embedding API密钥")
|
||||
embedding_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="Embedding API地址")
|
||||
embedding_timeout_seconds: int = Field(default=120, description="Embedding API超时时间(秒)")
|
||||
alibaba_access_key_id: str = Field(default="", description="阿里云文档解析 Access Key ID")
|
||||
alibaba_access_key_secret: str = Field(default="", description="阿里云文档解析 Access Key Secret")
|
||||
alibaba_endpoint: str = Field(default="docmind-api.cn-hangzhou.aliyuncs.com", description="阿里云文档解析 endpoint")
|
||||
|
||||
# MinIO对象存储配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址")
|
||||
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
|
||||
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
|
||||
minio_bucket: str = Field(default="upload-files", description="文档存储桶名称")
|
||||
minio_secure: bool = Field(default=False, description="是否使用HTTPS")
|
||||
|
||||
# Redis配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
redis_host: str = Field(default="localhost", description="Redis服务地址")
|
||||
redis_port: int = Field(default=6379, description="Redis服务端口")
|
||||
redis_password: str = Field(default="", description="Redis密码")
|
||||
redis_db: int = Field(default=0, description="Redis数据库编号")
|
||||
|
||||
# PostgreSQL配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址")
|
||||
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
|
||||
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
|
||||
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
|
||||
postgres_db: str = Field(default="compliance_db", description="PostgreSQL数据库名称")
|
||||
|
||||
# 文档处理配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
chunk_size: int = Field(default=512, description="分块大小(字符数)")
|
||||
chunk_overlap: int = Field(default=50, description="分块重叠大小")
|
||||
max_file_size_mb: int = Field(default=100, description="最大文件大小(MB)")
|
||||
document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径")
|
||||
parser_backend: str = Field(default="local", description="解析后端(local/aliyun)")
|
||||
chunk_backend: str = Field(default="local", description="分块后端(local/aliyun)")
|
||||
|
||||
# API配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
api_host: str = Field(default="0.0.0.0", description="API服务地址")
|
||||
api_port: int = Field(default=8000, description="API服务端口")
|
||||
|
||||
# LLM配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
llm_provider: str = Field(default="deepseek", description="LLM提供商 (deepseek/qwen/qwen_vl)")
|
||||
llm_model: str = Field(default="deepseek-v4-flash", description="LLM模型名称")
|
||||
llm_max_tokens: int = Field(default=4096, description="LLM最大输出token数")
|
||||
llm_temperature: float = Field(default=0.7, description="LLM温度参数")
|
||||
|
||||
# DeepSeek配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
deepseek_api_key: str = Field(default="", description="DeepSeek API密钥")
|
||||
deepseek_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="DeepSeek API地址")
|
||||
deepseek_model: str = Field(default="deepseek-v4-flash", description="DeepSeek模型")
|
||||
|
||||
# Qwen配置(通过统一代理API)
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
qwen_api_key: str = Field(default="", description="Qwen API密钥")
|
||||
qwen_base_url: str = Field(default="http://6.86.80.4:30080/v1", description="Qwen API地址")
|
||||
qwen_model: str = Field(default="qwen3.5-flash", description="Qwen文本模型")
|
||||
qwen_vl_model: str = Field(default="qwen3-vl-plus", description="Qwen视觉模型")
|
||||
|
||||
# RAG配置
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
rag_top_k: int = Field(default=5, description="检索召回数量")
|
||||
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
|
||||
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型")
|
||||
milvus_nlist: int = Field(default=128, description="Milvus nlist参数")
|
||||
milvus_nprobe: int = Field(default=16, description="Milvus nprobe参数")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
session_max_sessions: int = Field(default=100, description="最大会话数量")
|
||||
session_timeout_minutes: int = Field(default=30, description="会话超时时间(分钟)")
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
"""获取配置实例(缓存)"""
|
||||
"""Return settings."""
|
||||
return Settings()
|
||||
|
||||
|
||||
# 导出默认配置实例
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
settings = get_settings()
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Initialize the app.core package."""
|
||||
|
||||
from .config import settings, Settings
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["settings", "Settings"]
|
||||
@@ -1,41 +1,54 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
"""Legacy-compatible config used by older utility modules."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
# Keep legacy settings aligned with the root-level env loading rules.
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[3]
|
||||
ROOT_ENV_FILES = tuple(str(path) for path in (ROOT_DIR / ".env", ROOT_DIR / ".env.development"))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# DashScope API
|
||||
"""Define configuration for settings."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=ROOT_ENV_FILES,
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
dashscope_api_key: str = ""
|
||||
|
||||
# Milvus
|
||||
milvus_host: str = "localhost"
|
||||
milvus_port: int = 19530
|
||||
milvus_collection: str = "regulations_dense_1536"
|
||||
|
||||
# LLM配置
|
||||
# LLM / embedding defaults aligned with the migrated backend path.
|
||||
llm_model: str = "qwen-max"
|
||||
embedding_model: str = "text-embedding-v3"
|
||||
embedding_dim: int = 1536
|
||||
|
||||
# 检索配置
|
||||
# Legacy workflow compatibility only.
|
||||
vector_top_k: int = 10
|
||||
bm25_top_k: int = 10
|
||||
final_top_k: int = 5
|
||||
|
||||
# 分块配置
|
||||
# Legacy local chunking compatibility only; main ingest now uses Aliyun vector_chunks.
|
||||
chunk_size: int = 800
|
||||
chunk_overlap: int = 50
|
||||
|
||||
# 服务配置
|
||||
# Service config.
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
|
||||
# Collection名称
|
||||
regulations_collection: str = "vehicle_regulations"
|
||||
# Legacy aliases retained for old utility modules.
|
||||
regulations_collection: str = "regulations_dense_1536"
|
||||
compliance_collection: str = "compliance_cache"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
# Preserve the legacy module API while keeping env resolution centralized at the repo root.
|
||||
settings = Settings()
|
||||
|
||||
5
backend/app/domain/__init__.py
Normal file
5
backend/app/domain/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.domain package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
15
backend/app/domain/conversation/__init__.py
Normal file
15
backend/app/domain/conversation/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Initialize the app.domain.conversation package."""
|
||||
|
||||
from .models import AnswerResult, AnswerSource, ConversationMessage, ConversationSession
|
||||
from .ports import AnswerGenerator, ConversationStore
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnswerGenerator",
|
||||
"AnswerResult",
|
||||
"AnswerSource",
|
||||
"ConversationMessage",
|
||||
"ConversationSession",
|
||||
"ConversationStore",
|
||||
]
|
||||
53
backend/app/domain/conversation/models.py
Normal file
53
backend/app/domain/conversation/models.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Define domain models for conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnswerSource:
|
||||
"""Represent answer source data."""
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
chunk_id: str
|
||||
section_title: str
|
||||
page_number: int
|
||||
score: float
|
||||
content: str
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationMessage:
|
||||
"""Represent conversation message data."""
|
||||
role: str
|
||||
content: str
|
||||
timestamp: int
|
||||
sources: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationSession:
|
||||
"""Represent conversation session data."""
|
||||
session_id: str
|
||||
messages: list[ConversationMessage] = field(default_factory=list)
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnswerResult:
|
||||
"""Represent answer result data."""
|
||||
answer: str
|
||||
sources: list[AnswerSource] = field(default_factory=list)
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
retrieved_count: int = 0
|
||||
context_tokens: int = 0
|
||||
truncated: bool = False
|
||||
error: str | None = None
|
||||
78
backend/app/domain/conversation/ports.py
Normal file
78
backend/app/domain/conversation/ports.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Define domain ports for conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator
|
||||
|
||||
from app.domain.retrieval.models import RetrievedChunk
|
||||
|
||||
from .models import AnswerResult, ConversationSession
|
||||
# Keep domain contracts explicit so adapters can swap implementations cleanly.
|
||||
|
||||
|
||||
|
||||
class AnswerGenerator(ABC):
|
||||
"""Represent the Answer Generator type."""
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
retrieved_chunks: list[RetrievedChunk],
|
||||
history: list[dict[str, str]] | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
prompt_template: str | None = None,
|
||||
) -> AnswerResult:
|
||||
"""Handle generate for the Answer Generator instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stream_generate(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
retrieved_chunks: list[RetrievedChunk],
|
||||
history: list[dict[str, str]] | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
prompt_template: str | None = None,
|
||||
) -> Generator[dict, None, AnswerResult]:
|
||||
"""Stream generate for the Answer Generator instance."""
|
||||
pass
|
||||
|
||||
|
||||
class ConversationStore(ABC):
|
||||
"""Provide the Conversation Store store implementation."""
|
||||
@abstractmethod
|
||||
def create_session(self, metadata: dict | None = None) -> ConversationSession:
|
||||
"""Create session for the Conversation Store instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session(self, session_id: str) -> ConversationSession | None:
|
||||
"""Return session for the Conversation Store instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_message(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
role: str,
|
||||
content: str,
|
||||
sources: list[dict] | None = None,
|
||||
) -> ConversationSession | None:
|
||||
"""Save message for the Conversation Store instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete session for the Conversation Store instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_sessions(self) -> list[dict]:
|
||||
"""List sessions for the Conversation Store instance."""
|
||||
pass
|
||||
17
backend/app/domain/documents/__init__.py
Normal file
17
backend/app/domain/documents/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Initialize the app.domain.documents package."""
|
||||
|
||||
from .models import Chunk, Document, DocumentStatus, ParsedDocument
|
||||
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Chunk",
|
||||
"Document",
|
||||
"DocumentStatus",
|
||||
"ParsedDocument",
|
||||
"ChunkBuilder",
|
||||
"DocumentBinaryStore",
|
||||
"DocumentParser",
|
||||
"DocumentRepository",
|
||||
]
|
||||
77
backend/app/domain/documents/models.py
Normal file
77
backend/app/domain/documents/models.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Define domain models for documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
|
||||
class DocumentStatus(str, Enum):
|
||||
"""Define the Document Status enumeration."""
|
||||
PENDING = "pending"
|
||||
STORED = "stored"
|
||||
PARSED = "parsed"
|
||||
INDEXED = "indexed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
"""Represent the Document type."""
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
file_name: str
|
||||
object_name: str
|
||||
content_type: str
|
||||
size_bytes: int
|
||||
status: DocumentStatus = DocumentStatus.PENDING
|
||||
regulation_type: str = ""
|
||||
version: str = ""
|
||||
summary: str = ""
|
||||
summary_latency_ms: int = 0
|
||||
chunk_count: int = 0
|
||||
parser_name: str = ""
|
||||
index_name: str = ""
|
||||
error_message: str = ""
|
||||
created_at: datetime = field(default_factory=utcnow)
|
||||
updated_at: datetime = field(default_factory=utcnow)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedDocument:
|
||||
"""Represent the Parsed Document type."""
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
structure_nodes: list[dict[str, Any]]
|
||||
semantic_blocks: list[dict[str, Any]]
|
||||
vector_chunks: list[dict[str, Any]]
|
||||
parser_name: str
|
||||
raw_text: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
"""Represent the Chunk type."""
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
content: str
|
||||
embedding_text: str
|
||||
section_title: str = ""
|
||||
section_path: list[str] = field(default_factory=list)
|
||||
page_number: int = 0
|
||||
regulation_type: str = ""
|
||||
version: str = ""
|
||||
semantic_id: str = ""
|
||||
block_type: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
96
backend/app/domain/documents/ports.py
Normal file
96
backend/app/domain/documents/ports.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Define domain ports for documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from .models import Chunk, Document, DocumentStatus, ParsedDocument
|
||||
# Keep domain contracts explicit so adapters can swap implementations cleanly.
|
||||
|
||||
|
||||
|
||||
class DocumentRepository(ABC):
|
||||
"""Provide the Document Repository repository implementation."""
|
||||
@abstractmethod
|
||||
def create(self, document: Document) -> Document:
|
||||
"""Handle create for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, document: Document) -> Document:
|
||||
"""Handle update for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
"""Handle get for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(self, limit: int | None = None) -> list[Document]:
|
||||
"""Handle list for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
status: DocumentStatus,
|
||||
*,
|
||||
error_message: str = "",
|
||||
chunk_count: int | None = None,
|
||||
summary: str | None = None,
|
||||
summary_latency_ms: int | None = None,
|
||||
parser_name: str | None = None,
|
||||
index_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Document | None:
|
||||
"""Update status for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
|
||||
class DocumentBinaryStore(ABC):
|
||||
"""Provide the Document Binary Store store implementation."""
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
*,
|
||||
object_name: str,
|
||||
data: bytes,
|
||||
content_type: str,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Handle save for the Document Binary Store instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def read(self, object_name: str) -> bytes:
|
||||
"""Handle read for the Document Binary Store instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, object_name: str) -> None:
|
||||
"""Handle delete for the Document Binary Store instance."""
|
||||
pass
|
||||
|
||||
|
||||
class DocumentParser(ABC):
|
||||
"""Provide the Document Parser parser."""
|
||||
@abstractmethod
|
||||
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
|
||||
"""Handle parse for the Document Parser instance."""
|
||||
pass
|
||||
|
||||
|
||||
class ChunkBuilder(ABC):
|
||||
"""Provide the Chunk Builder builder."""
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
*,
|
||||
parsed_document: ParsedDocument,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
) -> list[Chunk]:
|
||||
"""Handle build for the Chunk Builder instance."""
|
||||
pass
|
||||
8
backend/app/domain/retrieval/__init__.py
Normal file
8
backend/app/domain/retrieval/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Initialize the app.domain.retrieval package."""
|
||||
|
||||
from .models import RetrievalQuery, RetrievedChunk
|
||||
from .ports import EmbeddingProvider, Retriever, VectorIndex
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Retriever", "VectorIndex"]
|
||||
29
backend/app/domain/retrieval/models.py
Normal file
29
backend/app/domain/retrieval/models.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Define domain models for retrieval."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalQuery:
|
||||
"""Represent the Retrieval Query type."""
|
||||
query: str
|
||||
top_k: int
|
||||
filters: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedChunk:
|
||||
"""Represent the Retrieved Chunk type."""
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
content: str
|
||||
score: float
|
||||
section_title: str = ""
|
||||
page_number: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
60
backend/app/domain/retrieval/ports.py
Normal file
60
backend/app/domain/retrieval/ports.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Define domain ports for retrieval."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.domain.documents.models import Chunk
|
||||
|
||||
from .models import RetrievalQuery, RetrievedChunk
|
||||
# Keep domain contracts explicit so adapters can swap implementations cleanly.
|
||||
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Provide the Embedding Provider provider."""
|
||||
@abstractmethod
|
||||
def embed_texts(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed texts for the Embedding Provider instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query for the Embedding Provider instance."""
|
||||
pass
|
||||
|
||||
|
||||
class VectorIndex(ABC):
|
||||
"""Provide the Vector Index index implementation."""
|
||||
@abstractmethod
|
||||
def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
|
||||
"""Handle upsert for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_document(self, doc_id: str) -> int:
|
||||
"""Delete by document for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
|
||||
class Retriever(ABC):
|
||||
"""Provide the Retriever retriever."""
|
||||
@abstractmethod
|
||||
def retrieve(self, query: RetrievalQuery) -> list[RetrievedChunk]:
|
||||
"""Handle retrieve for the Retriever instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Retriever instance."""
|
||||
pass
|
||||
5
backend/app/infrastructure/__init__.py
Normal file
5
backend/app/infrastructure/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.infrastructure package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
5
backend/app/infrastructure/embedding/__init__.py
Normal file
5
backend/app/infrastructure/embedding/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.infrastructure.embedding package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Implement infrastructure support for openai compatible embedding provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.retrieval import EmbeddingProvider
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class OpenAICompatibleEmbeddingProvider(EmbeddingProvider):
|
||||
"""Provide the Open A I Compatible Embedding Provider provider."""
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Open A I Compatible Embedding Provider instance."""
|
||||
self.base_url = settings.embedding_base_url.rstrip("/")
|
||||
self.api_key = (
|
||||
settings.embedding_api_key
|
||||
or os.getenv("OPENAI_API_KEY", "")
|
||||
or os.getenv("QWEN_API_KEY", "")
|
||||
or os.getenv("DEEPSEEK_API_KEY", "")
|
||||
)
|
||||
self.model = settings.embedding_model
|
||||
self.timeout = settings.embedding_timeout_seconds
|
||||
self.dimension = settings.embedding_dim
|
||||
|
||||
def _request(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Handle request for this module for the Open A I Compatible Embedding Provider instance."""
|
||||
if not self.api_key:
|
||||
raise ValueError("缺少 EMBEDDING_API_KEY / OPENAI_API_KEY")
|
||||
response = httpx.post(
|
||||
f"{self.base_url}/embeddings",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"model": self.model, "input": texts},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
vectors = [item["embedding"] for item in sorted(data.get("data", []), key=lambda item: item["index"])]
|
||||
if any(len(vector) != self.dimension for vector in vectors):
|
||||
raise ValueError(f"embedding 维度不匹配,期望 {self.dimension}")
|
||||
return vectors
|
||||
|
||||
def embed_texts(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed texts for the Open A I Compatible Embedding Provider instance."""
|
||||
if not texts:
|
||||
return []
|
||||
return self._request(texts)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query for the Open A I Compatible Embedding Provider instance."""
|
||||
vectors = self._request([text])
|
||||
return vectors[0]
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Implement infrastructure support for openai compatible answer generator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Generator
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.conversation import AnswerGenerator, AnswerResult, AnswerSource
|
||||
from app.domain.retrieval import RetrievedChunk
|
||||
from app.services.llm.llm_factory import get_llm_client
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
PROMPT_TEMPLATES = {
|
||||
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
|
||||
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
|
||||
}
|
||||
|
||||
|
||||
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
"""Represent the Open A I Compatible Answer Generator type."""
|
||||
def _build_messages(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
retrieved_chunks: list[RetrievedChunk],
|
||||
history: list[dict[str, str]] | None,
|
||||
prompt_template: str | None,
|
||||
) -> tuple[list[dict[str, str]], int]:
|
||||
"""Handle build messages for this module for the Open A I Compatible Answer Generator instance."""
|
||||
system_prompt = PROMPT_TEMPLATES.get(prompt_template or "compliance_qa", PROMPT_TEMPLATES["default"])
|
||||
context_blocks = []
|
||||
context_tokens = 0
|
||||
for idx, chunk in enumerate(retrieved_chunks, start=1):
|
||||
block = (
|
||||
f"[{idx}] 文档: {chunk.doc_name}\n"
|
||||
f"章节: {chunk.section_title or '未标注'}\n"
|
||||
f"页码: {chunk.page_number}\n"
|
||||
f"内容: {chunk.content}"
|
||||
)
|
||||
context_tokens += len(block)
|
||||
context_blocks.append(block)
|
||||
context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4]
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
for item in history or []:
|
||||
messages.append({"role": item["role"], "content": item["content"]})
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
|
||||
}
|
||||
)
|
||||
return messages, min(context_tokens, settings.rag_max_context_tokens)
|
||||
|
||||
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
|
||||
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
|
||||
return [
|
||||
AnswerSource(
|
||||
doc_id=chunk.doc_id,
|
||||
doc_name=chunk.doc_name,
|
||||
chunk_id=chunk.chunk_id,
|
||||
section_title=chunk.section_title,
|
||||
page_number=chunk.page_number,
|
||||
score=chunk.score,
|
||||
content=chunk.content,
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
retrieved_chunks: list[RetrievedChunk],
|
||||
history: list[dict[str, str]] | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
prompt_template: str | None = None,
|
||||
) -> AnswerResult:
|
||||
"""Handle generate for the Open A I Compatible Answer Generator instance."""
|
||||
start = time.time()
|
||||
messages, context_tokens = self._build_messages(
|
||||
query=query,
|
||||
retrieved_chunks=retrieved_chunks,
|
||||
history=history,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
|
||||
response = client.chat(messages)
|
||||
latency_ms = int((time.time() - start) * 1000)
|
||||
return AnswerResult(
|
||||
answer=response.content if response.is_success else "",
|
||||
sources=self._sources(retrieved_chunks),
|
||||
model=response.model or (model or settings.llm_model),
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=len(retrieved_chunks),
|
||||
context_tokens=context_tokens,
|
||||
truncated=False,
|
||||
error=response.error,
|
||||
)
|
||||
|
||||
def stream_generate(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
retrieved_chunks: list[RetrievedChunk],
|
||||
history: list[dict[str, str]] | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
prompt_template: str | None = None,
|
||||
) -> Generator[dict, None, AnswerResult]:
|
||||
"""Stream generate for the Open A I Compatible Answer Generator instance."""
|
||||
start = time.time()
|
||||
messages, context_tokens = self._build_messages(
|
||||
query=query,
|
||||
retrieved_chunks=retrieved_chunks,
|
||||
history=history,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
sources = [source.__dict__ for source in self._sources(retrieved_chunks)]
|
||||
yield {"event": "sources", "data": sources}
|
||||
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
|
||||
answer_parts: list[str] = []
|
||||
if hasattr(client, "stream_chat"):
|
||||
for chunk in client.stream_chat(messages):
|
||||
answer_parts.append(chunk)
|
||||
yield {"event": "content", "data": chunk}
|
||||
else:
|
||||
response = client.chat(messages)
|
||||
answer_parts.append(response.content)
|
||||
yield {"event": "content", "data": response.content}
|
||||
full_answer = "".join(answer_parts)
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"latency_ms": int((time.time() - start) * 1000),
|
||||
"retrieved_count": len(retrieved_chunks),
|
||||
"context_tokens": context_tokens,
|
||||
"model": model or settings.llm_model,
|
||||
},
|
||||
}
|
||||
5
backend/app/infrastructure/parser/__init__.py
Normal file
5
backend/app/infrastructure/parser/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.infrastructure.parser package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
55
backend/app/infrastructure/parser/aliyun_document_parser.py
Normal file
55
backend/app/infrastructure/parser/aliyun_document_parser.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Implement infrastructure support for aliyun document parser."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.aliyun_parser.parse_pdf import (
|
||||
MAX_CHARS,
|
||||
OVERLAP_CHARS,
|
||||
build_semantic_blocks,
|
||||
build_structure_nodes,
|
||||
build_vector_chunks,
|
||||
collect_all_results,
|
||||
init_client,
|
||||
submit_job,
|
||||
wait_for_completion,
|
||||
)
|
||||
from app.domain.documents import DocumentParser, ParsedDocument
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class AliyunDocumentParser(DocumentParser):
|
||||
"""Provide the Aliyun Document Parser parser."""
|
||||
parser_name = "aliyun_docmind"
|
||||
|
||||
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
|
||||
"""Handle parse for the Aliyun Document Parser instance."""
|
||||
client = init_client()
|
||||
task_id = submit_job(client, file_path)
|
||||
if not wait_for_completion(client, task_id):
|
||||
raise RuntimeError("阿里云文档解析任务失败")
|
||||
layouts = collect_all_results(client, task_id)
|
||||
structure_nodes = build_structure_nodes(layouts)
|
||||
semantic_blocks = build_semantic_blocks(layouts)
|
||||
vector_chunks = build_vector_chunks(
|
||||
semantic_blocks,
|
||||
doc_id=doc_id,
|
||||
doc_title=doc_name,
|
||||
max_chars=MAX_CHARS,
|
||||
overlap_chars=OVERLAP_CHARS,
|
||||
)
|
||||
raw_text = "\n\n".join(
|
||||
block.get("text", "")
|
||||
for block in semantic_blocks
|
||||
if block.get("text")
|
||||
)
|
||||
return ParsedDocument(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
structure_nodes=structure_nodes,
|
||||
semantic_blocks=semantic_blocks,
|
||||
vector_chunks=vector_chunks,
|
||||
parser_name=self.parser_name,
|
||||
raw_text=raw_text,
|
||||
metadata={"task_id": task_id, "layout_count": len(layouts)},
|
||||
)
|
||||
66
backend/app/infrastructure/parser/local_chunk_builder.py
Normal file
66
backend/app/infrastructure/parser/local_chunk_builder.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Local chunk builder adapter for the migrated backend architecture."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.documents import Chunk, ChunkBuilder, ParsedDocument
|
||||
from app.services.embedding.text_chunker import RegulationChunker
|
||||
|
||||
|
||||
class LocalRegulationChunkBuilder(ChunkBuilder):
|
||||
"""Adapt the existing markdown chunker to the new chunk builder port."""
|
||||
|
||||
def __init__(self, *, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
|
||||
self.chunker = RegulationChunker(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
*,
|
||||
parsed_document: ParsedDocument,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
) -> list[Chunk]:
|
||||
markdown_text = parsed_document.raw_text.strip()
|
||||
if not markdown_text:
|
||||
return []
|
||||
|
||||
legacy_chunks = self.chunker.chunk_document(
|
||||
markdown_text,
|
||||
doc_id=parsed_document.doc_id,
|
||||
doc_name=parsed_document.doc_name,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
)
|
||||
|
||||
chunks: list[Chunk] = []
|
||||
for item in legacy_chunks:
|
||||
metadata = {
|
||||
"section_number": item.metadata.section_number,
|
||||
"section_title": item.metadata.section_title,
|
||||
"clause_number": item.metadata.clause_number,
|
||||
"start_position": item.metadata.start_position,
|
||||
"end_position": item.metadata.end_position,
|
||||
"token_count": item.token_count,
|
||||
"source": "local_chunker",
|
||||
}
|
||||
section_path = [value for value in [item.metadata.section_number, item.metadata.section_title] if value]
|
||||
chunks.append(
|
||||
Chunk(
|
||||
chunk_id=item.metadata.chunk_id,
|
||||
doc_id=parsed_document.doc_id,
|
||||
doc_name=parsed_document.doc_name,
|
||||
content=item.content,
|
||||
embedding_text=item.content,
|
||||
section_title=item.metadata.section_title or item.metadata.section_number,
|
||||
section_path=section_path,
|
||||
page_number=item.metadata.page_number,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
semantic_id=item.metadata.clause_number,
|
||||
block_type="local_markdown_chunk",
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
38
backend/app/infrastructure/parser/local_document_parser.py
Normal file
38
backend/app/infrastructure/parser/local_document_parser.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Local parser adapter for the migrated backend architecture."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.domain.documents import DocumentParser, ParsedDocument
|
||||
from app.services.parser.docx_parser import parse_docx_to_markdown
|
||||
from app.services.parser.pdf_parser import parse_pdf_to_markdown
|
||||
|
||||
|
||||
class LocalDocumentParser(DocumentParser):
|
||||
"""Adapt the existing local PDF/DOCX parsers to the new parser port."""
|
||||
|
||||
parser_name = "local_markdown_parser"
|
||||
|
||||
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
|
||||
suffix = Path(file_path).suffix.lower()
|
||||
if suffix == ".pdf":
|
||||
markdown_text = parse_pdf_to_markdown(file_path)
|
||||
elif suffix in {".docx", ".doc"}:
|
||||
markdown_text = parse_docx_to_markdown(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {suffix}")
|
||||
|
||||
if not markdown_text.strip():
|
||||
raise ValueError("本地解析完成但未提取到有效文本")
|
||||
|
||||
return ParsedDocument(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
structure_nodes=[],
|
||||
semantic_blocks=[],
|
||||
vector_chunks=[],
|
||||
parser_name=self.parser_name,
|
||||
raw_text=markdown_text,
|
||||
metadata={"source": "local_parser", "file_suffix": suffix},
|
||||
)
|
||||
48
backend/app/infrastructure/parser/vector_chunk_builder.py
Normal file
48
backend/app/infrastructure/parser/vector_chunk_builder.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Implement infrastructure support for vector chunk builder."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.documents import Chunk, ChunkBuilder, ParsedDocument
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class AliyunVectorChunkBuilder(ChunkBuilder):
|
||||
"""Provide the Aliyun Vector Chunk Builder builder."""
|
||||
def build(
|
||||
self,
|
||||
*,
|
||||
parsed_document: ParsedDocument,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
) -> list[Chunk]:
|
||||
"""Handle build for the Aliyun Vector Chunk Builder instance."""
|
||||
chunks: list[Chunk] = []
|
||||
for index, item in enumerate(parsed_document.vector_chunks):
|
||||
content = item.get("content") or item.get("text") or ""
|
||||
embedding_text = item.get("embedding_text") or content
|
||||
if not embedding_text.strip():
|
||||
continue
|
||||
section_path = item.get("section_path") or []
|
||||
section_title = item.get("section_title") or (section_path[-1] if section_path else "")
|
||||
page_number = item.get("page_start") or item.get("page") or 0
|
||||
chunk_id = item.get("chunk_id") or f"{parsed_document.doc_id}-chunk-{index}"
|
||||
metadata = {k: v for k, v in item.items() if k not in {"content", "embedding_text"}}
|
||||
chunks.append(
|
||||
Chunk(
|
||||
chunk_id=str(chunk_id),
|
||||
doc_id=parsed_document.doc_id,
|
||||
doc_name=parsed_document.doc_name,
|
||||
content=content,
|
||||
embedding_text=embedding_text,
|
||||
section_title=section_title,
|
||||
section_path=section_path,
|
||||
page_number=int(page_number or 0),
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
semantic_id=item.get("semantic_id", ""),
|
||||
block_type=item.get("block_type", ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
5
backend/app/infrastructure/session/__init__.py
Normal file
5
backend/app/infrastructure/session/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.infrastructure.session package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Implement infrastructure support for in memory conversation store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class InMemoryConversationStore(ConversationStore):
|
||||
"""Provide the In Memory Conversation Store store implementation."""
|
||||
def __init__(self, *, max_sessions: int = 100, timeout_minutes: int = 30) -> None:
|
||||
"""Initialize the In Memory Conversation Store instance."""
|
||||
self.max_sessions = max_sessions
|
||||
self.timeout_seconds = timeout_minutes * 60
|
||||
self.sessions: dict[str, ConversationSession] = {}
|
||||
|
||||
def _now(self) -> int:
|
||||
"""Handle now for this module for the In Memory Conversation Store instance."""
|
||||
return int(time.time())
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Handle cleanup expired for this module for the In Memory Conversation Store instance."""
|
||||
now = self._now()
|
||||
expired = [
|
||||
session_id
|
||||
for session_id, session in self.sessions.items()
|
||||
if (now - session.updated_at) > self.timeout_seconds
|
||||
]
|
||||
for session_id in expired:
|
||||
self.sessions.pop(session_id, None)
|
||||
|
||||
def create_session(self, metadata: dict | None = None) -> ConversationSession:
|
||||
"""Create session for the In Memory Conversation Store instance."""
|
||||
self._cleanup_expired()
|
||||
if len(self.sessions) >= self.max_sessions:
|
||||
oldest = min(self.sessions.values(), key=lambda item: item.updated_at)
|
||||
self.sessions.pop(oldest.session_id, None)
|
||||
session_id = str(uuid.uuid4())[:8]
|
||||
session = ConversationSession(
|
||||
session_id=session_id,
|
||||
created_at=self._now(),
|
||||
updated_at=self._now(),
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
def get_session(self, session_id: str) -> ConversationSession | None:
|
||||
"""Return session for the In Memory Conversation Store instance."""
|
||||
self._cleanup_expired()
|
||||
return self.sessions.get(session_id)
|
||||
|
||||
def save_message(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
role: str,
|
||||
content: str,
|
||||
sources: list[dict] | None = None,
|
||||
) -> ConversationSession | None:
|
||||
"""Save message for the In Memory Conversation Store instance."""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return None
|
||||
session.messages.append(
|
||||
ConversationMessage(
|
||||
role=role,
|
||||
content=content,
|
||||
timestamp=self._now(),
|
||||
sources=sources or [],
|
||||
)
|
||||
)
|
||||
session.updated_at = self._now()
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete session for the In Memory Conversation Store instance."""
|
||||
return self.sessions.pop(session_id, None) is not None
|
||||
|
||||
def list_sessions(self) -> list[dict]:
|
||||
"""List sessions for the In Memory Conversation Store instance."""
|
||||
self._cleanup_expired()
|
||||
return [
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"message_count": len(session.messages),
|
||||
"created_at": session.created_at,
|
||||
"updated_at": session.updated_at,
|
||||
}
|
||||
for session in self.sessions.values()
|
||||
]
|
||||
5
backend/app/infrastructure/storage/__init__.py
Normal file
5
backend/app/infrastructure/storage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.infrastructure.storage package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
109
backend/app/infrastructure/storage/json_document_repository.py
Normal file
109
backend/app/infrastructure/storage/json_document_repository.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Implement infrastructure support for json document repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from app.domain.documents import Document, DocumentRepository, DocumentStatus
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class JsonDocumentRepository(DocumentRepository):
|
||||
"""Provide the Json Document Repository repository implementation."""
|
||||
def __init__(self, file_path: str) -> None:
|
||||
"""Initialize the Json Document Repository instance."""
|
||||
self.file_path = Path(file_path)
|
||||
self.file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not self.file_path.exists():
|
||||
self.file_path.write_text("{}", encoding="utf-8")
|
||||
|
||||
def _load(self) -> dict[str, dict]:
|
||||
"""Handle load for this module for the Json Document Repository instance."""
|
||||
return json.loads(self.file_path.read_text(encoding="utf-8") or "{}")
|
||||
|
||||
def _save(self, payload: dict[str, dict]) -> None:
|
||||
"""Handle save for this module for the Json Document Repository instance."""
|
||||
self.file_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
def _serialize(self, document: Document) -> dict:
|
||||
"""Handle serialize for this module for the Json Document Repository instance."""
|
||||
payload = document.__dict__.copy()
|
||||
payload["status"] = document.status.value
|
||||
payload["created_at"] = document.created_at.isoformat()
|
||||
payload["updated_at"] = document.updated_at.isoformat()
|
||||
return payload
|
||||
|
||||
def _deserialize(self, payload: dict) -> Document:
|
||||
"""Handle deserialize for this module for the Json Document Repository instance."""
|
||||
return Document(
|
||||
**{
|
||||
**payload,
|
||||
"status": DocumentStatus(payload["status"]),
|
||||
"created_at": datetime.fromisoformat(payload["created_at"]),
|
||||
"updated_at": datetime.fromisoformat(payload["updated_at"]),
|
||||
}
|
||||
)
|
||||
|
||||
def create(self, document: Document) -> Document:
|
||||
"""Handle create for the Json Document Repository instance."""
|
||||
payload = self._load()
|
||||
payload[document.doc_id] = self._serialize(document)
|
||||
self._save(payload)
|
||||
return document
|
||||
|
||||
def update(self, document: Document) -> Document:
|
||||
"""Handle update for the Json Document Repository instance."""
|
||||
document.updated_at = datetime.now(UTC)
|
||||
payload = self._load()
|
||||
payload[document.doc_id] = self._serialize(document)
|
||||
self._save(payload)
|
||||
return document
|
||||
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
"""Handle get for the Json Document Repository instance."""
|
||||
payload = self._load()
|
||||
item = payload.get(doc_id)
|
||||
return self._deserialize(item) if item else None
|
||||
|
||||
def list(self, limit: int | None = None) -> list[Document]:
|
||||
"""Handle list for the Json Document Repository instance."""
|
||||
payload = self._load()
|
||||
documents = [self._deserialize(item) for item in payload.values()]
|
||||
documents.sort(key=lambda item: item.updated_at, reverse=True)
|
||||
return documents[:limit] if limit is not None else documents
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
status: DocumentStatus,
|
||||
*,
|
||||
error_message: str = "",
|
||||
chunk_count: int | None = None,
|
||||
summary: str | None = None,
|
||||
summary_latency_ms: int | None = None,
|
||||
parser_name: str | None = None,
|
||||
index_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Document | None:
|
||||
"""Update status for the Json Document Repository instance."""
|
||||
document = self.get(doc_id)
|
||||
if not document:
|
||||
return None
|
||||
document.status = status
|
||||
document.error_message = error_message
|
||||
if chunk_count is not None:
|
||||
document.chunk_count = chunk_count
|
||||
if summary is not None:
|
||||
document.summary = summary
|
||||
if summary_latency_ms is not None:
|
||||
document.summary_latency_ms = summary_latency_ms
|
||||
if parser_name is not None:
|
||||
document.parser_name = parser_name
|
||||
if index_name is not None:
|
||||
document.index_name = index_name
|
||||
if metadata:
|
||||
document.metadata.update(metadata)
|
||||
return self.update(document)
|
||||
47
backend/app/infrastructure/storage/minio_binary_store.py
Normal file
47
backend/app/infrastructure/storage/minio_binary_store.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Implement infrastructure support for minio binary store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.documents import DocumentBinaryStore
|
||||
from app.services.storage.minio_client import MinIOClient
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class MinioDocumentBinaryStore(DocumentBinaryStore):
|
||||
"""Provide the Minio Document Binary Store store implementation."""
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Minio Document Binary Store instance."""
|
||||
self.client = MinIOClient()
|
||||
self.client.connect()
|
||||
self.client.ensure_bucket()
|
||||
|
||||
def save(
|
||||
self,
|
||||
*,
|
||||
object_name: str,
|
||||
data: bytes,
|
||||
content_type: str,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Handle save for the Minio Document Binary Store instance."""
|
||||
success = self.client.upload_bytes(
|
||||
data=data,
|
||||
object_name=object_name,
|
||||
content_type=content_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError("MinIO 保存失败")
|
||||
|
||||
def read(self, object_name: str) -> bytes:
|
||||
"""Handle read for the Minio Document Binary Store instance."""
|
||||
data = self.client.get_object_data(object_name)
|
||||
if data is None:
|
||||
raise FileNotFoundError(f"对象不存在: {object_name}")
|
||||
return data
|
||||
|
||||
def delete(self, object_name: str) -> None:
|
||||
"""Handle delete for the Minio Document Binary Store instance."""
|
||||
if not self.client.delete_object(object_name):
|
||||
raise FileNotFoundError(f"对象删除失败: {object_name}")
|
||||
5
backend/app/infrastructure/vectorstore/__init__.py
Normal file
5
backend/app/infrastructure/vectorstore/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.infrastructure.vectorstore package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
24
backend/app/infrastructure/vectorstore/dense_retriever.py
Normal file
24
backend/app/infrastructure/vectorstore/dense_retriever.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Implement infrastructure support for dense retriever."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.retrieval import EmbeddingProvider, RetrievalQuery, Retriever, RetrievedChunk, VectorIndex
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class DenseRetriever(Retriever):
|
||||
"""Provide the Dense Retriever retriever."""
|
||||
def __init__(self, *, embedding_provider: EmbeddingProvider, vector_index: VectorIndex) -> None:
|
||||
"""Initialize the Dense Retriever instance."""
|
||||
self.embedding_provider = embedding_provider
|
||||
self.vector_index = vector_index
|
||||
|
||||
def retrieve(self, query: RetrievalQuery) -> list[RetrievedChunk]:
|
||||
"""Handle retrieve for the Dense Retriever instance."""
|
||||
query_vector = self.embedding_provider.embed_query(query.query)
|
||||
return self.vector_index.search(query_vector, query.top_k, query.filters)
|
||||
|
||||
def search(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Dense Retriever instance."""
|
||||
return self.retrieve(RetrievalQuery(query=query, top_k=top_k, filters=filters))
|
||||
154
backend/app/infrastructure/vectorstore/milvus_vector_index.py
Normal file
154
backend/app/infrastructure/vectorstore/milvus_vector_index.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Implement infrastructure support for milvus vector index."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import Chunk
|
||||
from app.domain.retrieval import RetrievedChunk, VectorIndex
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
class MilvusVectorIndex(VectorIndex):
|
||||
"""Provide the Milvus Vector Index index implementation."""
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Milvus Vector Index instance."""
|
||||
self.collection_name = settings.milvus_collection
|
||||
self.db_name = settings.milvus_db_name
|
||||
connections.connect(
|
||||
alias="default",
|
||||
host=settings.milvus_host,
|
||||
port=settings.milvus_port,
|
||||
db_name=self.db_name,
|
||||
)
|
||||
self.collection = self._ensure_collection()
|
||||
|
||||
def _ensure_collection(self) -> Collection:
|
||||
"""Handle ensure collection for this module for the Milvus Vector Index instance."""
|
||||
if utility.has_collection(self.collection_name):
|
||||
collection = Collection(self.collection_name)
|
||||
collection.load()
|
||||
return collection
|
||||
schema = CollectionSchema(
|
||||
fields=[
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=128, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
|
||||
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
|
||||
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
|
||||
FieldSchema(name="section_path", dtype=DataType.VARCHAR, max_length=4096),
|
||||
FieldSchema(name="page_number", dtype=DataType.INT64),
|
||||
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=128),
|
||||
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="semantic_id", dtype=DataType.VARCHAR, max_length=128),
|
||||
FieldSchema(name="block_type", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="metadata_json", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="created_at", dtype=DataType.INT64),
|
||||
],
|
||||
description="Dense-only regulations index",
|
||||
enable_dynamic_field=False,
|
||||
)
|
||||
collection = Collection(name=self.collection_name, schema=schema)
|
||||
collection.create_index(
|
||||
field_name="embedding",
|
||||
index_params={
|
||||
"metric_type": "COSINE",
|
||||
"index_type": settings.milvus_index_type,
|
||||
"params": {"nlist": settings.milvus_nlist},
|
||||
},
|
||||
)
|
||||
collection.load()
|
||||
return collection
|
||||
|
||||
def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
|
||||
"""Handle upsert for the Milvus Vector Index instance."""
|
||||
if len(chunks) != len(vectors):
|
||||
raise ValueError("chunks 与 vectors 数量不一致")
|
||||
data = []
|
||||
now = int(time.time())
|
||||
for chunk, vector in zip(chunks, vectors):
|
||||
data.append(
|
||||
{
|
||||
"id": chunk.chunk_id,
|
||||
"doc_id": chunk.doc_id,
|
||||
"doc_name": chunk.doc_name,
|
||||
"content": chunk.content[:65535],
|
||||
"embedding": vector,
|
||||
"section_title": chunk.section_title[:512],
|
||||
"section_path": json.dumps(chunk.section_path, ensure_ascii=False)[:4096],
|
||||
"page_number": chunk.page_number,
|
||||
"regulation_type": chunk.regulation_type[:128],
|
||||
"version": chunk.version[:64],
|
||||
"semantic_id": chunk.semantic_id[:128],
|
||||
"block_type": chunk.block_type[:64],
|
||||
"metadata_json": json.dumps(chunk.metadata, ensure_ascii=False)[:65535],
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
self.collection.insert(data)
|
||||
self.collection.flush()
|
||||
return len(data)
|
||||
|
||||
def delete_by_document(self, doc_id: str) -> int:
|
||||
"""Delete by document for the Milvus Vector Index instance."""
|
||||
result = self.collection.delete(f'doc_id == "{doc_id}"')
|
||||
return len(result.primary_keys)
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Milvus Vector Index instance."""
|
||||
results = self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_name",
|
||||
"content",
|
||||
"section_title",
|
||||
"page_number",
|
||||
"regulation_type",
|
||||
"version",
|
||||
"semantic_id",
|
||||
"block_type",
|
||||
"metadata_json",
|
||||
],
|
||||
)
|
||||
payload: list[RetrievedChunk] = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
metadata = {}
|
||||
raw_metadata = hit.entity.get("metadata_json", "")
|
||||
if raw_metadata:
|
||||
try:
|
||||
metadata = json.loads(raw_metadata)
|
||||
except json.JSONDecodeError:
|
||||
metadata = {"raw_metadata": raw_metadata}
|
||||
payload.append(
|
||||
RetrievedChunk(
|
||||
chunk_id=str(hit.id),
|
||||
doc_id=hit.entity.get("doc_id", ""),
|
||||
doc_name=hit.entity.get("doc_name", ""),
|
||||
content=hit.entity.get("content", ""),
|
||||
score=float(hit.score),
|
||||
section_title=hit.entity.get("section_title", ""),
|
||||
page_number=int(hit.entity.get("page_number", 0) or 0),
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return payload
|
||||
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Milvus Vector Index instance."""
|
||||
return {
|
||||
"connected": True,
|
||||
"collection_name": self.collection_name,
|
||||
"num_entities": self.collection.num_entities if self.collection else 0,
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Backend application entrypoint."""
|
||||
|
||||
from app.api.main import app
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Initialize the app.schemas package."""
|
||||
|
||||
from .doc import (
|
||||
DocumentUploadResponse,
|
||||
DocumentInfo,
|
||||
@@ -24,6 +26,8 @@ from .compliance import (
|
||||
ComplianceChatRequest,
|
||||
AnalyzeResponse,
|
||||
)
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DocumentUploadResponse",
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
"""Define schema models for compliance."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
# Group related schema definitions so validation rules stay consistent.
|
||||
|
||||
|
||||
|
||||
class RiskLevel(str, Enum):
|
||||
"""Define the Risk Level enumeration."""
|
||||
high = "high"
|
||||
medium = "medium"
|
||||
low = "low"
|
||||
|
||||
|
||||
class ComplianceStatus(str, Enum):
|
||||
"""Define the Compliance Status enumeration."""
|
||||
pass_status = "pass"
|
||||
warning = "warning"
|
||||
fail = "fail"
|
||||
|
||||
|
||||
class Regulation(BaseModel):
|
||||
"""Define the Regulation API model."""
|
||||
id: int
|
||||
name: str
|
||||
clause: Optional[str] = None
|
||||
@@ -26,6 +33,7 @@ class Regulation(BaseModel):
|
||||
|
||||
|
||||
class ComplianceSegment(BaseModel):
|
||||
"""Define the Compliance Segment API model."""
|
||||
id: int
|
||||
index: int
|
||||
intent: str
|
||||
@@ -37,6 +45,7 @@ class ComplianceSegment(BaseModel):
|
||||
|
||||
|
||||
class RiskDashboard(BaseModel):
|
||||
"""Define the Risk Dashboard API model."""
|
||||
score: float
|
||||
high_risk_count: int
|
||||
medium_risk_count: int
|
||||
@@ -47,6 +56,7 @@ class RiskDashboard(BaseModel):
|
||||
|
||||
|
||||
class PriorityAction(BaseModel):
|
||||
"""Define the Priority Action API model."""
|
||||
regulation: str
|
||||
issue: str
|
||||
suggestion: str
|
||||
@@ -54,6 +64,7 @@ class PriorityAction(BaseModel):
|
||||
|
||||
|
||||
class ComplianceResult(BaseModel):
|
||||
"""Define the Compliance Result API model."""
|
||||
task_id: str
|
||||
dashboard: RiskDashboard
|
||||
segments: list[ComplianceSegment]
|
||||
@@ -61,9 +72,11 @@ class ComplianceResult(BaseModel):
|
||||
|
||||
|
||||
class ComplianceChatRequest(BaseModel):
|
||||
"""Define the Compliance Chat Request API model."""
|
||||
query: str
|
||||
|
||||
|
||||
class AnalyzeResponse(BaseModel):
|
||||
"""Define the Analyze Response API model."""
|
||||
task_id: str
|
||||
status: str = "processing"
|
||||
@@ -1,9 +1,14 @@
|
||||
"""Define schema models for doc."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
# Group related schema definitions so validation rules stay consistent.
|
||||
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
|
||||
"""Define the Document Upload Response API model."""
|
||||
doc_id: str
|
||||
filename: str
|
||||
size: int
|
||||
@@ -11,6 +16,7 @@ class DocumentUploadResponse(BaseModel):
|
||||
|
||||
|
||||
class DocumentInfo(BaseModel):
|
||||
"""Define the Document Info API model."""
|
||||
id: str
|
||||
name: str
|
||||
chunks: int
|
||||
@@ -19,10 +25,12 @@ class DocumentInfo(BaseModel):
|
||||
|
||||
|
||||
class DocumentListResponse(BaseModel):
|
||||
"""Define the Document List Response API model."""
|
||||
docs: list[DocumentInfo]
|
||||
|
||||
|
||||
class ChunkInfo(BaseModel):
|
||||
"""Define the Chunk Info API model."""
|
||||
chunk_id: str
|
||||
doc_name: str
|
||||
clause_id: Optional[str] = None
|
||||
@@ -33,12 +41,14 @@ class ChunkInfo(BaseModel):
|
||||
|
||||
|
||||
class ParseResponse(BaseModel):
|
||||
"""Define the Parse Response API model."""
|
||||
doc_id: str
|
||||
chunks: int
|
||||
status: str = "parsed"
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
"""Define the Embed Response API model."""
|
||||
doc_id: str
|
||||
vectors: int
|
||||
status: str = "embedded"
|
||||
@@ -1,13 +1,19 @@
|
||||
"""Define schema models for rag."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
# Group related schema definitions so validation rules stay consistent.
|
||||
|
||||
|
||||
|
||||
class RagChatRequest(BaseModel):
|
||||
"""Define the Rag Chat Request API model."""
|
||||
query: str
|
||||
top_k: int = 5
|
||||
|
||||
|
||||
class RetrievedDoc(BaseModel):
|
||||
"""Define the Retrieved Doc API model."""
|
||||
id: str
|
||||
doc_name: str
|
||||
clause_id: Optional[str] = None
|
||||
@@ -17,15 +23,18 @@ class RetrievedDoc(BaseModel):
|
||||
|
||||
|
||||
class SourceInfo(BaseModel):
|
||||
"""Define the Source Info API model."""
|
||||
name: str
|
||||
clause: Optional[str] = None
|
||||
|
||||
|
||||
class QuickQuestion(BaseModel):
|
||||
"""Define the Quick Question API model."""
|
||||
id: str
|
||||
question: str
|
||||
category: str
|
||||
|
||||
|
||||
class QuickQuestionsResponse(BaseModel):
|
||||
"""Define the Quick Questions Response API model."""
|
||||
questions: list[QuickQuestion]
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Backend service package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Agent服务模块"""
|
||||
"""Initialize the app.services.agent package."""
|
||||
|
||||
from .qa_agent import QAAgent, ask_compliance_question
|
||||
from .session_manager import SessionManager, ChatSession
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
"""RAG问答Agent - 合规智能问答核心实现"""
|
||||
"""Provide service-layer logic for qa agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any, Generator
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.rag.retriever import Retriever, RetrievedDocument
|
||||
from app.services.rag.context_builder import ContextBuilder, RAGContext
|
||||
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
"""Agent响应结果"""
|
||||
"""Represent the Agent Response type."""
|
||||
answer: str
|
||||
sources: List[Dict] = field(default_factory=list)
|
||||
model: str = ""
|
||||
@@ -27,385 +25,73 @@ class AgentResponse:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the Agent Response instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Agent配置"""
|
||||
llm_provider: str = "deepseek"
|
||||
llm_model: str = "deepseek-v4-flash"
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
max_context_tokens: int = 2000
|
||||
temperature: float = 0.7
|
||||
"""Define configuration for agent config."""
|
||||
llm_provider: str = settings.llm_provider
|
||||
llm_model: str = settings.llm_model
|
||||
top_k: int = settings.rag_top_k
|
||||
min_score: float = 0.0
|
||||
max_context_tokens: int = settings.rag_max_context_tokens
|
||||
temperature: float = settings.llm_temperature
|
||||
prompt_template: str = "compliance_qa"
|
||||
include_metadata: bool = True
|
||||
|
||||
|
||||
class QAAgent:
|
||||
"""
|
||||
合规问答Agent
|
||||
|
||||
核心流程:
|
||||
1. 接收用户问题
|
||||
2. Milvus混合检索相关法规条款
|
||||
3. 构建RAG上下文
|
||||
4. 调用LLM生成回答
|
||||
5. 返回答案和引用来源
|
||||
|
||||
使用示例:
|
||||
agent = QAAgent()
|
||||
response = agent.ask("机动车安全技术检验有哪些要求?")
|
||||
print(response.answer)
|
||||
for source in response.sources:
|
||||
print(f"引用: {source['doc_name']} - {source['clause_number']}")
|
||||
"""
|
||||
|
||||
"""Represent the Q A Agent type."""
|
||||
def __init__(self, config: Optional[AgentConfig] = None):
|
||||
"""
|
||||
初始化问答Agent
|
||||
|
||||
Args:
|
||||
config: Agent配置(可选,使用默认配置)
|
||||
"""
|
||||
self.config = config or AgentConfig(
|
||||
llm_provider=settings.llm_provider,
|
||||
llm_model=settings.llm_model,
|
||||
top_k=settings.rag_top_k,
|
||||
max_context_tokens=settings.rag_max_context_tokens
|
||||
)
|
||||
|
||||
# 初始化组件(延迟加载)
|
||||
self.llm: Optional[BaseLLMClient] = None
|
||||
self.retriever: Optional[Retriever] = None
|
||||
self.context_builder: Optional[ContextBuilder] = None
|
||||
|
||||
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
|
||||
|
||||
def _init_llm(self):
|
||||
"""延迟初始化LLM客户端(优先使用全局缓存)"""
|
||||
if self.llm is None:
|
||||
# 尝试先获取全局缓存的客户端
|
||||
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
|
||||
if cached:
|
||||
self.llm = cached
|
||||
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
|
||||
else:
|
||||
logger.info("创建新的LLM客户端...")
|
||||
self.llm = get_llm_client(
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
|
||||
def _init_retriever(self):
|
||||
"""延迟初始化检索器"""
|
||||
if self.retriever is None:
|
||||
logger.info("初始化检索器...")
|
||||
self.retriever = Retriever(
|
||||
top_k=self.config.top_k,
|
||||
min_score=self.config.min_score
|
||||
)
|
||||
|
||||
def _init_context_builder(self):
|
||||
"""延迟初始化上下文构建器"""
|
||||
if self.context_builder is None:
|
||||
logger.info("初始化上下文构建器...")
|
||||
self.context_builder = ContextBuilder(
|
||||
max_context_tokens=self.config.max_context_tokens,
|
||||
include_metadata=self.config.include_metadata
|
||||
)
|
||||
"""Initialize the Q A Agent instance."""
|
||||
self.config = config or AgentConfig()
|
||||
|
||||
def ask(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None
|
||||
prompt_template: Optional[str] = None,
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
回答用户问题
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
filters: 检索过滤条件(如 "regulation_type=='车辆安全'")
|
||||
prompt_template: Prompt模板名称(可选,覆盖默认配置)
|
||||
|
||||
Returns:
|
||||
AgentResponse: 包含答案和引用来源的响应对象
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(f"收到问题: {query}")
|
||||
|
||||
try:
|
||||
# Step 1: 检索相关法规
|
||||
self._init_retriever()
|
||||
documents = self.retriever.retrieve(query, filters)
|
||||
retrieved_count = len(documents)
|
||||
|
||||
if retrieved_count == 0:
|
||||
return AgentResponse(
|
||||
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
|
||||
retrieved_count=0,
|
||||
error="no_retrieved_documents"
|
||||
)
|
||||
|
||||
# Step 2: 构建RAG上下文
|
||||
self._init_context_builder()
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
# Step 3: 构建LLM输入消息
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
# Step 4: 调用LLM生成回答
|
||||
self._init_llm()
|
||||
llm_response = self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
|
||||
if not llm_response.is_success:
|
||||
return AgentResponse(
|
||||
answer="",
|
||||
retrieved_count=retrieved_count,
|
||||
error=llm_response.error
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Step 5: 返回结果
|
||||
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||
|
||||
return AgentResponse(
|
||||
answer=llm_response.content,
|
||||
sources=context.sources,
|
||||
model=llm_response.model,
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=retrieved_count,
|
||||
context_tokens=context.total_tokens,
|
||||
truncated=context.truncated
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
return AgentResponse(
|
||||
answer="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def ask_with_context(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[RetrievedDocument],
|
||||
prompt_template: Optional[str] = None
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
使用提供的文档回答问题(不执行检索)
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
documents: 已检索的文档列表
|
||||
prompt_template: Prompt模板名称
|
||||
|
||||
Returns:
|
||||
AgentResponse: 响应结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self._init_context_builder()
|
||||
self._init_llm()
|
||||
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
llm_response = self.llm.chat(messages)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return AgentResponse(
|
||||
answer=llm_response.content,
|
||||
sources=context.sources,
|
||||
model=llm_response.model,
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=len(documents),
|
||||
context_tokens=context.total_tokens,
|
||||
truncated=context.truncated
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
return AgentResponse(answer="", error=str(e))
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
template: PromptTemplate,
|
||||
context: RAGContext
|
||||
) -> List[Dict[str, str]]:
|
||||
"""构建LLM输入消息"""
|
||||
user_content = template.user_template.format(
|
||||
context=context.context_text,
|
||||
query=context.user_query
|
||||
"""Handle ask for the Q A Agent instance."""
|
||||
_, result = get_agent_conversation_service().ask(
|
||||
query=query,
|
||||
filters=filters,
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
top_k=self.config.top_k,
|
||||
prompt_template=prompt_template or self.config.prompt_template,
|
||||
)
|
||||
return AgentResponse(
|
||||
answer=result.answer,
|
||||
sources=[source.__dict__ for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
retrieved_count=result.retrieved_count,
|
||||
context_tokens=result.context_tokens,
|
||||
truncated=result.truncated,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": template.system_prompt},
|
||||
{"role": "user", "content": user_content}
|
||||
]
|
||||
|
||||
def ask_stream(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""
|
||||
流式回答用户问题(SSE模式)
|
||||
|
||||
返回事件类型:
|
||||
- {"event": "status", "data": "正在检索..."} - 状态更新
|
||||
- {"event": "sources", "data": [...]} - 引用来源
|
||||
- {"event": "content", "data": "文本片段"} - 回答内容
|
||||
- {"event": "done", "data": {"latency_ms": ..., "model": ...}} - 完成
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
filters: 检索过滤条件
|
||||
prompt_template: Prompt模板名称
|
||||
|
||||
Yields:
|
||||
Dict: SSE事件数据
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(f"收到流式问题: {query}")
|
||||
|
||||
try:
|
||||
# Step 1: 检索相关法规
|
||||
yield {"event": "status", "data": "正在检索相关法规..."}
|
||||
self._init_retriever()
|
||||
documents = self.retriever.retrieve(query, filters)
|
||||
retrieved_count = len(documents)
|
||||
|
||||
if retrieved_count == 0:
|
||||
yield {"event": "status", "data": "未找到相关法规"}
|
||||
yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
|
||||
yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
|
||||
return
|
||||
|
||||
# Step 2: 发送检索结果
|
||||
yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
|
||||
sources = [
|
||||
{
|
||||
"doc_name": doc.doc_name,
|
||||
"doc_id": doc.doc_id,
|
||||
"clause_number": doc.clause_number,
|
||||
"score": doc.score
|
||||
}
|
||||
for doc in documents[:5] # 只返回前5条引用
|
||||
]
|
||||
yield {"event": "sources", "data": sources}
|
||||
|
||||
# Step 3: 构建RAG上下文
|
||||
self._init_context_builder()
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
# Step 4: 构建LLM输入消息
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
# Step 5: 流式调用LLM生成回答
|
||||
self._init_llm()
|
||||
full_answer = ""
|
||||
|
||||
# 检查LLM是否支持流式输出
|
||||
if hasattr(self.llm, 'stream_chat'):
|
||||
yield {"event": "status", "data": "思考中..."}
|
||||
for chunk in self.llm.stream_chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
):
|
||||
full_answer += chunk
|
||||
yield {"event": "content", "data": chunk}
|
||||
else:
|
||||
# 如果不支持流式,回退到普通调用
|
||||
yield {"event": "status", "data": "生成回答中..."}
|
||||
llm_response = self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
if llm_response.is_success:
|
||||
full_answer = llm_response.content
|
||||
yield {"event": "content", "data": full_answer}
|
||||
|
||||
# Step 6: 发送完成事件
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"latency_ms": latency_ms,
|
||||
"model": self.config.llm_model,
|
||||
"retrieved_count": retrieved_count,
|
||||
"context_tokens": context.total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式问答失败: {e}")
|
||||
yield {"event": "error", "data": str(e)}
|
||||
def ask_stream(self, query: str, filters: Optional[str] = None) -> Generator[dict, None, None]:
|
||||
"""Handle ask stream for the Q A Agent instance."""
|
||||
_, stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
filters=filters,
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
top_k=self.config.top_k,
|
||||
prompt_template=self.config.prompt_template,
|
||||
)
|
||||
for event in stream:
|
||||
yield event
|
||||
|
||||
def close(self):
|
||||
"""关闭Agent资源(不关闭LLM客户端,因为它全局缓存)"""
|
||||
if self.retriever:
|
||||
self.retriever.close()
|
||||
logger.info("问答Agent已关闭")
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
|
||||
def ask_compliance_question(
|
||||
query: str,
|
||||
provider: str = "deepseek",
|
||||
model: str = "deepseek-v4-flash",
|
||||
top_k: int = 10
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
便捷函数:问答合规问题
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
provider: LLM提供商
|
||||
model: LLM模型
|
||||
top_k: 检索数量
|
||||
|
||||
Returns:
|
||||
AgentResponse: 响应结果
|
||||
"""
|
||||
config = AgentConfig(
|
||||
llm_provider=provider,
|
||||
llm_model=model,
|
||||
top_k=top_k
|
||||
)
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(query)
|
||||
agent.close()
|
||||
return response
|
||||
def ask_compliance_question(query: str, top_k: int = 5) -> AgentResponse:
|
||||
"""Handle ask compliance question."""
|
||||
return QAAgent(AgentConfig(top_k=top_k)).ask(query)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""多轮对话会话管理"""
|
||||
"""Provide service-layer logic for session manager."""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""对话消息"""
|
||||
"""Represent the Chat Message type."""
|
||||
role: str # "user" / "assistant" / "system"
|
||||
content: str
|
||||
timestamp: int
|
||||
@@ -19,7 +19,7 @@ class ChatMessage:
|
||||
|
||||
@dataclass
|
||||
class ChatSession:
|
||||
"""对话会话"""
|
||||
"""Represent the Chat Session type."""
|
||||
session_id: str
|
||||
messages: List[ChatMessage] = field(default_factory=list)
|
||||
created_at: int = field(default_factory=lambda: int(time.time()))
|
||||
@@ -27,7 +27,7 @@ class ChatSession:
|
||||
metadata: Dict = field(default_factory=dict)
|
||||
|
||||
def add_user_message(self, content: str) -> ChatMessage:
|
||||
"""添加用户消息"""
|
||||
"""Handle add user message for the Chat Session instance."""
|
||||
message = ChatMessage(
|
||||
role="user",
|
||||
content=content,
|
||||
@@ -42,7 +42,7 @@ class ChatSession:
|
||||
content: str,
|
||||
sources: List[Dict] = None
|
||||
) -> ChatMessage:
|
||||
"""添加助手消息"""
|
||||
"""Handle add assistant message for the Chat Session instance."""
|
||||
message = ChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
@@ -54,9 +54,9 @@ class ChatSession:
|
||||
return message
|
||||
|
||||
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
|
||||
"""获取历史对话(用于LLM上下文)"""
|
||||
"""Return history for the Chat Session instance."""
|
||||
history = []
|
||||
# 获取最近N轮对话(每轮包含user + assistant)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
recent_messages = self.messages[-(max_turns * 2):]
|
||||
|
||||
for msg in recent_messages:
|
||||
@@ -68,81 +68,47 @@ class ChatSession:
|
||||
return history
|
||||
|
||||
def clear_history(self):
|
||||
"""清空对话历史"""
|
||||
"""Handle clear history for the Chat Session instance."""
|
||||
self.messages = []
|
||||
self.updated_at = int(time.time())
|
||||
logger.info(f"会话历史已清空: {self.session_id}")
|
||||
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""消息数量"""
|
||||
"""Handle message count for the Chat Session instance."""
|
||||
return len(self.messages)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""是否为空会话"""
|
||||
"""Return whether empty for the Chat Session instance."""
|
||||
return len(self.messages) == 0
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
会话管理器
|
||||
|
||||
功能:
|
||||
- 创建/获取/删除会话
|
||||
- 会话超时清理
|
||||
- 会话历史记录管理
|
||||
|
||||
使用示例:
|
||||
manager = SessionManager()
|
||||
|
||||
# 创建会话
|
||||
session = manager.create_session()
|
||||
|
||||
# 添加消息
|
||||
session.add_user_message("什么是机动车安全技术检验?")
|
||||
session.add_assistant_message("根据GB 7258...", sources=[...])
|
||||
|
||||
# 获取历史(用于LLM多轮对话)
|
||||
history = session.get_history(max_turns=3)
|
||||
"""
|
||||
"""Represent the Session Manager type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_sessions: int = 100,
|
||||
session_timeout_minutes: int = 30
|
||||
):
|
||||
"""
|
||||
初始化会话管理器
|
||||
|
||||
Args:
|
||||
max_sessions: 最大会话数量
|
||||
session_timeout_minutes: 会话超时时间(分钟)
|
||||
"""
|
||||
"""Initialize the Session Manager instance."""
|
||||
self.max_sessions = max_sessions
|
||||
self.session_timeout = session_timeout_minutes * 60
|
||||
|
||||
# 会话存储(内存)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
self._sessions: Dict[str, ChatSession] = {}
|
||||
|
||||
logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
|
||||
|
||||
def create_session(self, metadata: Dict = None) -> ChatSession:
|
||||
"""
|
||||
创建新会话
|
||||
|
||||
Args:
|
||||
metadata: 会话元数据(可选)
|
||||
|
||||
Returns:
|
||||
ChatSession: 新创建的会话
|
||||
"""
|
||||
# 检查会话数量限制
|
||||
"""Create session for the Session Manager instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(self._sessions) >= self.max_sessions:
|
||||
# 清理过期会话
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
self._cleanup_expired_sessions()
|
||||
|
||||
# 如果仍然超出限制,删除最老的会话
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(self._sessions) >= self.max_sessions:
|
||||
oldest_id = min(
|
||||
self._sessions.keys(),
|
||||
@@ -163,19 +129,11 @@ class SessionManager:
|
||||
return session
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[ChatSession]:
|
||||
"""
|
||||
获取会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
ChatSession: 会话对象(如不存在返回None)
|
||||
"""
|
||||
"""Return session for the Session Manager instance."""
|
||||
session = self._sessions.get(session_id)
|
||||
|
||||
if session:
|
||||
# 检查是否过期
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if self._is_session_expired(session):
|
||||
self.delete_session(session_id)
|
||||
logger.info(f"会话已过期,已删除: {session_id}")
|
||||
@@ -184,15 +142,7 @@ class SessionManager:
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
删除会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除
|
||||
"""
|
||||
"""Delete session for the Session Manager instance."""
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
logger.info(f"删除会话: {session_id}")
|
||||
@@ -200,12 +150,7 @@ class SessionManager:
|
||||
return False
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""
|
||||
列出所有会话
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表摘要
|
||||
"""
|
||||
"""List sessions for the Session Manager instance."""
|
||||
return [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
@@ -217,12 +162,12 @@ class SessionManager:
|
||||
]
|
||||
|
||||
def _is_session_expired(self, session: ChatSession) -> bool:
|
||||
"""检查会话是否过期"""
|
||||
"""Handle is session expired for this module for the Session Manager instance."""
|
||||
current_time = int(time.time())
|
||||
return (current_time - session.updated_at) > self.session_timeout
|
||||
|
||||
def _cleanup_expired_sessions(self) -> int:
|
||||
"""清理过期会话"""
|
||||
"""Handle cleanup expired sessions for this module for the Session Manager instance."""
|
||||
expired_ids = [
|
||||
sid for sid, session in self._sessions.items()
|
||||
if self._is_session_expired(session)
|
||||
@@ -237,10 +182,10 @@ class SessionManager:
|
||||
return len(expired_ids)
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取当前会话数量"""
|
||||
"""Return session count for the Session Manager instance."""
|
||||
return len(self._sessions)
|
||||
|
||||
def clear_all_sessions(self):
|
||||
"""清空所有会话"""
|
||||
"""Handle clear all sessions for the Session Manager instance."""
|
||||
self._sessions.clear()
|
||||
logger.info("所有会话已清空")
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
"""文档处理主流程 - 解析→摘要→分块→嵌入→入库"""
|
||||
"""Provide service-layer logic for document processor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from app.shared.bootstrap import get_document_command_service, get_retrieval_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
from .parser.pdf_parser import PDFParser
|
||||
from .parser.docx_parser import DocxParser
|
||||
from .parser.mineru_parser import ParserOrchestrator
|
||||
from .embedding.text_chunker import RegulationChunker, TextChunk
|
||||
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
|
||||
from .storage.milvus_client import MilvusClient
|
||||
from .llm.document_summarizer import DocumentSummarizer, DocumentSummary
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingResult:
|
||||
"""文档处理结果"""
|
||||
"""Represent the Processing Result type."""
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
success: bool
|
||||
@@ -30,87 +25,10 @@ class ProcessingResult:
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""
|
||||
文档处理服务 - 完整处理流程
|
||||
|
||||
流程:
|
||||
1. 文档解析(PDF/DOCX → Markdown)
|
||||
2. 智能分块(章节级+条款级)
|
||||
3. LLM摘要生成(可选)
|
||||
4. 向量嵌入(BGE-M3 Dense+Sparse)
|
||||
5. 存储入库(Milvus向量数据库)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = None,
|
||||
embedding_model: str = None,
|
||||
use_mineru: bool = True,
|
||||
generate_summary: bool = False, # 默认不生成摘要,节省约60秒
|
||||
llm_provider: str = None,
|
||||
llm_model: str = None
|
||||
):
|
||||
"""
|
||||
初始化文档处理器
|
||||
|
||||
Args:
|
||||
chunk_size: 分块大小
|
||||
embedding_model: 嵌入模型名称
|
||||
use_mineru: 是否优先使用MinerU解析
|
||||
generate_summary: 是否生成文档摘要(默认False,可节省约60秒处理时间)
|
||||
llm_provider: LLM提供商
|
||||
llm_model: LLM模型名称
|
||||
"""
|
||||
self.chunk_size = chunk_size or settings.chunk_size
|
||||
self.embedding_model = embedding_model or settings.embedding_model
|
||||
self.use_mineru = use_mineru
|
||||
"""Represent the Document Processor type."""
|
||||
def __init__(self, *args, generate_summary: bool = False, **kwargs):
|
||||
"""Initialize the Document Processor instance."""
|
||||
self.generate_summary = generate_summary
|
||||
self.llm_provider = llm_provider or settings.llm_provider
|
||||
self.llm_model = llm_model or settings.llm_model
|
||||
|
||||
# 初始化各组件
|
||||
logger.info("初始化文档处理组件...")
|
||||
|
||||
# 解析器
|
||||
self.parser = ParserOrchestrator()
|
||||
|
||||
# 分块器
|
||||
self.chunker = RegulationChunker(chunk_size=self.chunk_size)
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
self.embedder: Optional[BGEM3Embedder] = None
|
||||
|
||||
# Milvus客户端(延迟连接)
|
||||
self.milvus: Optional[MilvusClient] = None
|
||||
|
||||
# 摘要生成器(延迟加载)
|
||||
self.summarizer: Optional[DocumentSummarizer] = None
|
||||
|
||||
logger.success("文档处理器初始化完成")
|
||||
|
||||
def _init_embedder(self):
|
||||
"""延迟初始化嵌入模型"""
|
||||
if self.embedder is None:
|
||||
logger.info("加载嵌入模型...")
|
||||
self.embedder = BGEM3Embedder(model_name=self.embedding_model)
|
||||
|
||||
def _init_milvus(self):
|
||||
"""延迟初始化Milvus连接"""
|
||||
if self.milvus is None:
|
||||
logger.info("连接Milvus...")
|
||||
self.milvus = MilvusClient()
|
||||
self.milvus.connect()
|
||||
self.milvus.create_collection(recreate=False)
|
||||
self.milvus.load_collection()
|
||||
|
||||
def _init_summarizer(self):
|
||||
"""延迟初始化摘要生成器"""
|
||||
if self.summarizer is None:
|
||||
logger.info("初始化摘要生成器...")
|
||||
self.summarizer = DocumentSummarizer(
|
||||
provider=self.llm_provider,
|
||||
model=self.llm_model
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
@@ -118,286 +36,51 @@ class DocumentProcessor:
|
||||
doc_id: Optional[str] = None,
|
||||
doc_name: Optional[str] = None,
|
||||
regulation_type: str = "",
|
||||
version: str = ""
|
||||
version: str = "",
|
||||
) -> ProcessingResult:
|
||||
"""
|
||||
处理单个文档
|
||||
"""Handle process for the Document Processor instance."""
|
||||
path = Path(file_path)
|
||||
content = path.read_bytes()
|
||||
result = get_document_command_service().upload_and_process(
|
||||
doc_id=doc_id,
|
||||
file_name=path.name,
|
||||
content=content,
|
||||
content_type="application/octet-stream",
|
||||
doc_name=doc_name or path.name,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
generate_summary=self.generate_summary,
|
||||
)
|
||||
return ProcessingResult(
|
||||
doc_id=result.doc_id,
|
||||
doc_name=result.doc_name,
|
||||
success=result.status != "failed",
|
||||
num_chunks=result.num_chunks,
|
||||
message=result.message,
|
||||
summary=result.summary,
|
||||
summary_latency_ms=result.summary_latency_ms,
|
||||
)
|
||||
|
||||
Args:
|
||||
file_path: 文档文件路径
|
||||
doc_id: 文档ID(可选,默认自动生成)
|
||||
doc_name: 文档名称(可选,默认从文件名获取)
|
||||
regulation_type: 法规类型
|
||||
version: 文档版本
|
||||
|
||||
Returns:
|
||||
ProcessingResult: 处理结果
|
||||
"""
|
||||
# 生成或使用传入的文档ID
|
||||
if doc_id is None:
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 获取文档名称
|
||||
if doc_name is None:
|
||||
doc_name = os.path.basename(file_path)
|
||||
|
||||
logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})")
|
||||
|
||||
# 初始化结果变量
|
||||
summary = ""
|
||||
summary_latency_ms = 0
|
||||
|
||||
try:
|
||||
# 1. 文档解析
|
||||
logger.info("Step 1: 文档解析")
|
||||
markdown_text = self._parse_document(file_path)
|
||||
|
||||
if not markdown_text:
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message="文档解析失败,内容为空"
|
||||
)
|
||||
|
||||
# 2. LLM摘要生成(可选)
|
||||
if self.generate_summary:
|
||||
logger.info("Step 2: LLM摘要生成")
|
||||
self._init_summarizer()
|
||||
summary_result = self.summarizer.summarize(
|
||||
doc_name,
|
||||
markdown_text,
|
||||
regulation_type
|
||||
)
|
||||
if summary_result.is_success:
|
||||
summary = summary_result.summary
|
||||
summary_latency_ms = summary_result.latency_ms
|
||||
logger.success(f"摘要生成完成: {summary_latency_ms}ms")
|
||||
else:
|
||||
logger.warning(f"摘要生成失败: {summary_result.error}")
|
||||
else:
|
||||
logger.info("Step 2: 跳过摘要生成(未勾选,节省约60秒)")
|
||||
|
||||
# 3. 智能分块
|
||||
logger.info("Step 3: 智能分块")
|
||||
chunks = self._chunk_document(
|
||||
markdown_text,
|
||||
doc_id,
|
||||
doc_name,
|
||||
regulation_type,
|
||||
version
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message="分块失败,无有效内容",
|
||||
markdown_text=markdown_text,
|
||||
summary=summary
|
||||
)
|
||||
|
||||
# 4. 向量嵌入
|
||||
logger.info("Step 4: 向量嵌入")
|
||||
embeddings = self._embed_chunks(chunks)
|
||||
|
||||
if embeddings is None:
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message="向量嵌入失败",
|
||||
markdown_text=markdown_text,
|
||||
summary=summary
|
||||
)
|
||||
|
||||
# 5. 存储入库
|
||||
logger.info("Step 5: 存储入库")
|
||||
inserted_ids = self._insert_to_milvus(chunks, embeddings)
|
||||
|
||||
logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
|
||||
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=True,
|
||||
num_chunks=len(inserted_ids),
|
||||
message="处理成功",
|
||||
markdown_text=markdown_text,
|
||||
summary=summary,
|
||||
summary_latency_ms=summary_latency_ms
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文档处理失败: {e}")
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message=f"处理失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _parse_document(self, file_path: str) -> str:
|
||||
"""解析文档"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
try:
|
||||
if ext == ".pdf":
|
||||
# PDF文档解析(优先MinerU,回退PyMuPDF)
|
||||
markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
|
||||
elif ext in [".docx", ".doc"]:
|
||||
# Word文档解析
|
||||
markdown_text = self.parser.parse_docx(file_path)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {ext}")
|
||||
return ""
|
||||
|
||||
logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
|
||||
return markdown_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文档解析失败: {e}")
|
||||
return ""
|
||||
|
||||
def _chunk_document(
|
||||
self,
|
||||
markdown_text: str,
|
||||
doc_id: str,
|
||||
doc_name: str,
|
||||
regulation_type: str,
|
||||
version: str
|
||||
) -> List[TextChunk]:
|
||||
"""分块文档"""
|
||||
try:
|
||||
chunks = self.chunker.chunk_document(
|
||||
markdown_text,
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type,
|
||||
version=version
|
||||
)
|
||||
logger.success(f"分块完成,共{len(chunks)}个chunk")
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分块失败: {e}")
|
||||
return []
|
||||
|
||||
def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
|
||||
"""嵌入分块"""
|
||||
try:
|
||||
# 延迟初始化嵌入模型
|
||||
self._init_embedder()
|
||||
|
||||
# 提取文本内容
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
|
||||
# 执行嵌入
|
||||
embeddings = self.embedder.embed(texts)
|
||||
|
||||
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"嵌入失败: {e}")
|
||||
return None
|
||||
|
||||
def _insert_to_milvus(
|
||||
self,
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""插入Milvus"""
|
||||
try:
|
||||
# 延迟初始化Milvus
|
||||
self._init_milvus()
|
||||
|
||||
# 执行插入
|
||||
inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
|
||||
|
||||
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
|
||||
return inserted_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"入库失败: {e}")
|
||||
return []
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
检索法规内容
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
top_k: 返回结果数量
|
||||
filters: 过滤条件
|
||||
|
||||
Returns:
|
||||
List[Dict]: 检索结果
|
||||
"""
|
||||
logger.info(f"执行检索: {query}")
|
||||
|
||||
try:
|
||||
# 延迟初始化
|
||||
self._init_embedder()
|
||||
self._init_milvus()
|
||||
|
||||
# 生成查询向量
|
||||
query_embedding = self.embedder.embed_single(query)
|
||||
|
||||
# 执行混合检索
|
||||
results = self.milvus.hybrid_search(
|
||||
query_dense=query_embedding['dense'].tolist(),
|
||||
query_sparse=query_embedding['sparse'],
|
||||
top_k=top_k,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
result_dicts = []
|
||||
for r in results:
|
||||
result_dicts.append({
|
||||
"id": r.id,
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
"metadata": r.metadata
|
||||
})
|
||||
|
||||
logger.success(f"检索完成,返回{len(result_dicts)}条结果")
|
||||
return result_dicts
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索失败: {e}")
|
||||
return []
|
||||
def search(self, query: str, top_k: int = 10, filters: str | None = None) -> list[dict]:
|
||||
"""Handle search for the Document Processor instance."""
|
||||
results = get_retrieval_service().retrieve(query=query, top_k=top_k, filters=filters)
|
||||
return [
|
||||
{
|
||||
"id": item.chunk_id,
|
||||
"content": item.content,
|
||||
"score": item.score,
|
||||
"metadata": {
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"chunk_id": item.chunk_id,
|
||||
"section_title": item.section_title,
|
||||
"page_number": item.page_number,
|
||||
**item.metadata,
|
||||
},
|
||||
}
|
||||
for item in results
|
||||
]
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
if self.milvus:
|
||||
self.milvus.disconnect()
|
||||
logger.info("文档处理器已关闭")
|
||||
|
||||
|
||||
def process_document(
|
||||
file_path: str,
|
||||
doc_name: Optional[str] = None,
|
||||
regulation_type: str = "",
|
||||
version: str = ""
|
||||
) -> ProcessingResult:
|
||||
"""便捷函数:处理单个文档"""
|
||||
processor = DocumentProcessor()
|
||||
result = processor.process(file_path, doc_name, regulation_type, version)
|
||||
processor.close()
|
||||
return result
|
||||
|
||||
|
||||
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
|
||||
"""便捷函数:检索法规"""
|
||||
processor = DocumentProcessor()
|
||||
results = processor.search(query, top_k)
|
||||
processor.close()
|
||||
return results
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
"""嵌入和分块服务"""
|
||||
"""Initialize the app.services.embedding package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
from .text_chunker import RegulationChunker
|
||||
from .bge_m3_embedder import BGEM3Embedder
|
||||
|
||||
__all__ = ["RegulationChunker", "BGEM3Embedder"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name == "RegulationChunker":
|
||||
from .text_chunker import RegulationChunker
|
||||
|
||||
return RegulationChunker
|
||||
if name == "BGEM3Embedder":
|
||||
from .bge_m3_embedder import BGEM3Embedder
|
||||
|
||||
return BGEM3Embedder
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成"""
|
||||
"""Provide service-layer logic for bge m3 embedder."""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Union
|
||||
@@ -6,43 +6,31 @@ from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import torch
|
||||
import os
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
# 设置HuggingFace镜像(国内网络)
|
||||
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if 'HF_ENDPOINT' not in os.environ:
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 本地模型路径(按优先级检查)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
LOCAL_MODEL_PATHS = [
|
||||
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径
|
||||
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径
|
||||
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingResult:
|
||||
"""嵌入结果"""
|
||||
dense_embeddings: np.ndarray # Dense向量(语义检索)
|
||||
sparse_embeddings: List[Dict[int, float]] # Sparse向量(关键词匹配)
|
||||
"""Represent the Embedding Result type."""
|
||||
dense_embeddings: np.ndarray # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
sparse_embeddings: List[Dict[int, float]] # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
texts: List[str]
|
||||
dim: int = 1024
|
||||
|
||||
|
||||
class BGEM3Embedder:
|
||||
"""
|
||||
BGE-M3多语言嵌入模型服务
|
||||
|
||||
BGE-M3是BAAI发布的多语言嵌入模型,支持:
|
||||
- Dense向量:用于语义相似度检索
|
||||
- Sparse向量:用于关键词精确匹配(BM25风格)
|
||||
- ColBERT向量:用于细粒度交互匹配(可选)
|
||||
|
||||
特点:
|
||||
- 支持100+语言(中英双语优化)
|
||||
- 8192 tokens超长上下文
|
||||
- Dense+Sparse双路检索能力
|
||||
|
||||
GitHub: https://github.com/FlagOpen/FlagEmbedding
|
||||
"""
|
||||
"""Represent the B G E M3 Embedder type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -53,28 +41,18 @@ class BGEM3Embedder:
|
||||
max_length: int = 8192,
|
||||
local_model_path: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化BGE-M3嵌入模型
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如果使用本地路径,此参数会被忽略)
|
||||
use_fp16: 是否使用FP16加速
|
||||
device: 设备类型(cuda/cpu),默认自动选择
|
||||
batch_size: 批处理大小
|
||||
max_length: 最大序列长度
|
||||
local_model_path: 本地模型路径(可选,优先使用)
|
||||
"""
|
||||
"""Initialize the B G E M3 Embedder instance."""
|
||||
self.use_fp16 = use_fp16
|
||||
self.batch_size = batch_size
|
||||
self.max_length = max_length
|
||||
|
||||
# 确定模型路径(优先使用本地路径)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if local_model_path and os.path.exists(local_model_path):
|
||||
self.model_path = local_model_path
|
||||
self.model_name = "local"
|
||||
logger.info(f"使用本地模型路径: {local_model_path}")
|
||||
else:
|
||||
# 检查多个可能的本地路径
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
found_local = False
|
||||
for path in LOCAL_MODEL_PATHS:
|
||||
if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
|
||||
@@ -89,7 +67,7 @@ class BGEM3Embedder:
|
||||
self.model_name = model_name
|
||||
logger.info(f"使用远程模型: {model_name}")
|
||||
|
||||
# 自动选择设备
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
@@ -101,7 +79,7 @@ class BGEM3Embedder:
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""加载嵌入模型"""
|
||||
"""Handle load model for this module for the B G E M3 Embedder instance."""
|
||||
try:
|
||||
from FlagEmbedding import BGEM3FlagModel
|
||||
|
||||
@@ -127,18 +105,7 @@ class BGEM3Embedder:
|
||||
return_sparse: bool = True,
|
||||
return_colbert_vecs: bool = False
|
||||
) -> EmbeddingResult:
|
||||
"""
|
||||
对文本列表生成嵌入向量
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
return_dense: 是否返回Dense向量
|
||||
return_sparse: 是否返回Sparse向量
|
||||
return_colbert_vecs: 是否返回ColBERT向量
|
||||
|
||||
Returns:
|
||||
EmbeddingResult: 嵌入结果
|
||||
"""
|
||||
"""Handle embed for the B G E M3 Embedder instance."""
|
||||
if not texts:
|
||||
logger.warning("输入文本列表为空")
|
||||
return EmbeddingResult(
|
||||
@@ -151,7 +118,7 @@ class BGEM3Embedder:
|
||||
logger.info(f"开始嵌入{len(texts)}个文本块")
|
||||
|
||||
try:
|
||||
# 执行嵌入
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
embeddings = self.model.encode(
|
||||
texts,
|
||||
batch_size=self.batch_size,
|
||||
@@ -161,11 +128,11 @@ class BGEM3Embedder:
|
||||
return_colbert_vecs=return_colbert_vecs
|
||||
)
|
||||
|
||||
# 提取结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
dense_embeddings = embeddings.get('dense_vecs', np.array([]))
|
||||
sparse_embeddings = embeddings.get('lexical_weights', [])
|
||||
|
||||
# 获取维度
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
|
||||
|
||||
logger.success(f"嵌入完成,向量维度: {dim}")
|
||||
@@ -182,15 +149,7 @@ class BGEM3Embedder:
|
||||
raise
|
||||
|
||||
def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
|
||||
"""
|
||||
对单个文本生成嵌入向量
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
Dict: 包含dense和sparse向量
|
||||
"""
|
||||
"""Embed single for the B G E M3 Embedder instance."""
|
||||
result = self.embed([text])
|
||||
return {
|
||||
'dense': result.dense_embeddings[0],
|
||||
@@ -199,25 +158,17 @@ class BGEM3Embedder:
|
||||
}
|
||||
|
||||
def embed_dense(self, texts: List[str]) -> np.ndarray:
|
||||
"""只生成Dense向量"""
|
||||
"""Embed dense for the B G E M3 Embedder instance."""
|
||||
result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
|
||||
return result.dense_embeddings
|
||||
|
||||
def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
|
||||
"""只生成Sparse向量"""
|
||||
"""Embed sparse for the B G E M3 Embedder instance."""
|
||||
result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
|
||||
return result.sparse_embeddings
|
||||
|
||||
def embed_query(self, query: str) -> Dict:
|
||||
"""
|
||||
对查询文本生成嵌入(用于检索)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
|
||||
Returns:
|
||||
Dict: 包含dense和sparse向量
|
||||
"""
|
||||
"""Embed query for the B G E M3 Embedder instance."""
|
||||
return self.embed_single(query)
|
||||
|
||||
def compute_similarity(
|
||||
@@ -226,26 +177,16 @@ class BGEM3Embedder:
|
||||
doc_embeddings: np.ndarray,
|
||||
metric: str = "cosine"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
计算查询与文档的相似度
|
||||
|
||||
Args:
|
||||
query_embedding: 查询向量
|
||||
doc_embeddings: 文档向量矩阵
|
||||
metric: 相似度度量(cosine/dot)
|
||||
|
||||
Returns:
|
||||
np.ndarray: 相似度分数数组
|
||||
"""
|
||||
"""Handle compute similarity for the B G E M3 Embedder instance."""
|
||||
if metric == "cosine":
|
||||
# 余弦相似度
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
query_norm = np.linalg.norm(query_embedding)
|
||||
doc_norms = np.linalg.norm(doc_embeddings, axis=1)
|
||||
|
||||
similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm)
|
||||
|
||||
elif metric == "dot":
|
||||
# 点积相似度
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
similarities = np.dot(doc_embeddings, query_embedding)
|
||||
|
||||
else:
|
||||
@@ -258,17 +199,8 @@ class BGEM3Embedder:
|
||||
query_sparse: Dict[int, float],
|
||||
doc_sparse: Dict[int, float]
|
||||
) -> float:
|
||||
"""
|
||||
计算Sparse向量的相似度(BM25风格)
|
||||
|
||||
Args:
|
||||
query_sparse: 查询的Sparse向量(词ID -> 权重)
|
||||
doc_sparse: 文档的Sparse向量
|
||||
|
||||
Returns:
|
||||
float: 相似度分数
|
||||
"""
|
||||
# 计算交集词的点积
|
||||
"""Handle sparse similarity for the B G E M3 Embedder instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
common_keys = set(query_sparse.keys()) & set(doc_sparse.keys())
|
||||
|
||||
score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
|
||||
@@ -280,7 +212,7 @@ def embed_texts(
|
||||
model_name: str = "BAAI/bge-m3",
|
||||
**kwargs
|
||||
) -> EmbeddingResult:
|
||||
"""便捷函数:对文本列表生成嵌入"""
|
||||
"""Embed texts."""
|
||||
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
|
||||
return embedder.embed(texts)
|
||||
|
||||
@@ -290,6 +222,6 @@ def embed_single_text(
|
||||
model_name: str = "BAAI/bge-m3",
|
||||
**kwargs
|
||||
) -> Dict:
|
||||
"""便捷函数:对单个文本生成嵌入"""
|
||||
"""Embed single text."""
|
||||
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
|
||||
return embedder.embed_single(text)
|
||||
|
||||
@@ -1,51 +1,46 @@
|
||||
"""智能分块器 - 章节级+条款级双粒度切割"""
|
||||
"""Provide service-layer logic for text chunker."""
|
||||
|
||||
import re
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMetadata:
|
||||
"""分块元数据"""
|
||||
"""Represent the Chunk Metadata type."""
|
||||
doc_id: str = ""
|
||||
doc_name: str = ""
|
||||
chunk_id: str = ""
|
||||
section_number: str = "" # 章节编号(如 "第一章")
|
||||
section_title: str = "" # 章节标题
|
||||
clause_number: str = "" # 条款编号(如 "第一条")
|
||||
section_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
section_title: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
clause_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
page_number: int = 0
|
||||
start_position: int = 0 # 在原文中的起始位置
|
||||
end_position: int = 0 # 在原文中的结束位置
|
||||
regulation_type: str = "" # 法规类型
|
||||
start_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
end_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
regulation_type: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
version: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextChunk:
|
||||
"""文本分块"""
|
||||
"""Represent the Text Chunk type."""
|
||||
content: str
|
||||
metadata: ChunkMetadata
|
||||
token_count: int = 0 # 估算的token数量
|
||||
token_count: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
class RegulationChunker:
|
||||
"""
|
||||
法规文档智能分块器
|
||||
"""Represent the Regulation Chunker type."""
|
||||
|
||||
实现章节级/条款级双粒度切割,适配国标GB文档结构:
|
||||
- 国标文档通常有明确的层级结构:章 > 节 > 条
|
||||
- 每个条款应作为一个独立的语义单元
|
||||
- 保留条款完整性,避免跨条款截断
|
||||
"""
|
||||
|
||||
# 法规标题模式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
|
||||
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
|
||||
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
|
||||
|
||||
# 条款子项模式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
SUB_ITEM_PATTERN = re.compile(r'^[\((][一二三四五六七八九十]+[\))]\s')
|
||||
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
|
||||
|
||||
@@ -56,15 +51,7 @@ class RegulationChunker:
|
||||
max_chunk_size: int = 2048,
|
||||
min_chunk_size: int = 100
|
||||
):
|
||||
"""
|
||||
初始化分块器
|
||||
|
||||
Args:
|
||||
chunk_size: 默认分块大小(字符数)
|
||||
chunk_overlap: 分块重叠大小
|
||||
max_chunk_size: 最大分块大小(防止单个条款过长)
|
||||
min_chunk_size: 最小分块大小(防止碎片化)
|
||||
"""
|
||||
"""Initialize the Regulation Chunker instance."""
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.max_chunk_size = max_chunk_size
|
||||
@@ -78,30 +65,18 @@ class RegulationChunker:
|
||||
regulation_type: str = "",
|
||||
version: str = ""
|
||||
) -> List[TextChunk]:
|
||||
"""
|
||||
对法规文档进行智能分块
|
||||
|
||||
Args:
|
||||
markdown_text: Markdown格式的文档内容
|
||||
doc_id: 文档ID
|
||||
doc_name: 文档名称
|
||||
regulation_type: 法规类型
|
||||
version: 文档版本
|
||||
|
||||
Returns:
|
||||
List[TextChunk]: 分块列表
|
||||
"""
|
||||
"""Handle chunk document for the Regulation Chunker instance."""
|
||||
logger.info(f"开始分块文档: {doc_name}")
|
||||
|
||||
# 1. 按章节分割(一级分块)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
sections = self._split_by_sections(markdown_text)
|
||||
|
||||
# 2. 在每个章节内按条款分割(二级分块)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chunks = []
|
||||
global_position = 0
|
||||
|
||||
for section_num, section_title, section_content, section_start in sections:
|
||||
# 在章节内按条款分割
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
clause_chunks = self._split_by_clauses(
|
||||
section_content,
|
||||
section_num,
|
||||
@@ -110,7 +85,7 @@ class RegulationChunker:
|
||||
)
|
||||
|
||||
for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks:
|
||||
# 处理过长的条款(进一步细分)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(chunk_content) > self.max_chunk_size:
|
||||
sub_chunks = self._split_long_clause(
|
||||
chunk_content,
|
||||
@@ -150,12 +125,7 @@ class RegulationChunker:
|
||||
return chunks
|
||||
|
||||
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
|
||||
"""
|
||||
按章节分割文档
|
||||
|
||||
Returns:
|
||||
List of (section_number, section_title, section_content, start_position)
|
||||
"""
|
||||
"""Handle split by sections for this module for the Regulation Chunker instance."""
|
||||
sections = []
|
||||
lines = markdown_text.split('\n')
|
||||
|
||||
@@ -165,12 +135,12 @@ class RegulationChunker:
|
||||
current_section_start = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 检测章节标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chapter_match = self.CHAPTER_PATTERN.match(line.strip())
|
||||
section_match = self.SECTION_PATTERN.match(line.strip())
|
||||
|
||||
if chapter_match or section_match:
|
||||
# 保存上一个章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_section_content:
|
||||
content = '\n'.join(current_section_content)
|
||||
sections.append((
|
||||
@@ -180,7 +150,7 @@ class RegulationChunker:
|
||||
current_section_start
|
||||
))
|
||||
|
||||
# 开始新章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_section_start = sum(len(l) + 1 for l in lines[:i])
|
||||
current_section_content = []
|
||||
|
||||
@@ -193,7 +163,7 @@ class RegulationChunker:
|
||||
|
||||
current_section_content.append(line)
|
||||
|
||||
# 保存最后一个章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_section_content:
|
||||
content = '\n'.join(current_section_content)
|
||||
sections.append((
|
||||
@@ -203,7 +173,7 @@ class RegulationChunker:
|
||||
current_section_start
|
||||
))
|
||||
|
||||
# 如果没有检测到章节,将整个文档作为一个大章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if not sections:
|
||||
sections.append((
|
||||
"",
|
||||
@@ -221,12 +191,7 @@ class RegulationChunker:
|
||||
section_title: str,
|
||||
section_start: int
|
||||
) -> List[Tuple[str, str, str, int, int]]:
|
||||
"""
|
||||
在章节内按条款分割
|
||||
|
||||
Returns:
|
||||
List of (content, clause_number, clause_title, start_position, end_position)
|
||||
"""
|
||||
"""Handle split by clauses for this module for the Regulation Chunker instance."""
|
||||
clauses = []
|
||||
lines = section_content.split('\n')
|
||||
|
||||
@@ -236,11 +201,11 @@ class RegulationChunker:
|
||||
current_clause_start = section_start
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 检测条款标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
clause_match = self.CLAUSE_PATTERN.match(line.strip())
|
||||
|
||||
if clause_match:
|
||||
# 保存上一个条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_clause_content:
|
||||
content = '\n'.join(current_clause_content)
|
||||
end_pos = current_clause_start + len(content)
|
||||
@@ -252,7 +217,7 @@ class RegulationChunker:
|
||||
end_pos
|
||||
))
|
||||
|
||||
# 开始新条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i])
|
||||
current_clause_content = []
|
||||
current_clause_num = self._extract_clause_number(line.strip())
|
||||
@@ -260,7 +225,7 @@ class RegulationChunker:
|
||||
|
||||
current_clause_content.append(line)
|
||||
|
||||
# 保存最后一个条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_clause_content:
|
||||
content = '\n'.join(current_clause_content)
|
||||
end_pos = current_clause_start + len(content)
|
||||
@@ -272,7 +237,7 @@ class RegulationChunker:
|
||||
end_pos
|
||||
))
|
||||
|
||||
# 如果没有检测到条款,将整个章节作为一个条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if not clauses:
|
||||
clauses.append((
|
||||
section_content,
|
||||
@@ -290,15 +255,11 @@ class RegulationChunker:
|
||||
clause_num: str,
|
||||
clause_title: str
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
"""
|
||||
分割过长的条款内容
|
||||
|
||||
按条款子项或段落分割,保持语义完整性
|
||||
"""
|
||||
"""Handle split long clause for this module for the Regulation Chunker instance."""
|
||||
sub_chunks = []
|
||||
lines = content.split('\n')
|
||||
|
||||
# 检测是否有子项结构
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
has_sub_items = any(
|
||||
self.SUB_ITEM_PATTERN.match(line.strip()) or
|
||||
self.NUMBER_ITEM_PATTERN.match(line.strip())
|
||||
@@ -306,7 +267,7 @@ class RegulationChunker:
|
||||
)
|
||||
|
||||
if has_sub_items:
|
||||
# 按子项分割
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_sub_content = []
|
||||
current_sub_start = 0
|
||||
|
||||
@@ -326,14 +287,14 @@ class RegulationChunker:
|
||||
|
||||
current_sub_content.append(line)
|
||||
|
||||
# 保存最后一个子项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_sub_content:
|
||||
sub_content = '\n'.join(current_sub_content)
|
||||
sub_end = current_sub_start + len(sub_content)
|
||||
sub_chunks.append((sub_content, current_sub_start, sub_end))
|
||||
|
||||
else:
|
||||
# 按段落分割(滑动窗口)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
paragraphs = []
|
||||
current_para = []
|
||||
|
||||
@@ -348,7 +309,7 @@ class RegulationChunker:
|
||||
if current_para:
|
||||
paragraphs.append('\n'.join(current_para))
|
||||
|
||||
# 合并段落直到达到chunk_size
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
chunk_start = 0
|
||||
@@ -365,7 +326,7 @@ class RegulationChunker:
|
||||
current_chunk.append(para)
|
||||
current_length += len(para)
|
||||
|
||||
# 保存最后一个chunk
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_chunk:
|
||||
chunk_content = '\n'.join(current_chunk)
|
||||
chunk_end = chunk_start + len(chunk_content)
|
||||
@@ -374,13 +335,13 @@ class RegulationChunker:
|
||||
return sub_chunks
|
||||
|
||||
def _extract_title(self, header_line: str) -> str:
|
||||
"""从标题行提取标题内容"""
|
||||
# 移除"第X章"、"第X节"前缀
|
||||
"""Handle extract title for this module for the Regulation Chunker instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line)
|
||||
return title.strip()
|
||||
|
||||
def _extract_clause_number(self, clause_line: str) -> str:
|
||||
"""从条款行提取条款编号"""
|
||||
"""Handle extract clause number for this module for the Regulation Chunker instance."""
|
||||
match = self.CLAUSE_PATTERN.match(clause_line)
|
||||
if match:
|
||||
return match.group(0).strip()
|
||||
@@ -399,14 +360,14 @@ class RegulationChunker:
|
||||
regulation_type: str,
|
||||
version: str
|
||||
) -> TextChunk:
|
||||
"""创建文本分块"""
|
||||
# 清理内容
|
||||
"""Handle create chunk for this module for the Regulation Chunker instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
content = content.strip()
|
||||
|
||||
# 计算估算token数(中文约1.5字符/token)
|
||||
token_count = int(len(content) * 0.7) # 简化估算
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
token_count = int(len(content) * 0.7) # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
# 生成chunk_id
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}"
|
||||
|
||||
metadata = ChunkMetadata(
|
||||
@@ -437,7 +398,7 @@ def chunk_regulation_document(
|
||||
version: str = "",
|
||||
chunk_size: int = 512
|
||||
) -> List[TextChunk]:
|
||||
"""便捷函数:对法规文档进行分块"""
|
||||
"""Handle chunk regulation document."""
|
||||
chunker = RegulationChunker(chunk_size=chunk_size)
|
||||
return chunker.chunk_document(
|
||||
markdown_text,
|
||||
|
||||
@@ -1,14 +1,36 @@
|
||||
"""LLM服务模块"""
|
||||
"""Initialize the app.services.llm package."""
|
||||
|
||||
from .llm_factory import LLMFactory, get_llm_client
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||
from .deepseek_client import DeepSeekClient
|
||||
from .llm_factory import LLMFactory, get_llm_client
|
||||
from .qwen_client import QwenClient, QwenVLClient
|
||||
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLMFactory", "get_llm_client",
|
||||
"BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider",
|
||||
"DeepSeekClient", "QwenClient", "QwenVLClient",
|
||||
"DocumentSummarizer", "summarize_document", "DocumentSummary"
|
||||
"LLMFactory",
|
||||
"get_llm_client",
|
||||
"BaseLLMClient",
|
||||
"LLMResponse",
|
||||
"LLMConfig",
|
||||
"LLMProvider",
|
||||
"DeepSeekClient",
|
||||
"QwenClient",
|
||||
"QwenVLClient",
|
||||
"DocumentSummarizer",
|
||||
"summarize_document",
|
||||
"DocumentSummary",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name in {"DocumentSummarizer", "summarize_document", "DocumentSummary"}:
|
||||
from .document_summarizer import DocumentSummarizer, DocumentSummary, summarize_document
|
||||
|
||||
return {
|
||||
"DocumentSummarizer": DocumentSummarizer,
|
||||
"summarize_document": summarize_document,
|
||||
"DocumentSummary": DocumentSummary,
|
||||
}[name]
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""LLM客户端基类 - 统一接口定义"""
|
||||
"""Provide service-layer logic for base client."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from enum import Enum
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
class LLMProvider(Enum):
|
||||
"""LLM提供商"""
|
||||
"""Define the L L M Provider enumeration."""
|
||||
DEEPSEEK = "deepseek"
|
||||
QWEN = "qwen"
|
||||
QWEN_VL = "qwen_vl"
|
||||
@@ -15,7 +17,7 @@ class LLMProvider(Enum):
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""LLM响应结果"""
|
||||
"""Represent the L L M Response type."""
|
||||
content: str
|
||||
model: str
|
||||
usage: Dict[str, int] = field(default_factory=dict)
|
||||
@@ -25,12 +27,13 @@ class LLMResponse:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the L L M Response instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""LLM配置"""
|
||||
"""Define configuration for l l m config."""
|
||||
provider: LLMProvider
|
||||
model: str
|
||||
api_key: str
|
||||
@@ -38,19 +41,20 @@ class LLMConfig:
|
||||
max_tokens: int = 4096
|
||||
temperature: float = 0.7
|
||||
top_p: float = 0.9
|
||||
timeout: int = 300 # 默认超时300秒(摘要/Skills生成可能需要较长时间)
|
||||
timeout: int = 300 # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""LLM客户端基类"""
|
||||
"""Represent the Base L L M Client type."""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Base L L M Client instance."""
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
@abstractmethod
|
||||
def _init_client(self):
|
||||
"""初始化客户端"""
|
||||
"""Handle init client for this module for the Base L L M Client instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -61,18 +65,7 @@ class BaseLLMClient(ABC):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
对话补全
|
||||
|
||||
Args:
|
||||
messages: 对话消息列表 [{"role": "user/assistant/system", "content": "..."}]
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLMResponse: 响应结果
|
||||
"""
|
||||
"""Handle chat for the Base L L M Client instance."""
|
||||
pass
|
||||
|
||||
def complete(
|
||||
@@ -83,18 +76,7 @@ class BaseLLMClient(ABC):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
单轮补全(便捷方法)
|
||||
|
||||
Args:
|
||||
prompt: 用户输入
|
||||
system_prompt: 系统提示词
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
|
||||
Returns:
|
||||
LLMResponse: 响应结果
|
||||
"""
|
||||
"""Handle complete for the Base L L M Client instance."""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
@@ -104,12 +86,12 @@ class BaseLLMClient(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Base L L M Client instance."""
|
||||
pass
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""估算文本token数(粗略估计)"""
|
||||
# 中文字符约1.5 token,英文约0.25 token
|
||||
"""Handle estimate tokens for the Base L L M Client instance."""
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||
other_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""DeepSeek LLM客户端 - OpenAI兼容API"""
|
||||
"""Provide service-layer logic for deepseek client."""
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Optional
|
||||
@@ -6,20 +6,12 @@ from loguru import logger
|
||||
import httpx
|
||||
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
class DeepSeekClient(BaseLLMClient):
|
||||
"""
|
||||
DeepSeek API客户端(OpenAI兼容格式)
|
||||
|
||||
支持模型:
|
||||
- deepseek-chat
|
||||
- deepseek-coder
|
||||
- deepseek-reasoner
|
||||
- deepseek-v3
|
||||
- deepseek-v3.2
|
||||
- deepseek-v4-flash
|
||||
"""
|
||||
"""Represent the Deep Seek Client type."""
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"deepseek-chat",
|
||||
@@ -31,13 +23,14 @@ class DeepSeekClient(BaseLLMClient):
|
||||
]
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Deep Seek Client instance."""
|
||||
if config.provider != LLMProvider.DEEPSEEK:
|
||||
raise ValueError(f"配置provider应为DEEPSEEK,实际为{config.provider}")
|
||||
super().__init__(config)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化HTTP客户端"""
|
||||
"""Handle init client for this module for the Deep Seek Client instance."""
|
||||
self._client = httpx.Client(
|
||||
base_url=self.config.base_url,
|
||||
headers={
|
||||
@@ -55,7 +48,7 @@ class DeepSeekClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""对话补全"""
|
||||
"""Handle chat for the Deep Seek Client instance."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
@@ -103,11 +96,11 @@ class DeepSeekClient(BaseLLMClient):
|
||||
)
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Deep Seek Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
"""Release the resources held by this component."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
@@ -118,7 +111,7 @@ def create_deepseek_client(
|
||||
base_url: str = "http://6.86.80.4:30080/v1",
|
||||
**kwargs
|
||||
) -> DeepSeekClient:
|
||||
"""便捷函数:创建DeepSeek客户端"""
|
||||
"""Create deepseek client."""
|
||||
config = LLMConfig(
|
||||
provider=LLMProvider.DEEPSEEK,
|
||||
model=model,
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
"""文档摘要生成服务 - LLM生成法规文档摘要"""
|
||||
"""Provide service-layer logic for document summarizer."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from app.services.llm import get_llm_client, BaseLLMClient
|
||||
from app.services.llm.base_client import BaseLLMClient
|
||||
from app.services.llm.llm_factory import get_llm_client
|
||||
from app.services.rag.prompt_templates import get_prompt_template
|
||||
from app.config.settings import settings
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentSummary:
|
||||
"""文档摘要结果"""
|
||||
"""Represent the Document Summary type."""
|
||||
doc_name: str
|
||||
summary: str
|
||||
applicable_scope: str
|
||||
@@ -24,24 +27,12 @@ class DocumentSummary:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the Document Summary instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
class DocumentSummarizer:
|
||||
"""
|
||||
文档摘要生成器
|
||||
|
||||
功能:
|
||||
- 生成法规文档的核心要点摘要
|
||||
- 提取适用范围
|
||||
- 突出关键条款
|
||||
- 列出合规要点
|
||||
|
||||
使用示例:
|
||||
summarizer = DocumentSummarizer()
|
||||
result = summarizer.summarize("GB 7258-2017", markdown_content)
|
||||
print(result.summary)
|
||||
"""
|
||||
"""Represent the Document Summarizer type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -49,25 +40,18 @@ class DocumentSummarizer:
|
||||
model: str = None,
|
||||
max_tokens: int = None
|
||||
):
|
||||
"""
|
||||
初始化摘要生成器
|
||||
|
||||
Args:
|
||||
provider: LLM提供商
|
||||
model: LLM模型名称
|
||||
max_tokens: 最大输出token数
|
||||
"""
|
||||
"""Initialize the Document Summarizer instance."""
|
||||
self.provider = provider or settings.llm_provider
|
||||
self.model = model or settings.llm_model
|
||||
self.max_tokens = max_tokens or settings.rag_summary_max_tokens
|
||||
|
||||
# LLM客户端(延迟加载)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
self.llm: Optional[BaseLLMClient] = None
|
||||
|
||||
logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")
|
||||
|
||||
def _init_llm(self):
|
||||
"""延迟初始化LLM"""
|
||||
"""Handle init llm for this module for the Document Summarizer instance."""
|
||||
if self.llm is None:
|
||||
self.llm = get_llm_client(
|
||||
provider=self.provider,
|
||||
@@ -81,18 +65,7 @@ class DocumentSummarizer:
|
||||
regulation_type: str = "",
|
||||
max_tokens: Optional[int] = None
|
||||
) -> DocumentSummary:
|
||||
"""
|
||||
生成文档摘要
|
||||
|
||||
Args:
|
||||
doc_name: 文档名称
|
||||
content: 文档内容(Markdown格式)
|
||||
regulation_type: 法规类型
|
||||
max_tokens: 最大输出token数
|
||||
|
||||
Returns:
|
||||
DocumentSummary: 摘要结果
|
||||
"""
|
||||
"""Handle summarize for the Document Summarizer instance."""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
@@ -101,23 +74,23 @@ class DocumentSummarizer:
|
||||
try:
|
||||
self._init_llm()
|
||||
|
||||
# 使用摘要模板
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
template = get_prompt_template("document_summary")
|
||||
|
||||
# 构建用户消息
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
user_content = template.user_template.format(
|
||||
doc_name=doc_name,
|
||||
content=content[:8000] # 截取前8000字符(避免超出token限制)
|
||||
content=content[:8000] # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
response = self.llm.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": template.system_prompt},
|
||||
{"role": "user", "content": user_content}
|
||||
],
|
||||
max_tokens=max_tokens or self.max_tokens,
|
||||
temperature=0.3 # 低温度保证摘要准确性
|
||||
temperature=0.3 # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
@@ -135,7 +108,7 @@ class DocumentSummarizer:
|
||||
error=response.error
|
||||
)
|
||||
|
||||
# 解析摘要结构
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
summary_data = self._parse_summary(response.content)
|
||||
|
||||
logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
|
||||
@@ -166,7 +139,7 @@ class DocumentSummarizer:
|
||||
)
|
||||
|
||||
def _parse_summary(self, content: str) -> Dict:
|
||||
"""解析摘要内容(提取结构化信息)"""
|
||||
"""Handle parse summary for this module for the Document Summarizer instance."""
|
||||
result = {
|
||||
"summary": content,
|
||||
"applicable_scope": "",
|
||||
@@ -175,26 +148,26 @@ class DocumentSummarizer:
|
||||
"compliance_points": []
|
||||
}
|
||||
|
||||
# 简单解析(提取关键信息)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
lines = content.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# 提取适用范围
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if "适用范围" in line or "适用对象" in line:
|
||||
result["applicable_scope"] = line.split(":")[-1].strip() if ":" in line else line.split(":")[-1].strip()
|
||||
|
||||
# 提取关键条款
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if line.startswith("- 【条款") or line.startswith("【条款"):
|
||||
result["key_clauses"].append(line)
|
||||
|
||||
# 提取关键术语
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if "关键术语" in line or "术语定义" in line:
|
||||
# 继续读取后续几行
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
pass
|
||||
|
||||
# 提取合规要点
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if "合规要点" in line or "必须满足" in line:
|
||||
pass
|
||||
|
||||
@@ -204,15 +177,7 @@ class DocumentSummarizer:
|
||||
self,
|
||||
documents: list
|
||||
) -> list:
|
||||
"""
|
||||
批量生成摘要
|
||||
|
||||
Args:
|
||||
documents: 文档列表 [{"doc_name": str, "content": str}, ...]
|
||||
|
||||
Returns:
|
||||
list: 摘要结果列表
|
||||
"""
|
||||
"""Handle batch summarize for the Document Summarizer instance."""
|
||||
results = []
|
||||
for doc in documents:
|
||||
result = self.summarize(doc["doc_name"], doc["content"])
|
||||
@@ -225,6 +190,6 @@ def summarize_document(
|
||||
content: str,
|
||||
**kwargs
|
||||
) -> DocumentSummary:
|
||||
"""便捷函数:生成文档摘要"""
|
||||
"""Handle summarize document."""
|
||||
summarizer = DocumentSummarizer(**kwargs)
|
||||
return summarizer.summarize(doc_name, content)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""LLM工厂 - 统一创建和管理LLM客户端"""
|
||||
"""Provide service-layer logic for llm factory."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
@@ -7,16 +7,18 @@ from functools import lru_cache
|
||||
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||
from .deepseek_client import DeepSeekClient
|
||||
from .qwen_client import QwenClient, QwenVLClient
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
# 默认模型映射
|
||||
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.DEEPSEEK: "deepseek-v4-flash",
|
||||
LLMProvider.QWEN: "qwen3.5-flash",
|
||||
LLMProvider.QWEN_VL: "qwen3-vl-plus"
|
||||
}
|
||||
|
||||
# API基础URL(使用统一代理服务)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
DEFAULT_BASE_URLS = {
|
||||
LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
|
||||
LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
|
||||
@@ -25,31 +27,13 @@ DEFAULT_BASE_URLS = {
|
||||
|
||||
|
||||
class LLMFactory:
|
||||
"""
|
||||
LLM客户端工厂(支持全局缓存)
|
||||
"""Represent the L L M Factory type."""
|
||||
|
||||
支持的提供商和模型:
|
||||
- DeepSeek: deepseek-chat (DeepSeek-V3), deepseek-coder
|
||||
- Qwen: qwen-turbo, qwen-plus, qwen-max, qwen-long
|
||||
- QwenVL: qwen-vl-plus, qwen-vl-max (多模态)
|
||||
|
||||
使用示例:
|
||||
factory = LLMFactory()
|
||||
|
||||
# 使用默认配置
|
||||
client = factory.create("deepseek")
|
||||
|
||||
# 自定义配置
|
||||
client = factory.create("qwen", model="qwen-max", temperature=0.5)
|
||||
|
||||
# 调用LLM
|
||||
response = client.complete("你好,介绍一下自己")
|
||||
"""
|
||||
|
||||
# 全局客户端缓存(类级别,跨实例共享)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
_global_instances: Dict[str, BaseLLMClient] = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the L L M Factory instance."""
|
||||
self._config_cache: Dict[str, Any] = {}
|
||||
|
||||
def create(
|
||||
@@ -62,24 +46,10 @@ class LLMFactory:
|
||||
temperature: float = 0.7,
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""
|
||||
创建LLM客户端
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 ("deepseek", "qwen", "qwen_vl")
|
||||
api_key: API密钥(如未提供,从环境变量获取)
|
||||
model: 模型名称(如未提供,使用默认模型)
|
||||
base_url: API基础URL
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他配置参数
|
||||
|
||||
Returns:
|
||||
BaseLLMClient: LLM客户端实例
|
||||
"""
|
||||
"""Handle create for the L L M Factory instance."""
|
||||
provider_enum = self._parse_provider(provider)
|
||||
|
||||
# 获取配置
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
api_key = api_key or self._get_api_key(provider_enum)
|
||||
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||
base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
|
||||
@@ -87,7 +57,7 @@ class LLMFactory:
|
||||
if not api_key:
|
||||
raise ValueError(f"缺少API密钥,请设置环境变量或传入api_key参数")
|
||||
|
||||
# 检查全局缓存
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
cache_key = f"{provider}_{model}"
|
||||
if cache_key in LLMFactory._global_instances:
|
||||
logger.debug(f"使用缓存的LLM客户端: {cache_key}")
|
||||
@@ -103,17 +73,17 @@ class LLMFactory:
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
client = self._create_client(config)
|
||||
|
||||
# 缓存到全局实例
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
LLMFactory._global_instances[cache_key] = client
|
||||
|
||||
logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
|
||||
return client
|
||||
|
||||
def _parse_provider(self, provider: str) -> LLMProvider:
|
||||
"""解析提供商名称"""
|
||||
"""Handle parse provider for this module for the L L M Factory instance."""
|
||||
provider_map = {
|
||||
"deepseek": LLMProvider.DEEPSEEK,
|
||||
"deepseek-v3": LLMProvider.DEEPSEEK,
|
||||
@@ -137,7 +107,7 @@ class LLMFactory:
|
||||
return provider_map[provider_lower]
|
||||
|
||||
def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
|
||||
"""从环境变量获取API密钥"""
|
||||
"""Handle get api key for this module for the L L M Factory instance."""
|
||||
import os
|
||||
|
||||
key_map = {
|
||||
@@ -154,7 +124,7 @@ class LLMFactory:
|
||||
return None
|
||||
|
||||
def _create_client(self, config: LLMConfig) -> BaseLLMClient:
|
||||
"""创建具体客户端"""
|
||||
"""Handle create client for this module for the L L M Factory instance."""
|
||||
client_map = {
|
||||
LLMProvider.DEEPSEEK: DeepSeekClient,
|
||||
LLMProvider.QWEN: QwenClient,
|
||||
@@ -168,14 +138,14 @@ class LLMFactory:
|
||||
return client_class(config)
|
||||
|
||||
def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||
"""获取缓存的客户端"""
|
||||
"""Return cached for the L L M Factory instance."""
|
||||
provider_enum = self._parse_provider(provider)
|
||||
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||
cache_key = f"{provider}_{model}"
|
||||
return LLMFactory._global_instances.get(cache_key)
|
||||
|
||||
def list_available_providers(self) -> Dict[str, list]:
|
||||
"""列出可用的提供商和模型"""
|
||||
"""List available providers for the L L M Factory instance."""
|
||||
return {
|
||||
"deepseek": DeepSeekClient.SUPPORTED_MODELS,
|
||||
"qwen": QwenClient.SUPPORTED_MODELS,
|
||||
@@ -184,12 +154,7 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def preload_clients(cls, providers: list = None):
|
||||
"""
|
||||
预加载LLM客户端(应用启动时调用)
|
||||
|
||||
Args:
|
||||
providers: 要预加载的提供商列表,默认加载qwen和deepseek
|
||||
"""
|
||||
"""Handle preload clients for the L L M Factory instance."""
|
||||
if providers is None:
|
||||
providers = ["qwen", "deepseek"]
|
||||
|
||||
@@ -203,9 +168,9 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||
"""获取全局缓存的客户端"""
|
||||
"""Return global client for the L L M Factory instance."""
|
||||
provider_lower = provider.lower()
|
||||
# 处理模型名作为provider的情况(如 qwen3.5-flash)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if provider_lower.startswith("qwen"):
|
||||
provider_lower = "qwen"
|
||||
model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK)
|
||||
@@ -214,7 +179,7 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def cleanup(cls):
|
||||
"""清理所有缓存的客户端"""
|
||||
"""Handle cleanup for the L L M Factory instance."""
|
||||
for cache_key, client in cls._global_instances.items():
|
||||
try:
|
||||
client.close()
|
||||
@@ -227,7 +192,7 @@ class LLMFactory:
|
||||
|
||||
@lru_cache
|
||||
def get_llm_factory() -> LLMFactory:
|
||||
"""获取LLM工厂实例(缓存)"""
|
||||
"""Return llm factory."""
|
||||
return LLMFactory()
|
||||
|
||||
|
||||
@@ -236,20 +201,10 @@ def get_llm_client(
|
||||
model: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""
|
||||
便捷函数:获取LLM客户端(优先使用缓存)
|
||||
|
||||
Args:
|
||||
provider: 提供商名称
|
||||
model: 模型名称
|
||||
**kwargs: 其他配置
|
||||
|
||||
Returns:
|
||||
BaseLLMClient: LLM客户端实例
|
||||
"""
|
||||
"""Return llm client."""
|
||||
factory = get_llm_factory()
|
||||
|
||||
# 先尝试获取缓存的实例
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
cached = factory.get_cached(provider, model)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
|
||||
"""Provide service-layer logic for qwen client."""
|
||||
|
||||
import time
|
||||
import json
|
||||
@@ -7,21 +7,12 @@ from loguru import logger
|
||||
import httpx
|
||||
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
class QwenClient(BaseLLMClient):
|
||||
"""
|
||||
Qwen API客户端(OpenAI兼容格式)
|
||||
|
||||
支持通过new-api等代理服务调用:
|
||||
- qwen-turbo
|
||||
- qwen-plus
|
||||
- qwen-max
|
||||
- qwen3.5-flash (推荐:快速响应)
|
||||
- qwen3.5-plus
|
||||
- qwen-long
|
||||
- qwen2.5系列
|
||||
"""
|
||||
"""Represent the Qwen Client type."""
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"qwen-turbo",
|
||||
@@ -39,14 +30,15 @@ class QwenClient(BaseLLMClient):
|
||||
]
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Qwen Client instance."""
|
||||
if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
|
||||
raise ValueError(f"配置provider应为Qwen,实际为{config.provider}")
|
||||
super().__init__(config)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化HTTP客户端"""
|
||||
# OpenAI兼容API格式
|
||||
"""Handle init client for this module for the Qwen Client instance."""
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
self._client = httpx.Client(
|
||||
base_url=self.config.base_url,
|
||||
headers={
|
||||
@@ -64,11 +56,11 @@ class QwenClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""对话补全(OpenAI兼容格式)"""
|
||||
"""Handle chat for the Qwen Client instance."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# OpenAI兼容格式的请求体
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
@@ -78,7 +70,7 @@ class QwenClient(BaseLLMClient):
|
||||
"stream": False
|
||||
}
|
||||
|
||||
# OpenAI兼容接口路径
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
response = self._client.post("/chat/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -86,7 +78,7 @@ class QwenClient(BaseLLMClient):
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# OpenAI兼容格式的响应解析
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
choices = data.get("choices", [{}])
|
||||
message = choices[0].get("message", {})
|
||||
|
||||
@@ -121,42 +113,33 @@ class QwenClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
流式对话补全(SSE格式)
|
||||
|
||||
Yields:
|
||||
str: 每次返回一个文本片段
|
||||
|
||||
使用示例:
|
||||
for chunk in client.stream_chat(messages):
|
||||
print(chunk, end="", flush=True)
|
||||
"""
|
||||
"""Stream chat for the Qwen Client instance."""
|
||||
try:
|
||||
# OpenAI兼容格式的请求体,启用流式输出
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens or self.config.max_tokens,
|
||||
"temperature": temperature or self.config.temperature,
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"stream": True # 启用流式输出
|
||||
"stream": True # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
}
|
||||
|
||||
# 使用stream模式发送请求
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
with self._client.stream("POST", "/chat/completions", json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.strip()
|
||||
# SSE格式: data: {...}
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:] # 移除 "data: " 前缀
|
||||
data_str = line[6:] # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
continue # 跳过空的choices
|
||||
continue # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
@@ -179,41 +162,27 @@ class QwenClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
异步流式对话补全(用于FastAPI SSE响应)
|
||||
|
||||
Yields:
|
||||
str: 每次返回一个文本片段
|
||||
"""
|
||||
"""Handle async stream chat for the Qwen Client instance."""
|
||||
import asyncio
|
||||
|
||||
# 使用同步流式方法,包装为异步
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
for chunk in self.stream_chat(messages, max_tokens, temperature, **kwargs):
|
||||
yield chunk
|
||||
# 给async循环一个小延迟,让其他任务有机会执行
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Qwen Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
"""Release the resources held by this component."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
|
||||
class QwenVLClient(BaseLLMClient):
|
||||
"""
|
||||
Qwen VL多模态客户端(OpenAI兼容格式)
|
||||
|
||||
支持模型:
|
||||
- qwen-vl-plus
|
||||
- qwen-vl-max
|
||||
- qwen3-vl-plus
|
||||
- qwen2-vl-7b-instruct
|
||||
- qwen2-vl-72b-instruct
|
||||
"""
|
||||
"""Represent the Qwen V L Client type."""
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"qwen-vl-plus",
|
||||
@@ -224,13 +193,14 @@ class QwenVLClient(BaseLLMClient):
|
||||
]
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Qwen V L Client instance."""
|
||||
if config.provider != LLMProvider.QWEN_VL:
|
||||
raise ValueError(f"配置provider应为QWEN_VL,实际为{config.provider}")
|
||||
super().__init__(config)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化HTTP客户端"""
|
||||
"""Handle init client for this module for the Qwen V L Client instance."""
|
||||
self._client = httpx.Client(
|
||||
base_url=self.config.base_url,
|
||||
headers={
|
||||
@@ -248,21 +218,11 @@ class QwenVLClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""多模态对话补全(OpenAI兼容格式)
|
||||
|
||||
支持图片输入,消息格式:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
||||
{"type": "text", "text": "描述这张图片"}
|
||||
]
|
||||
}
|
||||
"""
|
||||
"""Handle chat for the Qwen V L Client instance."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# OpenAI兼容格式的请求体
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
@@ -312,7 +272,7 @@ class QwenVLClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
"""流式多模态对话补全"""
|
||||
"""Stream chat for the Qwen V L Client instance."""
|
||||
try:
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
@@ -335,7 +295,7 @@ class QwenVLClient(BaseLLMClient):
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
continue # 跳过空的choices
|
||||
continue # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
@@ -348,11 +308,11 @@ class QwenVLClient(BaseLLMClient):
|
||||
yield f"[ERROR: {str(e)}]"
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Qwen V L Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
"""Release the resources held by this component."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
@@ -363,7 +323,7 @@ def create_qwen_client(
|
||||
base_url: str = "http://6.86.80.4:30080/v1",
|
||||
**kwargs
|
||||
) -> QwenClient:
|
||||
"""便捷函数:创建Qwen客户端"""
|
||||
"""Create qwen client."""
|
||||
config = LLMConfig(
|
||||
provider=LLMProvider.QWEN,
|
||||
model=model,
|
||||
@@ -380,7 +340,7 @@ def create_qwen_vl_client(
|
||||
base_url: str = "http://6.86.80.4:30080/v1",
|
||||
**kwargs
|
||||
) -> QwenVLClient:
|
||||
"""便捷函数:创建QwenVL客户端"""
|
||||
"""Create qwen vl client."""
|
||||
config = LLMConfig(
|
||||
provider=LLMProvider.QWEN_VL,
|
||||
model=model,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
Mock数据服务 - 提供预设假数据供前后端对接测试
|
||||
"""
|
||||
"""Provide service-layer logic for mock data."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
import uuid
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
# 预设法规文档列表
|
||||
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_DOCUMENTS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"id": "doc-001",
|
||||
@@ -45,7 +45,7 @@ MOCK_DOCUMENTS: List[Dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# 预设快捷问题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
|
||||
{"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"},
|
||||
{"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"},
|
||||
@@ -53,7 +53,7 @@ MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
|
||||
{"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"},
|
||||
]
|
||||
|
||||
# 预设检索结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"id": "chunk-001",
|
||||
@@ -97,7 +97,7 @@ MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# 预设RAG问答答案模板(按关键词匹配)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
|
||||
"电动自行车": {
|
||||
"text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。",
|
||||
@@ -133,7 +133,7 @@ MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# 预设合规分析结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
|
||||
"task_id": "task-001",
|
||||
"dashboard": {
|
||||
@@ -310,7 +310,7 @@ MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
|
||||
],
|
||||
}
|
||||
|
||||
# 预设合规对话响应模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
|
||||
"车身结构设计": {
|
||||
"compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开,需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。",
|
||||
@@ -329,7 +329,7 @@ MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# 预设系统统计数据
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_SYSTEM_STATS: Dict[str, int] = {
|
||||
"docs": 5,
|
||||
"chunks": 510,
|
||||
@@ -337,7 +337,7 @@ MOCK_SYSTEM_STATS: Dict[str, int] = {
|
||||
"segments": 0,
|
||||
}
|
||||
|
||||
# 预设系统配置
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
|
||||
"llm": {
|
||||
"model": "qwen-max",
|
||||
@@ -358,17 +358,17 @@ MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
|
||||
|
||||
|
||||
def get_mock_documents() -> List[Dict[str, Any]]:
|
||||
"""获取预设法规文档列表"""
|
||||
"""Return mock documents."""
|
||||
return MOCK_DOCUMENTS
|
||||
|
||||
|
||||
def get_mock_quick_questions() -> List[Dict[str, str]]:
|
||||
"""获取预设快捷问题"""
|
||||
"""Return mock quick questions."""
|
||||
return MOCK_QUICK_QUESTIONS
|
||||
|
||||
|
||||
def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""根据查询关键词返回预设检索结果"""
|
||||
"""Return mock retrieval."""
|
||||
results = []
|
||||
for keyword, data in MOCK_RAG_ANSWERS.items():
|
||||
if keyword in query:
|
||||
@@ -389,7 +389,7 @@ def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def get_mock_rag_answer(query: str) -> str:
|
||||
"""根据查询关键词返回预设答案"""
|
||||
"""Return mock rag answer."""
|
||||
for keyword, data in MOCK_RAG_ANSWERS.items():
|
||||
if keyword in query:
|
||||
return data["text"]
|
||||
@@ -397,14 +397,14 @@ def get_mock_rag_answer(query: str) -> str:
|
||||
|
||||
|
||||
def get_mock_compliance_result(task_id: str) -> Dict[str, Any]:
|
||||
"""获取预设合规分析结果"""
|
||||
"""Return mock compliance result."""
|
||||
result = MOCK_COMPLIANCE_RESULT.copy()
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
|
||||
|
||||
def get_mock_compliance_chat_response(intent: str, query: str) -> str:
|
||||
"""获取预设合规对话响应"""
|
||||
"""Return mock compliance chat response."""
|
||||
responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {})
|
||||
if "合规" in query or "符合" in query:
|
||||
return responses.get("compliance", "根据相关法规分析,该段落的合规性需进一步评估。")
|
||||
@@ -416,10 +416,10 @@ def get_mock_compliance_chat_response(intent: str, query: str) -> str:
|
||||
|
||||
|
||||
def generate_task_id() -> str:
|
||||
"""生成任务ID"""
|
||||
"""Handle generate task id."""
|
||||
return f"task-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def generate_doc_id() -> str:
|
||||
"""生成文档ID"""
|
||||
"""Handle generate doc id."""
|
||||
return f"doc-{uuid.uuid4().hex[:8]}"
|
||||
@@ -1,6 +1,8 @@
|
||||
"""文档解析服务"""
|
||||
"""Initialize the app.services.parser package."""
|
||||
|
||||
from .pdf_parser import PDFParser
|
||||
from .docx_parser import DocxParser
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["PDFParser", "DocxParser"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Word文档解析 - 使用python-docx"""
|
||||
"""Provide service-layer logic for docx parser."""
|
||||
|
||||
from docx import Document
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
@@ -6,27 +6,29 @@ from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import re
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocxParagraph:
|
||||
"""段落内容"""
|
||||
"""Represent the Docx Paragraph type."""
|
||||
text: str
|
||||
level: int = 0 # 标题级别,0表示正文
|
||||
level: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
is_list: bool = False
|
||||
list_number: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocxTable:
|
||||
"""表格内容"""
|
||||
"""Represent the Docx Table type."""
|
||||
rows: List[List[str]]
|
||||
markdown: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocxDocumentContent:
|
||||
"""Word文档完整内容"""
|
||||
"""Represent the Docx Document Content type."""
|
||||
file_path: str
|
||||
paragraphs: List[DocxParagraph]
|
||||
tables: List[DocxTable]
|
||||
@@ -35,21 +37,14 @@ class DocxDocumentContent:
|
||||
|
||||
|
||||
class DocxParser:
|
||||
"""Word文档解析器 - 基于python-docx"""
|
||||
"""Provide the Docx Parser parser."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Docx Parser instance."""
|
||||
self.document = None
|
||||
|
||||
def parse(self, file_path: str) -> DocxDocumentContent:
|
||||
"""
|
||||
解析Word文档
|
||||
|
||||
Args:
|
||||
file_path: Word文档路径
|
||||
|
||||
Returns:
|
||||
DocxDocumentContent: 解析后的文档内容
|
||||
"""
|
||||
"""Handle parse for the Docx Parser instance."""
|
||||
logger.info(f"开始解析Word文档: {file_path}")
|
||||
|
||||
try:
|
||||
@@ -60,16 +55,16 @@ class DocxParser:
|
||||
tables=[]
|
||||
)
|
||||
|
||||
# 提取文档元数据
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.metadata = self._extract_metadata()
|
||||
|
||||
# 提取段落
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.paragraphs = self._extract_paragraphs()
|
||||
|
||||
# 提取表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.tables = self._extract_tables()
|
||||
|
||||
# 生成Markdown格式文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.markdown_text = self._generate_markdown(doc_content)
|
||||
|
||||
logger.success(f"Word文档解析完成,共{len(doc_content.paragraphs)}个段落")
|
||||
@@ -81,7 +76,7 @@ class DocxParser:
|
||||
raise
|
||||
|
||||
def _extract_metadata(self) -> Dict[str, str]:
|
||||
"""提取文档元数据"""
|
||||
"""Handle extract metadata for this module for the Docx Parser instance."""
|
||||
metadata = {}
|
||||
try:
|
||||
core_props = self.document.core_properties
|
||||
@@ -98,7 +93,7 @@ class DocxParser:
|
||||
return metadata
|
||||
|
||||
def _extract_paragraphs(self) -> List[DocxParagraph]:
|
||||
"""提取所有段落"""
|
||||
"""Handle extract paragraphs for this module for the Docx Parser instance."""
|
||||
paragraphs = []
|
||||
|
||||
for para in self.document.paragraphs:
|
||||
@@ -106,10 +101,10 @@ class DocxParser:
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# 判断标题级别
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
level = self._get_paragraph_level(para)
|
||||
|
||||
# 判断是否是列表项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
is_list, list_number = self._detect_list_item(para)
|
||||
|
||||
paragraph = DocxParagraph(
|
||||
@@ -123,66 +118,61 @@ class DocxParser:
|
||||
return paragraphs
|
||||
|
||||
def _get_paragraph_level(self, para) -> int:
|
||||
"""
|
||||
判断段落标题级别
|
||||
|
||||
Returns:
|
||||
int: 标题级别,0表示正文
|
||||
"""
|
||||
# 方法1:检查段落样式
|
||||
"""Handle get paragraph level for this module for the Docx Parser instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
style_name = para.style.name if para.style else ""
|
||||
|
||||
if "Heading" in style_name or "标题" in style_name:
|
||||
# 从样式名称中提取级别
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
|
||||
if match:
|
||||
level = int(match.group(1) or match.group(2))
|
||||
return level
|
||||
|
||||
# 方法2:检查段落格式(字号)
|
||||
# 标题通常字号较大
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if para.paragraph_format:
|
||||
# 可以根据字号判断,这里简化处理
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
pass
|
||||
|
||||
# 方法3:根据内容模式判断(法规文档特征)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
text = para.text.strip()
|
||||
|
||||
# 第一章、第X章 -> 二级标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^第[一二三四五六七八九十百]+章\s', text):
|
||||
return 2
|
||||
# 第X节 -> 三级标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+节\s', text):
|
||||
return 3
|
||||
# 第X条 -> 四级标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+条\s', text):
|
||||
return 4
|
||||
|
||||
return 0 # 正文
|
||||
return 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
|
||||
"""检测是否是列表项"""
|
||||
"""Handle detect list item for this module for the Docx Parser instance."""
|
||||
text = para.text.strip()
|
||||
|
||||
# 数字列表:1.、2.、(1)、[1]等
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^[\d]+[.、)\]]\s', text):
|
||||
match = re.match(r'^([\d]+[.、)\]])\s', text)
|
||||
return True, match.group(1) if match else None
|
||||
|
||||
# 中文数字列表:一、二、(一)等
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text):
|
||||
match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text)
|
||||
return True, match.group(1) if match else None
|
||||
|
||||
# 检查段落格式中的列表编号
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'):
|
||||
# 有缩进的可能是列表项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
pass
|
||||
|
||||
return False, None
|
||||
|
||||
def _extract_tables(self) -> List[DocxTable]:
|
||||
"""提取所有表格"""
|
||||
"""Handle extract tables for this module for the Docx Parser instance."""
|
||||
tables = []
|
||||
|
||||
for table in self.document.tables:
|
||||
@@ -193,7 +183,7 @@ class DocxParser:
|
||||
cells.append(cell.text.strip())
|
||||
rows.append(cells)
|
||||
|
||||
# 转换为Markdown表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
markdown = self._table_to_markdown(rows)
|
||||
|
||||
table_content = DocxTable(rows=rows, markdown=markdown)
|
||||
@@ -202,34 +192,34 @@ class DocxParser:
|
||||
return tables
|
||||
|
||||
def _table_to_markdown(self, rows: List[List[str]]) -> str:
|
||||
"""将表格转换为Markdown格式"""
|
||||
"""Handle table to markdown for this module for the Docx Parser instance."""
|
||||
if not rows or len(rows) < 1:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
# 表头
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(rows) >= 1:
|
||||
header = rows[0]
|
||||
lines.append("| " + " | ".join(cell for cell in header) + " |")
|
||||
lines.append("| " + " | ".join("---" for _ in header) + " |")
|
||||
|
||||
# 数据行
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for row in rows[1:]:
|
||||
lines.append("| " + " | ".join(cell for cell in row) + " |")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
|
||||
"""生成Markdown格式文本"""
|
||||
"""Handle generate markdown for this module for the Docx Parser instance."""
|
||||
lines = []
|
||||
|
||||
# 文档标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
title = doc_content.metadata.get("title", "")
|
||||
if title:
|
||||
lines.append(f"# {title}\n")
|
||||
else:
|
||||
# 从第一个段落获取标题(如果是标题样式)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for para in doc_content.paragraphs[:5]:
|
||||
if para.level == 1:
|
||||
lines.append(f"# {para.text}\n")
|
||||
@@ -237,29 +227,29 @@ class DocxParser:
|
||||
else:
|
||||
lines.append(f"# {doc_content.file_path}\n")
|
||||
|
||||
# 元数据信息
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 文档信息\n")
|
||||
for key, value in doc_content.metadata.items():
|
||||
if value:
|
||||
lines.append(f"- **{key}**: {value}")
|
||||
|
||||
# 正文内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 正文\n")
|
||||
|
||||
table_index = 0
|
||||
for para in doc_content.paragraphs:
|
||||
if para.level > 0:
|
||||
# 标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
prefix = "#" * para.level
|
||||
lines.append(f"\n{prefix} {para.text}\n")
|
||||
elif para.is_list:
|
||||
# 列表项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append(f"- {para.text}")
|
||||
else:
|
||||
# 正文
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append(para.text)
|
||||
|
||||
# 添加表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if doc_content.tables:
|
||||
lines.append("\n## 表格\n")
|
||||
for i, table in enumerate(doc_content.tables):
|
||||
@@ -269,18 +259,18 @@ class DocxParser:
|
||||
return "\n".join(lines)
|
||||
|
||||
def parse_to_markdown(self, file_path: str) -> str:
|
||||
"""直接解析并返回Markdown文本"""
|
||||
"""Parse to markdown for the Docx Parser instance."""
|
||||
doc_content = self.parse(file_path)
|
||||
return doc_content.markdown_text
|
||||
|
||||
|
||||
def parse_docx(file_path: str) -> DocxDocumentContent:
|
||||
"""便捷函数:解析Word文档"""
|
||||
"""Parse docx."""
|
||||
parser = DocxParser()
|
||||
return parser.parse(file_path)
|
||||
|
||||
|
||||
def parse_docx_to_markdown(file_path: str) -> str:
|
||||
"""便捷函数:解析Word并返回Markdown"""
|
||||
"""Parse docx to markdown."""
|
||||
parser = DocxParser()
|
||||
return parser.parse_to_markdown(file_path)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""MinerU多模态PDF解析 - 版面感知解析"""
|
||||
"""Provide service-layer logic for mineru parser."""
|
||||
|
||||
from typing import Optional, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import os
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinerUResult:
|
||||
"""MinerU解析结果"""
|
||||
"""Represent the Miner U Result type."""
|
||||
file_path: str
|
||||
markdown_text: str
|
||||
metadata: Dict[str, str] = field(default_factory=dict)
|
||||
@@ -17,21 +19,14 @@ class MinerUResult:
|
||||
|
||||
|
||||
class MinerUParser:
|
||||
"""
|
||||
MinerU多模态PDF解析器
|
||||
|
||||
MinerU (magic-pdf) 是一个开源的高质量PDF解析工具,
|
||||
支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素,
|
||||
并输出结构化的Markdown格式。
|
||||
|
||||
GitHub: https://github.com/opendatalab/MinerU
|
||||
"""
|
||||
"""Provide the Miner U Parser parser."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Miner U Parser instance."""
|
||||
self.available = self._check_mineru_available()
|
||||
|
||||
def _check_mineru_available(self) -> bool:
|
||||
"""检查MinerU是否可用"""
|
||||
"""Handle check mineru available for this module for the Miner U Parser instance."""
|
||||
try:
|
||||
from magic_pdf.pipe.UNIPipe import UNIPipe
|
||||
return True
|
||||
@@ -40,16 +35,7 @@ class MinerUParser:
|
||||
return False
|
||||
|
||||
def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
|
||||
"""
|
||||
使用MinerU解析PDF文档
|
||||
|
||||
Args:
|
||||
file_path: PDF文件路径
|
||||
output_dir: 输出目录(可选,用于保存解析产物)
|
||||
|
||||
Returns:
|
||||
MinerUResult: 解析结果
|
||||
"""
|
||||
"""Handle parse for the Miner U Parser instance."""
|
||||
logger.info(f"尝试使用MinerU解析: {file_path}")
|
||||
|
||||
if not self.available:
|
||||
@@ -64,19 +50,19 @@ class MinerUParser:
|
||||
from magic_pdf.pipe.UNIPipe import UNIPipe
|
||||
from magic_pdf.libs.MakeContentConfig import DropMode
|
||||
|
||||
# 设置输出目录
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if output_dir is None:
|
||||
output_dir = os.path.dirname(file_path)
|
||||
|
||||
# 创建解析管道
|
||||
# OCR模式可以根据PDF类型选择
|
||||
# auto: 自动判断是否需要OCR
|
||||
# txt: 纯文本PDF(无OCR)
|
||||
# ocr: 扫描件PDF(OCR)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
pipe = UNIPipe(file_path, output_dir)
|
||||
|
||||
# 执行解析
|
||||
# pipe_mk() 返回Markdown格式文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
markdown_content = pipe.pipe_mk()
|
||||
|
||||
logger.success(f"MinerU解析成功")
|
||||
@@ -98,13 +84,13 @@ class MinerUParser:
|
||||
)
|
||||
|
||||
def _extract_metadata(self, pipe) -> Dict[str, str]:
|
||||
"""从解析管道提取元数据"""
|
||||
"""Handle extract metadata for this module for the Miner U Parser instance."""
|
||||
metadata = {}
|
||||
try:
|
||||
# MinerU解析管道中可能包含的元数据信息
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data:
|
||||
mid_data = pipe.pdf_mid_data
|
||||
# 提取可能的元数据字段
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
metadata = {
|
||||
"page_count": str(mid_data.get("page_count", "")),
|
||||
"language": str(mid_data.get("language", "")),
|
||||
@@ -116,41 +102,27 @@ class MinerUParser:
|
||||
return metadata
|
||||
|
||||
def parse_to_markdown(self, file_path: str) -> str:
|
||||
"""直接解析并返回Markdown文本"""
|
||||
"""Parse to markdown for the Miner U Parser instance."""
|
||||
result = self.parse(file_path)
|
||||
return result.markdown_text if result.success else ""
|
||||
|
||||
|
||||
class ParserOrchestrator:
|
||||
"""
|
||||
解析服务编排 - 按优先级选择解析器
|
||||
|
||||
解析策略:
|
||||
1. 优先尝试MinerU(版面感知能力强)
|
||||
2. MinerU失败时回退到基础PyMuPDF解析
|
||||
"""
|
||||
"""Represent the Parser Orchestrator type."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Parser Orchestrator instance."""
|
||||
from .pdf_parser import PDFParser
|
||||
self.mineru_parser = MinerUParser()
|
||||
self.pdf_parser = PDFParser()
|
||||
self.mineru_available = self.mineru_parser.available
|
||||
|
||||
def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
|
||||
"""
|
||||
解析PDF文档,按优先级选择解析器
|
||||
|
||||
Args:
|
||||
file_path: PDF文件路径
|
||||
prefer_mineru: 是否优先使用MinerU
|
||||
|
||||
Returns:
|
||||
str: Markdown格式文本
|
||||
"""
|
||||
"""Parse pdf for the Parser Orchestrator instance."""
|
||||
markdown_text = ""
|
||||
|
||||
if prefer_mineru and self.mineru_available:
|
||||
# 优先尝试MinerU
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
result = self.mineru_parser.parse(file_path)
|
||||
if result.success:
|
||||
markdown_text = result.markdown_text
|
||||
@@ -159,28 +131,20 @@ class ParserOrchestrator:
|
||||
else:
|
||||
logger.warning(f"MinerU解析失败,回退到PyMuPDF: {result.error_message}")
|
||||
|
||||
# 回退到PyMuPDF基础解析
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
logger.info("使用PyMuPDF基础解析")
|
||||
markdown_text = self.pdf_parser.parse_to_markdown(file_path)
|
||||
|
||||
return markdown_text
|
||||
|
||||
def parse_docx(self, file_path: str) -> str:
|
||||
"""解析Word文档"""
|
||||
"""Parse docx for the Parser Orchestrator instance."""
|
||||
from .docx_parser import DocxParser
|
||||
docx_parser = DocxParser()
|
||||
return docx_parser.parse_to_markdown(file_path)
|
||||
|
||||
def parse(self, file_path: str) -> str:
|
||||
"""
|
||||
根据文件类型选择解析器
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
str: Markdown格式文本
|
||||
"""
|
||||
"""Handle parse for the Parser Orchestrator instance."""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
if ext == ".pdf":
|
||||
@@ -192,12 +156,12 @@ class ParserOrchestrator:
|
||||
|
||||
|
||||
def parse_with_mineru(file_path: str) -> MinerUResult:
|
||||
"""便捷函数:使用MinerU解析"""
|
||||
"""Parse with mineru."""
|
||||
parser = MinerUParser()
|
||||
return parser.parse(file_path)
|
||||
|
||||
|
||||
def parse_pdf_smart(file_path: str) -> str:
|
||||
"""便捷函数:智能解析PDF(自动选择最佳解析器)"""
|
||||
"""Parse pdf smart."""
|
||||
orchestrator = ParserOrchestrator()
|
||||
return orchestrator.parse_pdf(file_path)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""PDF文档解析 - 使用PyMuPDF基础解析"""
|
||||
"""Provide service-layer logic for pdf parser."""
|
||||
|
||||
import fitz # PyMuPDF
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
@@ -9,17 +9,17 @@ import re
|
||||
|
||||
@dataclass
|
||||
class PDFPageContent:
|
||||
"""PDF页面内容"""
|
||||
"""Represent the P D F Page Content type."""
|
||||
page_number: int
|
||||
text: str
|
||||
tables: List[str] = field(default_factory=list)
|
||||
images: List[str] = field(default_factory=list) # 图片路径列表
|
||||
images: List[str] = field(default_factory=list) # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
blocks: List[Dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PDFDocumentContent:
|
||||
"""PDF文档完整内容"""
|
||||
"""Represent the P D F Document Content type."""
|
||||
file_path: str
|
||||
total_pages: int
|
||||
pages: List[PDFPageContent]
|
||||
@@ -28,23 +28,14 @@ class PDFDocumentContent:
|
||||
|
||||
|
||||
class PDFParser:
|
||||
"""PDF文档解析器 - 基于PyMuPDF"""
|
||||
"""Provide the P D F Parser parser."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the P D F Parser instance."""
|
||||
self.pdf = None
|
||||
|
||||
def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
|
||||
"""
|
||||
解析PDF文档
|
||||
|
||||
Args:
|
||||
file_path: PDF文件路径
|
||||
extract_tables: 是否提取表格
|
||||
extract_images: 是否提取图片
|
||||
|
||||
Returns:
|
||||
PDFDocumentContent: 解析后的文档内容
|
||||
"""
|
||||
"""Handle parse for the P D F Parser instance."""
|
||||
logger.info(f"开始解析PDF文档: {file_path}")
|
||||
|
||||
try:
|
||||
@@ -55,16 +46,16 @@ class PDFParser:
|
||||
pages=[]
|
||||
)
|
||||
|
||||
# 提取文档元数据
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.metadata = self._extract_metadata()
|
||||
|
||||
# 逐页解析
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for page_num in range(self.pdf.page_count):
|
||||
page = self.pdf[page_num]
|
||||
page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images)
|
||||
doc_content.pages.append(page_content)
|
||||
|
||||
# 生成Markdown格式文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.markdown_text = self._generate_markdown(doc_content)
|
||||
|
||||
self.pdf.close()
|
||||
@@ -77,7 +68,7 @@ class PDFParser:
|
||||
raise
|
||||
|
||||
def _extract_metadata(self) -> Dict[str, str]:
|
||||
"""提取PDF元数据"""
|
||||
"""Handle extract metadata for this module for the P D F Parser instance."""
|
||||
metadata = {}
|
||||
try:
|
||||
meta = self.pdf.metadata
|
||||
@@ -97,23 +88,23 @@ class PDFParser:
|
||||
|
||||
def _parse_page(self, page: fitz.Page, page_num: int,
|
||||
extract_tables: bool, extract_images: bool) -> PDFPageContent:
|
||||
"""解析单页内容"""
|
||||
"""Handle parse page for this module for the P D F Parser instance."""
|
||||
page_content = PDFPageContent(page_number=page_num, text="")
|
||||
|
||||
# 提取文本块(保留结构)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
|
||||
page_content.blocks = blocks
|
||||
|
||||
# 提取纯文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE)
|
||||
page_content.text = text.strip()
|
||||
|
||||
# 提取表格(使用PyMuPDF的表格提取功能)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if extract_tables:
|
||||
tables = self._extract_tables_from_page(page)
|
||||
page_content.tables = tables
|
||||
|
||||
# 提取图片
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if extract_images:
|
||||
images = self._extract_images_from_page(page, page_num)
|
||||
page_content.images = images
|
||||
@@ -121,25 +112,22 @@ class PDFParser:
|
||||
return page_content
|
||||
|
||||
def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
|
||||
"""
|
||||
从页面提取表格(基于文本块分析)
|
||||
注意:PyMuPDF基础版表格提取能力有限,复杂表格建议使用MinerU
|
||||
"""
|
||||
"""Handle extract tables from page for this module for the P D F Parser instance."""
|
||||
tables = []
|
||||
|
||||
try:
|
||||
# 使用PyMuPDF的表格提取方法(2.4+版本)
|
||||
# 对于更复杂的表格,需要在mineru_parser中使用更高级的方法
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
tabs = page.find_tables()
|
||||
if tabs:
|
||||
for tab in tabs:
|
||||
table_text = tab.extract()
|
||||
# 将表格转换为Markdown格式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
markdown_table = self._table_to_markdown(table_text)
|
||||
tables.append(markdown_table)
|
||||
|
||||
except AttributeError:
|
||||
# 旧版本PyMuPDF没有表格提取功能
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
logger.warning("PyMuPDF版本不支持表格提取,请升级到2.4+版本")
|
||||
except Exception as e:
|
||||
logger.warning(f"表格提取失败: {e}")
|
||||
@@ -147,28 +135,28 @@ class PDFParser:
|
||||
return tables
|
||||
|
||||
def _table_to_markdown(self, table_data: List[List[str]]) -> str:
|
||||
"""将表格数据转换为Markdown格式"""
|
||||
"""Handle table to markdown for this module for the P D F Parser instance."""
|
||||
if not table_data or len(table_data) < 1:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
# 表头
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(table_data) >= 1:
|
||||
header = table_data[0]
|
||||
lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |")
|
||||
lines.append("| " + " | ".join("---" for _ in header) + " |")
|
||||
|
||||
# 数据行
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for row in table_data[1:]:
|
||||
lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
|
||||
"""提取页面图片"""
|
||||
"""Handle extract images from page for this module for the P D F Parser instance."""
|
||||
images = []
|
||||
# 图片提取功能(可选实现)
|
||||
# 这里仅记录图片信息,实际图片需要额外保存
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
try:
|
||||
image_list = page.get_images()
|
||||
for img_index, img in enumerate(image_list):
|
||||
@@ -179,52 +167,52 @@ class PDFParser:
|
||||
return images
|
||||
|
||||
def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
|
||||
"""生成Markdown格式文本"""
|
||||
"""Handle generate markdown for this module for the P D F Parser instance."""
|
||||
lines = []
|
||||
|
||||
# 文档标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
title = doc_content.metadata.get("title", "")
|
||||
if title:
|
||||
lines.append(f"# {title}\n")
|
||||
else:
|
||||
lines.append(f"# {doc_content.file_path}\n")
|
||||
|
||||
# 元数据信息
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 文档信息\n")
|
||||
for key, value in doc_content.metadata.items():
|
||||
if value and key in ["author", "subject", "keywords", "creation_date"]:
|
||||
lines.append(f"- **{key}**: {value}")
|
||||
|
||||
# 正文内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 正文\n")
|
||||
|
||||
for page in doc_content.pages:
|
||||
# 页码标记
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append(f"\n---\n**第 {page.page_number} 页**\n")
|
||||
|
||||
# 处理文本内容,识别标题结构
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
text = self._process_page_text(page.text, page.blocks)
|
||||
lines.append(text)
|
||||
|
||||
# 添加表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for table in page.tables:
|
||||
lines.append("\n" + table + "\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
|
||||
"""处理页面文本,识别标题结构"""
|
||||
# 基于字体大小识别标题
|
||||
"""Handle process page text for this module for the P D F Parser instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
processed_text = text
|
||||
|
||||
# 尝试识别标题(基于字号)
|
||||
# 法规文档通常有明确的层级结构:章、节、条
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
processed_text = self._detect_headers(text, blocks)
|
||||
|
||||
return processed_text
|
||||
|
||||
def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
|
||||
"""检测并标记标题(基于字号或内容模式)"""
|
||||
"""Handle detect headers for this module for the P D F Parser instance."""
|
||||
lines = text.split("\n")
|
||||
processed_lines = []
|
||||
|
||||
@@ -233,8 +221,8 @@ class PDFParser:
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 法规标题模式检测
|
||||
# 第一章、第X章、第X节、第X条等
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^第[一二三四五六七八九十百]+章\s', line):
|
||||
processed_lines.append(f"\n## {line}\n")
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+节\s', line):
|
||||
@@ -242,7 +230,7 @@ class PDFParser:
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+条\s', line):
|
||||
processed_lines.append(f"\n#### {line}\n")
|
||||
elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line):
|
||||
# 条款子项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
processed_lines.append(f"- {line}")
|
||||
else:
|
||||
processed_lines.append(line)
|
||||
@@ -250,18 +238,18 @@ class PDFParser:
|
||||
return "\n".join(processed_lines)
|
||||
|
||||
def parse_to_markdown(self, file_path: str) -> str:
|
||||
"""直接解析并返回Markdown文本"""
|
||||
"""Parse to markdown for the P D F Parser instance."""
|
||||
doc_content = self.parse(file_path)
|
||||
return doc_content.markdown_text
|
||||
|
||||
|
||||
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
|
||||
"""便捷函数:解析PDF文档"""
|
||||
"""Parse pdf."""
|
||||
parser = PDFParser()
|
||||
return parser.parse(file_path, **kwargs)
|
||||
|
||||
|
||||
def parse_pdf_to_markdown(file_path: str) -> str:
|
||||
"""便捷函数:解析PDF并返回Markdown"""
|
||||
"""Parse pdf to markdown."""
|
||||
parser = PDFParser()
|
||||
return parser.parse_to_markdown(file_path)
|
||||
|
||||
@@ -1,11 +1,29 @@
|
||||
"""RAG服务模块"""
|
||||
"""Initialize the app.services.rag package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
from .retriever import Retriever, retrieve_regulations
|
||||
from .context_builder import ContextBuilder, build_rag_context
|
||||
from .prompt_templates import PromptTemplates, get_prompt_template
|
||||
|
||||
__all__ = [
|
||||
"Retriever", "retrieve_regulations",
|
||||
"ContextBuilder", "build_rag_context",
|
||||
"PromptTemplates", "get_prompt_template"
|
||||
"Retriever",
|
||||
"retrieve_regulations",
|
||||
"ContextBuilder",
|
||||
"build_rag_context",
|
||||
"PromptTemplates",
|
||||
"get_prompt_template",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name in {"Retriever", "retrieve_regulations"}:
|
||||
from .retriever import Retriever, retrieve_regulations
|
||||
|
||||
return {"Retriever": Retriever, "retrieve_regulations": retrieve_regulations}[name]
|
||||
if name in {"ContextBuilder", "build_rag_context"}:
|
||||
from .context_builder import ContextBuilder, build_rag_context
|
||||
|
||||
return {"ContextBuilder": ContextBuilder, "build_rag_context": build_rag_context}[name]
|
||||
if name in {"PromptTemplates", "get_prompt_template"}:
|
||||
from .prompt_templates import PromptTemplates, get_prompt_template
|
||||
|
||||
return {"PromptTemplates": PromptTemplates, "get_prompt_template": get_prompt_template}[name]
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""RAG上下文构建服务 - 构建LLM输入上下文"""
|
||||
"""Provide service-layer logic for context builder."""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
@@ -6,11 +6,13 @@ from loguru import logger
|
||||
|
||||
from .retriever import RetrievedDocument
|
||||
from app.config.settings import settings
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGContext:
|
||||
"""RAG构建的上下文"""
|
||||
"""Represent the R A G Context type."""
|
||||
system_prompt: str
|
||||
context_text: str
|
||||
user_query: str
|
||||
@@ -20,14 +22,7 @@ class RAGContext:
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""
|
||||
RAG上下文构建器
|
||||
|
||||
功能:
|
||||
- 格式化检索结果为上下文文本
|
||||
- 控制上下文长度(token限制)
|
||||
- 构建完整的LLM输入格式
|
||||
"""
|
||||
"""Provide the Context Builder builder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,14 +30,7 @@ class ContextBuilder:
|
||||
include_metadata: bool = True,
|
||||
citation_format: str = "【条款{clause}】"
|
||||
):
|
||||
"""
|
||||
初始化上下文构建器
|
||||
|
||||
Args:
|
||||
max_context_tokens: 最大上下文token数
|
||||
include_metadata: 是否包含元数据(文档名、条款号等)
|
||||
citation_format: 引用格式模板
|
||||
"""
|
||||
"""Initialize the Context Builder instance."""
|
||||
self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
|
||||
self.include_metadata = include_metadata
|
||||
self.citation_format = citation_format
|
||||
@@ -56,30 +44,19 @@ class ContextBuilder:
|
||||
system_prompt: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> RAGContext:
|
||||
"""
|
||||
构建RAG上下文
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
documents: 检索到的文档列表
|
||||
system_prompt: 系统提示词(可选)
|
||||
max_tokens: 最大token数(可选,覆盖默认值)
|
||||
|
||||
Returns:
|
||||
RAGContext: 构建的上下文对象
|
||||
"""
|
||||
"""Handle build for the Context Builder instance."""
|
||||
max_tokens = max_tokens or self.max_context_tokens
|
||||
|
||||
# 格式化文档内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
context_text, sources, truncated = self._format_documents(
|
||||
documents,
|
||||
max_tokens
|
||||
)
|
||||
|
||||
# 构建系统提示词
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
system_prompt = system_prompt or self._default_system_prompt()
|
||||
|
||||
# 估算总token数
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
total_tokens = self._estimate_tokens(system_prompt + context_text + query)
|
||||
|
||||
logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
|
||||
@@ -98,29 +75,20 @@ class ContextBuilder:
|
||||
documents: List[RetrievedDocument],
|
||||
max_tokens: int
|
||||
) -> tuple:
|
||||
"""
|
||||
格式化文档内容
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
(context_text, sources, truncated)
|
||||
"""
|
||||
"""Handle format documents for this module for the Context Builder instance."""
|
||||
context_parts = []
|
||||
sources = []
|
||||
current_tokens = 0
|
||||
truncated = False
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
# 格式化单个文档
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
formatted = self._format_single_doc(doc, i + 1)
|
||||
|
||||
# 估算token数
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_tokens = self._estimate_tokens(formatted)
|
||||
|
||||
# 检查是否超出限制
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_tokens + doc_tokens > max_tokens:
|
||||
truncated = True
|
||||
logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
|
||||
@@ -129,7 +97,7 @@ class ContextBuilder:
|
||||
context_parts.append(formatted)
|
||||
current_tokens += doc_tokens
|
||||
|
||||
# 记录来源
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
sources.append({
|
||||
"index": i + 1,
|
||||
"doc_id": doc.doc_id,
|
||||
@@ -148,13 +116,13 @@ class ContextBuilder:
|
||||
doc: RetrievedDocument,
|
||||
index: int
|
||||
) -> str:
|
||||
"""格式化单个文档"""
|
||||
"""Handle format single doc for this module for the Context Builder instance."""
|
||||
parts = []
|
||||
|
||||
# 索引编号
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
parts.append(f"[{index}]")
|
||||
|
||||
# 元数据(可选)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if self.include_metadata:
|
||||
meta_parts = []
|
||||
|
||||
@@ -171,13 +139,13 @@ class ContextBuilder:
|
||||
if meta_parts:
|
||||
parts.append(" | ".join(meta_parts))
|
||||
|
||||
# 内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
parts.append(doc.content)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _default_system_prompt(self) -> str:
|
||||
"""默认系统提示词"""
|
||||
"""Handle default system prompt for this module for the Context Builder instance."""
|
||||
return """你是合规专家助手,基于提供的法规条款回答问题。
|
||||
|
||||
回答要求:
|
||||
@@ -192,8 +160,8 @@ class ContextBuilder:
|
||||
- 最后给出合规建议"""
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""估算文本token数"""
|
||||
# 中文字符约1.5 token,英文约0.25 token
|
||||
"""Handle estimate tokens for this module for the Context Builder instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||
other_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||
@@ -202,15 +170,7 @@ class ContextBuilder:
|
||||
self,
|
||||
context: RAGContext
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
构建LLM消息格式
|
||||
|
||||
Args:
|
||||
context: RAG上下文对象
|
||||
|
||||
Returns:
|
||||
List[Dict]: [{"role": "system/user/assistant", "content": "..."}]
|
||||
"""
|
||||
"""Build messages for the Context Builder instance."""
|
||||
messages = [
|
||||
{"role": "system", "content": context.system_prompt},
|
||||
{"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
|
||||
@@ -224,6 +184,6 @@ def build_rag_context(
|
||||
documents: List[RetrievedDocument],
|
||||
**kwargs
|
||||
) -> RAGContext:
|
||||
"""便捷函数:构建RAG上下文"""
|
||||
"""Build rag context."""
|
||||
builder = ContextBuilder()
|
||||
return builder.build(query, documents, **kwargs)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""RAG Prompt模板 - 合规问答专用Prompt"""
|
||||
"""Provide service-layer logic for prompt templates."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Prompt模板"""
|
||||
"""Represent the Prompt Template type."""
|
||||
name: str
|
||||
system_prompt: str
|
||||
user_template: str
|
||||
@@ -14,18 +16,9 @@ class PromptTemplate:
|
||||
|
||||
|
||||
class PromptTemplates:
|
||||
"""
|
||||
合规问答Prompt模板库
|
||||
"""Represent the Prompt Templates type."""
|
||||
|
||||
包含多种场景的Prompt模板:
|
||||
- 合规问答(标准)
|
||||
- 条款解读(详细解释)
|
||||
- 合规检查(判断合规状态)
|
||||
- 差异对比(新旧法规对比)
|
||||
- 报告生成(合规报告)
|
||||
"""
|
||||
|
||||
# 合规问答标准模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
COMPLIANCE_QA = PromptTemplate(
|
||||
name="compliance_qa",
|
||||
system_prompt="""你是合规专家助手,专门解答法规合规问题。
|
||||
@@ -63,7 +56,7 @@ class PromptTemplates:
|
||||
description="标准合规问答模板"
|
||||
)
|
||||
|
||||
# 条款解读模板(详细解释)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
CLAUSE_INTERPRETATION = PromptTemplate(
|
||||
name="clause_interpretation",
|
||||
system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
|
||||
@@ -96,7 +89,7 @@ class PromptTemplates:
|
||||
description="条款详细解读模板"
|
||||
)
|
||||
|
||||
# 合规检查模板(判断合规状态)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
COMPLIANCE_CHECK = PromptTemplate(
|
||||
name="compliance_check",
|
||||
system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
|
||||
@@ -140,7 +133,7 @@ class PromptTemplates:
|
||||
description="合规检查评估模板"
|
||||
)
|
||||
|
||||
# 差异对比模板(新旧法规对比)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
COMPARISON = PromptTemplate(
|
||||
name="comparison",
|
||||
system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
|
||||
@@ -192,7 +185,7 @@ class PromptTemplates:
|
||||
description="法规版本对比模板"
|
||||
)
|
||||
|
||||
# 报告生成模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
REPORT_GENERATION = PromptTemplate(
|
||||
name="report_generation",
|
||||
system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
|
||||
@@ -222,7 +215,7 @@ class PromptTemplates:
|
||||
description="合规报告生成模板"
|
||||
)
|
||||
|
||||
# 文档摘要生成模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
DOCUMENT_SUMMARY = PromptTemplate(
|
||||
name="document_summary",
|
||||
system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
|
||||
@@ -263,7 +256,7 @@ class PromptTemplates:
|
||||
|
||||
@classmethod
|
||||
def get_template(cls, name: str) -> Optional[PromptTemplate]:
|
||||
"""获取指定模板"""
|
||||
"""Return template for the Prompt Templates instance."""
|
||||
templates = {
|
||||
"compliance_qa": cls.COMPLIANCE_QA,
|
||||
"clause_interpretation": cls.CLAUSE_INTERPRETATION,
|
||||
@@ -276,7 +269,7 @@ class PromptTemplates:
|
||||
|
||||
@classmethod
|
||||
def list_templates(cls) -> Dict[str, str]:
|
||||
"""列出所有模板"""
|
||||
"""List templates for the Prompt Templates instance."""
|
||||
return {
|
||||
"compliance_qa": cls.COMPLIANCE_QA.description,
|
||||
"clause_interpretation": cls.CLAUSE_INTERPRETATION.description,
|
||||
@@ -288,7 +281,7 @@ class PromptTemplates:
|
||||
|
||||
|
||||
def get_prompt_template(name: str) -> PromptTemplate:
|
||||
"""便捷函数:获取Prompt模板"""
|
||||
"""Return prompt template."""
|
||||
template = PromptTemplates.get_template(name)
|
||||
if not template:
|
||||
raise ValueError(f"不存在的模板: {name}")
|
||||
|
||||
@@ -1,192 +1,82 @@
|
||||
"""RAG检索服务 - 封装Milvus检索"""
|
||||
"""Provide service-layer logic for retriever."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.shared.bootstrap import get_retrieval_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
|
||||
from app.services.storage.milvus_client import MilvusClient, SearchResult
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedDocument:
|
||||
"""检索到的文档"""
|
||||
"""Represent the Retrieved Document type."""
|
||||
content: str
|
||||
doc_id: str # 文档ID,用于下载
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
section_title: str
|
||||
clause_number: str
|
||||
page_number: int
|
||||
score: float
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class Retriever:
|
||||
"""
|
||||
RAG检索器
|
||||
|
||||
功能:
|
||||
- 向量检索(Dense + Sparse混合)
|
||||
- 重排序(可选)
|
||||
- 过滤和筛选
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int = None,
|
||||
rerank: bool = False,
|
||||
min_score: float = 0.3
|
||||
):
|
||||
"""
|
||||
初始化检索器
|
||||
|
||||
Args:
|
||||
top_k: 检索召回数量
|
||||
rerank: 是否启用重排序
|
||||
min_score: 最低相关性分数阈值
|
||||
"""
|
||||
self.top_k = top_k or settings.rag_top_k
|
||||
"""Provide the Retriever retriever."""
|
||||
def __init__(self, top_k: int = 5, rerank: bool = False, min_score: float = 0.0):
|
||||
"""Initialize the Retriever instance."""
|
||||
self.top_k = top_k
|
||||
self.rerank = rerank
|
||||
self.min_score = min_score
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
self.embedder: Optional[BGEM3Embedder] = None
|
||||
|
||||
# Milvus客户端(延迟连接)
|
||||
self.milvus: Optional[MilvusClient] = None
|
||||
|
||||
logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")
|
||||
|
||||
def _init_embedder(self):
|
||||
"""延迟初始化嵌入模型"""
|
||||
if self.embedder is None:
|
||||
logger.info("加载嵌入模型...")
|
||||
self.embedder = BGEM3Embedder(model_name=settings.embedding_model)
|
||||
|
||||
def _init_milvus(self):
|
||||
"""延迟初始化Milvus"""
|
||||
if self.milvus is None:
|
||||
logger.info("连接Milvus...")
|
||||
self.milvus = MilvusClient()
|
||||
self.milvus.connect()
|
||||
self.milvus.create_collection(recreate=False)
|
||||
self.milvus.load_collection()
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
top_k: Optional[int] = None
|
||||
) -> List[RetrievedDocument]:
|
||||
"""
|
||||
检索相关文档
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
filters: 过滤条件(如 "regulation_type=='车辆安全'")
|
||||
top_k: 返回数量(可选,覆盖默认值)
|
||||
|
||||
Returns:
|
||||
List[RetrievedDocument]: 检索结果列表
|
||||
"""
|
||||
logger.info(f"执行检索: {query}")
|
||||
|
||||
# 初始化组件
|
||||
self._init_embedder()
|
||||
self._init_milvus()
|
||||
|
||||
# 生成查询向量
|
||||
query_embedding = self.embedder.embed_single(query)
|
||||
|
||||
# 执行混合检索
|
||||
results = self.milvus.hybrid_search(
|
||||
query_dense=query_embedding['dense'].tolist(),
|
||||
query_sparse=query_embedding['sparse'],
|
||||
top_k=top_k or self.top_k,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
# 转换为RetrievedDocument格式
|
||||
documents = []
|
||||
for r in results:
|
||||
if r.score >= self.min_score:
|
||||
doc = RetrievedDocument(
|
||||
content=r.content,
|
||||
doc_id=r.metadata.get("doc_id", ""),
|
||||
doc_name=r.metadata.get("doc_name", ""),
|
||||
section_title=r.metadata.get("section_title", ""),
|
||||
clause_number=r.metadata.get("clause_number", ""),
|
||||
page_number=r.metadata.get("page_number", 0),
|
||||
score=r.score,
|
||||
metadata=r.metadata
|
||||
)
|
||||
documents.append(doc)
|
||||
|
||||
logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
|
||||
return documents
|
||||
|
||||
def retrieve_with_scores(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
检索并返回完整结果(包含分数)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
filters: 过滤条件
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含分数的检索结果
|
||||
"""
|
||||
documents = self.retrieve(query, filters)
|
||||
def retrieve(self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None) -> list[RetrievedDocument]:
|
||||
"""Handle retrieve for the Retriever instance."""
|
||||
results = get_retrieval_service().retrieve(query=query, top_k=top_k or self.top_k, filters=filters)
|
||||
return [
|
||||
{
|
||||
"content": doc.content,
|
||||
"doc_id": doc.doc_id,
|
||||
"doc_name": doc.doc_name,
|
||||
"section_title": doc.section_title,
|
||||
"clause_number": doc.clause_number,
|
||||
"page_number": doc.page_number,
|
||||
"score": doc.score
|
||||
}
|
||||
for doc in documents
|
||||
RetrievedDocument(
|
||||
content=item.content,
|
||||
doc_id=item.doc_id,
|
||||
doc_name=item.doc_name,
|
||||
section_title=item.section_title,
|
||||
clause_number=item.metadata.get("clause_number", ""),
|
||||
page_number=item.page_number,
|
||||
score=item.score,
|
||||
metadata=item.metadata,
|
||||
)
|
||||
for item in results
|
||||
if item.score >= self.min_score
|
||||
]
|
||||
|
||||
def search_by_doc_name(
|
||||
self,
|
||||
query: str,
|
||||
doc_name: str
|
||||
) -> List[RetrievedDocument]:
|
||||
"""按文档名称过滤检索"""
|
||||
filters = f'doc_name=="{doc_name}"'
|
||||
return self.retrieve(query, filters)
|
||||
def retrieve_with_scores(self, query: str, filters: Optional[str] = None) -> list[dict]:
|
||||
"""Handle retrieve with scores for the Retriever instance."""
|
||||
return [
|
||||
{
|
||||
"content": item.content,
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"section_title": item.section_title,
|
||||
"clause_number": item.clause_number,
|
||||
"page_number": item.page_number,
|
||||
"score": item.score,
|
||||
}
|
||||
for item in self.retrieve(query, filters)
|
||||
]
|
||||
|
||||
def search_by_regulation_type(
|
||||
self,
|
||||
query: str,
|
||||
regulation_type: str
|
||||
) -> List[RetrievedDocument]:
|
||||
"""按法规类型过滤检索"""
|
||||
filters = f'regulation_type=="{regulation_type}"'
|
||||
return self.retrieve(query, filters)
|
||||
def search_by_doc_name(self, query: str, doc_name: str) -> list[RetrievedDocument]:
|
||||
"""Search by doc name for the Retriever instance."""
|
||||
return self.retrieve(query, filters=f'doc_name == "{doc_name}"')
|
||||
|
||||
def search_by_regulation_type(self, query: str, regulation_type: str) -> list[RetrievedDocument]:
|
||||
"""Search by regulation type for the Retriever instance."""
|
||||
return self.retrieve(query, filters=f'regulation_type == "{regulation_type}"')
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
if self.milvus:
|
||||
self.milvus.disconnect()
|
||||
logger.info("检索器已关闭")
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
|
||||
def retrieve_regulations(
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[RetrievedDocument]:
|
||||
"""便捷函数:检索法规"""
|
||||
retriever = Retriever(top_k=top_k)
|
||||
results = retriever.retrieve(query, filters)
|
||||
retriever.close()
|
||||
return results
|
||||
def retrieve_regulations(query: str, top_k: int = 10, filters: Optional[str] = None) -> list[RetrievedDocument]:
|
||||
"""Handle retrieve regulations."""
|
||||
return Retriever(top_k=top_k).retrieve(query, filters)
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
"""存储服务"""
|
||||
"""Initialize the app.services.storage package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
from .milvus_client import MilvusClient
|
||||
from .minio_client import MinIOClient
|
||||
|
||||
__all__ = ["MilvusClient", "MinIOClient"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name == "MilvusClient":
|
||||
from .milvus_client import MilvusClient
|
||||
|
||||
return MilvusClient
|
||||
if name == "MinIOClient":
|
||||
from .minio_client import MinIOClient
|
||||
|
||||
return MinIOClient
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Milvus向量数据库客户端 - 存储与检索服务"""
|
||||
"""Provide service-layer logic for milvus client."""
|
||||
|
||||
from pymilvus import (
|
||||
connections,
|
||||
@@ -17,11 +17,13 @@ import numpy as np
|
||||
from ..embedding.text_chunker import TextChunk
|
||||
from ..embedding.bge_m3_embedder import EmbeddingResult
|
||||
from app.config.settings import settings
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""检索结果"""
|
||||
"""Represent the Search Result type."""
|
||||
id: int
|
||||
content: str
|
||||
score: float
|
||||
@@ -30,7 +32,7 @@ class SearchResult:
|
||||
|
||||
@dataclass
|
||||
class MilvusDocument:
|
||||
"""Milvus文档数据结构"""
|
||||
"""Represent the Milvus Document type."""
|
||||
doc_id: str
|
||||
chunk_id: str
|
||||
content: str
|
||||
@@ -46,7 +48,7 @@ class MilvusDocument:
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
"""Milvus向量数据库客户端"""
|
||||
"""Represent the Milvus Client type."""
|
||||
|
||||
COLLECTION_NAME = "regulations"
|
||||
|
||||
@@ -73,6 +75,7 @@ class MilvusClient:
|
||||
collection_name: str = None,
|
||||
db_name: str = None
|
||||
):
|
||||
"""Initialize the Milvus Client instance."""
|
||||
self.host = host or settings.milvus_host
|
||||
self.port = port or settings.milvus_port
|
||||
self.collection_name = collection_name or settings.milvus_collection
|
||||
@@ -84,7 +87,7 @@ class MilvusClient:
|
||||
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""连接到Milvus服务器"""
|
||||
"""Handle connect for the Milvus Client instance."""
|
||||
try:
|
||||
connections.connect(
|
||||
alias="default",
|
||||
@@ -101,7 +104,7 @@ class MilvusClient:
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
"""Handle disconnect for the Milvus Client instance."""
|
||||
try:
|
||||
connections.disconnect("default")
|
||||
self.connected = False
|
||||
@@ -110,7 +113,7 @@ class MilvusClient:
|
||||
logger.warning(f"断开连接时出错: {e}")
|
||||
|
||||
def create_collection(self, recreate: bool = False) -> bool:
|
||||
"""创建Collection"""
|
||||
"""Create collection for the Milvus Client instance."""
|
||||
if not self.connected:
|
||||
logger.warning("未连接到Milvus,请先调用connect()")
|
||||
return False
|
||||
@@ -146,7 +149,7 @@ class MilvusClient:
|
||||
return False
|
||||
|
||||
def _create_indexes(self):
|
||||
"""创建向量索引"""
|
||||
"""Handle create indexes for this module for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return
|
||||
|
||||
@@ -177,13 +180,13 @@ class MilvusClient:
|
||||
logger.warning(f"创建索引时出错: {e}")
|
||||
|
||||
def load_collection(self):
|
||||
"""加载Collection到内存"""
|
||||
"""Load collection for the Milvus Client instance."""
|
||||
if self.collection:
|
||||
self.collection.load()
|
||||
logger.info(f"Collection已加载: {self.collection_name}")
|
||||
|
||||
def release_collection(self):
|
||||
"""释放Collection内存"""
|
||||
"""Handle release collection for the Milvus Client instance."""
|
||||
if self.collection:
|
||||
self.collection.release()
|
||||
logger.info(f"Collection已释放: {self.collection_name}")
|
||||
@@ -193,7 +196,7 @@ class MilvusClient:
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""插入文档分块和嵌入向量"""
|
||||
"""Handle insert chunks for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
logger.warning("Collection未初始化")
|
||||
return []
|
||||
@@ -246,7 +249,7 @@ class MilvusClient:
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""混合检索:Dense + Sparse"""
|
||||
"""Handle hybrid search for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
logger.warning("Collection未初始化")
|
||||
return []
|
||||
@@ -254,10 +257,10 @@ class MilvusClient:
|
||||
try:
|
||||
self.collection.load()
|
||||
|
||||
# 使用简单的Dense检索(兼容所有版本)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
dense_results = self.dense_search(query_dense, top_k, filters)
|
||||
|
||||
# 可选:合并Sparse结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if query_sparse:
|
||||
sparse_results = self.sparse_search(query_sparse, top_k, filters)
|
||||
merged = self._merge_results(dense_results, sparse_results, top_k)
|
||||
@@ -277,7 +280,7 @@ class MilvusClient:
|
||||
top_k: int,
|
||||
dense_weight: float = 0.6
|
||||
) -> List[SearchResult]:
|
||||
"""手动融合Dense和Sparse结果"""
|
||||
"""Handle merge results for this module for the Milvus Client instance."""
|
||||
sparse_weight = 1 - dense_weight
|
||||
merged_dict = {}
|
||||
|
||||
@@ -318,7 +321,7 @@ class MilvusClient:
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""纯Dense向量检索"""
|
||||
"""Handle dense search for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return []
|
||||
|
||||
@@ -375,7 +378,7 @@ class MilvusClient:
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""纯Sparse向量检索"""
|
||||
"""Handle sparse search for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return []
|
||||
|
||||
@@ -427,7 +430,7 @@ class MilvusClient:
|
||||
return []
|
||||
|
||||
def delete_by_doc_id(self, doc_id: str) -> int:
|
||||
"""根据doc_id删除记录"""
|
||||
"""Delete by doc id for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return 0
|
||||
|
||||
@@ -441,7 +444,7 @@ class MilvusClient:
|
||||
return 0
|
||||
|
||||
def get_collection_stats(self) -> Dict[str, Any]:
|
||||
"""获取Collection统计信息"""
|
||||
"""Return collection stats for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return {}
|
||||
|
||||
@@ -458,7 +461,7 @@ class MilvusClient:
|
||||
|
||||
|
||||
def create_milvus_client() -> MilvusClient:
|
||||
"""便捷函数:创建Milvus客户端"""
|
||||
"""Create milvus client."""
|
||||
client = MilvusClient()
|
||||
client.connect()
|
||||
client.create_collection(recreate=False)
|
||||
@@ -470,7 +473,7 @@ def insert_documents(
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""便捷函数:插入文档"""
|
||||
"""Handle insert documents."""
|
||||
return client.insert_chunks(chunks, embeddings)
|
||||
|
||||
|
||||
@@ -480,5 +483,5 @@ def search_regulations(
|
||||
query_sparse: Dict[int, float],
|
||||
top_k: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""便捷函数:检索法规"""
|
||||
"""Search regulations."""
|
||||
return client.hybrid_search(query_dense, query_sparse, top_k)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""MinIO对象存储客户端 - 文档文件存储"""
|
||||
"""Provide service-layer logic for minio client."""
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
@@ -8,10 +8,12 @@ from io import BytesIO
|
||||
import os
|
||||
|
||||
from app.config.settings import settings
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
class MinIOClient:
|
||||
"""MinIO对象存储客户端"""
|
||||
"""Represent the Min I O Client type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -21,16 +23,7 @@ class MinIOClient:
|
||||
bucket: str = None,
|
||||
secure: bool = None
|
||||
):
|
||||
"""
|
||||
初始化MinIO客户端
|
||||
|
||||
Args:
|
||||
endpoint: MinIO服务地址
|
||||
access_key: 访问密钥
|
||||
secret_key: 秘密密钥
|
||||
bucket: 存储桶名称
|
||||
secure: 是否使用HTTPS
|
||||
"""
|
||||
"""Initialize the Min I O Client instance."""
|
||||
self.endpoint = endpoint or settings.minio_endpoint
|
||||
self.access_key = access_key or settings.minio_access_key
|
||||
self.secret_key = secret_key or settings.minio_secret_key
|
||||
@@ -43,7 +36,7 @@ class MinIOClient:
|
||||
logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""连接MinIO服务"""
|
||||
"""Handle connect for the Min I O Client instance."""
|
||||
try:
|
||||
self.client = Minio(
|
||||
self.endpoint,
|
||||
@@ -60,7 +53,7 @@ class MinIOClient:
|
||||
return False
|
||||
|
||||
def ensure_bucket(self) -> bool:
|
||||
"""确保存储桶存在"""
|
||||
"""Handle ensure bucket for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
logger.warning("未连接MinIO,请先调用connect()")
|
||||
return False
|
||||
@@ -82,17 +75,7 @@ class MinIOClient:
|
||||
object_name: str,
|
||||
metadata: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上传本地文件到MinIO
|
||||
|
||||
Args:
|
||||
file_path: 本地文件路径
|
||||
object_name: MinIO对象名称
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Handle upload file for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
self.ensure_bucket()
|
||||
@@ -125,18 +108,7 @@ class MinIOClient:
|
||||
content_type: str = "application/octet-stream",
|
||||
metadata: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上传字节数据到MinIO
|
||||
|
||||
Args:
|
||||
data: 文件字节数据
|
||||
object_name: MinIO对象名称
|
||||
content_type: 内容类型
|
||||
metadata: 元数据(注意:MinIO仅支持US-ASCII字符)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Handle upload bytes for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
self.ensure_bucket()
|
||||
@@ -144,18 +116,18 @@ class MinIOClient:
|
||||
try:
|
||||
data_stream = BytesIO(data)
|
||||
|
||||
# 处理metadata:仅保留ASCII安全字符
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
safe_metadata = None
|
||||
if metadata:
|
||||
safe_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
# 只保留ASCII字符或转换为安全格式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
try:
|
||||
value.encode('ascii')
|
||||
safe_metadata[key] = value
|
||||
except UnicodeEncodeError:
|
||||
# 中文字符跳过或用占位符
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
safe_metadata[key] = ""
|
||||
else:
|
||||
safe_metadata[key] = str(value)
|
||||
@@ -181,16 +153,7 @@ class MinIOClient:
|
||||
object_name: str,
|
||||
file_path: str
|
||||
) -> bool:
|
||||
"""
|
||||
从MinIO下载文件到本地
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
file_path: 本地保存路径
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Handle download file for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -212,16 +175,7 @@ class MinIOClient:
|
||||
object_name: str,
|
||||
expires: int = 3600
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取对象下载URL(临时URL)
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
expires: URL有效期(秒)
|
||||
|
||||
Returns:
|
||||
str: 下载URL
|
||||
"""
|
||||
"""Return object url for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -238,15 +192,7 @@ class MinIOClient:
|
||||
return None
|
||||
|
||||
def get_object_data(self, object_name: str) -> Optional[bytes]:
|
||||
"""
|
||||
获取对象数据(字节)
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bytes: 文件数据
|
||||
"""
|
||||
"""Return object data for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -262,15 +208,7 @@ class MinIOClient:
|
||||
return None
|
||||
|
||||
def delete_object(self, object_name: str) -> bool:
|
||||
"""
|
||||
删除对象
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Delete object for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -284,15 +222,7 @@ class MinIOClient:
|
||||
return False
|
||||
|
||||
def list_objects(self, prefix: str = "") -> list:
|
||||
"""
|
||||
列出存储桶中的对象
|
||||
|
||||
Args:
|
||||
prefix: 对象名称前缀
|
||||
|
||||
Returns:
|
||||
list: 对象列表
|
||||
"""
|
||||
"""List objects for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -305,15 +235,7 @@ class MinIOClient:
|
||||
return []
|
||||
|
||||
def object_exists(self, object_name: str) -> bool:
|
||||
"""
|
||||
检查对象是否存在
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bool: 是否存在
|
||||
"""
|
||||
"""Handle object exists for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -325,7 +247,7 @@ class MinIOClient:
|
||||
return False
|
||||
|
||||
def _get_content_type(self, file_path: str) -> str:
|
||||
"""根据文件扩展名获取Content-Type"""
|
||||
"""Handle get content type for this module for the Min I O Client instance."""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
content_types = {
|
||||
'.pdf': 'application/pdf',
|
||||
@@ -338,13 +260,13 @@ class MinIOClient:
|
||||
return content_types.get(ext, 'application/octet-stream')
|
||||
|
||||
def close(self):
|
||||
"""关闭连接(MinIO客户端无需显式关闭)"""
|
||||
"""Release the resources held by this component."""
|
||||
self.connected = False
|
||||
logger.info("MinIO客户端已关闭")
|
||||
|
||||
|
||||
def create_minio_client() -> MinIOClient:
|
||||
"""便捷函数:创建MinIO客户端"""
|
||||
"""Create minio client."""
|
||||
client = MinIOClient()
|
||||
client.connect()
|
||||
client.ensure_bucket()
|
||||
|
||||
5
backend/app/shared/__init__.py
Normal file
5
backend/app/shared/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Initialize the app.shared package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = []
|
||||
117
backend/app/shared/bootstrap.py
Normal file
117
backend/app/shared/bootstrap.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Share backend wiring for bootstrap."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from app.application.agent import AgentConversationService
|
||||
from app.application.documents import DocumentCommandService, DocumentQueryService
|
||||
from app.application.knowledge import KnowledgeRetrievalService
|
||||
from app.config.settings import settings
|
||||
from app.infrastructure.embedding.openai_compatible_embedding_provider import OpenAICompatibleEmbeddingProvider
|
||||
from app.infrastructure.llm.openai_compatible_answer_generator import OpenAICompatibleAnswerGenerator
|
||||
from app.infrastructure.parser.aliyun_document_parser import AliyunDocumentParser
|
||||
from app.infrastructure.parser.local_chunk_builder import LocalRegulationChunkBuilder
|
||||
from app.infrastructure.parser.local_document_parser import LocalDocumentParser
|
||||
from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuilder
|
||||
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
|
||||
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
|
||||
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
|
||||
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
|
||||
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||
# Keep shared wiring centralized so dependency construction remains consistent.
|
||||
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_document_repository() -> JsonDocumentRepository:
|
||||
"""Return document repository."""
|
||||
return JsonDocumentRepository(settings.document_metadata_path)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_binary_store() -> MinioDocumentBinaryStore:
|
||||
"""Return binary store."""
|
||||
return MinioDocumentBinaryStore()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_parser():
|
||||
"""Return parser."""
|
||||
if settings.parser_backend == "aliyun":
|
||||
return AliyunDocumentParser()
|
||||
return LocalDocumentParser()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_chunk_builder():
|
||||
"""Return chunk builder."""
|
||||
if settings.chunk_backend == "aliyun":
|
||||
return AliyunVectorChunkBuilder()
|
||||
return LocalRegulationChunkBuilder(
|
||||
chunk_size=settings.chunk_size,
|
||||
chunk_overlap=settings.chunk_overlap,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_embedding_provider() -> OpenAICompatibleEmbeddingProvider:
|
||||
"""Return embedding provider."""
|
||||
return OpenAICompatibleEmbeddingProvider()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_vector_index() -> MilvusVectorIndex:
|
||||
"""Return vector index."""
|
||||
return MilvusVectorIndex()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_retrieval_service() -> KnowledgeRetrievalService:
|
||||
"""Return retrieval service."""
|
||||
retriever = DenseRetriever(
|
||||
embedding_provider=get_embedding_provider(),
|
||||
vector_index=get_vector_index(),
|
||||
)
|
||||
return KnowledgeRetrievalService(retriever=retriever)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_document_command_service() -> DocumentCommandService:
|
||||
"""Return document command service."""
|
||||
return DocumentCommandService(
|
||||
document_repository=get_document_repository(),
|
||||
binary_store=get_binary_store(),
|
||||
parser=get_parser(),
|
||||
chunk_builder=get_chunk_builder(),
|
||||
embedding_provider=get_embedding_provider(),
|
||||
vector_index=get_vector_index(),
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_document_query_service() -> DocumentQueryService:
|
||||
"""Return document query service."""
|
||||
return DocumentQueryService(
|
||||
document_repository=get_document_repository(),
|
||||
binary_store=get_binary_store(),
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_conversation_store() -> InMemoryConversationStore:
|
||||
"""Return conversation store."""
|
||||
return InMemoryConversationStore(
|
||||
max_sessions=settings.session_max_sessions,
|
||||
timeout_minutes=settings.session_timeout_minutes,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_agent_conversation_service() -> AgentConversationService:
|
||||
"""Return agent conversation service."""
|
||||
return AgentConversationService(
|
||||
retrieval_service=get_retrieval_service(),
|
||||
answer_generator=OpenAICompatibleAnswerGenerator(),
|
||||
conversation_store=get_conversation_store(),
|
||||
)
|
||||
@@ -1,4 +1,8 @@
|
||||
"""Initialize the app.utils package."""
|
||||
|
||||
from .chunking import TextChunker, chunker
|
||||
from .logger import logger, setup_logging
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["TextChunker", "chunker", "logger", "setup_logging"]
|
||||
@@ -1,19 +1,25 @@
|
||||
"""Provide utility helpers for chunking."""
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
from app.core.config import settings
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
|
||||
class TextChunker:
|
||||
"""Represent the Text Chunker type."""
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = settings.chunk_size,
|
||||
chunk_overlap: int = settings.chunk_overlap,
|
||||
):
|
||||
"""Initialize the Text Chunker instance."""
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
def chunk_by_clause(self, text: str) -> List[dict]:
|
||||
"""按条款边界分块(适用于法规文档)"""
|
||||
"""Handle chunk by clause for the Text Chunker instance."""
|
||||
clause_pattern = r"(第[一二三四五六七八九十百]+条)"
|
||||
parts = re.split(clause_pattern, text)
|
||||
|
||||
@@ -46,7 +52,7 @@ class TextChunker:
|
||||
return chunks
|
||||
|
||||
def chunk_by_size(self, text: str) -> List[dict]:
|
||||
"""按固定大小分块"""
|
||||
"""Handle chunk by size for the Text Chunker instance."""
|
||||
chunks = []
|
||||
start = 0
|
||||
chunk_index = 0
|
||||
@@ -69,7 +75,7 @@ class TextChunker:
|
||||
return chunks
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""估算token数量"""
|
||||
"""Handle estimate tokens for the Text Chunker instance."""
|
||||
chinese_chars = len(re.findall(r"[^\x00-\xff]", text))
|
||||
english_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars / 1.5 + english_chars / 4)
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
"""Provide utility helpers for logger."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
|
||||
def setup_logging() -> logging.Logger:
|
||||
"""配置日志"""
|
||||
"""Handle setup logging."""
|
||||
logger = logging.getLogger("app")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
"""异步任务Worker模块"""
|
||||
"""Initialize the app.workers package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
"""Initialize the app.workflows package."""
|
||||
|
||||
from .rag_workflow import RagState, rag_workflow, run_rag_workflow, stream_rag_workflow
|
||||
from .compliance_workflow import ComplianceState, compliance_workflow, run_compliance_workflow
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RagState",
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Define workflow state for compliance workflow."""
|
||||
|
||||
from typing import TypedDict, List
|
||||
from langgraph.graph import StateGraph, END
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
|
||||
|
||||
|
||||
class ComplianceState(TypedDict):
|
||||
"""Track workflow state for compliance state."""
|
||||
document_path: str
|
||||
raw_text: str
|
||||
segments: List[dict]
|
||||
@@ -12,7 +17,7 @@ class ComplianceState(TypedDict):
|
||||
|
||||
|
||||
def parse_document(state: ComplianceState) -> dict:
|
||||
"""解析文档"""
|
||||
"""Parse document."""
|
||||
from app.services import get_document_service
|
||||
doc_service = get_document_service(
|
||||
"/airegulation/demo-mao/backend/data/raw",
|
||||
@@ -23,7 +28,7 @@ def parse_document(state: ComplianceState) -> dict:
|
||||
|
||||
|
||||
def segment_document(state: ComplianceState) -> dict:
|
||||
"""AI语义分段"""
|
||||
"""Handle segment document."""
|
||||
from app.services import llm_service
|
||||
prompt = f"""请分析以下设计方案文档,按照设计意图将其分成若干语义段落。
|
||||
|
||||
@@ -39,7 +44,7 @@ def segment_document(state: ComplianceState) -> dict:
|
||||
输出格式:
|
||||
[{{"intent": "...", "startPos": 0, "endPos": 100, "keywords": [...]}}]"""
|
||||
|
||||
# 简化处理:返回基本分段
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
segments = [
|
||||
{
|
||||
"id": 1,
|
||||
@@ -53,7 +58,7 @@ def segment_document(state: ComplianceState) -> dict:
|
||||
|
||||
|
||||
def match_regulations(state: ComplianceState) -> dict:
|
||||
"""法规匹配"""
|
||||
"""Handle match regulations."""
|
||||
from app.services import embedding_service, milvus_service
|
||||
matched = []
|
||||
|
||||
@@ -83,7 +88,7 @@ def match_regulations(state: ComplianceState) -> dict:
|
||||
|
||||
|
||||
def calculate_risk(state: ComplianceState) -> dict:
|
||||
"""计算风险等级"""
|
||||
"""Handle calculate risk."""
|
||||
segments = state["matched_regulations"]
|
||||
|
||||
high_count = 0
|
||||
@@ -133,7 +138,7 @@ def calculate_risk(state: ComplianceState) -> dict:
|
||||
|
||||
|
||||
def generate_suggestions(state: ComplianceState) -> dict:
|
||||
"""生成优先建议"""
|
||||
"""Handle generate suggestions."""
|
||||
actions = []
|
||||
|
||||
for segment in state["segments"]:
|
||||
@@ -149,7 +154,7 @@ def generate_suggestions(state: ComplianceState) -> dict:
|
||||
return {"priority_actions": actions}
|
||||
|
||||
|
||||
# 构建工作流图
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
compliance_graph = StateGraph(ComplianceState)
|
||||
|
||||
compliance_graph.add_node("parse", parse_document)
|
||||
@@ -169,7 +174,7 @@ compliance_workflow = compliance_graph.compile()
|
||||
|
||||
|
||||
async def run_compliance_workflow(document_path: str) -> ComplianceState:
|
||||
"""运行合规分析工作流"""
|
||||
"""Handle run compliance workflow."""
|
||||
initial_state: ComplianceState = {"document_path": document_path}
|
||||
result = compliance_workflow.invoke(initial_state)
|
||||
return result
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Define workflow state for rag workflow."""
|
||||
|
||||
from typing import TypedDict, List
|
||||
from langgraph.graph import StateGraph, END
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
|
||||
|
||||
|
||||
class RagState(TypedDict):
|
||||
"""Track workflow state for rag state."""
|
||||
query: str
|
||||
query_embedding: List[float]
|
||||
retrieved_docs: List[dict]
|
||||
@@ -12,14 +17,14 @@ class RagState(TypedDict):
|
||||
|
||||
|
||||
def embed_query(state: RagState) -> dict:
|
||||
"""将查询转为向量"""
|
||||
"""Embed query."""
|
||||
from app.services import embedding_service
|
||||
embedding = embedding_service.embed_single(state["query"])
|
||||
return {"query_embedding": embedding}
|
||||
|
||||
|
||||
def retrieve_docs(state: RagState) -> dict:
|
||||
"""向量检索"""
|
||||
"""Handle retrieve docs."""
|
||||
from app.services import milvus_service
|
||||
from app.core.config import settings
|
||||
docs = milvus_service.search(
|
||||
@@ -30,7 +35,7 @@ def retrieve_docs(state: RagState) -> dict:
|
||||
|
||||
|
||||
def build_context(state: RagState) -> dict:
|
||||
"""构建上下文"""
|
||||
"""Build context."""
|
||||
context_parts = []
|
||||
sources = []
|
||||
|
||||
@@ -46,7 +51,7 @@ def build_context(state: RagState) -> dict:
|
||||
|
||||
|
||||
def generate_answer(state: RagState) -> dict:
|
||||
"""生成答案"""
|
||||
"""Handle generate answer."""
|
||||
from app.services import llm_service
|
||||
prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
|
||||
|
||||
@@ -64,7 +69,7 @@ def generate_answer(state: RagState) -> dict:
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
# 构建工作流图
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
rag_graph = StateGraph(RagState)
|
||||
|
||||
rag_graph.add_node("embed", embed_query)
|
||||
@@ -82,23 +87,23 @@ rag_workflow = rag_graph.compile()
|
||||
|
||||
|
||||
async def run_rag_workflow(query: str) -> RagState:
|
||||
"""运行RAG工作流"""
|
||||
"""Handle run rag workflow."""
|
||||
initial_state: RagState = {"query": query}
|
||||
result = rag_workflow.invoke(initial_state)
|
||||
return result
|
||||
|
||||
|
||||
def stream_rag_workflow(query: str):
|
||||
"""流式运行RAG工作流"""
|
||||
"""Stream rag workflow."""
|
||||
from app.services import llm_service
|
||||
|
||||
# 先完成检索阶段
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
state: RagState = {"query": query}
|
||||
state.update(embed_query(state))
|
||||
state.update(retrieve_docs(state))
|
||||
state.update(build_context(state))
|
||||
|
||||
# 流式生成阶段
|
||||
# Keep workflow state definitions compact so transitions stay easy to audit.
|
||||
prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
|
||||
|
||||
法规内容:
|
||||
|
||||
1
backend/backend/data/documents.json
Normal file
1
backend/backend/data/documents.json
Normal file
@@ -0,0 +1 @@
|
||||
{}
|
||||
25
backend/data/documents.json
Normal file
25
backend/data/documents.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"69280841": {
|
||||
"doc_id": "69280841",
|
||||
"doc_name": "TCT算法接口.pdf",
|
||||
"file_name": "TCT算法接口.pdf",
|
||||
"object_name": "69280841/TCT算法接口.pdf",
|
||||
"content_type": "application/pdf",
|
||||
"size_bytes": 165557,
|
||||
"status": "failed",
|
||||
"regulation_type": "",
|
||||
"version": "",
|
||||
"summary": "",
|
||||
"summary_latency_ms": 0,
|
||||
"chunk_count": 0,
|
||||
"parser_name": "local_markdown_parser",
|
||||
"index_name": "",
|
||||
"error_message": "embedding 维度不匹配,期望 1536",
|
||||
"created_at": "2026-05-18T07:12:16.668306+00:00",
|
||||
"updated_at": "2026-05-18T07:12:19.417142+00:00",
|
||||
"metadata": {
|
||||
"generate_summary": true,
|
||||
"structure_nodes": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,12 @@
|
||||
import uvicorn
|
||||
|
||||
from app.config.settings import settings
|
||||
# Keep module behavior explicit so the backend flow stays easy to audit.
|
||||
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the module entrypoint."""
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.api_host,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user