Move session-related responsibilities into a new application-layer AgentSessionService (and an AgentSessionFeedbackResult dataclass), provide a bootstrap factory (get_agent_session_service), and update the agent API routes to call the service instead of accessing ConversationStore directly. Routes now translate ValueError into 404 responses and use service methods for get/list/history/delete/feedback. Also update package exports and docs/READMEs to designate the authoritative backend architecture, enforce api -> application -> domain ports -> infrastructure boundaries, and mark legacy services/workflows as migration-only. These changes centralize session logic in the application layer and tighten architecture guidance for future backend work.
185 lines
6.4 KiB
Python
185 lines
6.4 KiB
Python
"""Define API routes for agent."""
|
|
|
|
from __future__ import annotations

import json
from dataclasses import asdict
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

from app.api.models import (
    AskRequest,
    AskResponse,
    ChatRequest,
    ChatResponse,
    FeedbackRequest,
    SessionInfo,
)
from app.config.settings import settings
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import (
    get_agent_conversation_service,
    get_agent_session_service,
    get_conversation_store,
)


# Route handlers stay thin and live next to their HTTP wiring so the
# transport layer is easy to audit. Every endpoint below is mounted under
# the "/agent" prefix and grouped under the "agent" OpenAPI tag.
router = APIRouter(tags=["agent"], prefix="/agent")


@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a one-off question through the conversation service.

    No session id is passed to the service, and the first element of the
    tuple it returns is discarded. Provider, model and ``top_k`` fall back to
    the configured defaults when the request leaves them unset (or falsy).

    Raises:
        HTTPException: 500 carrying the error text if anything fails.
    """
    try:
        # ``ask`` is a synchronous call (it returns its result directly), so
        # run it in the worker threadpool. Otherwise a slow LLM round-trip
        # would block the event loop and stall every other request.
        _, result = await run_in_threadpool(
            get_agent_conversation_service().ask,
            query=request.query,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
            prompt_template=request.prompt_template,
        )
        return AskResponse(
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )
    except Exception as exc:
        # Chain the cause so server-side tracebacks show the real failure.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Run one chat turn and report the session's resulting message count.

    The conversation service returns the session id it actually used
    (presumably creating a session when the request has none — confirm in
    the service). That id is echoed back, and the session is re-read to
    count its messages.

    Raises:
        HTTPException: 404 when a service raises ``ValueError`` (treated as
            an unknown session), 500 with the error text for any other failure.
    """
    try:
        # ``chat`` is synchronous (its result is used directly). Run it off
        # the event loop so a slow LLM call does not block other requests.
        session_id, result = await run_in_threadpool(
            get_agent_conversation_service().chat,
            query=request.query,
            session_id=request.session_id,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )
        session = get_agent_session_service().get_session(session_id)
        return ChatResponse(
            session_id=session_id,
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            message_count=len(session.messages) if session else 0,
        )
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


