This commit is contained in:
2026-05-14 15:07:34 +08:00
parent c2a398930d
commit 10d04c4083
179 changed files with 24073 additions and 1243 deletions

View File

@@ -0,0 +1,17 @@
# src/api/routes/__init__.py
"""API路由模块"""
from fastapi import APIRouter
from .documents import router as documents_router
from .knowledge import router as knowledge_router
from .agent import router as agent_router
# Top-level router that aggregates the feature routers of this package.
api_router = APIRouter()
# Register the sub-routers (documents, knowledge, agent) on the top-level router.
api_router.include_router(documents_router)
api_router.include_router(knowledge_router)
api_router.include_router(agent_router)
__all__ = ["api_router", "documents_router", "knowledge_router", "agent_router"]

View File

@@ -0,0 +1,449 @@
# src/api/routes/agent.py
"""Agent API接口 - 问答对话接口"""
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, AsyncGenerator
from loguru import logger
import json
import asyncio
from app.services.agent.qa_agent import QAAgent, AgentConfig
from app.services.agent.session_manager import SessionManager
from app.config.settings import settings
router = APIRouter(prefix="/agent", tags=["agent"])
# Session manager (single module-level instance used by the endpoints below).
session_manager = SessionManager()
# ===== Pydantic Models =====
class AskRequest(BaseModel):
    """Request body for a single-shot question."""
    # The user's question; required, 1 to 2000 characters.
    query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
    # Optional retrieval filter expression.
    filters: Optional[str] = Field(None, description="检索过滤条件")
    # Optional LLM provider name (the description mentions qwen/deepseek).
    provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
    # Optional LLM model name.
    model: Optional[str] = Field(None, description="LLM模型名称")
    # Optional number of chunks to retrieve; validated to the range 1..20.
    top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
    # Optional name of the prompt template to use.
    prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
class AskResponse(BaseModel):
    """Response body for a single-shot question."""
    # Generated answer text.
    answer: str
    # Source references attached to the answer (free-form dicts).
    sources: List[Dict] = []
    # Name of the model that produced the answer.
    model: str = ""
    # Latency in milliseconds, as reported by the agent response.
    latency_ms: int = 0
    # Number of retrieved items, as reported by the agent response.
    retrieved_count: int = 0
    # Context token count, as reported by the agent response.
    context_tokens: int = 0
    # Truncation flag, as reported by the agent response.
    truncated: bool = False
    # Error text, if the agent reported one.
    error: Optional[str] = None
class ChatRequest(BaseModel):
    """Request body for a multi-turn chat message."""
    # The user's question; required, 1 to 2000 characters.
    query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
    # Session ID; may be omitted on the first message of a conversation.
    session_id: Optional[str] = Field(None, description="会话ID首次对话可不传")
    # Optional retrieval filter expression.
    filters: Optional[str] = Field(None, description="检索过滤条件")
    # Optional LLM provider name.
    provider: Optional[str] = Field(None, description="LLM提供商")
    # Optional LLM model name.
    model: Optional[str] = Field(None, description="LLM模型名称")
class ChatResponse(BaseModel):
    """Response body for a multi-turn chat message."""
    # ID of the session the exchange was recorded in.
    session_id: str
    # Generated answer text.
    answer: str
    # Source references attached to the answer (free-form dicts).
    sources: List[Dict] = []
    # Name of the model that produced the answer.
    model: str = ""
    # Latency in milliseconds, as reported by the agent response.
    latency_ms: int = 0
    # Message count of the session after this exchange.
    message_count: int = 0
class SessionInfo(BaseModel):
    """Summary information about one chat session."""
    # Session identifier.
    session_id: str
    # Number of messages in the session.
    message_count: int
    # Creation time as an integer (presumably a Unix timestamp; TODO confirm the unit).
    created_at: int
    # Last-update time as an integer (presumably a Unix timestamp; TODO confirm the unit).
    updated_at: int
class FeedbackRequest(BaseModel):
    """Request body for rating an answer."""
    # Session the feedback refers to.
    session_id: str
    # Index of the message being rated (no range validation is declared here).
    message_index: int
    # Rating, validated to the range 1..5.
    rating: int = Field(..., ge=1, le=5, description="评分 1-5")
    # Optional free-text feedback.
    comment: Optional[str] = Field(None, description="反馈内容")
