450 lines
14 KiB
Python
450 lines
14 KiB
Python
# src/api/routes/agent.py
|
||
"""Agent API接口 - 问答对话接口"""
|
||
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel, Field
|
||
from typing import List, Dict, Optional, AsyncGenerator
|
||
from loguru import logger
|
||
import json
|
||
import asyncio
|
||
|
||
from app.services.agent.qa_agent import QAAgent, AgentConfig
|
||
from app.services.agent.session_manager import SessionManager
|
||
from app.config.settings import settings
|
||
|
||
|
||
# Router that groups every endpoint declared in this module under the "/agent" prefix.
router = APIRouter(prefix="/agent", tags=["agent"])
|
||
|
||
# 会话管理器(全局实例)
|
||
# Module-level instance: the endpoints in this file create, look up, list and delete sessions through it.
session_manager = SessionManager()
|
||
|
||
|
||
# ===== Pydantic Models =====
|
||
|
||
class AskRequest(BaseModel):
    """Request body for a one-shot question (no conversation history)."""

    # Required question text; validated to be 1-2000 characters long.
    query: str = Field(..., min_length=1, max_length=2000, description="用户问题")
    # Optional retrieval filter, passed through to the agent as-is.
    filters: Optional[str] = Field(default=None, description="检索过滤条件")
    # Optional LLM provider override.
    provider: Optional[str] = Field(default=None, description="LLM提供商 (qwen/deepseek)")
    # Optional LLM model-name override.
    model: Optional[str] = Field(default=None, description="LLM模型名称")
    # Optional number of documents to retrieve; validated to be within 1-20.
    top_k: Optional[int] = Field(default=None, ge=1, le=20, description="检索数量")
    # Optional name of the prompt template to use.
    prompt_template: Optional[str] = Field(default=None, description="Prompt模板名称")
|
||
|
||
|
||
class AskResponse(BaseModel):
    """Response body returned by the one-shot question endpoint."""

    # Generated answer text (the only required field).
    answer: str
    # Source entries reported by the agent alongside the answer.
    sources: List[Dict] = []
    # Model identifier reported by the agent.
    model: str = ""
    # Latency reported by the agent, in milliseconds.
    latency_ms: int = 0
    # Retrieved-item count reported by the agent.
    retrieved_count: int = 0
    # Context token count reported by the agent.
    context_tokens: int = 0
    # Truncation flag reported by the agent.
    truncated: bool = False
    # Error text reported by the agent, if any.
    error: Optional[str] = None
