# src/api/routes/agent.py
|
|||
|
|
"""Agent API接口 - 问答对话接口"""
|
|||
|
|
|
|||
|
|
from fastapi import APIRouter, HTTPException, Depends
|
|||
|
|
from fastapi.responses import StreamingResponse
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from typing import List, Dict, Optional, AsyncGenerator
|
|||
|
|
from loguru import logger
|
|||
|
|
import json
|
|||
|
|
import asyncio
|
|||
|
|
|
|||
|
|
from app.services.agent.qa_agent import QAAgent, AgentConfig
|
|||
|
|
from app.services.agent.session_manager import SessionManager
|
|||
|
|
from app.config.settings import settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Every endpoint in this module is mounted under the /agent prefix and grouped
# under the "agent" tag in the generated API docs.
router = APIRouter(tags=["agent"], prefix="/agent")

# Single module-level session store shared by all endpoints below; how and
# where sessions are kept is up to SessionManager (not visible here).
session_manager = SessionManager()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===== Pydantic Models =====
|
|||
|
|
|
|||
|
|
class AskRequest(BaseModel):
    """Request body for a single, stateless question."""

    # Required question text, 1-2000 characters.
    query: str = Field(..., min_length=1, max_length=2000, description="用户问题")
    # Optional retrieval filter expression forwarded to the agent.
    filters: Optional[str] = Field(default=None, description="检索过滤条件")
    # Optional LLM provider override; the handler falls back to settings when unset.
    provider: Optional[str] = Field(default=None, description="LLM提供商 (qwen/deepseek)")
    # Optional LLM model-name override.
    model: Optional[str] = Field(default=None, description="LLM模型名称")
    # Optional number of retrieved chunks, validated to 1-20.
    top_k: Optional[int] = Field(default=None, ge=1, le=20, description="检索数量")
    # Optional name of the prompt template to use.
    prompt_template: Optional[str] = Field(default=None, description="Prompt模板名称")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AskResponse(BaseModel):
    """Response body for a single question."""

    # Generated answer text.
    answer: str
    # Cited source records as returned by the agent.
    sources: List[Dict] = []
    # Name of the model that produced the answer.
    model: str = ""
    # End-to-end latency reported by the agent, in milliseconds.
    latency_ms: int = 0
    # Number of chunks retrieved for the answer.
    retrieved_count: int = 0
    # Size of the context passed to the model, in tokens.
    context_tokens: int = 0
    # Whether the agent reported that the context was truncated.
    truncated: bool = False
    # Error message reported by the agent, if any.
    error: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatRequest(BaseModel):
    """Request body for one turn of a multi-turn conversation."""

    # Required question text, 1-2000 characters.
    query: str = Field(..., min_length=1, max_length=2000, description="用户问题")
    # Existing session to continue; omit on the first turn to start a new one.
    session_id: Optional[str] = Field(default=None, description="会话ID(首次对话可不传)")
    # Optional retrieval filter expression forwarded to the agent.
    filters: Optional[str] = Field(default=None, description="检索过滤条件")
    # Optional LLM provider override.
    provider: Optional[str] = Field(default=None, description="LLM提供商")
    # Optional LLM model-name override.
    model: Optional[str] = Field(default=None, description="LLM模型名称")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatResponse(BaseModel):
    """Response body for one conversation turn."""

    # Session the turn belongs to (newly created when none was supplied).
    session_id: str
    # Generated answer text.
    answer: str
    # Cited source records as returned by the agent.
    sources: List[Dict] = []
    # Name of the model that produced the answer.
    model: str = ""
    # Latency reported by the agent, in milliseconds.
    latency_ms: int = 0
    # Message count of the session after this turn was recorded.
    message_count: int = 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SessionInfo(BaseModel):
    """Summary metadata describing one conversation session."""

    # Session identifier.
    session_id: str
    # Number of messages stored in the session.
    message_count: int
    # Creation timestamp as an integer (unit defined by SessionManager — TODO confirm).
    created_at: int
    # Last-update timestamp as an integer (same unit as created_at).
    updated_at: int
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FeedbackRequest(BaseModel):
    """Request body for rating an answer within a session."""

    # Session the rated answer belongs to.
    session_id: str
    # Index of the rated message (indexing scheme not visible here — TODO confirm).
    message_index: int
    # Required rating, validated to the range 1-5.
    rating: int = Field(..., ge=1, le=5, description="评分 1-5")
    # Optional free-text comment.
    comment: Optional[str] = Field(default=None, description="反馈内容")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TemplateListResponse(BaseModel):
    """Response body listing the available prompt templates."""

    # Mapping of template name to a string value (presumably a description).
    templates: Dict[str, str]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ===== API Endpoints =====
