Files
AIRegulation-DocAnalysis/backend/app/api/routes/agent.py
2026-05-21 23:20:39 +08:00

178 lines
6.3 KiB
Python

"""Define API routes for agent."""
from __future__ import annotations

import json
from dataclasses import asdict
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

from app.api.models import (
    AskRequest,
    AskResponse,
    ChatRequest,
    ChatResponse,
    FeedbackRequest,
    SessionInfo,
)
from app.config.settings import settings
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
# Route handlers live next to their transport-layer wiring so they are easy to audit.
# Every endpoint below is mounted under the ``/agent`` prefix and tagged "agent".
router = APIRouter(tags=["agent"], prefix="/agent")
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a single question through the agent conversation service.

    ``provider``, ``model`` and ``top_k`` fall back to ``settings.llm_provider``,
    ``settings.llm_model`` and ``settings.rag_top_k`` when the request leaves
    them empty. The session id returned by the service is discarded.

    Raises:
        HTTPException: 500 with the exception text if the service call or the
            response construction fails.
    """

    def _run_ask():
        # ``ask`` returns a plain tuple (it is unpacked without ``await``),
        # so it is a blocking call and must not run on the event loop.
        return get_agent_conversation_service().ask(
            query=request.query,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
            prompt_template=request.prompt_template,
        )

    try:
        # Offload the blocking call; running it inline would stall every other
        # request served by this event loop until the answer is produced.
        _, result = await run_in_threadpool(_run_ask)
        return AskResponse(
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )
    except Exception as exc:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Run one conversational turn, creating or continuing a session.

    ``provider``, ``model`` and ``top_k`` fall back to the configured defaults
    when the request leaves them empty. The response carries the session id
    returned by the service and the number of messages the store holds for
    that session after this call (0 if the store no longer has it).

    Raises:
        HTTPException: 404 when anything in the handler raises ``ValueError``
            (presumably an unknown ``session_id`` — confirm in the service);
            500 with the exception text for any other failure.
    """

    def _run_turn():
        # Both the service (its result is unpacked without ``await``) and the
        # store are synchronous, so this whole unit runs in a worker thread.
        session_id, result = get_agent_conversation_service().chat(
            query=request.query,
            session_id=request.session_id,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )
        session = get_conversation_store().get_session(session_id)
        return session_id, result, session

    try:
        # Keep the event loop free while the (potentially slow) turn runs.
        session_id, result, session = await run_in_threadpool(_run_turn)
        return ChatResponse(
            session_id=session_id,
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            message_count=len(session.messages) if session else 0,
        )
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def _format_sse(event: str, data: str) -> str:
    """Serialize one Server-Sent Events frame with the given event name and payload.

    Each line of ``data`` gets its own ``data:`` prefix. A raw newline inside a
    single ``data:`` field would end that field, so the client would drop the
    rest of the text. A blank line would terminate the event early. Clients
    re-join multi-line ``data`` fields with ``\\n``, so the text arrives intact.
    """
    # SSE treats CRLF, CR and LF all as line terminators.
    normalized = data.replace("\r\n", "\n").replace("\r", "\n")
    body = "".join(f"data: {line}\n" for line in normalized.split("\n"))
    return f"event: {event}\n{body}\n"


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    top_k: Optional[int] = None,
):
    """Stream a chat turn to the client as Server-Sent Events.

    The first event is ``session`` with a JSON ``{"session_id": ...}`` payload.
    Then one event is emitted per item yielded by the service stream. The event
    name comes from the item's ``event`` key (default ``content``). Dict and
    list payloads are JSON-encoded, and anything else is sent as text. A failure
    at any point is reported in-band as a final ``error`` event.

    Args:
        query: The user message.
        session_id: Existing session to continue, or None.
        filters: Passed through to the service unchanged. The GET route
            supplies the raw query-string value. The POST route forwards
            ``ChatRequest.filters`` as-is.
        provider: LLM provider; defaults to ``settings.llm_provider``.
        model: LLM model; defaults to ``settings.llm_model``.
        top_k: Passed to the service as ``top_k``; defaults to
            ``settings.rag_top_k``.
    """

    async def generate_sse() -> AsyncGenerator[str, None]:
        """Yield SSE frames for the session id and each streamed service event."""
        try:
            session_id_, event_stream = get_agent_conversation_service().stream_chat(
                query=query,
                session_id=session_id,
                filters=filters,
                provider=provider or settings.llm_provider,
                model=model or settings.llm_model,
                top_k=top_k or settings.rag_top_k,
            )
            yield _format_sse("session", json.dumps({"session_id": session_id_}))
            # The service stream is a blocking iterator; iter_in_thread consumes
            # it off the event loop.
            async for event_data in iter_in_thread(event_stream):
                event_type = event_data.get("event", "content")
                data = event_data.get("data", "")
                if isinstance(data, (dict, list)):
                    payload = json.dumps(data, ensure_ascii=False)
                else:
                    payload = str(data)
                yield _format_sse(event_type, payload)
        except Exception as exc:
            # Headers are already on their way, so errors travel in-band.
            yield _format_sse("error", str(exc))

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """POST variant of the chat stream; forwards the body fields, including ``top_k``."""
    return await chat_stream_get(
        query=request.query,
        session_id=request.session_id,
        filters=request.filters,
        provider=request.provider,
        model=request.model,
        top_k=request.top_k,
    )
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Report id, message count and timestamps for one stored session.

    Raises:
        HTTPException: 404 when the store returns no (or a falsy) session.
    """
    found = get_conversation_store().get_session(session_id)
    if not found:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    summary = dict(
        session_id=found.session_id,
        message_count=len(found.messages),
        created_at=found.created_at,
        updated_at=found.updated_at,
    )
    return SessionInfo(**summary)
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return the most recent messages of a session as role/content pairs.

    Args:
        session_id: Session to read.
        max_turns: Number of recent turns to return. Each turn counts as two
            messages. Values <= 0 return an empty history.

    Raises:
        HTTPException: 404 when the store returns no (or a falsy) session.
    """
    session = get_conversation_store().get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    # Clamp before slicing. ``messages[-0:]`` is the whole list, and a negative
    # max_turns would turn ``-(n * 2)`` into a positive start index.
    message_limit = max(max_turns, 0) * 2
    recent = session.messages[-message_limit:] if message_limit else []
    history = [{"role": msg.role, "content": msg.content} for msg in recent]
    return {"session_id": session_id, "history": history}
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Remove a session from the conversation store.

    Raises:
        HTTPException: 404 when the store reports that nothing was deleted.
    """
    deleted = get_conversation_store().delete_session(session_id)
    if deleted:
        return {"message": "会话已删除", "session_id": session_id}
    raise HTTPException(status_code=404, detail="会话不存在")
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """Build a ``SessionInfo`` for every summary dict the conversation store lists."""
    summaries = get_conversation_store().list_sessions()
    infos = [SessionInfo(**entry) for entry in summaries]
    return infos
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Acknowledge feedback for a message in an existing session.

    Only the session's existence is checked, and the response echoes the
    session id and message index.
    NOTE(review): this handler neither stores the feedback nor checks
    ``message_index`` against the session's messages. Confirm whether
    persistence is meant to happen elsewhere.

    Raises:
        HTTPException: 404 when the store returns no (or a falsy) session.
    """
    target_id = request.session_id
    if not get_conversation_store().get_session(target_id):
        raise HTTPException(status_code=404, detail="会话不存在")
    return {
        "message": "反馈已提交",
        "session_id": target_id,
        "message_index": request.message_index,
    }