class TemplateListResponse(BaseModel):
    """Response body listing the available prompt templates."""
    # String-to-string mapping of templates (presumably name -> description; TODO confirm).
    templates: Dict[str, str]
# ===== API Endpoints =====
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """
    Answer a single question.

    No session history is stored, so this endpoint suits one-off queries.
    Provider, model and top_k fall back to the application settings when the
    request leaves them unset.

    Args:
        request: The question plus optional retrieval/LLM overrides.

    Returns:
        AskResponse populated from the agent's response.

    Raises:
        HTTPException: 500 carrying the error text when the agent fails.
    """
    logger.info(f"收到问答请求: {request.query}")
    try:
        # Build the agent configuration, preferring per-request overrides.
        config = AgentConfig(
            llm_provider=request.provider or settings.llm_provider,
            llm_model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k
        )
        agent = QAAgent(config)
        try:
            # NOTE(review): agent.ask is a synchronous call inside an async
            # handler; it occupies the event loop while it runs.
            response = agent.ask(
                query=request.query,
                filters=request.filters,
                prompt_template=request.prompt_template
            )
        finally:
            # Always release the agent, even when ask() raises.
            agent.close()
        return AskResponse(
            answer=response.answer,
            sources=response.sources,
            model=response.model,
            latency_ms=response.latency_ms,
            retrieved_count=response.retrieved_count,
            context_tokens=response.context_tokens,
            truncated=response.truncated,
            error=response.error
        )
    except Exception as e:
        logger.error(f"问答失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """
    Answer one message of a multi-turn conversation.

    The user message and the assistant answer are both recorded in the
    session. A new session is created when the request carries no session_id.

    Args:
        request: The question, optional session_id and optional LLM overrides.

    Returns:
        ChatResponse with the session id, the answer and the message count.

    Raises:
        HTTPException: 404 when the given session does not exist or expired;
            500 carrying the error text for any other failure.
    """
    logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
    try:
        # Look up the existing session or start a new one.
        if request.session_id:
            session = session_manager.get_session(request.session_id)
            if not session:
                raise HTTPException(status_code=404, detail="会话不存在或已过期")
        else:
            session = session_manager.create_session()
        session.add_user_message(request.query)
        config = AgentConfig(
            llm_provider=request.provider or settings.llm_provider,
            llm_model=request.model or settings.llm_model
        )
        agent = QAAgent(config)
        try:
            response = agent.ask(
                query=request.query,
                filters=request.filters
            )
        finally:
            # Always release the agent, even when ask() raises.
            agent.close()
        session.add_assistant_message(
            response.answer,
            response.sources
        )
        return ChatResponse(
            session_id=session.session_id,
            answer=response.answer,
            sources=response.sources,
            model=response.model,
            latency_ms=response.latency_ms,
            message_count=session.message_count
        )
    except HTTPException:
        # Pass deliberate HTTP errors (the 404 above) through unchanged.
        raise
    except Exception as e:
        logger.error(f"对话失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Response headers shared by both SSE endpoints.
_SSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",  # disable nginx buffering
}


def _sse_frame(event_type: str, data) -> str:
    """Serialize one Server-Sent Event.

    dict/list payloads are JSON-encoded; anything else is converted with str().
    A payload containing line breaks is emitted as one ``data:`` line per
    payload line, because a raw newline inside a single ``data:`` line would
    end the field early and corrupt the stream. Clients re-join the lines
    with "\\n".
    """
    if isinstance(data, (dict, list)):
        payload = json.dumps(data)
    else:
        payload = str(data)
    # Normalize CRLF / CR so every line break is framed the same way.
    payload = payload.replace("\r\n", "\n").replace("\r", "\n")
    data_lines = "\n".join(f"data: {line}" for line in payload.split("\n"))
    return f"event: {event_type}\n{data_lines}\n\n"


async def _chat_event_stream(
    query: str,
    session_id: Optional[str],
    filters: Optional[str],
    provider: Optional[str],
    model: Optional[str],
) -> AsyncGenerator[str, None]:
    """Yield the SSE frames of one streamed chat turn (shared by GET and POST)."""
    try:
        # Look up the existing session or start a new one.
        if session_id:
            session = session_manager.get_session(session_id)
            if not session:
                yield _sse_frame("error", "会话不存在或已过期")
                return
        else:
            session = session_manager.create_session()
        # Tell the client which session this stream belongs to.
        yield _sse_frame("session", {"session_id": session.session_id})
        session.add_user_message(query)
        config = AgentConfig(
            llm_provider=provider or settings.llm_provider,
            llm_model=model or settings.llm_model
        )
        agent = QAAgent(config)
        answer_parts = []
        sources = []
        try:
            # NOTE(review): ask_stream is iterated synchronously; each step
            # occupies the event loop until the agent produces the next event.
            for event_data in agent.ask_stream(query=query, filters=filters):
                event_type = event_data.get("event", "content")
                data = event_data.get("data", "")
                # Collect the full answer and the sources for the history.
                if event_type == "content":
                    answer_parts.append(str(data))
                elif event_type == "sources":
                    sources = data
                yield _sse_frame(event_type, data)
                # Yield control so other tasks get a chance to run.
                await asyncio.sleep(0)
        finally:
            # Runs on success, on error and when the client disconnects
            # (generator closed mid-stream), so the agent never leaks.
            agent.close()
        # Persist the assistant turn only after a complete stream.
        session.add_assistant_message("".join(answer_parts), sources)
    except Exception as e:
        logger.error(f"流式对话失败: {e}")
        yield _sse_frame("error", str(e))


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None
):
    """
    Streaming chat endpoint (SSE), GET variant.

    EventSource can only issue GET requests, so this variant takes its
    parameters from the URL query string.

    SSE events:
    - event: session - session ID
    - event: status - status update (retrieving, generating)
    - event: sources - cited sources
    - event: content - answer text fragment
    - event: done - completion, with statistics
    - event: error - error message
    """
    logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
    return StreamingResponse(
        _chat_event_stream(query, session_id, filters, provider, model),
        media_type="text/event-stream",
        headers=dict(_SSE_HEADERS)
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """
    Streaming chat endpoint (SSE), POST variant.

    Returns a Server-Sent Events stream so the user can watch progress and
    the answer as it is generated.

    SSE events:
    - event: session - session ID
    - event: status - status update (retrieving, generating)
    - event: sources - cited sources
    - event: content - answer text fragment
    - event: done - completion, with statistics
    - event: error - error message
    """
    logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
    return StreamingResponse(
        _chat_event_stream(
            request.query,
            request.session_id,
            request.filters,
            request.provider,
            request.model,
        ),
        media_type="text/event-stream",
        headers=dict(_SSE_HEADERS)
    )
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return summary information for one session, or 404 if it is unknown or expired."""
    found = session_manager.get_session(session_id)
    if found:
        return SessionInfo(
            session_id=found.session_id,
            message_count=found.message_count,
            created_at=found.created_at,
            updated_at=found.updated_at,
        )
    raise HTTPException(status_code=404, detail="会话不存在或已过期")
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return the session's history, passing max_turns through to the session; 404 if unknown or expired."""
    current = session_manager.get_session(session_id)
    if not current:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    return {
        "session_id": session_id,
        "history": current.get_history(max_turns),
    }
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a session; respond 404 when the manager reports nothing was deleted."""
    if session_manager.delete_session(session_id):
        return {"message": "会话已删除", "session_id": session_id}
    raise HTTPException(status_code=404, detail="会话不存在")
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """Return one SessionInfo per entry reported by the session manager."""
    summaries = []
    for entry in session_manager.list_sessions():
        summaries.append(SessionInfo(**entry))
    return summaries
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Accept a rating for a session; the feedback is only logged, not persisted."""
    if not session_manager.get_session(request.session_id):
        raise HTTPException(status_code=404, detail="会话不存在")
    # Feedback is logged only (a real deployment could store it in a database).
    logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
    acknowledgement = {"message": "反馈已记录", "rating": request.rating}
    return acknowledgement
@router.get("/templates", response_model=TemplateListResponse)
async def list_prompt_templates():
    """Return the prompt templates reported by PromptTemplates.list_templates()."""
    # Imported at call time, as in the rest of this module's optional lookups.
    from app.services.rag.prompt_templates import PromptTemplates
    return TemplateListResponse(templates=PromptTemplates.list_templates())
@router.get("/models")
async def list_available_models():
    """Return the providers reported by LLMFactory under the "models" key."""
    from app.services.llm import LLMFactory
    return {"models": LLMFactory().list_available_providers()}

View File

@@ -0,0 +1,96 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
from sse_starlette.sse import EventSourceResponse
import uuid
import os
import json
import asyncio
from app.schemas.compliance import (
AnalyzeResponse,
ComplianceChatRequest,
)
from app.services.mock_data import (
generate_task_id,
get_mock_compliance_result,
get_mock_compliance_chat_response,
)
router = APIRouter(prefix="/compliance", tags=["合规分析"])
# Temporary in-memory store of analysis tasks, keyed by task ID.
tasks_store: dict[str, dict] = {}
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)):
    """Store an uploaded design document and register a (mock) analysis task.

    The upload is written under the raw-data directory and a task record is
    created in ``tasks_store``. The analysis itself is simulated: the task is
    marked completed immediately with a mock result.

    Args:
        file: The uploaded document.

    Returns:
        AnalyzeResponse carrying the new task ID.
    """
    task_id = generate_task_id()
    # The client-supplied filename is untrusted: keep only its final path
    # component (either separator style) and fall back to a fixed name when
    # it is missing or empty.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/")) or "upload"
    raw_dir = "/airegulation/demo-mao/backend/data/raw"
    os.makedirs(raw_dir, exist_ok=True)
    file_path = os.path.join(raw_dir, f"compliance_{task_id}_{safe_name}")
    content = await file.read()
    with open(file_path, "wb") as f:
        f.write(content)
    # Register the task.
    tasks_store[task_id] = {
        "task_id": task_id,
        "file_path": file_path,
        "status": "processing",
        "result": None,
    }
    # Simulate the asynchronous processing finishing at once; a real
    # implementation would run this as a background task.
    tasks_store[task_id]["status"] = "completed"
    tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
    return AnalyzeResponse(task_id=task_id)
