Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,16 +1,29 @@
|
||||
"""API路由模块"""
|
||||
"""Initialize the app.api.routes package."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from .compliance import router as compliance_router
|
||||
from .documents import router as documents_router
|
||||
from .knowledge import router as knowledge_router
|
||||
from .agent import router as agent_router
|
||||
from .status import router as status_router
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
# 主路由
|
||||
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册子路由
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
api_router.include_router(documents_router)
|
||||
api_router.include_router(knowledge_router)
|
||||
api_router.include_router(agent_router)
|
||||
api_router.include_router(compliance_router)
|
||||
api_router.include_router(status_router)
|
||||
|
||||
__all__ = ["api_router", "documents_router", "knowledge_router", "agent_router"]
|
||||
__all__ = [
|
||||
"api_router",
|
||||
"documents_router",
|
||||
"knowledge_router",
|
||||
"agent_router",
|
||||
"compliance_router",
|
||||
"status_router",
|
||||
]
|
||||
|
||||
@@ -1,186 +1,83 @@
|
||||
"""Agent API接口 - 问答对话接口"""
|
||||
"""Define API routes for agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from app.services.agent.qa_agent import QAAgent, AgentConfig
|
||||
from app.services.agent.session_manager import SessionManager
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.api.models import (
|
||||
AskRequest,
|
||||
AskResponse,
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
FeedbackRequest,
|
||||
SessionInfo,
|
||||
)
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
# 会话管理器(全局实例)
|
||||
session_manager = SessionManager()
|
||||
|
||||
|
||||
# ===== Pydantic Models =====
|
||||
|
||||
class AskRequest(BaseModel):
|
||||
"""单次问答请求"""
|
||||
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
|
||||
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
|
||||
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
|
||||
|
||||
|
||||
class AskResponse(BaseModel):
|
||||
"""问答响应"""
|
||||
answer: str
|
||||
sources: List[Dict] = []
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
retrieved_count: int = 0
|
||||
context_tokens: int = 0
|
||||
truncated: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""多轮对话请求"""
|
||||
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||
session_id: Optional[str] = Field(None, description="会话ID(首次对话可不传)")
|
||||
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||
provider: Optional[str] = Field(None, description="LLM提供商")
|
||||
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""多轮对话响应"""
|
||||
session_id: str
|
||||
answer: str
|
||||
sources: List[Dict] = []
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
|
||||
"""会话信息"""
|
||||
session_id: str
|
||||
message_count: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
|
||||
"""反馈请求"""
|
||||
session_id: str
|
||||
message_index: int
|
||||
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
|
||||
comment: Optional[str] = Field(None, description="反馈内容")
|
||||
|
||||
|
||||
class TemplateListResponse(BaseModel):
|
||||
"""模板列表响应"""
|
||||
templates: Dict[str, str]
|
||||
|
||||
|
||||
# ===== API Endpoints =====
|
||||
|
||||
@router.post("/ask", response_model=AskResponse)
|
||||
async def ask_question(request: AskRequest):
|
||||
"""
|
||||
单次问答接口
|
||||
|
||||
不保存会话历史,适合单次查询场景。
|
||||
"""
|
||||
logger.info(f"收到问答请求: {request.query}")
|
||||
|
||||
"""Handle ask question."""
|
||||
try:
|
||||
# 构建Agent配置
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k
|
||||
)
|
||||
|
||||
# 创建Agent并执行问答
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(
|
||||
_, result = get_agent_conversation_service().ask(
|
||||
query=request.query,
|
||||
filters=request.filters,
|
||||
prompt_template=request.prompt_template
|
||||
provider=request.provider or settings.llm_provider,
|
||||
model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
prompt_template=request.prompt_template,
|
||||
)
|
||||
agent.close()
|
||||
|
||||
return AskResponse(
|
||||
answer=response.answer,
|
||||
sources=response.sources,
|
||||
model=response.model,
|
||||
latency_ms=response.latency_ms,
|
||||
retrieved_count=response.retrieved_count,
|
||||
context_tokens=response.context_tokens,
|
||||
truncated=response.truncated,
|
||||
error=response.error
|
||||
answer=result.answer,
|
||||
sources=[asdict(source) for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
retrieved_count=result.retrieved_count,
|
||||
context_tokens=result.context_tokens,
|
||||
truncated=result.truncated,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat_with_session(request: ChatRequest):
|
||||
"""
|
||||
多轮对话接口
|
||||
|
||||
支持会话历史记录,适合连续对话场景。
|
||||
"""
|
||||
logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
|
||||
|
||||
"""Handle chat with session."""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if request.session_id:
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(request.query)
|
||||
|
||||
# 执行问答
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(
|
||||
session_id, result = get_agent_conversation_service().chat(
|
||||
query=request.query,
|
||||
filters=request.filters
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
provider=request.provider or settings.llm_provider,
|
||||
model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
)
|
||||
agent.close()
|
||||
|
||||
# 添加助手消息
|
||||
session.add_assistant_message(
|
||||
response.answer,
|
||||
response.sources
|
||||
)
|
||||
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
return ChatResponse(
|
||||
session_id=session.session_id,
|
||||
answer=response.answer,
|
||||
sources=response.sources,
|
||||
model=response.model,
|
||||
latency_ms=response.latency_ms,
|
||||
message_count=session.message_count
|
||||
session_id=session_id,
|
||||
answer=result.answer,
|
||||
sources=[asdict(source) for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
message_count=len(session.messages) if session else 0,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"对话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/chat/stream")
|
||||
@@ -189,260 +86,93 @@ async def chat_stream_get(
|
||||
session_id: Optional[str] = None,
|
||||
filters: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
流式对话接口(SSE)- GET版本
|
||||
|
||||
EventSource只能发送GET请求,因此提供此接口。
|
||||
query参数通过URL传递。
|
||||
|
||||
SSE事件格式:
|
||||
- event: session - 会话ID
|
||||
- event: status - 状态更新(检索中、生成中)
|
||||
- event: sources - 引用来源
|
||||
- event: content - 回答内容片段
|
||||
- event: done - 完成,包含统计信息
|
||||
- event: error - 错误信息
|
||||
"""
|
||||
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
|
||||
|
||||
"""Handle chat stream get."""
|
||||
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
"""Handle generate sse."""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if session_id:
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||
return
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 发送session_id
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(query)
|
||||
|
||||
# 创建Agent
|
||||
config = AgentConfig(
|
||||
llm_provider=provider or settings.llm_provider,
|
||||
llm_model=model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
|
||||
# 执行流式问答
|
||||
full_answer = ""
|
||||
sources = []
|
||||
done_data = {}
|
||||
|
||||
for event_data in agent.ask_stream(
|
||||
session_id_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
filters=filters
|
||||
):
|
||||
session_id=session_id,
|
||||
filters=filters,
|
||||
provider=provider or settings.llm_provider,
|
||||
model=model or settings.llm_model,
|
||||
top_k=settings.rag_top_k,
|
||||
)
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
|
||||
for event_data in event_stream:
|
||||
event_type = event_data.get("event", "content")
|
||||
data = event_data.get("data", "")
|
||||
|
||||
# 收集完整回答和来源
|
||||
if event_type == "content":
|
||||
full_answer += str(data)
|
||||
elif event_type == "sources":
|
||||
sources = data
|
||||
elif event_type == "done":
|
||||
done_data = data
|
||||
|
||||
# 发送SSE事件
|
||||
if isinstance(data, (dict, list)):
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
# 小延迟让其他任务有机会执行
|
||||
await asyncio.sleep(0)
|
||||
|
||||
agent.close()
|
||||
|
||||
# 保存到会话历史
|
||||
session.add_assistant_message(full_answer, sources)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式对话失败: {e}")
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
except Exception as exc:
|
||||
yield f"event: error\ndata: {str(exc)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
"""
|
||||
流式对话接口(SSE)
|
||||
|
||||
返回Server-Sent Events格式的流式响应,用户可实时看到思考过程和回答生成。
|
||||
|
||||
SSE事件格式:
|
||||
- event: status - 状态更新(检索中、生成中)
|
||||
- event: sources - 引用来源
|
||||
- event: content - 回答内容片段
|
||||
- event: done - 完成,包含统计信息
|
||||
- event: error - 错误信息
|
||||
"""
|
||||
logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
|
||||
|
||||
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if request.session_id:
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||
return
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 发送session_id
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(request.query)
|
||||
|
||||
# 创建Agent
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
|
||||
# 执行流式问答
|
||||
full_answer = ""
|
||||
sources = []
|
||||
done_data = {}
|
||||
|
||||
for event_data in agent.ask_stream(
|
||||
query=request.query,
|
||||
filters=request.filters
|
||||
):
|
||||
event_type = event_data.get("event", "content")
|
||||
data = event_data.get("data", "")
|
||||
|
||||
# 收集完整回答和来源
|
||||
if event_type == "content":
|
||||
full_answer += str(data)
|
||||
elif event_type == "sources":
|
||||
sources = data
|
||||
elif event_type == "done":
|
||||
done_data = data
|
||||
|
||||
# 发送SSE事件
|
||||
if isinstance(data, (dict, list)):
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
else:
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
# 小延迟让其他任务有机会执行
|
||||
await asyncio.sleep(0)
|
||||
|
||||
agent.close()
|
||||
|
||||
# 保存到会话历史
|
||||
session.add_assistant_message(full_answer, sources)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式对话失败: {e}")
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
"""Handle chat stream."""
|
||||
return await chat_stream_get(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session/{session_id}", response_model=SessionInfo)
|
||||
async def get_session_info(session_id: str):
|
||||
"""获取会话信息"""
|
||||
session = session_manager.get_session(session_id)
|
||||
"""Return session info."""
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
|
||||
return SessionInfo(
|
||||
session_id=session.session_id,
|
||||
message_count=session.message_count,
|
||||
message_count=len(session.messages),
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at
|
||||
updated_at=session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session/{session_id}/history")
|
||||
async def get_session_history(session_id: str, max_turns: int = 5):
|
||||
"""获取会话历史"""
|
||||
session = session_manager.get_session(session_id)
|
||||
"""Return session history."""
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
|
||||
history = session.get_history(max_turns)
|
||||
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-(max_turns * 2):]]
|
||||
return {"session_id": session_id, "history": history}
|
||||
|
||||
|
||||
@router.delete("/session/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""删除会话"""
|
||||
success = session_manager.delete_session(session_id)
|
||||
if not success:
|
||||
"""Delete session."""
|
||||
if not get_conversation_store().delete_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
return {"message": "会话已删除", "session_id": session_id}
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionInfo])
|
||||
async def list_sessions():
|
||||
"""列出所有活跃会话"""
|
||||
sessions = session_manager.list_sessions()
|
||||
return [SessionInfo(**s) for s in sessions]
|
||||
"""List sessions."""
|
||||
return [SessionInfo(**item) for item in get_conversation_store().list_sessions()]
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
async def submit_feedback(request: FeedbackRequest):
|
||||
"""提交问答反馈"""
|
||||
session = session_manager.get_session(request.session_id)
|
||||
"""Submit feedback."""
|
||||
session = get_conversation_store().get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
# 记录反馈(实际应用中可存储到数据库)
|
||||
logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
|
||||
|
||||
return {"message": "反馈已记录", "rating": request.rating}
|
||||
|
||||
|
||||
@router.get("/templates", response_model=TemplateListResponse)
|
||||
async def list_prompt_templates():
|
||||
"""列出可用的Prompt模板"""
|
||||
from app.services.rag.prompt_templates import PromptTemplates
|
||||
|
||||
templates = PromptTemplates.list_templates()
|
||||
return TemplateListResponse(templates=templates)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_available_models():
|
||||
"""列出可用的LLM模型"""
|
||||
from app.services.llm import LLMFactory
|
||||
|
||||
factory = LLMFactory()
|
||||
models = factory.list_available_providers()
|
||||
return {"models": models}
|
||||
return {"message": "反馈已提交", "session_id": request.session_id, "message_index": request.message_index}
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
"""Define API routes for compliance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
@@ -13,38 +19,42 @@ from app.services.mock_data import (
|
||||
get_mock_compliance_result,
|
||||
get_mock_compliance_chat_response,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
|
||||
# 临时存储分析任务
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store: dict[str, dict] = {}
|
||||
|
||||
# Store uploaded compliance files inside the local backend data directory.
|
||||
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||
async def analyze_document(file: UploadFile = File(...)):
|
||||
"""上传设计方案进行分析"""
|
||||
# 生成任务ID
|
||||
"""Handle analyze document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
task_id = generate_task_id()
|
||||
|
||||
# 保存文件
|
||||
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||
os.makedirs(raw_dir, exist_ok=True)
|
||||
file_path = os.path.join(raw_dir, f"compliance_{task_id}_{file.filename}")
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
|
||||
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
with file_path.open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 记录任务
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id] = {
|
||||
"task_id": task_id,
|
||||
"file_path": file_path,
|
||||
"file_path": str(file_path),
|
||||
"status": "processing",
|
||||
"result": None,
|
||||
}
|
||||
|
||||
# 模拟异步处理完成(立即返回结果)
|
||||
# 实际应用中这应该是后台任务
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id]["status"] = "completed"
|
||||
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
|
||||
|
||||
@@ -53,9 +63,9 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
|
||||
@router.get("/result/{task_id}")
|
||||
async def get_result(task_id: str):
|
||||
"""获取分析结果"""
|
||||
"""Return result."""
|
||||
if task_id not in tasks_store:
|
||||
# 如果任务ID不存在,返回默认mock结果
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
return get_mock_compliance_result(task_id)
|
||||
|
||||
task = tasks_store[task_id]
|
||||
@@ -68,8 +78,8 @@ async def get_result(task_id: str):
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""针对段落进行合规对话"""
|
||||
# 根据segment_id获取对应的intent
|
||||
"""Handle compliance chat."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
intent_map = {
|
||||
1: "车身结构设计",
|
||||
2: "动力系统配置",
|
||||
@@ -77,11 +87,12 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
}
|
||||
intent = intent_map.get(segment_id, "车身结构设计")
|
||||
|
||||
async def generate():
|
||||
# 获取预设响应
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
response = get_mock_compliance_chat_response(intent, request.query)
|
||||
|
||||
# 流式输出响应
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = response.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
@@ -89,8 +100,15 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05)
|
||||
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
yield {"event": "message", "data": json.dumps({"type": "done"})}
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return EventSourceResponse(generate())
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Define API routes for docs."""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
import os
|
||||
import uuid
|
||||
@@ -10,30 +12,32 @@ from app.schemas.doc import (
|
||||
EmbedResponse,
|
||||
)
|
||||
from app.services.mock_data import get_mock_documents, generate_doc_id
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/docs", tags=["文档管理"])
|
||||
|
||||
# 临时存储文档信息(包含预设的mock文档)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
documents_store: dict[str, dict] = {}
|
||||
|
||||
# 初始化时加载mock文档
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
for doc in get_mock_documents():
|
||||
documents_store[doc["id"]] = doc
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(file: UploadFile = File(...)):
|
||||
"""上传法规文档"""
|
||||
# 检查文件格式
|
||||
"""Handle upload document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
allowed_ext = [".pdf", ".docx", ".doc", ".txt"]
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in allowed_ext:
|
||||
raise HTTPException(400, f"Unsupported file format: {ext}")
|
||||
|
||||
# 生成文档ID
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
doc_id = generate_doc_id()
|
||||
|
||||
# 保存文件
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
raw_dir = "/airegulation/demo-mao/backend/data/raw"
|
||||
os.makedirs(raw_dir, exist_ok=True)
|
||||
file_path = os.path.join(raw_dir, f"{doc_id}_{file.filename}")
|
||||
@@ -42,7 +46,7 @@ async def upload_document(file: UploadFile = File(...)):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 记录文档信息
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
documents_store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"name": file.filename,
|
||||
@@ -62,7 +66,7 @@ async def upload_document(file: UploadFile = File(...)):
|
||||
|
||||
@router.get("/list", response_model=DocumentListResponse)
|
||||
async def list_documents():
|
||||
"""获取已索引文档列表"""
|
||||
"""List documents."""
|
||||
docs = [
|
||||
DocumentInfo(
|
||||
id=d["id"],
|
||||
@@ -78,14 +82,14 @@ async def list_documents():
|
||||
|
||||
@router.post("/parse/{doc_id}", response_model=ParseResponse)
|
||||
async def parse_document(doc_id: str):
|
||||
"""解析文档并分块"""
|
||||
"""Parse document."""
|
||||
if doc_id not in documents_store:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟解析逻辑
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
doc["status"] = "parsed"
|
||||
# 根据文件大小计算chunks数量
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
file_size = doc.get("size", 100000)
|
||||
doc["chunks"] = max(20, file_size // 8000)
|
||||
|
||||
@@ -94,12 +98,12 @@ async def parse_document(doc_id: str):
|
||||
|
||||
@router.post("/embed/{doc_id}", response_model=EmbedResponse)
|
||||
async def embed_document(doc_id: str):
|
||||
"""嵌入并存入向量库"""
|
||||
"""Embed document."""
|
||||
if doc_id not in documents_store:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
doc = documents_store[doc_id]
|
||||
# 模拟嵌入逻辑
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
doc["status"] = "indexed"
|
||||
|
||||
return EmbedResponse(doc_id=doc_id, vectors=doc["chunks"])
|
||||
@@ -107,7 +111,7 @@ async def embed_document(doc_id: str):
|
||||
|
||||
@router.delete("/delete/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""删除文档"""
|
||||
"""Delete document."""
|
||||
if doc_id not in documents_store:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
|
||||
@@ -1,290 +1,140 @@
|
||||
"""文档上传与处理接口"""
|
||||
"""Define API routes for documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from typing import Optional
|
||||
import os
|
||||
import uuid
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from io import BytesIO
|
||||
from urllib.parse import quote
|
||||
|
||||
from ..models import DocumentUploadResponse, ErrorResponse
|
||||
from app.services.document_processor import DocumentProcessor
|
||||
from app.services.storage.minio_client import MinIOClient
|
||||
from app.config.settings import settings
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.api.models import DocumentUploadResponse
|
||||
from app.application.documents import DocumentProcessResult
|
||||
from app.shared.bootstrap import get_document_command_service, get_document_query_service
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
|
||||
# MinIO客户端(用于文档存储)
|
||||
minio_client: Optional[MinIOClient] = None
|
||||
|
||||
|
||||
def get_minio_client() -> MinIOClient:
|
||||
"""获取MinIO客户端实例"""
|
||||
global minio_client
|
||||
if minio_client is None:
|
||||
minio_client = MinIOClient()
|
||||
minio_client.connect()
|
||||
minio_client.ensure_bucket()
|
||||
return minio_client
|
||||
|
||||
|
||||
def _build_document_records(limit: Optional[int] = None):
|
||||
"""构建文档列表记录,支持按最近更新时间倒序截断。"""
|
||||
minio = get_minio_client()
|
||||
|
||||
document_records = []
|
||||
objects = minio.client.list_objects(minio.bucket, recursive=True)
|
||||
for obj in objects:
|
||||
parts = obj.object_name.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
doc_id, filename = parts
|
||||
last_modified = getattr(obj, "last_modified", None)
|
||||
document_records.append({
|
||||
"doc_id": doc_id,
|
||||
"filename": filename,
|
||||
"size": getattr(obj, "size", 0) or 0,
|
||||
"object_name": obj.object_name,
|
||||
"download_url": f"/api/v1/documents/download/{doc_id}",
|
||||
"last_modified": last_modified.isoformat() if last_modified else None,
|
||||
"_sort_key": last_modified.timestamp() if last_modified else 0,
|
||||
})
|
||||
|
||||
document_records.sort(key=lambda item: item["_sort_key"], reverse=True)
|
||||
if limit is not None:
|
||||
document_records = document_records[:limit]
|
||||
|
||||
for item in document_records:
|
||||
item.pop("_sort_key", None)
|
||||
|
||||
return document_records
|
||||
def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
|
||||
"""Handle document response for this module."""
|
||||
return DocumentUploadResponse(
|
||||
doc_id=result.doc_id,
|
||||
doc_name=result.doc_name,
|
||||
status=result.status,
|
||||
message=result.message,
|
||||
num_chunks=result.num_chunks,
|
||||
summary=result.summary,
|
||||
summary_latency_ms=result.summary_latency_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(
|
||||
file: UploadFile = File(..., description="上传的文档文件"),
|
||||
doc_name: Optional[str] = Form(None, description="文档名称"),
|
||||
regulation_type: Optional[str] = Form(None, description="法规类型"),
|
||||
version: Optional[str] = Form(None, description="文档版本"),
|
||||
generate_summary: bool = Form(False, description="是否生成摘要(默认不生成,可节省约60秒)")
|
||||
doc_name: str | None = Form(None, description="文档名称"),
|
||||
regulation_type: str | None = Form(None, description="法规类型"),
|
||||
version: str | None = Form(None, description="文档版本"),
|
||||
generate_summary: bool = Form(False, description="是否生成摘要"),
|
||||
):
|
||||
"""
|
||||
上传文档并处理
|
||||
|
||||
支持格式:PDF、DOCX、DOC
|
||||
处理流程:解析 → 分块 → 嵌入 → 入库(摘要可选)
|
||||
文件存储:MinIO对象存储
|
||||
|
||||
参数说明:
|
||||
- generate_summary: 是否生成LLM摘要,默认False。勾选后处理时间增加约60秒。
|
||||
"""
|
||||
# 验证文件类型
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in [".pdf", ".docx", ".doc"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型: {ext},仅支持PDF、DOCX、DOC"
|
||||
)
|
||||
|
||||
# 验证文件大小
|
||||
if file.size and file.size > settings.max_file_size_mb * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大,最大支持{settings.max_file_size_mb}MB"
|
||||
)
|
||||
|
||||
# 生成文档ID
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 文档名称
|
||||
final_doc_name = doc_name or file.filename
|
||||
|
||||
# MinIO对象名称
|
||||
object_name = f"{doc_id}/{file.filename}"
|
||||
|
||||
logger.info(f"接收到文件上传: {final_doc_name}, 类型: {ext}, doc_id={doc_id}")
|
||||
"""Handle upload document."""
|
||||
content = await file.read()
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="上传文件为空")
|
||||
|
||||
try:
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
|
||||
# 保存临时文件用于处理
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_path = os.path.join(temp_dir, f"{doc_id}_{file.filename}")
|
||||
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info(f"临时文件已保存到: {temp_path}")
|
||||
|
||||
# 上传到MinIO
|
||||
minio = get_minio_client()
|
||||
upload_success = minio.upload_bytes(
|
||||
data=content,
|
||||
object_name=object_name,
|
||||
content_type=minio._get_content_type(file.filename),
|
||||
metadata={
|
||||
"doc_id": doc_id # 仅传递ASCII安全的metadata
|
||||
}
|
||||
)
|
||||
|
||||
if upload_success:
|
||||
logger.success(f"文件已上传到MinIO: {object_name}")
|
||||
else:
|
||||
logger.warning(f"MinIO上传失败,仅使用本地临时文件")
|
||||
|
||||
# 处理文档(传入相同的doc_id,保持一致性)
|
||||
processor = DocumentProcessor(generate_summary=generate_summary)
|
||||
result = processor.process(
|
||||
file_path=temp_path,
|
||||
doc_id=doc_id, # 使用相同的doc_id
|
||||
doc_name=final_doc_name,
|
||||
result = get_document_command_service().upload_and_process(
|
||||
file_name=file.filename,
|
||||
content=content,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or ""
|
||||
)
|
||||
processor.close()
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
if result.success:
|
||||
return DocumentUploadResponse(
|
||||
doc_id=result.doc_id,
|
||||
doc_name=result.doc_name,
|
||||
status="success",
|
||||
message=result.message,
|
||||
num_chunks=result.num_chunks,
|
||||
summary=result.summary,
|
||||
summary_latency_ms=result.summary_latency_ms
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=result.message
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文档处理失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"文档处理失败: {str(e)}"
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
)
|
||||
if result.status == "failed":
|
||||
raise HTTPException(status_code=500, detail=result.message)
|
||||
return _document_response(result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("文档上传失败")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/status/{doc_id}", response_model=DocumentUploadResponse)
|
||||
async def get_document_status(doc_id: str):
|
||||
"""
|
||||
查询文档处理状态
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
"""
|
||||
# TODO: 实现状态查询(需要数据库支持)
|
||||
"""Return document status."""
|
||||
document = get_document_query_service().get(doc_id)
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="文档不存在")
|
||||
return DocumentUploadResponse(
|
||||
doc_id=doc_id,
|
||||
doc_name="",
|
||||
status="unknown",
|
||||
message="状态查询功能待实现"
|
||||
doc_id=document.doc_id,
|
||||
doc_name=document.doc_name,
|
||||
status=document.status.value,
|
||||
message=document.error_message or "查询成功",
|
||||
num_chunks=document.chunk_count,
|
||||
summary=document.summary,
|
||||
summary_latency_ms=document.summary_latency_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/download/{doc_id}")
|
||||
async def download_document(doc_id: str):
|
||||
"""
|
||||
下载文档(从MinIO获取)
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
|
||||
Returns:
|
||||
文件下载响应
|
||||
"""
|
||||
logger.info(f"请求下载文档: doc_id={doc_id}")
|
||||
|
||||
"""Handle download document."""
|
||||
try:
|
||||
minio = get_minio_client()
|
||||
|
||||
# 查找该doc_id下的文件(MinIO对象名称格式: {doc_id}/{filename})
|
||||
objects = minio.list_objects(prefix=f"{doc_id}/")
|
||||
|
||||
if not objects:
|
||||
logger.warning(f"MinIO中未找到文档: doc_id={doc_id}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"文档不存在: doc_id={doc_id}"
|
||||
)
|
||||
|
||||
# 获取第一个匹配的对象
|
||||
object_name = objects[0]
|
||||
logger.info(f"找到MinIO对象: {object_name}")
|
||||
|
||||
# 获取文件数据
|
||||
file_data = minio.get_object_data(object_name)
|
||||
if file_data is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取文档数据失败"
|
||||
)
|
||||
|
||||
# 解析原始文件名
|
||||
original_name = object_name.split("/", 1)[1] if "/" in object_name else object_name
|
||||
|
||||
# 获取Content-Type
|
||||
content_type = minio._get_content_type(original_name)
|
||||
|
||||
logger.success(f"文档下载成功: {original_name}, 大小={len(file_data)}")
|
||||
|
||||
# 返回文件流(URL编码文件名以支持中文)
|
||||
encoded_name = quote(original_name)
|
||||
document, file_data = get_document_query_service().download(doc_id)
|
||||
encoded_name = quote(document.file_name)
|
||||
return StreamingResponse(
|
||||
BytesIO(file_data),
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"文档下载失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"文档下载失败: {str(e)}"
|
||||
media_type=document.content_type or "application/octet-stream",
|
||||
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}"},
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger.exception("文档下载失败")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_documents():
|
||||
"""
|
||||
列出所有已上传的文档(从MinIO获取)
|
||||
"""
|
||||
try:
|
||||
documents = _build_document_records()
|
||||
return {"documents": documents, "total": len(documents)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出文档失败: {e}")
|
||||
return {"documents": [], "total": 0, "error": str(e)}
|
||||
"""List documents."""
|
||||
documents = get_document_query_service().list_documents()
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/management-list")
|
||||
async def get_document_management_list():
|
||||
"""
|
||||
文档管理清单接口:仅返回最近的10条文档。
|
||||
"""
|
||||
try:
|
||||
documents = _build_document_records(limit=10)
|
||||
return {"documents": documents, "total": len(documents), "limit": 10}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取文档管理清单失败: {e}")
|
||||
return {"documents": [], "total": 0, "limit": 10, "error": str(e)}
|
||||
"""Return document management list."""
|
||||
documents = get_document_query_service().list_documents(limit=10)
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
@@ -1,80 +1,51 @@
|
||||
"""知识库检索接口"""
|
||||
"""Define API routes for knowledge."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from ..models import SearchRequest, SearchResponse, SearchResultItem, ErrorResponse
|
||||
from app.services.document_processor import DocumentProcessor
|
||||
from app.api.models import SearchResponse, SearchResultItem, SearchRequest
|
||||
from app.shared.bootstrap import get_retrieval_service
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
@router.post("/search", response_model=SearchResponse)
|
||||
async def search_knowledge(request: SearchRequest):
|
||||
"""
|
||||
检索法规知识库
|
||||
"""Search knowledge."""
|
||||
if not request.query or not request.query.strip():
|
||||
raise HTTPException(status_code=400, detail="查询文本不能为空")
|
||||
|
||||
使用混合检索:Dense向量 + Sparse向量 + RRF融合
|
||||
|
||||
Args:
|
||||
request: 检索请求参数
|
||||
"""
|
||||
if not request.query or len(request.query.strip()) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="查询文本不能为空"
|
||||
)
|
||||
|
||||
logger.info(f"收到检索请求: {request.query}")
|
||||
|
||||
try:
|
||||
# 执行检索
|
||||
processor = DocumentProcessor()
|
||||
results = processor.search(
|
||||
query=request.query,
|
||||
top_k=request.top_k,
|
||||
filters=request.filters
|
||||
)
|
||||
processor.close()
|
||||
|
||||
# 转换结果格式
|
||||
result_items = []
|
||||
for r in results:
|
||||
item = SearchResultItem(
|
||||
id=r.get("id", 0),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
metadata=r.get("metadata", {})
|
||||
results = get_retrieval_service().retrieve(
|
||||
query=request.query,
|
||||
top_k=request.top_k,
|
||||
filters=request.filters,
|
||||
)
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
total=len(results),
|
||||
results=[
|
||||
SearchResultItem(
|
||||
id=index + 1,
|
||||
content=item.content,
|
||||
score=item.score,
|
||||
metadata={
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"chunk_id": item.chunk_id,
|
||||
"section_title": item.section_title,
|
||||
"page_number": item.page_number,
|
||||
**item.metadata,
|
||||
},
|
||||
)
|
||||
result_items.append(item)
|
||||
|
||||
return SearchResponse(
|
||||
query=request.query,
|
||||
total=len(result_items),
|
||||
results=result_items
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"检索失败: {str(e)}"
|
||||
)
|
||||
for index, item in enumerate(results)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/retrieval", response_model=SearchResponse)
|
||||
async def knowledge_retrieval(request: SearchRequest):
|
||||
"""
|
||||
知识检索接口(与架构文档对齐)
|
||||
|
||||
该接口实现完整的检索流程:
|
||||
1. 意图识别
|
||||
2. BM25关键词检索 + 向量语义检索(双路召回)
|
||||
3. Cross-Encoder精排
|
||||
4. 返回结果
|
||||
|
||||
Args:
|
||||
request: 检索请求
|
||||
"""
|
||||
# 当前版本使用混合检索,后续可添加精排步骤
|
||||
"""Handle knowledge retrieval."""
|
||||
return await search_knowledge(request)
|
||||
|
||||
@@ -1,29 +1,39 @@
|
||||
"""Define API routes for rag."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.services.mock_data import (
|
||||
get_mock_quick_questions,
|
||||
get_mock_retrieval,
|
||||
get_mock_rag_answer,
|
||||
)
|
||||
import json
|
||||
import asyncio
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
"""SSE流式问答"""
|
||||
"""Handle rag chat."""
|
||||
|
||||
async def generate():
|
||||
# 发送检索开始事件
|
||||
yield {"event": "message", "data": json.dumps({"type": "retrieving"})}
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 模拟检索延迟
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# 执行检索
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
docs = get_mock_retrieval(request.query, top_k=request.top_k)
|
||||
|
||||
retrieved_data = [
|
||||
@@ -36,39 +46,49 @@ async def rag_chat(request: RagChatRequest):
|
||||
}
|
||||
for d in docs
|
||||
]
|
||||
yield {"event": "message", "data": json.dumps({"type": "retrieved", "docs": retrieved_data})}
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送生成开始事件
|
||||
yield {"event": "message", "data": json.dumps({"type": "generating", "text": "正在生成答案..."})}
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
f"event: message\ndata: "
|
||||
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# 模拟生成延迟
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 获取预设答案
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
answer = get_mock_rag_answer(request.query)
|
||||
|
||||
# 流式输出答案(按句子分割)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = answer.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
# 进一步分割长句子
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05) # 模拟生成延迟
|
||||
yield {"event": "message", "data": json.dumps({"type": "chunk", "text": chunk + "\n"})}
|
||||
await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# 发送完成事件
|
||||
yield {"event": "message", "data": json.dumps({"type": "done"})}
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return EventSourceResponse(generate())
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
|
||||
async def get_quick_questions():
|
||||
"""获取预设快捷问题"""
|
||||
"""Return quick questions."""
|
||||
questions = [
|
||||
QuickQuestion(id=q["id"], question=q["question"], category=q["category"])
|
||||
for q in get_mock_quick_questions()
|
||||
]
|
||||
return QuickQuestionsResponse(questions=questions)
|
||||
return QuickQuestionsResponse(questions=questions)
|
||||
|
||||
@@ -1,28 +1,44 @@
|
||||
"""Define API routes for status."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from app.core.config import settings
|
||||
from app.services.mock_data import MOCK_SYSTEM_STATS, MOCK_SYSTEM_CONFIG
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_document_query_service, get_vector_index
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/status", tags=["系统状态"])
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats():
|
||||
"""获取系统统计"""
|
||||
# 返回预设统计数据
|
||||
return MOCK_SYSTEM_STATS
|
||||
"""Return stats."""
|
||||
documents = get_document_query_service().list_documents()
|
||||
indexed = sum(1 for item in documents if item.status.value == "indexed")
|
||||
failed = sum(1 for item in documents if item.status.value == "failed")
|
||||
return {
|
||||
"documents_total": len(documents),
|
||||
"documents_indexed": indexed,
|
||||
"documents_failed": failed,
|
||||
"chunks_total": sum(item.chunk_count for item in documents),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config():
|
||||
"""获取当前配置"""
|
||||
return MOCK_SYSTEM_CONFIG
|
||||
"""Return config."""
|
||||
return {
|
||||
"embedding_model": settings.embedding_model,
|
||||
"embedding_dim": settings.embedding_dim,
|
||||
"embedding_base_url": settings.embedding_base_url,
|
||||
"milvus_collection": settings.milvus_collection,
|
||||
"llm_provider": settings.llm_provider,
|
||||
"llm_model": settings.llm_model,
|
||||
"document_metadata_path": settings.document_metadata_path,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/milvus/health")
|
||||
async def milvus_health():
|
||||
"""Milvus健康检查"""
|
||||
# 模拟连接状态(假数据模式下始终返回连接成功)
|
||||
return {
|
||||
"connected": True,
|
||||
"collections": ["vehicle_regulations"],
|
||||
}
|
||||
"""Handle milvus health."""
|
||||
return get_vector_index().health()
|
||||
|
||||
Reference in New Issue
Block a user