Fix SSE route dependency and align architecture docs

This commit is contained in:
ash66
2026-05-18 16:32:42 +08:00
parent 86b9ac806a
commit 3f69cad404
149 changed files with 4786 additions and 5957 deletions

View File

@@ -1,16 +1,29 @@
"""API路由模块"""
"""Initialize the app.api.routes package."""
from fastapi import APIRouter
from .compliance import router as compliance_router
from .documents import router as documents_router
from .knowledge import router as knowledge_router
from .agent import router as agent_router
from .status import router as status_router
# Main router: aggregates every feature router of the API package.
api_router = APIRouter()

# Register the sub-routers on the main router.
api_router.include_router(documents_router)
api_router.include_router(knowledge_router)
api_router.include_router(agent_router)
api_router.include_router(compliance_router)
api_router.include_router(status_router)

# Public names of the package; lists every router that is registered above.
__all__ = [
    "api_router",
    "documents_router",
    "knowledge_router",
    "agent_router",
    "compliance_router",
    "status_router",
]

View File

@@ -1,186 +1,83 @@
"""Agent API接口 - 问答对话接口"""
"""Define API routes for agent."""
from __future__ import annotations
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, AsyncGenerator
from loguru import logger
import json
import asyncio
import json
from typing import AsyncGenerator, List, Optional
from app.services.agent.qa_agent import QAAgent, AgentConfig
from app.services.agent.session_manager import SessionManager
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from dataclasses import asdict
from app.api.models import (
AskRequest,
AskResponse,
ChatRequest,
ChatResponse,
FeedbackRequest,
SessionInfo,
)
from app.config.settings import settings
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/agent", tags=["agent"])
# 会话管理器(全局实例)
session_manager = SessionManager()
# ===== Pydantic Models =====
class AskRequest(BaseModel):
"""单次问答请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
model: Optional[str] = Field(None, description="LLM模型名称")
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
class AskResponse(BaseModel):
"""问答响应"""
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
retrieved_count: int = 0
context_tokens: int = 0
truncated: bool = False
error: Optional[str] = None
class ChatRequest(BaseModel):
"""多轮对话请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
session_id: Optional[str] = Field(None, description="会话ID首次对话可不传")
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商")
model: Optional[str] = Field(None, description="LLM模型名称")
class ChatResponse(BaseModel):
"""多轮对话响应"""
session_id: str
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
message_count: int = 0
class SessionInfo(BaseModel):
"""会话信息"""
session_id: str
message_count: int
created_at: int
updated_at: int
class FeedbackRequest(BaseModel):
"""反馈请求"""
session_id: str
message_index: int
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
comment: Optional[str] = Field(None, description="反馈内容")
class TemplateListResponse(BaseModel):
"""模板列表响应"""
templates: Dict[str, str]
# ===== API Endpoints =====
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a single question through the agent conversation service.

    Provider, model and top_k fall back to the values in ``settings`` when
    the request leaves them unset (or falsy).

    Returns:
        AskResponse built from the service result; each source is converted
        to a plain dict with ``dataclasses.asdict``.

    Raises:
        HTTPException: 500 carrying the error text when anything fails.
    """
    try:
        # The first element of the returned pair is not needed by this endpoint.
        _, result = get_agent_conversation_service().ask(
            query=request.query,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
            prompt_template=request.prompt_template,
        )
        return AskResponse(
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )
    except Exception as exc:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Handle one turn of a session-based conversation.

    The conversation service returns the session id together with the
    answer; the session is then read back from the conversation store only
    to report how many messages it holds.

    Raises:
        HTTPException: 404 when a ``ValueError`` is raised inside the try
            block, 500 for any other failure.
    """
    try:
        session_id, result = get_agent_conversation_service().chat(
            query=request.query,
            session_id=request.session_id,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )
        session = get_conversation_store().get_session(session_id)
        return ChatResponse(
            session_id=session_id,
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            # The store lookup may come back empty; report 0 in that case.
            message_count=len(session.messages) if session else 0,
        )
    except HTTPException:
        raise
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.get("/chat/stream")
@@ -189,260 +86,93 @@ async def chat_stream_get(
session_id: Optional[str] = None,
filters: Optional[str] = None,
provider: Optional[str] = None,
model: Optional[str] = None
model: Optional[str] = None,
):
"""
流式对话接口SSE- GET版本
EventSource只能发送GET请求因此提供此接口。
query参数通过URL传递。
SSE事件格式
- event: session - 会话ID
- event: status - 状态更新(检索中、生成中)
- event: sources - 引用来源
- event: content - 回答内容片段
- event: done - 完成,包含统计信息
- event: error - 错误信息
"""
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
"""Handle chat stream get."""
async def generate_sse() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
"""Handle generate sse."""
try:
# 获取或创建会话
if session_id:
session = session_manager.get_session(session_id)
if not session:
yield f"event: error\ndata: 会话不存在或已过期\n\n"
return
else:
session = session_manager.create_session()
# 发送session_id
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
# 添加用户消息
session.add_user_message(query)
# 创建Agent
config = AgentConfig(
llm_provider=provider or settings.llm_provider,
llm_model=model or settings.llm_model
)
agent = QAAgent(config)
# 执行流式问答
full_answer = ""
sources = []
done_data = {}
for event_data in agent.ask_stream(
session_id_, event_stream = get_agent_conversation_service().stream_chat(
query=query,
filters=filters
):
session_id=session_id,
filters=filters,
provider=provider or settings.llm_provider,
model=model or settings.llm_model,
top_k=settings.rag_top_k,
)
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
for event_data in event_stream:
event_type = event_data.get("event", "content")
data = event_data.get("data", "")
# 收集完整回答和来源
if event_type == "content":
full_answer += str(data)
elif event_type == "sources":
sources = data
elif event_type == "done":
done_data = data
# 发送SSE事件
if isinstance(data, (dict, list)):
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
else:
yield f"event: {event_type}\ndata: {data}\n\n"
# 小延迟让其他任务有机会执行
await asyncio.sleep(0)
agent.close()
# 保存到会话历史
session.add_assistant_message(full_answer, sources)
except Exception as e:
logger.error(f"流式对话失败: {e}")
yield f"event: error\ndata: {str(e)}\n\n"
except Exception as exc:
yield f"event: error\ndata: {str(exc)}\n\n"
return StreamingResponse(
generate_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream a chat answer for a POST body.

    Thin wrapper: forwards the request fields to ``chat_stream_get`` and
    returns whatever that handler returns.
    """
    forwarded = {
        "query": request.query,
        "session_id": request.session_id,
        "filters": request.filters,
        "provider": request.provider,
        "model": request.model,
    }
    return await chat_stream_get(**forwarded)
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return the metadata of one conversation session.

    Raises:
        HTTPException: 404 when the store has no session for ``session_id``.
    """
    store = get_conversation_store()
    found = store.get_session(session_id)
    if found:
        return SessionInfo(
            session_id=found.session_id,
            message_count=len(found.messages),
            created_at=found.created_at,
            updated_at=found.updated_at,
        )
    raise HTTPException(status_code=404, detail="会话不存在或已过期")
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return the most recent messages of a session as role/content pairs.

    At most ``max_turns * 2`` messages are returned (presumably one user and
    one assistant message per turn -- confirm against the store). A zero or
    negative ``max_turns`` yields an empty history.

    Raises:
        HTTPException: 404 when the store has no session for ``session_id``.
    """
    session = get_conversation_store().get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    # Clamp first: a bare ``messages[-0:]`` would return the whole list, and a
    # negative count would silently turn into a "skip the first N" slice.
    message_limit = max(max_turns, 0) * 2
    recent = session.messages[-message_limit:] if message_limit else []
    history = [{"role": msg.role, "content": msg.content} for msg in recent]
    return {"session_id": session_id, "history": history}
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a conversation session.

    Raises:
        HTTPException: 404 when the store reports that nothing was deleted.
    """
    removed = get_conversation_store().delete_session(session_id)
    if removed:
        return {"message": "会话已删除", "session_id": session_id}
    raise HTTPException(status_code=404, detail="会话不存在")
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """Return every session reported by the conversation store."""
    session_infos = []
    for entry in get_conversation_store().list_sessions():
        # Each entry is expanded as keyword arguments of SessionInfo.
        session_infos.append(SessionInfo(**entry))
    return session_infos
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Acknowledge feedback for an existing session.

    The handler only checks that the session exists; the rating and comment
    are not stored by this function.

    Raises:
        HTTPException: 404 when the store has no session for the given id.
    """
    session = get_conversation_store().get_session(request.session_id)
    if session:
        return {
            "message": "反馈已提交",
            "session_id": request.session_id,
            "message_index": request.message_index,
        }
    raise HTTPException(status_code=404, detail="会话不存在")

View File

@@ -1,9 +1,15 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
from sse_starlette.sse import EventSourceResponse
import uuid
import os
import json
"""Define API routes for compliance."""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import AsyncGenerator
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import StreamingResponse
from app.schemas.compliance import (
AnalyzeResponse,
ComplianceChatRequest,
@@ -13,38 +19,42 @@ from app.services.mock_data import (
get_mock_compliance_result,
get_mock_compliance_chat_response,
)
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/compliance", tags=["合规分析"])
# 临时存储分析任务
# Temporary in-memory store for analysis tasks, keyed by task_id.
tasks_store: dict[str, dict] = {}
# Store uploaded compliance files inside the local backend data directory.
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)):
"""上传设计方案进行分析"""
# 生成任务ID
"""Handle analyze document."""
# Keep route handlers close to their transport-layer wiring for easier auditing.
task_id = generate_task_id()
# 保存文件
raw_dir = "/airegulation/demo-mao/backend/data/raw"
os.makedirs(raw_dir, exist_ok=True)
file_path = os.path.join(raw_dir, f"compliance_{task_id}_{file.filename}")
# Keep route handlers close to their transport-layer wiring for easier auditing.
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
content = await file.read()
with open(file_path, "wb") as f:
with file_path.open("wb") as f:
f.write(content)
# 记录任务
# Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store[task_id] = {
"task_id": task_id,
"file_path": file_path,
"file_path": str(file_path),
"status": "processing",
"result": None,
}
# 模拟异步处理完成(立即返回结果)
# 实际应用中这应该是后台任务
# Keep route handlers close to their transport-layer wiring for easier auditing.
# Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store[task_id]["status"] = "completed"
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
@@ -53,9 +63,9 @@ async def analyze_document(file: UploadFile = File(...)):
@router.get("/result/{task_id}")
async def get_result(task_id: str):
"""获取分析结果"""
"""Return result."""
if task_id not in tasks_store:
# 如果任务ID不存在返回默认mock结果
# Keep route handlers close to their transport-layer wiring for easier auditing.
return get_mock_compliance_result(task_id)
task = tasks_store[task_id]
@@ -68,8 +78,8 @@ async def get_result(task_id: str):
@router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
"""针对段落进行合规对话"""
# 根据segment_id获取对应的intent
"""Handle compliance chat."""
# Keep route handlers close to their transport-layer wiring for easier auditing.
intent_map = {
1: "车身结构设计",
2: "动力系统配置",
@@ -77,11 +87,12 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
}
intent = intent_map.get(segment_id, "车身结构设计")
async def generate():
# 获取预设响应
async def generate() -> AsyncGenerator[str, None]:
# Keep route handlers close to their transport-layer wiring for easier auditing.
"""Handle generate."""
response = get_mock_compliance_chat_response(intent, request.query)
# 流式输出响应
# Keep route handlers close to their transport-layer wiring for easier auditing.
sentences = response.split("\n\n")
for sentence in sentences:
if sentence.strip():
@@ -89,8 +100,15 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
for chunk in chunks:
if chunk.strip():
await asyncio.sleep(0.05)
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
)
yield {"event": "message", "data": json.dumps({"type": "done"})}
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return EventSourceResponse(generate())
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)

View File

@@ -1,3 +1,5 @@
"""Define API routes for docs."""
from fastapi import APIRouter, UploadFile, File, HTTPException
import os
import uuid
@@ -10,30 +12,32 @@ from app.schemas.doc import (
EmbedResponse,
)
from app.services.mock_data import get_mock_documents, generate_doc_id
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/docs", tags=["文档管理"])
# 临时存储文档信息包含预设的mock文档
# Temporary in-memory store for document records, including the preset mock documents.
documents_store: dict[str, dict] = {}
# 初始化时加载mock文档
# Load the mock documents into the store when the module is imported.
for doc in get_mock_documents():
documents_store[doc["id"]] = doc
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(file: UploadFile = File(...)):
"""上传法规文档"""
# 检查文件格式
"""Handle upload document."""
# Keep route handlers close to their transport-layer wiring for easier auditing.
allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
ext = os.path.splitext(file.filename)[1].lower()
if ext not in allowed_ext:
raise HTTPException(400, f"Unsupported file format: {ext}")
# 生成文档ID
# Keep route handlers close to their transport-layer wiring for easier auditing.
doc_id = generate_doc_id()
# 保存文件
# Keep route handlers close to their transport-layer wiring for easier auditing.
raw_dir = "/airegulation/demo-mao/backend/data/raw"
os.makedirs(raw_dir, exist_ok=True)
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
@@ -42,7 +46,7 @@ async def upload_document(file: UploadFile = File(...)):
with open(file_path, "wb") as f:
f.write(content)
# 记录文档信息
# Keep route handlers close to their transport-layer wiring for easier auditing.
documents_store[doc_id] = {
"id": doc_id,
"name": file.filename,
@@ -62,7 +66,7 @@ async def upload_document(file: UploadFile = File(...)):
@router.get("/list", response_model=DocumentListResponse)
async def list_documents():
"""获取已索引文档列表"""
"""List documents."""
docs = [
DocumentInfo(
id=d["id"],
@@ -78,14 +82,14 @@ async def list_documents():
@router.post("/parse/{doc_id}", response_model=ParseResponse)
async def parse_document(doc_id: str):
"""解析文档并分块"""
"""Parse document."""
if doc_id not in documents_store:
raise HTTPException(404, "Document not found")
doc = documents_store[doc_id]
# 模拟解析逻辑
# Keep route handlers close to their transport-layer wiring for easier auditing.
doc["status"] = "parsed"
# 根据文件大小计算chunks数量
# Keep route handlers close to their transport-layer wiring for easier auditing.
file_size = doc.get("size", 100000)
doc["chunks"] = max(20, file_size // 8000)
@@ -94,12 +98,12 @@ async def parse_document(doc_id: str):
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
async def embed_document(doc_id: str):
"""嵌入并存入向量库"""
"""Embed document."""
if doc_id not in documents_store:
raise HTTPException(404, "Document not found")
doc = documents_store[doc_id]
# 模拟嵌入逻辑
# Keep route handlers close to their transport-layer wiring for easier auditing.
doc["status"] = "indexed"
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
@@ -107,7 +111,7 @@ async def embed_document(doc_id: str):
@router.delete("/delete/{doc_id}")
async def delete_document(doc_id: str):
"""删除文档"""
"""Delete document."""
if doc_id not in documents_store:
raise HTTPException(404, "Document not found")

View File

@@ -1,290 +1,140 @@
"""文档上传与处理接口"""
"""Define API routes for documents."""
from __future__ import annotations
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from typing import Optional
import os
import uuid
import tempfile
from pathlib import Path
from loguru import logger
from io import BytesIO
from urllib.parse import quote
from ..models import DocumentUploadResponse, ErrorResponse
from app.services.document_processor import DocumentProcessor
from app.services.storage.minio_client import MinIOClient
from app.config.settings import settings
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from app.api.models import DocumentUploadResponse
from app.application.documents import DocumentProcessResult
from app.shared.bootstrap import get_document_command_service, get_document_query_service
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/documents", tags=["documents"])
# MinIO客户端用于文档存储
minio_client: Optional[MinIOClient] = None
def get_minio_client() -> MinIOClient:
"""获取MinIO客户端实例"""
global minio_client
if minio_client is None:
minio_client = MinIOClient()
minio_client.connect()
minio_client.ensure_bucket()
return minio_client
def _build_document_records(limit: Optional[int] = None):
"""构建文档列表记录,支持按最近更新时间倒序截断。"""
minio = get_minio_client()
document_records = []
objects = minio.client.list_objects(minio.bucket, recursive=True)
for obj in objects:
parts = obj.object_name.split("/", 1)
if len(parts) != 2:
continue
doc_id, filename = parts
last_modified = getattr(obj, "last_modified", None)
document_records.append({
"doc_id": doc_id,
"filename": filename,
"size": getattr(obj, "size", 0) or 0,
"object_name": obj.object_name,
"download_url": f"/api/v1/documents/download/{doc_id}",
"last_modified": last_modified.isoformat() if last_modified else None,
"_sort_key": last_modified.timestamp() if last_modified else 0,
})
document_records.sort(key=lambda item: item["_sort_key"], reverse=True)
if limit is not None:
document_records = document_records[:limit]
for item in document_records:
item.pop("_sort_key", None)
return document_records
def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
    """Convert a document processing result into the upload response model."""
    # These attributes are copied one-to-one from the result onto the response.
    copied_fields = (
        "doc_id",
        "doc_name",
        "status",
        "message",
        "num_chunks",
        "summary",
        "summary_latency_ms",
    )
    payload = {name: getattr(result, name) for name in copied_fields}
    return DocumentUploadResponse(**payload)
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(..., description="上传的文档文件"),
    doc_name: str | None = Form(None, description="文档名称"),
    regulation_type: str | None = Form(None, description="法规类型"),
    version: str | None = Form(None, description="文档版本"),
    generate_summary: bool = Form(False, description="是否生成摘要"),
):
    """Accept an uploaded document and hand it to the document command service.

    Raises:
        HTTPException: 400 when the filename or the file content is empty;
            500 when the service reports ``status == "failed"`` or raises.
    """
    # Reject a nameless upload before buffering its content in memory.
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="上传文件为空")

    try:
        result = get_document_command_service().upload_and_process(
            file_name=file.filename,
            content=content,
            content_type=file.content_type or "application/octet-stream",
            doc_name=doc_name,
            regulation_type=regulation_type or "",
            version=version or "",
            generate_summary=generate_summary,
        )
        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)
    except HTTPException:
        # Already a well-formed HTTP error (e.g. the "failed" case above).
        raise
    except Exception as exc:
        logger.exception("文档上传失败")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
async def get_document_status(doc_id: str):
    """Look up a document and report its processing status.

    Raises:
        HTTPException: 404 when the query service returns no document.
    """
    record = get_document_query_service().get(doc_id)
    if record:
        return DocumentUploadResponse(
            doc_id=record.doc_id,
            doc_name=record.doc_name,
            status=record.status.value,
            # Fall back to a generic success text when no error was recorded.
            message=record.error_message or "查询成功",
            num_chunks=record.chunk_count,
            summary=record.summary,
            summary_latency_ms=record.summary_latency_ms,
        )
    raise HTTPException(status_code=404, detail="文档不存在")
@router.get("/download/{doc_id}")
async def download_document(doc_id: str):
    """Return a stored document as a file download.

    The filename is percent-encoded and sent through the ``filename*``
    parameter of ``Content-Disposition`` so non-ASCII names survive.

    Raises:
        HTTPException: 404 when the query service raises
            ``FileNotFoundError``; 500 for any other failure.
    """
    try:
        document, file_data = get_document_query_service().download(doc_id)
        encoded_name = quote(document.file_name)
        return StreamingResponse(
            BytesIO(file_data),
            media_type=document.content_type or "application/octet-stream",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("文档下载失败")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def _document_summaries(documents) -> list[dict]:
    """Serialize document records into the summary dicts used by the list endpoints."""
    return [
        {
            "doc_id": item.doc_id,
            "doc_name": item.doc_name,
            "status": item.status.value,
            "chunk_count": item.chunk_count,
            "updated_at": item.updated_at.isoformat(),
        }
        for item in documents
    ]


@router.get("/list")
async def list_documents():
    """Return a summary of every document known to the query service."""
    documents = get_document_query_service().list_documents()
    return {
        "documents": _document_summaries(documents),
        "total": len(documents),
    }


@router.get("/management-list")
async def get_document_management_list():
    """Return the document summaries for the management view, capped at 10 entries."""
    documents = get_document_query_service().list_documents(limit=10)
    return {
        "documents": _document_summaries(documents),
        "total": len(documents),
        "limit": 10,
    }

View File

@@ -1,80 +1,51 @@
"""知识库检索接口"""
"""Define API routes for knowledge."""
from __future__ import annotations
from fastapi import APIRouter, HTTPException
from loguru import logger
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse
from app.services.document_processor import DocumentProcessor
from app.api.models import SearchResponse, SearchResultItem, SearchRequest
from app.shared.bootstrap import get_retrieval_service
# Router for the knowledge-base search endpoints; routes below are registered with the /knowledge prefix.
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
@router.post("/search", response_model=SearchResponse)
async def search_knowledge(request: SearchRequest):
"""
Search the regulation knowledge base.
"""Search the knowledge base for the given query; a blank query is rejected with HTTP 400."""
if not request.query or not request.query.strip():
raise HTTPException(status_code=400, detail="查询文本不能为空")
Uses hybrid retrieval: dense vectors + sparse vectors + RRF fusion.
Args:
request: Search request parameters.
"""
if not request.query or len(request.query.strip()) == 0:
raise HTTPException(
status_code=400,
detail="查询文本不能为空"
)
logger.info(f"收到检索请求: {request.query}")
try:
# Run the retrieval.
processor = DocumentProcessor()
results = processor.search(
query=request.query,
top_k=request.top_k,
filters=request.filters
)
processor.close()
# Convert the raw results into response items.
result_items = []
for r in results:
item = SearchResultItem(
id=r.get("id", 0),
content=r.get("content", ""),
score=r.get("score", 0.0),
metadata=r.get("metadata", {})
results = get_retrieval_service().retrieve(
query=request.query,
top_k=request.top_k,
filters=request.filters,
)
return SearchResponse(
query=request.query,
total=len(results),
results=[
SearchResultItem(
# 1-based position in the result list, not a stored identifier.
id=index + 1,
content=item.content,
score=item.score,
metadata={
"doc_id": item.doc_id,
"doc_name": item.doc_name,
"chunk_id": item.chunk_id,
"section_title": item.section_title,
"page_number": item.page_number,
# Unpacked last, so keys in item.metadata override the explicit fields above.
**item.metadata,
},
)
result_items.append(item)
return SearchResponse(
query=request.query,
total=len(result_items),
results=result_items
)
except Exception as e:
logger.error(f"检索失败: {e}")
raise HTTPException(
status_code=500,
detail=f"检索失败: {str(e)}"
)
for index, item in enumerate(results)
],
)
@router.post("/retrieval", response_model=SearchResponse)
async def knowledge_retrieval(request: SearchRequest):
    """Knowledge retrieval endpoint; currently an alias for ``search_knowledge``.

    The earlier design notes for this endpoint describe a fuller pipeline:
    intent recognition, BM25 keyword retrieval plus vector semantic retrieval
    (dual recall), Cross-Encoder re-ranking, then returning the results. This
    handler does none of that itself - it forwards the request unchanged, and
    a re-ranking step may be added later.

    Args:
        request: The search request, passed straight to ``search_knowledge``.

    Returns:
        Whatever ``search_knowledge`` returns for the same request.
    """
    return await search_knowledge(request)

View File

@@ -1,29 +1,39 @@
"""Define API routes for rag."""
from __future__ import annotations
import asyncio
import json
from typing import AsyncGenerator
from fastapi import APIRouter
from sse_starlette.sse import EventSourceResponse
from fastapi.responses import StreamingResponse
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.services.mock_data import (
get_mock_quick_questions,
get_mock_retrieval,
get_mock_rag_answer,
)
import json
import asyncio
# Router for the RAG chat endpoints; routes below are registered with the /rag prefix.
router = APIRouter(prefix="/rag", tags=["RAG问答"])
@router.post("/chat")
async def rag_chat(request: RagChatRequest):
"""Streaming SSE question answering."""
"""Stream a mock RAG answer to the client as text/event-stream frames."""
async def generate():
# Emit the "retrieving" event before any work starts.
yield {"event": "message", "data": json.dumps({"type": "retrieving"})}
async def generate() -> AsyncGenerator[str, None]:
# Emit the "retrieving" event before any work starts.
"""Yield the SSE frames in order: retrieving, retrieved, generating, chunk(s), done."""
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
# Simulate retrieval latency.
# Simulate retrieval latency.
await asyncio.sleep(0.3)
# Run the retrieval (backed by mock data).
# Run the retrieval (backed by mock data).
docs = get_mock_retrieval(request.query, top_k=request.top_k)
retrieved_data = [
@@ -36,39 +46,49 @@ async def rag_chat(request: RagChatRequest):
}
for d in docs
]
yield {"event": "message", "data": json.dumps({"type": "retrieved", "docs": retrieved_data})}
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
# Emit the "generating" event.
yield {"event": "message", "data": json.dumps({"type": "generating", "text": "正在生成答案..."})}
# Emit the "generating" event.
yield (
f"event: message\ndata: "
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
)
# Simulate generation latency.
# Simulate generation latency.
await asyncio.sleep(0.2)
# Fetch the canned answer for this query.
# Fetch the canned answer for this query.
answer = get_mock_rag_answer(request.query)
# Stream the answer piece by piece, first split on blank lines.
# Stream the answer piece by piece, first split on blank lines.
sentences = answer.split("\n\n")
for sentence in sentences:
if sentence.strip():
# Split each piece further, line by line.
# Split each piece further, line by line.
chunks = sentence.split("\n")
for chunk in chunks:
if chunk.strip():
await asyncio.sleep(0.05) # Simulate generation latency.
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
await asyncio.sleep(0.05) # Simulate generation latency.
# chr(10) is a newline; presumably spelled this way because f-string
# expressions could not contain a backslash before Python 3.12.
yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
)
# Emit the "done" event to mark the end of the stream.
yield {"event": "message", "data": json.dumps({"type": "done"})}
# Emit the "done" event to mark the end of the stream.
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return EventSourceResponse(generate())
# The generator already yields fully formatted SSE frames, so a plain
# StreamingResponse with the text/event-stream media type is sufficient.
# "X-Accel-Buffering: no" asks nginx-style reverse proxies not to buffer the stream.
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """Return the preset quick questions.

    Each entry from ``get_mock_quick_questions()`` is converted into a
    ``QuickQuestion`` using its ``id``, ``question`` and ``category`` keys.

    Returns:
        A ``QuickQuestionsResponse`` wrapping the converted questions.
    """
    questions = [
        QuickQuestion(id=q["id"], question=q["question"], category=q["category"])
        for q in get_mock_quick_questions()
    ]
    return QuickQuestionsResponse(questions=questions)

View File

@@ -1,28 +1,44 @@
"""Define API routes for status."""
from fastapi import APIRouter
from app.core.config import settings
from app.services.mock_data import MOCK_SYSTEM_STATS, MOCK_SYSTEM_CONFIG
from app.config.settings import settings
from app.shared.bootstrap import get_document_query_service, get_vector_index
# Router for the system status endpoints; routes below are registered with the /status prefix.
router = APIRouter(prefix="/status", tags=["系统状态"])
@router.get("/stats")
async def get_stats():
    """Return document and chunk counts from the document query service.

    Returns:
        A dict with four keys:

        * ``documents_total`` - number of documents returned by the service.
        * ``documents_indexed`` - documents whose status value is ``"indexed"``.
        * ``documents_failed`` - documents whose status value is ``"failed"``.
        * ``chunks_total`` - sum of ``chunk_count`` over all documents.
    """
    documents = get_document_query_service().list_documents()
    indexed = sum(1 for item in documents if item.status.value == "indexed")
    failed = sum(1 for item in documents if item.status.value == "failed")
    return {
        "documents_total": len(documents),
        "documents_indexed": indexed,
        "documents_failed": failed,
        "chunks_total": sum(item.chunk_count for item in documents),
    }
@router.get("/config")
async def get_config():
    """Return a selection of values from the application ``settings`` object.

    NOTE(review): the payload includes ``embedding_base_url`` and
    ``document_metadata_path``; confirm this endpoint is not reachable by
    untrusted clients before relying on it in production.

    Returns:
        A dict with the embedding model name, dimension and base URL, the
        Milvus collection name, the LLM provider and model, and the document
        metadata path, each copied from the same-named ``settings`` attribute.
    """
    return {
        "embedding_model": settings.embedding_model,
        "embedding_dim": settings.embedding_dim,
        "embedding_base_url": settings.embedding_base_url,
        "milvus_collection": settings.milvus_collection,
        "llm_provider": settings.llm_provider,
        "llm_model": settings.llm_model,
        "document_metadata_path": settings.document_metadata_path,
    }
@router.get("/milvus/health")
async def milvus_health():
    """Report the health of the vector index.

    Returns:
        The result of ``get_vector_index().health()``, passed through unchanged.
    """
    return get_vector_index().health()