@router.get("/result/{task_id}")
async def get_result(task_id: str):
    """Return the analysis result for a task, or a default mock result for unknown IDs."""
    try:
        record = tasks_store[task_id]
    except KeyError:
        # Unknown task ID: fall back to the default mock result.
        return get_mock_compliance_result(task_id)
    if record["status"] != "processing":
        return record["result"]
    return {"status": "processing", "message": "分析进行中"}
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
    """Stream a canned compliance answer for one document segment as SSE messages."""
    # Map the segment ID to its intent; unknown IDs use the first intent.
    intent = {
        1: "车身结构设计",
        2: "动力系统配置",
        3: "安全配置设计",
    }.get(segment_id, "车身结构设计")

    async def generate():
        reply = get_mock_compliance_chat_response(intent, request.query)
        # Emit the reply line by line, skipping blank paragraphs and lines.
        for paragraph in reply.split("\n\n"):
            if not paragraph.strip():
                continue
            for line in paragraph.split("\n"):
                if not line.strip():
                    continue
                await asyncio.sleep(0.05)
                yield {"event": "message", "data": json.dumps({"type": "chunk", "text": line + "\n"})}
        yield {"event": "message", "data": json.dumps({"type": "done"})}

    return EventSourceResponse(generate())

View File

@@ -0,0 +1,115 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
import os
import uuid
from datetime import datetime
from app.schemas.doc import (
DocumentUploadResponse,
DocumentListResponse,
DocumentInfo,
ParseResponse,
EmbedResponse,
)
from app.services.mock_data import get_mock_documents, generate_doc_id
router = APIRouter(prefix="/docs", tags=["文档管理"])
# Temporary in-memory document registry keyed by document ID, seeded at
# import time with the preset mock documents. Built with a comprehension so
# no loop variable is left behind as a module-level global.
documents_store: dict[str, dict] = {
    mock_doc["id"]: mock_doc for mock_doc in get_mock_documents()
}
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(file: UploadFile = File(...)):
    """Store an uploaded regulation document and register it in the in-memory store.

    Args:
        file: The uploaded document; must end in .pdf, .docx, .doc or .txt.

    Returns:
        DocumentUploadResponse with the new document ID, filename and size.

    Raises:
        HTTPException: 400 when the file extension is not supported
            (including a missing filename).
    """
    # The client-supplied filename is untrusted and may be missing: keep only
    # its final path component (either separator style).
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
    ext = os.path.splitext(filename)[1].lower()
    if ext not in allowed_ext:
        raise HTTPException(400, f"Unsupported file format: {ext}")
    doc_id = generate_doc_id()
    raw_dir = "/airegulation/demo-mao/backend/data/raw"
    os.makedirs(raw_dir, exist_ok=True)
    file_path = os.path.join(raw_dir, f"{doc_id}_{filename}")
    content = await file.read()
    with open(file_path, "wb") as f:
        f.write(content)
    # Record the document metadata.
    documents_store[doc_id] = {
        "id": doc_id,
        "name": filename,
        "path": file_path,
        "size": len(content),
        "status": "uploaded",
        "chunks": 0,
        "created_at": datetime.now(),
    }
    return DocumentUploadResponse(
        doc_id=doc_id,
        filename=filename,
        size=len(content),
    )
