179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
"""Define API routes for agent."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import AsyncGenerator, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from dataclasses import asdict
|
|
|
|
from app.api.models import (
|
|
AskRequest,
|
|
AskResponse,
|
|
ChatRequest,
|
|
ChatResponse,
|
|
FeedbackRequest,
|
|
SessionInfo,
|
|
)
|
|
from app.config.settings import settings
|
|
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
|
|
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
|
|
|
|
|
# Router that the endpoint decorators in this module register on; every path
# is served under the "/agent" prefix and tagged "agent".
router = APIRouter(prefix="/agent", tags=["agent"])
|
|
|
|
|
|
@router.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a one-off question through the agent conversation service.

    The request is forwarded to the service's ``ask`` method. The first
    element of the returned pair is discarded; the second is mapped
    field-by-field onto ``AskResponse``.

    Args:
        request: Carries the query, filters and prompt template, plus
            optional provider, model and top_k. Falsy provider/model/top_k
            values fall back to the matching ``settings`` defaults.

    Returns:
        An ``AskResponse`` built from the service result.

    Raises:
        HTTPException: 500, with the exception message as detail, when the
            service call or the response construction raises.
    """
    try:
        _, result = get_agent_conversation_service().ask(
            query=request.query,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
            prompt_template=request.prompt_template,
        )
        return AskResponse(
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )
    except Exception as exc:
        # Chain explicitly so the original failure stays attached as
        # ``__cause__`` instead of reading as an error raised while handling
        # another error.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