|
|||
|
|
|
|||
|
|
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """
    Single-shot question answering endpoint.

    No conversation history is stored, so this suits one-off queries.

    Raises:
        HTTPException: 500 carrying the error message if building the agent
            or answering the question fails.
    """
    logger.info(f"收到问答请求: {request.query}")

    try:
        # Request values override the configured defaults.
        config = AgentConfig(
            llm_provider=request.provider or settings.llm_provider,
            llm_model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )

        def _ask_blocking():
            # Create, use and close the agent inside one worker call so the
            # whole lifecycle happens on the same thread.
            agent = QAAgent(config)
            try:
                return agent.ask(
                    query=request.query,
                    filters=request.filters,
                    prompt_template=request.prompt_template,
                )
            finally:
                # Release agent resources even when ask() raises.
                agent.close()

        # QAAgent.ask is synchronous; running it in the default executor keeps
        # slow retrieval / LLM calls from blocking the event loop for every
        # other request.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _ask_blocking)

        return AskResponse(
            answer=response.answer,
            sources=response.sources,
            model=response.model,
            latency_ms=response.latency_ms,
            retrieved_count=response.retrieved_count,
            context_tokens=response.context_tokens,
            truncated=response.truncated,
            error=response.error,
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"问答失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """
    Multi-turn conversation endpoint.

    Records each user question and assistant answer in a session, creating a
    new session when no ``session_id`` is supplied.

    Raises:
        HTTPException: 404 if the given session does not exist or has expired;
            500 carrying the error message if the agent fails.
    """
    logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")

    try:
        # Continue the requested session, or start a new one.
        if request.session_id:
            session = session_manager.get_session(request.session_id)
            if not session:
                raise HTTPException(status_code=404, detail="会话不存在或已过期")
        else:
            session = session_manager.create_session()

        session.add_user_message(request.query)

        config = AgentConfig(
            llm_provider=request.provider or settings.llm_provider,
            llm_model=request.model or settings.llm_model,
        )

        def _ask_blocking():
            # Create, use and close the agent inside one worker call so the
            # whole lifecycle happens on the same thread.
            agent = QAAgent(config)
            try:
                # NOTE(review): the session history is not passed to the agent,
                # so each turn is answered independently — confirm whether
                # QAAgent.ask can accept prior messages.
                return agent.ask(
                    query=request.query,
                    filters=request.filters,
                )
            finally:
                # Release agent resources even when ask() raises.
                agent.close()

        # QAAgent.ask is synchronous; offload it so the event loop stays free.
        # Session bookkeeping above and below stays on the event-loop thread.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _ask_blocking)

        session.add_assistant_message(
            response.answer,
            response.sources,
        )

        return ChatResponse(
            session_id=session.session_id,
            answer=response.answer,
            sources=response.sources,
            model=response.model,
            latency_ms=response.latency_ms,
            message_count=session.message_count,
        )

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) pass through unchanged.
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"对话失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _format_sse_event(event: str, data) -> str:
    """Serialize one Server-Sent Events frame.

    Dicts and lists are JSON-encoded; any other value is converted with ``str``.
    The payload is split on CR, LF and CRLF, and each line is written as its own
    ``data:`` field. In SSE a blank line ends the event and every other line is
    parsed as a field, so a multi-line payload (e.g. Markdown produced by the
    LLM) must not be written as a single ``data:`` line. Conforming clients
    rejoin the ``data`` lines with newlines.
    """
    if isinstance(data, (dict, list)):
        payload = json.dumps(data)
    else:
        payload = str(data)
    lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    data_block = "".join(f"data: {line}\n" for line in lines)
    return f"event: {event}\n{data_block}\n"


def _sse_response(events: AsyncGenerator[str, None]) -> StreamingResponse:
    """Wrap an async generator of SSE frames in a streaming HTTP response."""
    return StreamingResponse(
        events,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx response buffering
        },
    )


async def _chat_sse_events(
    query: str,
    session_id: Optional[str],
    filters: Optional[str],
    provider: Optional[str],
    model: Optional[str],
) -> AsyncGenerator[str, None]:
    """Run one streaming chat turn and yield it as SSE frames.

    First emits a ``session`` event carrying the (possibly new) session id.
    It then relays every event from ``QAAgent.ask_stream`` under the event's
    own name, and finally stores the accumulated answer and sources in the
    session history. Failures are reported in-band as an ``error`` event
    instead of being raised.
    """
    try:
        # Continue the requested session, or start a new one.
        if session_id:
            session = session_manager.get_session(session_id)
            if not session:
                yield _format_sse_event("error", "会话不存在或已过期")
                return
        else:
            session = session_manager.create_session()

        yield _format_sse_event("session", {"session_id": session.session_id})

        session.add_user_message(query)

        config = AgentConfig(
            llm_provider=provider or settings.llm_provider,
            llm_model=model or settings.llm_model,
        )
        agent = QAAgent(config)

        answer_parts: List[str] = []
        sources = []
        try:
            # NOTE(review): ask_stream is iterated synchronously, so each step
            # runs on the event-loop thread; if it does blocking I/O it stalls
            # other requests while it waits — consider offloading.
            for event_data in agent.ask_stream(query=query, filters=filters):
                event_type = event_data.get("event", "content")
                data = event_data.get("data", "")

                # Collect what will be persisted to the session history.
                if event_type == "content":
                    answer_parts.append(str(data))
                elif event_type == "sources":
                    sources = data

                yield _format_sse_event(event_type, data)

                # Hand control back to the event loop between events.
                await asyncio.sleep(0)
        finally:
            # Runs on normal completion, on errors, and when the generator is
            # closed or cancelled mid-stream (e.g. the client disconnects).
            agent.close()

        session.add_assistant_message("".join(answer_parts), sources)

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"流式对话失败: {e}")
        yield _format_sse_event("error", str(e))


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
):
    """
    Streaming chat endpoint (SSE), GET variant.

    ``EventSource`` can only send GET requests, so this variant takes its
    arguments as URL query parameters.

    SSE events:
    - event: session - session ID
    - event: status - status updates (retrieving, generating)
    - event: sources - cited sources
    - event: content - answer fragment
    - event: done - completion, with statistics
    - event: error - error message
    """
    logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
    return _sse_response(
        _chat_sse_events(query, session_id, filters, provider, model)
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """
    Streaming chat endpoint (SSE).

    Returns a Server-Sent Events stream so the client can watch progress and
    the answer as it is generated.

    SSE events:
    - event: session - session ID
    - event: status - status updates (retrieving, generating)
    - event: sources - cited sources
    - event: content - answer fragment
    - event: done - completion, with statistics
    - event: error - error message
    """
    logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
    return _sse_response(
        _chat_sse_events(
            request.query,
            request.session_id,
            request.filters,
            request.provider,
            request.model,
        )
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return summary metadata for one session.

    Raises:
        HTTPException: 404 if the session does not exist or has expired.
    """
    found = session_manager.get_session(session_id)
    if not found:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")

    return SessionInfo(
        session_id=found.session_id,
        message_count=found.message_count,
        created_at=found.created_at,
        updated_at=found.updated_at,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return a session's history, limited by ``max_turns`` (default 5).

    Raises:
        HTTPException: 404 if the session does not exist or has expired.
    """
    current = session_manager.get_session(session_id)
    if not current:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")

    return {"session_id": session_id, "history": current.get_history(max_turns)}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a session.

    Raises:
        HTTPException: 404 if the session manager reports nothing was deleted.
    """
    if not session_manager.delete_session(session_id):
        raise HTTPException(status_code=404, detail="会话不存在")

    return {"message": "会话已删除", "session_id": session_id}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """List every session known to the session manager as SessionInfo records."""
    return [SessionInfo(**summary) for summary in session_manager.list_sessions()]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Record a rating (and optional comment) for a session's answer.

    Feedback is currently only written to the log, not persisted.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    if not session_manager.get_session(request.session_id):
        raise HTTPException(status_code=404, detail="会话不存在")

    # NOTE(review): request.message_index is neither validated nor logged;
    # persist it alongside the rating once feedback storage exists.
    logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")

    return {"message": "反馈已记录", "rating": request.rating}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/templates", response_model=TemplateListResponse)
async def list_prompt_templates():
    """List the available prompt templates."""
    # Imported inside the handler, as before; presumably to defer the import
    # cost or avoid an import cycle — TODO confirm.
    from app.services.rag.prompt_templates import PromptTemplates

    return TemplateListResponse(templates=PromptTemplates.list_templates())
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/models")
async def list_available_models():
    """List the LLM providers reported as available by LLMFactory."""
    # Imported inside the handler, as before; presumably to defer the import
    # cost or avoid an import cycle — TODO confirm.
    from app.services.llm import LLMFactory

    providers = LLMFactory().list_available_providers()
    return {"models": providers}
|