Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,186 +1,83 @@
|
||||
"""Agent API接口 - 问答对话接口"""
|
||||
"""Define API routes for agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from app.services.agent.qa_agent import QAAgent, AgentConfig
|
||||
from app.services.agent.session_manager import SessionManager
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.api.models import (
|
||||
AskRequest,
|
||||
AskResponse,
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
FeedbackRequest,
|
||||
SessionInfo,
|
||||
)
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
# 会话管理器(全局实例)
|
||||
session_manager = SessionManager()
|
||||
|
||||
|
||||
# ===== Pydantic Models =====
|
||||
|
||||
class AskRequest(BaseModel):
|
||||
"""单次问答请求"""
|
||||
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
|
||||
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
|
||||
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
|
||||
|
||||
|
||||
class AskResponse(BaseModel):
|
||||
"""问答响应"""
|
||||
answer: str
|
||||
sources: List[Dict] = []
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
retrieved_count: int = 0
|
||||
context_tokens: int = 0
|
||||
truncated: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""多轮对话请求"""
|
||||
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
|
||||
session_id: Optional[str] = Field(None, description="会话ID(首次对话可不传)")
|
||||
filters: Optional[str] = Field(None, description="检索过滤条件")
|
||||
provider: Optional[str] = Field(None, description="LLM提供商")
|
||||
model: Optional[str] = Field(None, description="LLM模型名称")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""多轮对话响应"""
|
||||
session_id: str
|
||||
answer: str
|
||||
sources: List[Dict] = []
|
||||
model: str = ""
|
||||
latency_ms: int = 0
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
|
||||
"""会话信息"""
|
||||
session_id: str
|
||||
message_count: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
|
||||
"""反馈请求"""
|
||||
session_id: str
|
||||
message_index: int
|
||||
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
|
||||
comment: Optional[str] = Field(None, description="反馈内容")
|
||||
|
||||
|
||||
class TemplateListResponse(BaseModel):
|
||||
"""模板列表响应"""
|
||||
templates: Dict[str, str]
|
||||
|
||||
|
||||
# ===== API Endpoints =====
|
||||
|
||||
@router.post("/ask", response_model=AskResponse)
|
||||
async def ask_question(request: AskRequest):
|
||||
"""
|
||||
单次问答接口
|
||||
|
||||
不保存会话历史,适合单次查询场景。
|
||||
"""
|
||||
logger.info(f"收到问答请求: {request.query}")
|
||||
|
||||
"""Handle ask question."""
|
||||
try:
|
||||
# 构建Agent配置
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k
|
||||
)
|
||||
|
||||
# 创建Agent并执行问答
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(
|
||||
_, result = get_agent_conversation_service().ask(
|
||||
query=request.query,
|
||||
filters=request.filters,
|
||||
prompt_template=request.prompt_template
|
||||
provider=request.provider or settings.llm_provider,
|
||||
model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
prompt_template=request.prompt_template,
|
||||
)
|
||||
agent.close()
|
||||
|
||||
return AskResponse(
|
||||
answer=response.answer,
|
||||
sources=response.sources,
|
||||
model=response.model,
|
||||
latency_ms=response.latency_ms,
|
||||
retrieved_count=response.retrieved_count,
|
||||
context_tokens=response.context_tokens,
|
||||
truncated=response.truncated,
|
||||
error=response.error
|
||||
answer=result.answer,
|
||||
sources=[asdict(source) for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
retrieved_count=result.retrieved_count,
|
||||
context_tokens=result.context_tokens,
|
||||
truncated=result.truncated,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat_with_session(request: ChatRequest):
|
||||
"""
|
||||
多轮对话接口
|
||||
|
||||
支持会话历史记录,适合连续对话场景。
|
||||
"""
|
||||
logger.info(f"收到对话请求: session={request.session_id}, query={request.query}")
|
||||
|
||||
"""Handle chat with session."""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if request.session_id:
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(request.query)
|
||||
|
||||
# 执行问答
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(
|
||||
session_id, result = get_agent_conversation_service().chat(
|
||||
query=request.query,
|
||||
filters=request.filters
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
provider=request.provider or settings.llm_provider,
|
||||
model=request.model or settings.llm_model,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
)
|
||||
agent.close()
|
||||
|
||||
# 添加助手消息
|
||||
session.add_assistant_message(
|
||||
response.answer,
|
||||
response.sources
|
||||
)
|
||||
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
return ChatResponse(
|
||||
session_id=session.session_id,
|
||||
answer=response.answer,
|
||||
sources=response.sources,
|
||||
model=response.model,
|
||||
latency_ms=response.latency_ms,
|
||||
message_count=session.message_count
|
||||
session_id=session_id,
|
||||
answer=result.answer,
|
||||
sources=[asdict(source) for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
message_count=len(session.messages) if session else 0,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"对话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/chat/stream")
|
||||
@@ -189,260 +86,93 @@ async def chat_stream_get(
|
||||
session_id: Optional[str] = None,
|
||||
filters: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
流式对话接口(SSE)- GET版本
|
||||
|
||||
EventSource只能发送GET请求,因此提供此接口。
|
||||
query参数通过URL传递。
|
||||
|
||||
SSE事件格式:
|
||||
- event: session - 会话ID
|
||||
- event: status - 状态更新(检索中、生成中)
|
||||
- event: sources - 引用来源
|
||||
- event: content - 回答内容片段
|
||||
- event: done - 完成,包含统计信息
|
||||
- event: error - 错误信息
|
||||
"""
|
||||
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
|
||||
|
||||
"""Handle chat stream get."""
|
||||
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
"""Handle generate sse."""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if session_id:
|
||||
session = session_manager.get_session(session_id)
|
||||
if not session:
|
||||
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||
return
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 发送session_id
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(query)
|
||||
|
||||
# 创建Agent
|
||||
config = AgentConfig(
|
||||
llm_provider=provider or settings.llm_provider,
|
||||
llm_model=model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
|
||||
# 执行流式问答
|
||||
full_answer = ""
|
||||
sources = []
|
||||
done_data = {}
|
||||
|
||||
for event_data in agent.ask_stream(
|
||||
session_id_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
filters=filters
|
||||
):
|
||||
session_id=session_id,
|
||||
filters=filters,
|
||||
provider=provider or settings.llm_provider,
|
||||
model=model or settings.llm_model,
|
||||
top_k=settings.rag_top_k,
|
||||
)
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
|
||||
for event_data in event_stream:
|
||||
event_type = event_data.get("event", "content")
|
||||
data = event_data.get("data", "")
|
||||
|
||||
# 收集完整回答和来源
|
||||
if event_type == "content":
|
||||
full_answer += str(data)
|
||||
elif event_type == "sources":
|
||||
sources = data
|
||||
elif event_type == "done":
|
||||
done_data = data
|
||||
|
||||
# 发送SSE事件
|
||||
if isinstance(data, (dict, list)):
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
# 小延迟让其他任务有机会执行
|
||||
await asyncio.sleep(0)
|
||||
|
||||
agent.close()
|
||||
|
||||
# 保存到会话历史
|
||||
session.add_assistant_message(full_answer, sources)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式对话失败: {e}")
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
except Exception as exc:
|
||||
yield f"event: error\ndata: {str(exc)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
"""
|
||||
流式对话接口(SSE)
|
||||
|
||||
返回Server-Sent Events格式的流式响应,用户可实时看到思考过程和回答生成。
|
||||
|
||||
SSE事件格式:
|
||||
- event: status - 状态更新(检索中、生成中)
|
||||
- event: sources - 引用来源
|
||||
- event: content - 回答内容片段
|
||||
- event: done - 完成,包含统计信息
|
||||
- event: error - 错误信息
|
||||
"""
|
||||
logger.info(f"收到流式对话请求: session={request.session_id}, query={request.query}")
|
||||
|
||||
async def generate_sse() -> AsyncGenerator[str, None]:
|
||||
"""生成SSE事件流"""
|
||||
try:
|
||||
# 获取或创建会话
|
||||
if request.session_id:
|
||||
session = session_manager.get_session(request.session_id)
|
||||
if not session:
|
||||
yield f"event: error\ndata: 会话不存在或已过期\n\n"
|
||||
return
|
||||
else:
|
||||
session = session_manager.create_session()
|
||||
|
||||
# 发送session_id
|
||||
yield f"event: session\ndata: {json.dumps({'session_id': session.session_id})}\n\n"
|
||||
|
||||
# 添加用户消息
|
||||
session.add_user_message(request.query)
|
||||
|
||||
# 创建Agent
|
||||
config = AgentConfig(
|
||||
llm_provider=request.provider or settings.llm_provider,
|
||||
llm_model=request.model or settings.llm_model
|
||||
)
|
||||
|
||||
agent = QAAgent(config)
|
||||
|
||||
# 执行流式问答
|
||||
full_answer = ""
|
||||
sources = []
|
||||
done_data = {}
|
||||
|
||||
for event_data in agent.ask_stream(
|
||||
query=request.query,
|
||||
filters=request.filters
|
||||
):
|
||||
event_type = event_data.get("event", "content")
|
||||
data = event_data.get("data", "")
|
||||
|
||||
# 收集完整回答和来源
|
||||
if event_type == "content":
|
||||
full_answer += str(data)
|
||||
elif event_type == "sources":
|
||||
sources = data
|
||||
elif event_type == "done":
|
||||
done_data = data
|
||||
|
||||
# 发送SSE事件
|
||||
if isinstance(data, (dict, list)):
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
else:
|
||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||
|
||||
# 小延迟让其他任务有机会执行
|
||||
await asyncio.sleep(0)
|
||||
|
||||
agent.close()
|
||||
|
||||
# 保存到会话历史
|
||||
session.add_assistant_message(full_answer, sources)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式对话失败: {e}")
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用nginx缓冲
|
||||
}
|
||||
"""Handle chat stream."""
|
||||
return await chat_stream_get(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session/{session_id}", response_model=SessionInfo)
|
||||
async def get_session_info(session_id: str):
|
||||
"""获取会话信息"""
|
||||
session = session_manager.get_session(session_id)
|
||||
"""Return session info."""
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
|
||||
return SessionInfo(
|
||||
session_id=session.session_id,
|
||||
message_count=session.message_count,
|
||||
message_count=len(session.messages),
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at
|
||||
updated_at=session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session/{session_id}/history")
|
||||
async def get_session_history(session_id: str, max_turns: int = 5):
|
||||
"""获取会话历史"""
|
||||
session = session_manager.get_session(session_id)
|
||||
"""Return session history."""
|
||||
session = get_conversation_store().get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
|
||||
history = session.get_history(max_turns)
|
||||
history = [{"role": msg.role, "content": msg.content} for msg in session.messages[-(max_turns * 2):]]
|
||||
return {"session_id": session_id, "history": history}
|
||||
|
||||
|
||||
@router.delete("/session/{session_id}")
|
||||
async def delete_session(session_id: str):
|
||||
"""删除会话"""
|
||||
success = session_manager.delete_session(session_id)
|
||||
if not success:
|
||||
"""Delete session."""
|
||||
if not get_conversation_store().delete_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
return {"message": "会话已删除", "session_id": session_id}
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionInfo])
|
||||
async def list_sessions():
|
||||
"""列出所有活跃会话"""
|
||||
sessions = session_manager.list_sessions()
|
||||
return [SessionInfo(**s) for s in sessions]
|
||||
"""List sessions."""
|
||||
return [SessionInfo(**item) for item in get_conversation_store().list_sessions()]
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
async def submit_feedback(request: FeedbackRequest):
|
||||
"""提交问答反馈"""
|
||||
session = session_manager.get_session(request.session_id)
|
||||
"""Submit feedback."""
|
||||
session = get_conversation_store().get_session(request.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
# 记录反馈(实际应用中可存储到数据库)
|
||||
logger.info(f"收到反馈: session={request.session_id}, rating={request.rating}, comment={request.comment}")
|
||||
|
||||
return {"message": "反馈已记录", "rating": request.rating}
|
||||
|
||||
|
||||
@router.get("/templates", response_model=TemplateListResponse)
|
||||
async def list_prompt_templates():
|
||||
"""列出可用的Prompt模板"""
|
||||
from app.services.rag.prompt_templates import PromptTemplates
|
||||
|
||||
templates = PromptTemplates.list_templates()
|
||||
return TemplateListResponse(templates=templates)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_available_models():
|
||||
"""列出可用的LLM模型"""
|
||||
from app.services.llm import LLMFactory
|
||||
|
||||
factory = LLMFactory()
|
||||
models = factory.list_available_providers()
|
||||
return {"models": models}
|
||||
return {"message": "反馈已提交", "session_id": request.session_id, "message_index": request.message_index}
|
||||
|
||||
Reference in New Issue
Block a user