Files
AIRegulation-DocAnalysis/backend/app/api/routes/agent.py

185 lines
6.4 KiB
Python
Raw Normal View History

"""Define API routes for agent."""
2026-05-14 15:07:34 +08:00
from __future__ import annotations

import json
from dataclasses import asdict
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

from app.api.models import (
    AskRequest,
    AskResponse,
    ChatRequest,
    ChatResponse,
    FeedbackRequest,
    SessionInfo,
)
from app.config.settings import settings
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import get_agent_conversation_service, get_agent_session_service
# Route handlers live next to their transport-layer wiring so they are easy to
# audit. Every endpoint below is mounted under "/agent" and grouped under the
# "agent" tag in the generated OpenAPI schema.
router = APIRouter(tags=["agent"], prefix="/agent")
2026-05-14 15:07:34 +08:00
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a single question with the agent, without creating a session.

    ``provider``, ``model`` and ``top_k`` fall back to the configured defaults
    in ``settings`` when the request leaves them unset (any falsy value, e.g.
    ``top_k=0``, also falls back because the fallback uses ``or``).

    Raises:
        HTTPException: 500 carrying the underlying error message if answering
            the question or building the response fails.
    """
    try:
        # ``ask`` is a synchronous call; run it in the threadpool so the work
        # does not block the event loop (and every other request) meanwhile.
        _, result = await run_in_threadpool(
            get_agent_conversation_service().ask,
            query=request.query,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
            prompt_template=request.prompt_template,
        )
        return AskResponse(
            answer=result.answer,
            # Each source is a dataclass instance; serialize it to a plain dict.
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
2026-05-14 15:07:34 +08:00
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Run one chat turn within a session and return the agent's answer.

    The turn continues ``request.session_id``. Presumably the service starts a
    new session when that field is omitted; the session id it actually used is
    returned in the response. ``provider``, ``model`` and ``top_k`` fall back
    to ``settings`` when unset or falsy.

    Raises:
        HTTPException: 404 if the service raises ``ValueError`` (presumably an
            unknown session id); 500 for any other failure.
    """
    try:
        # ``chat`` is a synchronous call; run it in the threadpool so the work
        # does not block the event loop while it runs.
        session_id, result = await run_in_threadpool(
            get_agent_conversation_service().chat,
            query=request.query,
            session_id=request.session_id,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )
        # Re-read the session to report how many messages it now holds.
        session = get_agent_session_service().get_session(session_id)

        return ChatResponse(
            session_id=session_id,
            answer=result.answer,
            # Each source is a dataclass instance; serialize it to a plain dict.
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            message_count=len(session.messages) if session else 0,
        )
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
2026-05-14 15:07:34 +08:00
def _format_sse(event: str, data: str) -> str:
    """Encode one Server-Sent Events frame with the given event name and payload.

    In the SSE wire format, every line is a separate field and a blank line
    ends the frame. A payload containing line breaks (e.g. multi-line answer
    text or an error message) must therefore be split across several ``data:``
    lines, which the client rejoins with ``\\n``. A payload without line
    breaks produces exactly ``event: <event>\\ndata: <data>\\n\\n``.
    """
    # SSE recognizes CRLF, CR and LF as line terminators; normalize to LF first.
    normalized = data.replace("\r\n", "\n").replace("\r", "\n")
    data_lines = "".join(f"data: {line}\n" for line in normalized.split("\n"))
    return f"event: {event}\n{data_lines}\n"


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    top_k: Optional[int] = None,
):
    """Stream one chat turn to the client as Server-Sent Events.

    The first frame is a ``session`` event whose data is
    ``{"session_id": ...}``. Each later frame mirrors one item from the
    conversation service's event stream. The item's ``event`` key defaults to
    ``"content"``, and dict/list ``data`` is JSON-encoded. Any exception is
    reported as an ``error`` event inside the stream rather than as an HTTP
    error status.

    ``provider``, ``model`` and ``top_k`` fall back to ``settings`` when unset
    or falsy.
    """

    async def generate_sse() -> AsyncGenerator[str, None]:
        """Yield the SSE frames for this chat turn."""
        try:
            # ``stream_chat`` is a synchronous call; keep it off the event loop.
            session_id_, event_stream = await run_in_threadpool(
                get_agent_conversation_service().stream_chat,
                query=query,
                session_id=session_id,
                filters=filters,
                provider=provider or settings.llm_provider,
                model=model or settings.llm_model,
                top_k=top_k or settings.rag_top_k,
            )
            yield _format_sse("session", json.dumps({"session_id": session_id_}))

            # ``iter_in_thread`` presumably drives the synchronous event iterator
            # from a worker thread so the event loop stays responsive.
            async for event_data in iter_in_thread(event_stream):
                event_type = event_data.get("event", "content")
                data = event_data.get("data", "")
                if isinstance(data, (dict, list)):
                    payload = json.dumps(data, ensure_ascii=False)
                else:
                    payload = str(data)
                yield _format_sse(event_type, payload)
        except Exception as exc:
            yield _format_sse("error", str(exc))

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        # ``X-Accel-Buffering: no`` asks an nginx reverse proxy not to buffer
        # the stream, so frames reach the client as they are produced.
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream one chat turn as SSE from a JSON body; same output as ``GET /chat/stream``."""
    return await chat_stream_get(
        query=request.query,
        session_id=request.session_id,
        filters=request.filters,
        provider=request.provider,
        model=request.model,
        # Honor the requested top_k, as ``POST /chat`` does.
        top_k=request.top_k,
    )
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return summary metadata (message count, timestamps) for one session.

    Raises:
        HTTPException: 404 if the session service raises ``ValueError`` or
            returns no session for ``session_id``.
    """
    try:
        session = get_agent_session_service().get_session(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    # ``chat_with_session`` allows ``get_session`` to return nothing. Guard here
    # too, instead of failing with AttributeError (which would surface as a 500).
    if session is None:
        raise HTTPException(status_code=404, detail=f"会话不存在: {session_id}")
    return SessionInfo(
        session_id=session.session_id,
        message_count=len(session.messages),
        created_at=session.created_at,
        updated_at=session.updated_at,
    )
2026-05-14 15:07:34 +08:00
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return a session's conversation history, as produced by the session service.

    ``max_turns`` (default 5) is forwarded unchanged to the service, which
    presumably uses it to cap how many recent turns are returned.

    Raises:
        HTTPException: 404 if the session service raises ``ValueError``.
    """
    try:
        sessions = get_agent_session_service()
        turns = sessions.get_history(session_id=session_id, max_turns=max_turns)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    return {"session_id": session_id, "history": turns}
2026-05-14 15:07:34 +08:00
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a session and acknowledge the deletion with its id.

    Raises:
        HTTPException: 404 if the session service raises ``ValueError``.
    """
    try:
        get_agent_session_service().delete_session(session_id)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    # The message reads "session deleted".
    return {"message": "会话已删除", "session_id": session_id}
2026-05-14 15:07:34 +08:00
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """List every session known to the session service as ``SessionInfo`` records."""
    # Each summary is a mapping whose keys match the SessionInfo fields.
    summaries = get_agent_session_service().list_sessions()
    infos = [SessionInfo(**summary) for summary in summaries]
    return infos
2026-05-14 15:07:34 +08:00
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Record feedback on one message of a session and acknowledge it.

    Only ``session_id`` and ``message_index`` are forwarded to the session
    service. NOTE(review): if ``FeedbackRequest`` also carries the feedback
    content itself (e.g. a rating or comment), it is not passed on here;
    confirm this is intended.

    Raises:
        HTTPException: 404 if the session service raises ``ValueError``.
    """
    try:
        ack = get_agent_session_service().submit_feedback(
            session_id=request.session_id,
            message_index=request.message_index,
        )
        # The message reads "feedback submitted".
        payload = {
            "message": "反馈已提交",
            "session_id": ack.session_id,
            "message_index": ack.message_index,
        }
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    return payload