@router.get("/list", response_model=DocumentListResponse)
async def list_documents():
    """Return every document currently held in the in-memory store."""
    listed = []
    for record in documents_store.values():
        listed.append(
            DocumentInfo(
                id=record["id"],
                name=record["name"],
                chunks=record["chunks"],
                status=record["status"],
                created_at=record.get("created_at"),
            )
        )
    return DocumentListResponse(docs=listed)
@router.post("/parse/{doc_id}", response_model=ParseResponse)
async def parse_document(doc_id: str):
    """Simulate parsing: mark the document parsed and derive a chunk count from its size."""
    if doc_id not in documents_store:
        raise HTTPException(404, "Document not found")
    record = documents_store[doc_id]
    # Mock parsing logic: no real parsing happens here.
    record["status"] = "parsed"
    # Chunk count is size // 8000 with a floor of 20; size defaults to 100000.
    size_bytes = record.get("size", 100000)
    chunk_count = max(20, size_bytes // 8000)
    record["chunks"] = chunk_count
    return ParseResponse(doc_id=doc_id, chunks=chunk_count)
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
async def embed_document(doc_id: str):
    """Simulate embedding: mark the document indexed and report its chunk count as vectors."""
    if doc_id not in documents_store:
        raise HTTPException(404, "Document not found")
    record = documents_store[doc_id]
    # Mock embedding logic: only the status changes.
    record["status"] = "indexed"
    vector_count = record["chunks"]
    return EmbedResponse(doc_id=doc_id, vectors=vector_count)
@router.delete("/delete/{doc_id}")
async def delete_document(doc_id: str):
    """Remove a document from the in-memory store; 404 if it is not registered."""
    if doc_id not in documents_store:
        raise HTTPException(404, "Document not found")
    documents_store.pop(doc_id)
    return {"success": True}

View File

@@ -0,0 +1,291 @@
# src/api/routes/documents.py
"""文档上传与处理接口"""
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from typing import Optional
import os
import uuid
import tempfile
from pathlib import Path
from loguru import logger
from io import BytesIO
from urllib.parse import quote
from ..models import DocumentUploadResponse, ErrorResponse
from app.services.document_processor import DocumentProcessor
from app.services.storage.minio_client import MinIOClient
from app.config.settings import settings
router = APIRouter(prefix="/documents", tags=["documents"])
# MinIO client used for document storage; starts as None and is filled in on first use.
minio_client: Optional[MinIOClient] = None
def get_minio_client() -> MinIOClient:
    """Return the shared MinIO client, creating and connecting it on first use.

    The module-level instance is published only after ``connect()`` and
    ``ensure_bucket()`` both succeed, so a failed initialization is retried
    on the next call instead of caching a half-initialized client.

    Returns:
        The connected, module-wide MinIOClient instance.
    """
    global minio_client
    if minio_client is None:
        client = MinIOClient()
        client.connect()
        client.ensure_bucket()
        minio_client = client
    return minio_client
def _build_document_records(limit: Optional[int] = None):
    """Build the document list from MinIO, newest first, optionally cut to `limit` entries.

    Only objects named "{doc_id}/{filename}" are listed; objects without a
    "/" in their name are skipped. Objects lacking a modification time sort
    last (sort key 0).
    """
    storage = get_minio_client()
    keyed_records = []
    for entry in storage.client.list_objects(storage.bucket, recursive=True):
        pieces = entry.object_name.split("/", 1)
        if len(pieces) != 2:
            continue
        doc_id, filename = pieces
        modified = getattr(entry, "last_modified", None)
        record = {
            "doc_id": doc_id,
            "filename": filename,
            "size": getattr(entry, "size", 0) or 0,
            "object_name": entry.object_name,
            "download_url": f"/api/v1/documents/download/{doc_id}",
            "last_modified": modified.isoformat() if modified else None,
        }
        # Keep the sort key beside the record instead of inside it.
        keyed_records.append((modified.timestamp() if modified else 0, record))
    keyed_records.sort(key=lambda pair: pair[0], reverse=True)
    if limit is not None:
        keyed_records = keyed_records[:limit]
    return [record for _, record in keyed_records]
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(..., description="上传的文档文件"),
    doc_name: Optional[str] = Form(None, description="文档名称"),
    regulation_type: Optional[str] = Form(None, description="法规类型"),
    version: Optional[str] = Form(None, description="文档版本"),
    generate_summary: bool = Form(False, description="是否生成摘要默认不生成可节省约60秒")
):
    """
    Upload a document and process it.

    Supported formats: PDF, DOCX, DOC.
    Pipeline: parse -> chunk -> embed -> store (summary optional).
    File storage: MinIO object storage, plus a temporary local copy that is
    removed once processing ends (successfully or not).

    Args:
        file: The uploaded document.
        doc_name: Display name; defaults to the uploaded filename.
        regulation_type: Regulation type passed to the processor ("" if unset).
        version: Document version passed to the processor ("" if unset).
        generate_summary: Whether to generate an LLM summary (default False;
            per the field description it adds roughly 60 seconds).

    Returns:
        DocumentUploadResponse describing the processed document.

    Raises:
        HTTPException: 400 for an unsupported type or an oversized file;
            500 when processing fails.
    """
    # The client-supplied filename is untrusted and may be missing: keep only
    # its final path component so it is safe in the temp path and object name.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(filename)[1].lower()
    if ext not in [".pdf", ".docx", ".doc"]:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {ext}仅支持PDF、DOCX、DOC"
        )
    max_bytes = settings.max_file_size_mb * 1024 * 1024
    if file.size and file.size > max_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
        )
    doc_id = str(uuid.uuid4())[:8]
    final_doc_name = doc_name or filename
    # MinIO object name: "{doc_id}/{filename}".
    object_name = f"{doc_id}/{filename}"
    temp_path = os.path.join(tempfile.gettempdir(), f"{doc_id}_{filename}")
    logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}, doc_id={doc_id}")
    try:
        content = await file.read()
        # file.size may be unset, so enforce the limit on the bytes actually read.
        if len(content) > max_bytes:
            raise HTTPException(
                status_code=400,
                detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
            )
        # Write a temporary copy for the processor to read.
        with open(temp_path, "wb") as f:
            f.write(content)
        logger.info(f"临时文件已保存到: {temp_path}")
        minio = get_minio_client()
        upload_success = minio.upload_bytes(
            data=content,
            object_name=object_name,
            content_type=minio._get_content_type(filename),
            metadata={
                "doc_id": doc_id  # only ASCII-safe metadata is passed
            }
        )
        if upload_success:
            logger.success(f"文件已上传到MinIO: {object_name}")
        else:
            # Best effort: processing continues from the local temporary copy.
            logger.warning("MinIO上传失败仅使用本地临时文件")
        processor = DocumentProcessor(generate_summary=generate_summary)
        try:
            result = processor.process(
                file_path=temp_path,
                doc_id=doc_id,  # same ID as the MinIO object prefix
                doc_name=final_doc_name,
                regulation_type=regulation_type or "",
                version=version or ""
            )
        finally:
            # Always release the processor, even when process() raises.
            processor.close()
        if not result.success:
            raise HTTPException(
                status_code=500,
                detail=result.message
            )
        return DocumentUploadResponse(
            doc_id=result.doc_id,
            doc_name=result.doc_name,
            status="success",
            message=result.message,
            num_chunks=result.num_chunks,
            summary=result.summary,
            summary_latency_ms=result.summary_latency_ms
        )
    except HTTPException:
        # Deliberate HTTP errors must reach the client unchanged rather than
        # being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"文档处理失败: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"文档处理失败: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup on every exit path; a failure is logged, not raised.
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"临时文件清理失败: {temp_path}: {cleanup_error}")
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
async def get_document_status(doc_id: str):
    """
    Query the processing status of a document.

    Not implemented yet: always answers with status "unknown".

    Args:
        doc_id: Document ID
    """
    # TODO: implement the status lookup (needs database support)
    placeholder = {
        "doc_id": doc_id,
        "doc_name": "",
        "status": "unknown",
        "message": "状态查询功能待实现",
    }
    return DocumentUploadResponse(**placeholder)
