This commit is contained in:
2026-05-14 15:07:34 +08:00
parent c2a398930d
commit 10d04c4083
179 changed files with 24073 additions and 1243 deletions

View File

@@ -0,0 +1,449 @@
# src/api/routes/agent.py
"""Agent API接口 - 问答对话接口"""
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, AsyncGenerator
from loguru import logger
import json
import asyncio
from app.services.agent.qa_agent import QAAgent, AgentConfig
from app.services.agent.session_manager import SessionManager
from app.config.settings import settings
router = APIRouter(prefix="/agent", tags=["agent"])
# 会话管理器(全局实例)
session_manager = SessionManager()
# ===== Pydantic Models =====
class AskRequest(BaseModel):
"""单次问答请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
model: Optional[str] = Field(None, description="LLM模型名称")
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
class AskResponse(BaseModel):
"""问答响应"""
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
retrieved_count: int = 0
context_tokens: int = 0
truncated: bool = False
error: Optional[str] = None
class ChatRequest(BaseModel):
"""多轮对话请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
session_id: Optional[str] = Field(None, description="会话ID首次对话可不传")
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商")
model: Optional[str] = Field(None, description="LLM模型名称")
class ChatResponse(BaseModel):
"""多轮对话响应"""
session_id: str
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
message_count: int = 0
class SessionInfo(BaseModel):
"""会话信息"""
session_id: str
message_count: int
created_at: int
updated_at: int
class FeedbackRequest(BaseModel):
"""反馈请求"""
session_id: str
message_index: int
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
comment: Optional[str] = Field(None, description="反馈内容")
class TemplateListResponse(BaseModel):
"""模板列表响应"""
templates: Dict[str, str]
# ===== API Endpoints =====
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
"""
单次问答接口
不保存会话历史,适合单次查询场景。
"""
logger.info(f"收到问答请求: {request.query}")
try:
# 构建Agent配置
config = AgentConfig(
llm_provider=request.provider or settings.llm_provider,
llm_model=request.model or settings.llm_model,
top_k=request.top_k or settings.rag_top_k
)
# 创建Agent并执行问答
agent = QAAgent(config)
response = agent.ask(
query=request.query,
filters=request.filters,
prompt_template=request.prompt_template
)
agent.close()
return AskResponse(
answer=response.answer,
sources=response.sources,
model=response.model,
latency_ms=response.latency_ms,
retrieved_count=response.retrieved_count,
context_tokens=response.context_tokens,
truncated=response.truncated,
error=response.error
)
except Exception as e:
logger.error(f"问答失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
"""
多轮对话接口
支持会话历史记录,适合连续对话场景。
"""
logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
try:
# 获取或创建会话
if request.session_id:
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="会话不存在或已过期")
else:
session = session_manager.create_session()
# 添加用户消息
session.add_user_message(request.query)
# 执行问答
config = AgentConfig(
llm_provider=request.provider or settings.llm_provider,
llm_model=request.model or settings.llm_model
)
agent = QAAgent(config)
response = agent.ask(
query=request.query,
filters=request.filters
)
agent.close()
# 添加助手消息
session.add_assistant_message(
response.answer,
response.sources
)
return ChatResponse(
session_id=session.session_id,
answer=response.answer,
sources=response.sources,
model=response.model,
latency_ms=response.latency_ms,
message_count=session.message_count
)
except HTTPException:
raise
except Exception as e:
logger.error(f"对话失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/chat/stream")
async def chat_stream_get(
query: str,
session_id: Optional[str] = None,
filters: Optional[str] = None,
provider: Optional[str] = None,
model: Optional[str] = None
):
"""
流式对话接口SSE- GET版本
EventSource只能发送GET请求因此提供此接口。
query参数通过URL传递。
SSE事件格式
- event: session - 会话ID
- event: status - 状态更新(检索中、生成中)
- event: sources - 引用来源
- event: content - 回答内容片段
- event: done - 完成,包含统计信息
- event: error - 错误信息
"""
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
async def generate_sse() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
try:
# 获取或创建会话
if session_id:
session = session_manager.get_session(session_id)
if not session:
yield f"event: error\ndata: 会话不存在或已过期\n\n"
return
else:
session = session_manager.create_session()
# 发送session_id
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
# 添加用户消息
session.add_user_message(query)
# 创建Agent
config = AgentConfig(
llm_provider=provider or settings.llm_provider,
llm_model=model or settings.llm_model
)
agent = QAAgent(config)
# 执行流式问答
full_answer = ""
sources = []
done_data = {}
for event_data in agent.ask_stream(
query=query,
filters=filters
):
event_type = event_data.get("event", "content")
data = event_data.get("data", "")
# 收集完整回答和来源
if event_type == "content":
full_answer += str(data)
elif event_type == "sources":
sources = data
elif event_type == "done":
done_data = data
# 发送SSE事件
if isinstance(data, (dict, list)):
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
else:
yield f"event: {event_type}\ndata: {data}\n\n"
# 小延迟让其他任务有机会执行
await asyncio.sleep(0)
agent.close()
# 保存到会话历史
session.add_assistant_message(full_answer, sources)
except Exception as e:
logger.error(f"流式对话失败: {e}")
yield f"event: error\ndata: {str(e)}\n\n"
return StreamingResponse(
generate_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
)
@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
"""
流式对话接口SSE
返回Server-Sent Events格式的流式响应用户可实时看到思考过程和回答生成。
SSE事件格式
- event: status - 状态更新(检索中、生成中)
- event: sources - 引用来源
- event: content - 回答内容片段
- event: done - 完成,包含统计信息
- event: error - 错误信息
"""
logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
async def generate_sse() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
try:
# 获取或创建会话
if request.session_id:
session = session_manager.get_session(request.session_id)
if not session:
yield f"event: error\ndata: 会话不存在或已过期\n\n"
return
else:
session = session_manager.create_session()
# 发送session_id
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
# 添加用户消息
session.add_user_message(request.query)
# 创建Agent
config = AgentConfig(
llm_provider=request.provider or settings.llm_provider,
llm_model=request.model or settings.llm_model
)
agent = QAAgent(config)
# 执行流式问答
full_answer = ""
sources = []
done_data = {}
for event_data in agent.ask_stream(
query=request.query,
filters=request.filters
):
event_type = event_data.get("event", "content")
data = event_data.get("data", "")
# 收集完整回答和来源
if event_type == "content":
full_answer += str(data)
elif event_type == "sources":
sources = data
elif event_type == "done":
done_data = data
# 发送SSE事件
if isinstance(data, (dict, list)):
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
else:
yield f"event: {event_type}\ndata: {data}\n\n"
# 小延迟让其他任务有机会执行
await asyncio.sleep(0)
agent.close()
# 保存到会话历史
session.add_assistant_message(full_answer, sources)
except Exception as e:
logger.error(f"流式对话失败: {e}")
yield f"event: error\ndata: {str(e)}\n\n"
return StreamingResponse(
generate_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
)
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
"""获取会话信息"""
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="会话不存在或已过期")
return SessionInfo(
session_id=session.session_id,
message_count=session.message_count,
created_at=session.created_at,
updated_at=session.updated_at
)
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
"""获取会话历史"""
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="会话不存在或已过期")
history = session.get_history(max_turns)
return {"session_id": session_id, "history": history}
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
"""删除会话"""
success = session_manager.delete_session(session_id)
if not success:
raise HTTPException(status_code=404, detail="会话不存在")
return {"message": "会话已删除", "session_id": session_id}
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
"""列出所有活跃会话"""
sessions = session_manager.list_sessions()
return [SessionInfo(**s) for s in sessions]
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
"""提交问答反馈"""
session = session_manager.get_session(request.session_id)
if not session:
raise HTTPException(status_code=404, detail="会话不存在")
# 记录反馈(实际应用中可存储到数据库)
logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
return {"message": "反馈已记录", "rating": request.rating}
@router.get("/templates", response_model=TemplateListResponse)
async def list_prompt_templates():
"""列出可用的Prompt模板"""
from app.services.rag.prompt_templates import PromptTemplates
templates = PromptTemplates.list_templates()
return TemplateListResponse(templates=templates)
@router.get("/models")
async def list_available_models():
"""列出可用的LLM模型"""
from app.services.llm import LLMFactory
factory = LLMFactory()
models = factory.list_available_providers()
return {"models": models}