|
||
|
||
|
||
class ChatRequest(BaseModel):
    """Request body for a multi-turn chat message."""

    # Required question text; validated to be 1-2000 characters long.
    query: str = Field(..., min_length=1, max_length=2000, description="用户问题")
    # Existing session to continue; when omitted a new session is created.
    session_id: Optional[str] = Field(default=None, description="会话ID(首次对话可不传)")
    # Optional retrieval filter, passed through to the agent as-is.
    filters: Optional[str] = Field(default=None, description="检索过滤条件")
    # Optional LLM provider override.
    provider: Optional[str] = Field(default=None, description="LLM提供商")
    # Optional LLM model-name override.
    model: Optional[str] = Field(default=None, description="LLM模型名称")
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Response body returned by the multi-turn chat endpoint."""

    # Identifier of the session the exchange was recorded in.
    session_id: str
    # Generated answer text.
    answer: str
    # Source entries reported by the agent alongside the answer.
    sources: List[Dict] = []
    # Model identifier reported by the agent.
    model: str = ""
    # Latency reported by the agent, in milliseconds.
    latency_ms: int = 0
    # Message count of the session after this exchange was recorded.
    message_count: int = 0
|
||
|
||
|
||
class SessionInfo(BaseModel):
    """Summary metadata describing one chat session."""

    # Identifier of the session.
    session_id: str
    # Message count of the session.
    message_count: int
    # Creation time as an integer (presumably a Unix timestamp; unit not visible here, confirm against SessionManager).
    created_at: int
    # Last-update time as an integer (same caveat as created_at).
    updated_at: int
|
||
|
||
|
||
class FeedbackRequest(BaseModel):
    """Request body for rating an answer within a session."""

    # Session the rated message belongs to.
    session_id: str
    # Index of the rated message (no range validation is applied here).
    message_index: int
    # Required rating; validated to be within 1-5.
    rating: int = Field(..., description="评分 1-5", ge=1, le=5)
    # Optional free-text comment.
    comment: Optional[str] = Field(default=None, description="反馈内容")
|
||
|
||
|
||
class TemplateListResponse(BaseModel):
    """Response body listing the available prompt templates."""

    # String-to-string mapping as returned by PromptTemplates.list_templates().
    templates: Dict[str, str]
|
||
|
||
|
||
# ===== API Endpoints =====
|
||
|
||
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a single question without recording any session history.

    Args:
        request: Question text plus optional filter, provider, model,
            ``top_k`` and prompt-template overrides. Missing provider, model
            and ``top_k`` values fall back to the application settings.

    Returns:
        AskResponse populated from the agent's response object.

    Raises:
        HTTPException: 500 carrying the exception text when building the
            agent, asking, or closing the agent fails.
    """
    logger.info(f"收到问答请求: {request.query}")

    try:
        # Per-request agent configuration: request overrides win over settings.
        config = AgentConfig(
            llm_provider=request.provider or settings.llm_provider,
            llm_model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )

        agent = QAAgent(config)
        try:
            # NOTE(review): agent.ask() is a plain synchronous call, so the
            # event loop is occupied until it returns; consider offloading it
            # to a worker thread if QAAgent is safe to use from one.
            response = agent.ask(
                query=request.query,
                filters=request.filters,
                prompt_template=request.prompt_template,
            )
        finally:
            # Close the agent even when ask() raises, so it is never leaked.
            agent.close()

        return AskResponse(
            answer=response.answer,
            sources=response.sources,
            model=response.model,
            latency_ms=response.latency_ms,
            retrieved_count=response.retrieved_count,
            context_tokens=response.context_tokens,
            truncated=response.truncated,
            error=response.error,
        )

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"问答失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Answer one message of a multi-turn chat and record it in a session.

    When ``request.session_id`` is given the existing session is used,
    otherwise a new one is created. The user message is recorded before the
    agent is asked; the assistant answer and its sources are recorded
    afterwards.

    Args:
        request: Question text, optional session id, and optional filter,
            provider and model overrides.

    Returns:
        ChatResponse with the session id, the answer and the session's
        message count after the exchange.

    Raises:
        HTTPException: 404 when ``request.session_id`` is given but the
            session manager returns no session for it; 500 carrying the
            exception text for any other failure.
    """
    logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")

    try:
        # Resolve the session: reuse the given one or start a fresh one.
        if request.session_id:
            session = session_manager.get_session(request.session_id)
            if not session:
                raise HTTPException(status_code=404, detail="会话不存在或已过期")
        else:
            session = session_manager.create_session()

        session.add_user_message(request.query)

        config = AgentConfig(
            llm_provider=request.provider or settings.llm_provider,
            llm_model=request.model or settings.llm_model,
        )

        agent = QAAgent(config)
        try:
            # NOTE(review): agent.ask() is a plain synchronous call, so the
            # event loop is occupied until it returns.
            response = agent.ask(
                query=request.query,
                filters=request.filters,
            )
        finally:
            # Close the agent even when ask() raises, so it is never leaked.
            agent.close()

        session.add_assistant_message(response.answer, response.sources)

        return ChatResponse(
            session_id=session.session_id,
            answer=response.answer,
            sources=response.sources,
            model=response.model,
            latency_ms=response.latency_ms,
            message_count=session.message_count,
        )

    except HTTPException:
        # Let the 404 above through untouched instead of wrapping it in a 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"对话失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
# Response headers shared by both SSE endpoints.
_SSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",  # disable nginx buffering
}


def _sse_frame(event: str, data) -> str:
    """Serialize one Server-Sent Events frame.

    dict/list payloads are JSON-encoded; anything else is sent as ``str(data)``.
    Text containing line breaks is emitted as one ``data:`` line per line of
    text: a raw newline inside a single ``data:`` line would drop the rest of
    the text on the client, and an empty line would end the event early.
    """
    if isinstance(data, (dict, list)):
        payload = json.dumps(data)
    else:
        payload = str(data)
    text_lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    body = "\n".join(f"data: {line}" for line in text_lines)
    return f"event: {event}\n{body}\n\n"


async def _stream_chat_events(
    query: str,
    session_id: Optional[str],
    filters: Optional[str],
    provider: Optional[str],
    model: Optional[str],
) -> AsyncGenerator[str, None]:
    """Yield the SSE frames for one streamed chat exchange.

    Emits an ``error`` frame and stops when ``session_id`` is given but no
    session is found. Otherwise emits a ``session`` frame, then forwards every
    event produced by ``agent.ask_stream`` and finally records the collected
    answer and sources in the session. Any exception is logged and reported
    to the client as an ``error`` frame.
    """
    try:
        # Resolve the session: reuse the given one or start a fresh one.
        if session_id:
            session = session_manager.get_session(session_id)
            if not session:
                yield _sse_frame("error", "会话不存在或已过期")
                return
        else:
            session = session_manager.create_session()

        # Tell the client which session this stream belongs to.
        yield _sse_frame("session", {"session_id": session.session_id})

        session.add_user_message(query)

        config = AgentConfig(
            llm_provider=provider or settings.llm_provider,
            llm_model=model or settings.llm_model,
        )
        agent = QAAgent(config)

        answer_parts = []
        sources = []
        try:
            # NOTE(review): ask_stream is iterated synchronously; the
            # sleep(0) below only hands control back between events.
            for event_data in agent.ask_stream(query=query, filters=filters):
                event_type = event_data.get("event", "content")
                data = event_data.get("data", "")

                # Collect what has to be stored in the session afterwards.
                if event_type == "content":
                    answer_parts.append(str(data))
                elif event_type == "sources":
                    sources = data

                yield _sse_frame(event_type, data)

                # Give other tasks a chance to run between events.
                await asyncio.sleep(0)
        finally:
            # Runs on errors and when the generator is closed early
            # (e.g. client disconnect), so the agent is never leaked.
            agent.close()

        # Only a fully streamed answer is stored in the history.
        session.add_assistant_message("".join(answer_parts), sources)

    except Exception as e:
        logger.exception(f"流式对话失败: {e}")
        yield _sse_frame("error", str(e))


def _sse_response(events: AsyncGenerator[str, None]) -> StreamingResponse:
    """Wrap an SSE frame generator in a ``text/event-stream`` response."""
    return StreamingResponse(
        events,
        media_type="text/event-stream",
        headers=dict(_SSE_HEADERS),
    )


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None
):
    """
    Streaming chat endpoint (SSE) - GET variant.

    EventSource can only issue GET requests, so this variant takes all
    inputs as URL query parameters.

    SSE events:
    - event: session - session id
    - event: status - status update (retrieving, generating)
    - event: sources - cited sources
    - event: content - answer text fragment
    - event: done - finished, with statistics
    - event: error - error message
    """
    logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")

    return _sse_response(
        _stream_chat_events(query, session_id, filters, provider, model)
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """
    Streaming chat endpoint (SSE).

    Returns a Server-Sent Events stream so the client can follow progress
    and see the answer while it is being generated.

    SSE events:
    - event: session - session id
    - event: status - status update (retrieving, generating)
    - event: sources - cited sources
    - event: content - answer text fragment
    - event: done - finished, with statistics
    - event: error - error message
    """
    logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")

    return _sse_response(
        _stream_chat_events(
            request.query,
            request.session_id,
            request.filters,
            request.provider,
            request.model,
        )
    )
|
||
|
||
|
||
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return summary metadata for one session.

    Raises:
        HTTPException: 404 when the session manager returns no session for
            ``session_id``.
    """
    found = session_manager.get_session(session_id)
    if found:
        return SessionInfo(
            session_id=found.session_id,
            message_count=found.message_count,
            created_at=found.created_at,
            updated_at=found.updated_at,
        )

    raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||