@router.get("/download/{doc_id}")
async def download_document(doc_id: str):
    """
    Download a document stored in MinIO.

    The first object found under the "{doc_id}/" prefix is returned as an
    attachment whose filename is URL-encoded so non-ASCII names survive.

    Args:
        doc_id: Document ID

    Returns:
        A file download response.

    Raises:
        HTTPException: 404 when no object exists for the ID; 500 when the
            data cannot be read or any other error occurs.
    """
    logger.info(f"请求下载文档: doc_id={doc_id}")
    try:
        storage = get_minio_client()
        # Objects are named "{doc_id}/{filename}", so search by prefix.
        matches = storage.list_objects(prefix=f"{doc_id}/")
        if not matches:
            logger.warning(f"MinIO中未找到文档: doc_id={doc_id}")
            raise HTTPException(
                status_code=404,
                detail=f"文档不存在: doc_id={doc_id}"
            )
        stored_name = matches[0]
        logger.info(f"找到MinIO对象: {stored_name}")
        payload = storage.get_object_data(stored_name)
        if payload is None:
            raise HTTPException(
                status_code=500,
                detail="获取文档数据失败"
            )
        # Strip the "{doc_id}/" prefix to recover the original filename.
        _, separator, remainder = stored_name.partition("/")
        original_name = remainder if separator else stored_name
        mime_type = storage._get_content_type(original_name)
        logger.success(f"文档下载成功: {original_name}, 大小={len(payload)}")
        disposition = f"attachment; filename*=UTF-8''{quote(original_name)}"
        return StreamingResponse(
            BytesIO(payload),
            media_type=mime_type,
            headers={"Content-Disposition": disposition}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"文档下载失败: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"文档下载失败: {str(e)}"
        )
