"""Define API routes for rag.""" from __future__ import annotations import asyncio import json from typing import AsyncGenerator from fastapi import APIRouter from fastapi.responses import StreamingResponse from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion from app.services.mock_data import ( get_mock_quick_questions, get_mock_retrieval, get_mock_rag_answer, ) # Keep route handlers close to their transport-layer wiring for easier auditing. router = APIRouter(prefix="/rag", tags=["RAG问答"]) @router.post("/chat") async def rag_chat(request: RagChatRequest): """Handle rag chat.""" async def generate() -> AsyncGenerator[str, None]: # Keep route handlers close to their transport-layer wiring for easier auditing. """Handle generate.""" yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n" # Keep route handlers close to their transport-layer wiring for easier auditing. await asyncio.sleep(0.3) # Keep route handlers close to their transport-layer wiring for easier auditing. docs = get_mock_retrieval(request.query, top_k=request.top_k) retrieved_data = [ { "id": d["id"], "score": d["score"], "preview": d["preview"], "doc_name": d.get("doc_name", ""), "clause": d.get("clause", ""), } for d in docs ] yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n" # Keep route handlers close to their transport-layer wiring for easier auditing. yield ( f"event: message\ndata: " f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n" ) # Keep route handlers close to their transport-layer wiring for easier auditing. await asyncio.sleep(0.2) # Keep route handlers close to their transport-layer wiring for easier auditing. answer = get_mock_rag_answer(request.query) # Keep route handlers close to their transport-layer wiring for easier auditing. sentences = answer.split("\n\n") for sentence in sentences: if sentence.strip(): # Keep route handlers close to their transport-layer wiring for easier auditing. chunks = sentence.split("\n") for chunk in chunks: if chunk.strip(): await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing. yield ( "event: message\n" f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n" ) # Keep route handlers close to their transport-layer wiring for easier auditing. yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) @router.get("/quick-questions", response_model=QuickQuestionsResponse) async def get_quick_questions(): """Return quick questions.""" questions = [ QuickQuestion(id=q["id"], question=q["question"], category=q["category"]) for q in get_mock_quick_questions() ] return QuickQuestionsResponse(questions=questions)