def _sse_frame(event: str, data: str) -> str:
    """Serialize one Server-Sent Events frame.

    In the SSE wire format a field ends at any CR, LF or CRLF. A payload
    containing line breaks must therefore be split across several ``data:``
    lines, which the client re-joins with newlines. Writing it verbatim would
    misparse the text after each line break as a separate (ignored) field,
    and a blank line inside the payload would end the event early.
    """
    normalized = data.replace("\r\n", "\n").replace("\r", "\n")
    data_lines = "".join(f"data: {line}\n" for line in normalized.split("\n"))
    return f"event: {event}\n{data_lines}\n"


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    top_k: Optional[int] = None,
):
    """Stream a chat answer to the client as Server-Sent Events.

    Event sequence:
      1. One ``session`` event whose data is ``{"session_id": ...}``.
      2. One event per item produced by the conversation service. The event
         name comes from the item's ``"event"`` key (default ``"content"``);
         dict/list data is sent as JSON, anything else as text.
      3. On failure, an ``error`` event carrying the error text. Failures are
         reported in-band because the streaming response has already begun.

    Args:
        query: The user message.
        session_id: Existing session to continue, if any.
        filters: Passed through unchanged. NOTE(review): for GET this is the
            raw query-string value, while POST forwards ``ChatRequest.filters``;
            confirm ``stream_chat`` accepts both forms.
        provider: LLM provider; defaults to ``settings.llm_provider``.
        model: LLM model; defaults to ``settings.llm_model``.
        top_k: Retrieval ``top_k``. Falls back to ``settings.rag_top_k`` when
            omitted or 0, matching ``/chat``.
    """

    async def generate_sse() -> AsyncGenerator[str, None]:
        """Yield the SSE frames for this request."""
        try:
            new_session_id, event_stream = get_agent_conversation_service().stream_chat(
                query=query,
                session_id=session_id,
                filters=filters,
                provider=provider or settings.llm_provider,
                model=model or settings.llm_model,
                top_k=top_k or settings.rag_top_k,
            )
            yield _sse_frame("session", json.dumps({"session_id": new_session_id}))
            # iter_in_thread presumably drives the synchronous event iterator in
            # a worker thread so the event loop stays responsive.
            async for event_data in iter_in_thread(event_stream):
                event_type = event_data.get("event", "content")
                data = event_data.get("data", "")
                if isinstance(data, (dict, list)):
                    payload = json.dumps(data, ensure_ascii=False)
                else:
                    payload = str(data)
                yield _sse_frame(event_type, payload)
        except Exception as exc:
            yield _sse_frame("error", str(exc))

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        # Disable client caching and ask nginx-style proxies not to buffer,
        # so frames reach the client as soon as they are produced.
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """POST variant of ``/chat/stream``: same SSE stream, parameters from the JSON body."""
    return await chat_stream_get(
        query=request.query,
        session_id=request.session_id,
        filters=request.filters,
        provider=request.provider,
        model=request.model,
        top_k=request.top_k,
    )


@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Return the message count and timestamps of one session.

    Raises:
        HTTPException: 404 when the session service rejects the id with
            ``ValueError`` or returns no session.
    """
    # Keep only the lookup inside the ``try``. Pydantic's ValidationError
    # subclasses ValueError, so a malformed stored session must not be
    # misreported as "not found".
    try:
        session = get_agent_session_service().get_session(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    if session is None:
        # Guard against a service that returns None for unknown ids. Without
        # it, the attribute access below would surface as an HTTP 500. The
        # detail text means "session does not exist", matching this module's
        # Chinese user-facing messages.
        raise HTTPException(status_code=404, detail="会话不存在")
    return SessionInfo(
        session_id=session.session_id,
        message_count=len(session.messages),
        created_at=session.created_at,
        updated_at=session.updated_at,
    )


@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Fetch recent conversation turns for a session.

    Args:
        session_id: Session identifier from the URL path.
        max_turns: Forwarded to the session service as ``max_turns``
            (query parameter, default 5).

    Returns:
        ``{"session_id": ..., "history": ...}``, where ``history`` is whatever
        the session service returns.

    Raises:
        HTTPException: 404 when the session service raises ``ValueError``.
    """
    try:
        turns = get_agent_session_service().get_history(
            session_id=session_id,
            max_turns=max_turns,
        )
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    return {"session_id": session_id, "history": turns}


@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Remove a session through the session service.

    Returns:
        A confirmation payload echoing ``session_id``; the message text means
        "session deleted".

    Raises:
        HTTPException: 404 when the session service raises ``ValueError``.
    """
    try:
        get_agent_session_service().delete_session(session_id)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    return {"message": "会话已删除", "session_id": session_id}


@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """Return one ``SessionInfo`` per session reported by the session service.

    Each summary mapping from the service is expanded as keyword arguments
    into ``SessionInfo``.
    """
    summaries = get_agent_session_service().list_sessions()
    return [SessionInfo(**summary) for summary in summaries]


@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Record feedback on one message of a session via the session service.

    Returns:
        A confirmation payload (the message text means "feedback submitted")
        echoing the ``session_id`` and ``message_index`` from the service result.

    Raises:
        HTTPException: 404 when the session service raises ``ValueError``.

    NOTE(review): only ``session_id`` and ``message_index`` are forwarded.
    If ``FeedbackRequest`` also carries a rating or comment, it is not passed
    to the service here; confirm this is intended.
    """
    try:
        outcome = get_agent_session_service().submit_feedback(
            session_id=request.session_id,
            message_index=request.message_index,
        )
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    return {
        "message": "反馈已提交",
        "session_id": outcome.session_id,
        "message_index": outcome.message_index,
    }