|
|
@router.post("/chat", response_model=ChatResponse)
async def chat_with_session(request: ChatRequest):
    """Run one chat turn within a session and return the answer.

    The request is forwarded to the agent conversation service's ``chat``
    method, which returns the session id together with the result. The
    session is then looked up in the conversation store to report how many
    messages it holds.

    Args:
        request: Carries the query, session id and filters, plus optional
            provider, model and top_k. Falsy provider/model/top_k values
            fall back to the matching ``settings`` defaults.

    Returns:
        A ``ChatResponse``; ``message_count`` is 0 when the store returns no
        session for the id.

    Raises:
        HTTPException: 404 when a ``ValueError`` is raised inside the
            handler, 500 for any other exception. In both cases the
            exception message becomes the response detail.
    """
    try:
        session_id, result = get_agent_conversation_service().chat(
            query=request.query,
            session_id=request.session_id,
            filters=request.filters,
            provider=request.provider or settings.llm_provider,
            model=request.model or settings.llm_model,
            top_k=request.top_k or settings.rag_top_k,
        )
        session = get_conversation_store().get_session(session_id)
        return ChatResponse(
            session_id=session_id,
            answer=result.answer,
            sources=[asdict(source) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            message_count=len(session.messages) if session else 0,
        )
    except ValueError as exc:
        # Chain explicitly so the original failure stays attached as
        # ``__cause__`` of the HTTP error.
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
|
|
def _format_sse_event(event: str, data: object) -> str:
    """Serialize one Server-Sent Events message.

    Dicts and lists are JSON-encoded (non-ASCII characters kept as-is); any
    other value is converted with ``str``. The payload is then emitted as one
    ``data:`` line per line of text, so a payload containing line breaks stays
    inside a single event instead of corrupting or terminating the frame.
    A single-line payload produces ``event: <event>\\ndata: <payload>\\n\\n``.
    """
    if isinstance(data, (dict, list)):
        payload = json.dumps(data, ensure_ascii=False)
    else:
        payload = str(data)
    # The event-stream format treats CRLF, CR and LF all as line terminators.
    lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    body = "".join(f"data: {line}\n" for line in lines)
    return f"event: {event}\n{body}\n"


@router.get("/chat/stream")
async def chat_stream_get(
    query: str,
    session_id: Optional[str] = None,
    filters: Optional[str] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    top_k: Optional[int] = None,
):
    """Stream a chat turn as Server-Sent Events.

    The response first emits a ``session`` event carrying the session id
    returned by the service's ``stream_chat``, then one event per item of the
    returned event stream. Each item's ``event`` key (default ``"content"``)
    becomes the SSE event name and its ``data`` key (default ``""``) the
    payload. Any exception raised while starting or consuming the stream is
    reported to the client as an ``error`` event rather than an HTTP error.

    Args:
        query: User query text.
        session_id: Session id forwarded to the service.
        filters: Forwarded to the service unchanged.
        provider: LLM provider; falls back to ``settings.llm_provider``.
        model: LLM model; falls back to ``settings.llm_model``.
        top_k: Retrieval size; falls back to ``settings.rag_top_k`` when
            falsy, matching the non-streaming chat endpoint.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    # NOTE(review): on the GET route ``filters`` arrives as a raw query string
    # and is forwarded unparsed, while the POST route forwards
    # ``ChatRequest.filters`` as-is — confirm ``stream_chat`` accepts both.

    async def generate_sse() -> AsyncGenerator[str, None]:
        """Yield the SSE-formatted messages for this request."""
        try:
            session_id_, event_stream = get_agent_conversation_service().stream_chat(
                query=query,
                session_id=session_id,
                filters=filters,
                provider=provider or settings.llm_provider,
                model=model or settings.llm_model,
                top_k=top_k or settings.rag_top_k,
            )
            yield _format_sse_event("session", json.dumps({"session_id": session_id_}))
            # NOTE(review): ``event_stream`` is iterated synchronously; if
            # producing an item blocks, the event loop is blocked for that
            # time — confirm whether it should be consumed in a worker thread.
            for event_data in event_stream:
                yield _format_sse_event(
                    event_data.get("event", "content"),
                    event_data.get("data", ""),
                )
                # Hand control back to the event loop between events.
                await asyncio.sleep(0)
        except Exception as exc:
            yield _format_sse_event("error", str(exc))

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )


@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream a chat turn as Server-Sent Events from a JSON request body.

    Delegates to ``chat_stream_get`` with the fields of ``request``,
    including ``top_k`` so the streaming endpoint honours it the same way the
    non-streaming chat endpoint does.
    """
    return await chat_stream_get(
        query=request.query,
        session_id=request.session_id,
        filters=request.filters,
        provider=request.provider,
        model=request.model,
        top_k=request.top_k,
    )
|
|
|
|
|
|
@router.get("/session/{session_id}", response_model=SessionInfo)
async def get_session_info(session_id: str):
    """Report metadata for a single session.

    Looks the session up in the conversation store and returns its id,
    message count and ``created_at`` / ``updated_at`` values. Responds with
    HTTP 404 when the store yields no session for ``session_id``.
    """
    store = get_conversation_store()
    record = store.get_session(session_id)
    if record:
        return SessionInfo(
            session_id=record.session_id,
            message_count=len(record.messages),
            created_at=record.created_at,
            updated_at=record.updated_at,
        )
    raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
|
|
|
|
|
@router.get("/session/{session_id}/history")
async def get_session_history(session_id: str, max_turns: int = 5):
    """Return the most recent messages of a session.

    Args:
        session_id: Id of the session to read from the conversation store.
        max_turns: Number of turns to return; the handler selects
            ``max_turns * 2`` messages (i.e. it treats a turn as two
            messages). Zero or negative values yield an empty history.

    Returns:
        A dict with ``session_id`` and ``history``, the latter a list of
        ``{"role", "content"}`` dicts in stored order.

    Raises:
        HTTPException: 404 when the store yields no session for the id.
    """
    session = get_conversation_store().get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")

    messages = session.messages
    wanted = max(max_turns, 0) * 2
    # Use an explicit start index instead of a negative slice: ``[-0:]`` would
    # return every message, and a negative ``max_turns`` would produce a
    # positive start that drops the oldest messages instead of selecting the
    # newest ones.
    start = max(len(messages) - wanted, 0)
    history = [{"role": msg.role, "content": msg.content} for msg in messages[start:]]
    return {"session_id": session_id, "history": history}
|
|
|
|
|
|
@router.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Remove a session from the conversation store.

    Returns a confirmation dict containing the session id when the store's
    ``delete_session`` reports a truthy result; responds with HTTP 404
    otherwise.
    """
    removed = get_conversation_store().delete_session(session_id)
    if removed:
        return {"message": "会话已删除", "session_id": session_id}
    raise HTTPException(status_code=404, detail="会话不存在")
|
|
|
|
|
|
@router.get("/sessions", response_model=List[SessionInfo])
async def list_sessions():
    """Return every session known to the conversation store.

    Each item produced by the store's ``list_sessions`` is expanded as
    keyword arguments into a ``SessionInfo``.
    """
    infos = []
    for entry in get_conversation_store().list_sessions():
        infos.append(SessionInfo(**entry))
    return infos
|
|
|
|
|
|
@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Acknowledge feedback for a message in an existing session.

    Responds with HTTP 404 when the conversation store yields no session for
    ``request.session_id``; otherwise echoes the session id and message index
    back with a confirmation message.
    """
    # NOTE(review): this handler only checks that the session exists; it does
    # not store the feedback anywhere.
    target = get_conversation_store().get_session(request.session_id)
    if target:
        return {
            "message": "反馈已提交",
            "session_id": request.session_id,
            "message_index": request.message_index,
        }
    raise HTTPException(status_code=404, detail="会话不存在")
|