@router.get("/list")
async def list_documents():
    """
    List every uploaded document found in MinIO.

    Failures are reported in an "error" field of an empty listing rather
    than as an HTTP error.
    """
    try:
        records = _build_document_records()
    except Exception as e:
        logger.error(f"列出文档失败: {e}")
        return {"documents": [], "total": 0, "error": str(e)}
    return {"documents": records, "total": len(records)}
@router.get("/management-list")
async def get_document_management_list():
    """
    Document-management listing: only the 10 most recent documents.

    Failures are reported in an "error" field of an empty listing rather
    than as an HTTP error.
    """
    try:
        recent = _build_document_records(limit=10)
    except Exception as e:
        logger.error(f"获取文档管理清单失败: {e}")
        return {"documents": [], "total": 0, "limit": 10, "error": str(e)}
    return {"documents": recent, "total": len(recent), "limit": 10}

View File

@@ -0,0 +1,81 @@
# src/api/routes/knowledge.py
"""知识库检索接口"""
from fastapi import APIRouter, HTTPException
from loguru import logger
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse
from app.services.document_processor import DocumentProcessor
# Router for the knowledge-base search endpoints, mounted under /knowledge.
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
@router.post("/search", response_model=SearchResponse)
async def search_knowledge(request: SearchRequest):
    """
    Search the regulation knowledge base.

    Delegates to DocumentProcessor.search (described by the module as hybrid
    retrieval: dense vectors + sparse vectors + RRF fusion).

    Args:
        request: Search parameters (query, top_k, filters).

    Returns:
        SearchResponse with the query, the hit count and the hits.

    Raises:
        HTTPException: 400 when the query is empty or whitespace only;
            500 when the search fails.
    """
    if not (request.query and request.query.strip()):
        raise HTTPException(
            status_code=400,
            detail="查询文本不能为空"
        )
    logger.info(f"收到检索请求: {request.query}")
    try:
        processor = DocumentProcessor()
        try:
            results = processor.search(
                query=request.query,
                top_k=request.top_k,
                filters=request.filters
            )
        finally:
            # Always release the processor, even when search() raises.
            processor.close()
        # Convert raw hits, defaulting any missing field.
        result_items = [
            SearchResultItem(
                id=r.get("id", 0),
                content=r.get("content", ""),
                score=r.get("score", 0.0),
                metadata=r.get("metadata", {})
            )
            for r in results
        ]
        return SearchResponse(
            query=request.query,
            total=len(result_items),
            results=result_items
        )
    except Exception as e:
        logger.error(f"检索失败: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"检索失败: {str(e)}"
        ) from e