|
||
|
||
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return the history of one session.

    ``max_turns`` is handed to the session's ``get_history`` unchanged.

    Raises:
        HTTPException: 404 when the session manager returns no session for
            ``session_id``.
    """
    found = session_manager.get_session(session_id)
    if found:
        return {
            "session_id": session_id,
            "history": found.get_history(max_turns),
        }

    raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||
|
||
|
||
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Delete one session.

    Raises:
        HTTPException: 404 when the session manager reports that nothing
            was deleted.
    """
    if session_manager.delete_session(session_id):
        return {"message": "会话已删除", "session_id": session_id}

    raise HTTPException(status_code=404, detail="会话不存在")
|
||
|
||
|
||
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """List the sessions reported by the session manager.

    Each entry returned by ``list_sessions()`` is expanded as keyword
    arguments into a ``SessionInfo``.
    """
    return [SessionInfo(**entry) for entry in session_manager.list_sessions()]
|
||
|
||
|
||
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Record a rating for a message of an existing session.

    The feedback is only written to the log here; nothing is persisted.

    Raises:
        HTTPException: 404 when the session manager returns no session for
            ``request.session_id``.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    # Log-only for now (a real deployment could store this in a database).
    # message_index is included so the rating can be tied to a message.
    logger.info(
        f"收到反馈: session={request.session_id}, "
        f"message_index={request.message_index}, "
        f"rating={request.rating}, comment={request.comment}"
    )

    return {"message": "反馈已记录", "rating": request.rating}
|
||
|
||
|
||
@router.get("/templates", response_model=TemplateListResponse)
async def list_prompt_templates():
    """List the available prompt templates."""
    # Imported at call time, as in the rest of this module's optional lookups.
    from app.services.rag.prompt_templates import PromptTemplates

    return TemplateListResponse(templates=PromptTemplates.list_templates())
|
||
|
||
|
||
@router.get("/models")
async def list_available_models():
    """List the available LLM providers under the ``models`` key."""
    # Imported at call time rather than at module load.
    from app.services.llm import LLMFactory

    return {"models": LLMFactory().list_available_providers()}
|