Fix SSE route dependency and align architecture docs

This commit is contained in:
ash66
2026-05-18 16:32:42 +08:00
parent 86b9ac806a
commit 3f69cad404
149 changed files with 4786 additions and 5957 deletions

View File

@@ -1,186 +1,83 @@
"""Agent API接口 - 问答对话接口"""
"""Define API routes for agent."""
from __future__ import annotations
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, AsyncGenerator
from loguru import logger
import json
import asyncio
import json
from typing import AsyncGenerator, List, Optional
from app.services.agent.qa_agent import QAAgent, AgentConfig
from app.services.agent.session_manager import SessionManager
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from dataclasses import asdict
from app.api.models import (
AskRequest,
AskResponse,
ChatRequest,
ChatResponse,
FeedbackRequest,
SessionInfo,
)
from app.config.settings import settings
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/agent", tags=["agent"])
# 会话管理器(全局实例)
session_manager = SessionManager()
# ===== Pydantic Models =====
class AskRequest(BaseModel):
"""单次问答请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商 (qwen/deepseek)")
model: Optional[str] = Field(None, description="LLM模型名称")
top_k: Optional[int] = Field(None, description="检索数量", ge=1, le=20)
prompt_template: Optional[str] = Field(None, description="Prompt模板名称")
class AskResponse(BaseModel):
"""问答响应"""
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
retrieved_count: int = 0
context_tokens: int = 0
truncated: bool = False
error: Optional[str] = None
class ChatRequest(BaseModel):
"""多轮对话请求"""
query: str = Field(..., description="用户问题", min_length=1, max_length=2000)
session_id: Optional[str] = Field(None, description="会话ID首次对话可不传")
filters: Optional[str] = Field(None, description="检索过滤条件")
provider: Optional[str] = Field(None, description="LLM提供商")
model: Optional[str] = Field(None, description="LLM模型名称")
class ChatResponse(BaseModel):
"""多轮对话响应"""
session_id: str
answer: str
sources: List[Dict] = []
model: str = ""
latency_ms: int = 0
message_count: int = 0
class SessionInfo(BaseModel):
"""会话信息"""
session_id: str
message_count: int
created_at: int
updated_at: int
class FeedbackRequest(BaseModel):
"""反馈请求"""
session_id: str
message_index: int
rating: int = Field(..., ge=1, le=5, description="评分 1-5")
comment: Optional[str] = Field(None, description="反馈内容")
class TemplateListResponse(BaseModel):
"""模板列表响应"""
templates: Dict[str, str]
# ===== API Endpoints =====
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a single question through the agent conversation service.

    Provider, model and ``top_k`` fall back to the application settings when
    the request leaves them unset.

    Args:
        request: Validated question payload (query, optional filters,
            provider/model overrides, ``top_k`` and prompt template name).

    Returns:
        AskResponse: The answer together with its sources (converted from
        dataclass instances to plain dicts) and the run statistics reported
        by the service.

    Raises:
        HTTPException: 500 carrying the exception text when the service call
            or the response construction fails.
    """
    try:
        # The service returns a pair; only the second element (the result) is
        # needed here. The first is presumably a session id - confirm against
        # the service before relying on it.
        _, result = get_agent_conversation_service().ask(
            query=request.query,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
            prompt_template=request.prompt_template,
        )
        return AskResponse(
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )
    except Exception as exc:
        # Route boundary: turn any failure into a 500, but chain the cause so
        # the original traceback is still available to logging/middleware.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Run one turn of a multi-turn conversation.

    The turn is delegated to the agent conversation service, which returns the
    session id used for the turn together with the result. The session is then
    read back from the conversation store only to report its message count.

    Args:
        request: Validated chat payload (query, optional ``session_id``,
            filters and provider/model/``top_k`` overrides).

    Returns:
        ChatResponse: Session id, answer, sources as plain dicts, model name,
        latency and the number of messages currently stored for the session
        (0 when the store no longer returns the session).

    Raises:
        HTTPException: 404 when the service raises ``ValueError`` (presumably
            an unknown or expired session id - confirm against the service);
            500 for any other failure.
    """
    try:
        session_id, result = get_agent_conversation_service().chat(
            query=request.query,
            session_id=request.session_id,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )
        session = get_conversation_store().get_session(session_id)
        return ChatResponse(
            session_id=session_id,
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            message_count=len(session.messages) if session else 0,
        )
    except HTTPException:
        # Already an HTTP-level error; do not let the generic handler below
        # rewrite its status code.
        raise
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.get("/chat/stream")
@@ -189,260 +86,93 @@ async def chat_stream_get(
session_id: Optional[str] = None,
filters: Optional[str] = None,
provider: Optional[str] = None,
model: Optional[str] = None
model: Optional[str] = None,
):
"""
流式对话接口SSE- GET版本
EventSource只能发送GET请求因此提供此接口。
query参数通过URL传递。
SSE事件格式
- event: session - 会话ID
- event: status - 状态更新(检索中、生成中)
- event: sources - 引用来源
- event: content - 回答内容片段
- event: done - 完成,包含统计信息
- event: error - 错误信息
"""
logger.info(f"收到GET流式对话请求: session={session_id}, query={query}")
"""Handle chat stream get."""
async def generate_sse() -> AsyncGenerator[str, None]:
    """Yield the chat reply as Server-Sent Events frames.

    Emits a ``session`` event carrying the session id first, then one frame
    per event produced by the conversation service. Dict and list payloads are
    JSON-encoded (non-ASCII kept readable); everything else is sent as text.
    Any exception is reported to the client as a final ``error`` event instead
    of propagating, because the response headers have already been sent by
    the time the generator runs.

    Reads ``query``, ``session_id``, ``filters``, ``provider`` and ``model``
    from the enclosing scope.
    """

    def _frame(event: str, payload: str) -> str:
        # The SSE wire format is line based: every line of the payload needs
        # its own "data:" prefix, otherwise a client drops everything after
        # the first newline. Single-line payloads come out unchanged.
        lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        body = "".join(f"data: {line}\n" for line in lines)
        return f"event: {event}\n{body}\n"

    try:
        session_id_, event_stream = get_agent_conversation_service().stream_chat(
            query=query,
            session_id=session_id,
            filters=filters,
            provider=provider or settings.llm_provider,
            model=model or settings.llm_model,
            top_k=settings.rag_top_k,
        )
        yield _frame("session", json.dumps({"session_id": session_id_}))
        for event_data in event_stream:
            event_type = event_data.get("event", "content")
            data = event_data.get("data", "")
            if isinstance(data, (dict, list)):
                yield _frame(event_type, json.dumps(data, ensure_ascii=False))
            else:
                yield _frame(event_type, str(data))
            # event_stream is a plain (synchronous) iterator; hand control
            # back to the event loop between frames so other tasks can run.
            await asyncio.sleep(0)
    except Exception as exc:
        yield _frame("error", str(exc))
return StreamingResponse(
generate_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用nginx缓冲
}
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream a chat reply as Server-Sent Events for a JSON POST body.

    This is the POST counterpart of ``chat_stream_get``: it unpacks the
    request model and delegates to that handler, so both routes share one
    streaming implementation.

    Args:
        request: Validated chat payload (query, optional ``session_id``,
            filters and provider/model overrides).

    Returns:
        Whatever ``chat_stream_get`` returns for the same arguments.
    """
    return await chat_stream_get(
        query=request.query,
        session_id=request.session_id,
        filters=request.filters,
        provider=request.provider,
        model=request.model,
    )
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return summary information for one conversation session.

    Args:
        session_id: Identifier of the session to look up in the conversation
            store.

    Returns:
        SessionInfo: Session id, number of stored messages and the
        created/updated timestamps taken from the stored session.

    Raises:
        HTTPException: 404 when the store returns no session for the id.
    """
    session = get_conversation_store().get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    return SessionInfo(
        session_id=session.session_id,
        message_count=len(session.messages),
        created_at=session.created_at,
        updated_at=session.updated_at,
    )
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return the most recent messages of a conversation session.

    Args:
        session_id: Identifier of the session to look up in the conversation
            store.
        max_turns: Number of turns to return, counted as two messages per
            turn. Zero or a negative value yields an empty history.

    Returns:
        dict: ``{"session_id": ..., "history": [...]}`` where each history
        item holds the ``role`` and ``content`` of one message, oldest first.

    Raises:
        HTTPException: 404 when the store returns no session for the id.
    """
    session = get_conversation_store().get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    # Clamp before slicing: messages[-0:] is the whole list and a negative
    # limit would slice from the front, neither of which is "the last N".
    limit = max(max_turns, 0) * 2
    recent = session.messages[-limit:] if limit else []
    history = [{"role": msg.role, "content": msg.content} for msg in recent]
    return {"session_id": session_id, "history": history}
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a conversation session from the conversation store.

    Args:
        session_id: Identifier of the session to delete.

    Returns:
        dict: A confirmation message together with the deleted session id.

    Raises:
        HTTPException: 404 when the store reports that nothing was deleted.
    """
    if not get_conversation_store().delete_session(session_id):
        raise HTTPException(status_code=404, detail="会话不存在")
    return {"message": "会话已删除", "session_id": session_id}
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """List the sessions currently held by the conversation store.

    Returns:
        list[SessionInfo]: One entry per item returned by the store; each
        item is expanded as keyword arguments, so it must be a mapping whose
        keys match the ``SessionInfo`` fields.
    """
    return [SessionInfo(**item) for item in get_conversation_store().list_sessions()]
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Acknowledge feedback for a message of an existing session.

    Only the existence of the session is checked here; the rating and comment
    carried by the request are not stored or forwarded by this handler.

    Args:
        request: Validated feedback payload (session id, message index,
            rating and optional comment).

    Returns:
        dict: A confirmation message echoing the session id and message
        index from the request.

    Raises:
        HTTPException: 404 when the store returns no session for the id.
    """
    session = get_conversation_store().get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")
    return {"message": "反馈已提交", "session_id": request.session_id, "message_index": request.message_index}