@router.post("/retrieval", response_model=SearchResponse)
async def knowledge_retrieval(request: SearchRequest):
    """
    Knowledge retrieval endpoint (named to align with the architecture document).

    The planned pipeline is intent recognition, BM25 + vector recall,
    Cross-Encoder re-ranking, then returning results. The current version
    simply delegates to search_knowledge; the re-ranking step is not
    implemented here yet.

    Args:
        request: Search request
    """
    # Current version reuses the hybrid search; re-ranking can be added later.
    response = await search_knowledge(request)
    return response

View File

@@ -0,0 +1,74 @@
from fastapi import APIRouter
from sse_starlette.sse import EventSourceResponse
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.services.mock_data import (
get_mock_quick_questions,
get_mock_retrieval,
get_mock_rag_answer,
)
import json
import asyncio
# Router for the RAG question-answering endpoints, mounted under /rag.
router = APIRouter(prefix="/rag", tags=["RAG问答"])
@router.post("/chat")
async def rag_chat(request: RagChatRequest):
    """Stream a mock RAG answer over SSE: retrieval events first, then the answer line by line."""

    def as_message(payload):
        # Every event uses the "message" event name with a JSON-encoded body.
        return {"event": "message", "data": json.dumps(payload)}

    async def generate():
        # Announce that retrieval has started, then simulate its latency.
        yield as_message({"type": "retrieving"})
        await asyncio.sleep(0.3)
        hits = get_mock_retrieval(request.query, top_k=request.top_k)
        previews = []
        for hit in hits:
            previews.append({
                "id": hit["id"],
                "score": hit["score"],
                "preview": hit["preview"],
                "doc_name": hit.get("doc_name", ""),
                "clause": hit.get("clause", ""),
            })
        yield as_message({"type": "retrieved", "docs": previews})
        # Announce that generation has started, then simulate its latency.
        yield as_message({"type": "generating", "text": "正在生成答案..."})
        await asyncio.sleep(0.2)
        answer = get_mock_rag_answer(request.query)
        # Stream the preset answer: paragraphs split on blank lines, then
        # each non-blank line is sent as its own chunk.
        for paragraph in answer.split("\n\n"):
            if not paragraph.strip():
                continue
            for line in paragraph.split("\n"):
                if not line.strip():
                    continue
                await asyncio.sleep(0.05)  # simulated generation delay
                yield as_message({"type": "chunk", "text": line + "\n"})
        yield as_message({"type": "done"})

    return EventSourceResponse(generate())
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """Return the preset quick questions taken from the mock data."""
    presets = []
    for item in get_mock_quick_questions():
        presets.append(
            QuickQuestion(
                id=item["id"],
                question=item["question"],
                category=item["category"],
            )
        )
    return QuickQuestionsResponse(questions=presets)

View File

@@ -0,0 +1,28 @@
from fastapi import APIRouter
from app.core.config import settings
from app.services.mock_data import MOCK_SYSTEM_STATS, MOCK_SYSTEM_CONFIG
# Router for the system status endpoints, mounted under /status.
router = APIRouter(prefix="/status", tags=["系统状态"])
@router.get("/stats")
async def get_stats():
    """Return the system statistics (preset mock data)."""
    stats = MOCK_SYSTEM_STATS
    return stats
@router.get("/config")
async def get_config():
    """Return the current configuration (preset mock data)."""
    current_config = MOCK_SYSTEM_CONFIG
    return current_config
@router.get("/milvus/health")
async def milvus_health():
    """Report Milvus health; in mock mode this always reports a successful connection."""
    # Simulated status: no real connection check is made.
    health = {
        "connected": True,
        "collections": ["vehicle_regulations"],
    }
    return health