Files
AIRegulation-DocAnalysis/backend/app/api/routes/rag.py

95 lines
3.6 KiB
Python

"""FastAPI routes for the RAG (retrieval-augmented generation) Q&A feature, currently backed by mock data."""
from __future__ import annotations
import asyncio
import json
from typing import AsyncGenerator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.services.mock_data import (
get_mock_quick_questions,
get_mock_retrieval,
get_mock_rag_answer,
)
# Every endpoint in this module is mounted under /rag; the "RAG问答" ("RAG Q&A")
# tag groups them together in the generated OpenAPI docs.
router = APIRouter(
    tags=["RAG问答"],
    prefix="/rag",
)
def _sse_event(payload: dict) -> str:
    """Serialize *payload* as a single Server-Sent Events frame on the ``message`` event.

    ``ensure_ascii=False`` keeps non-ASCII text (such as Chinese answer content)
    readable on the wire instead of emitting escaped code points.
    """
    return f"event: message\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"


@router.post("/chat")
async def rag_chat(request: RagChatRequest):
    """Stream a mock RAG answer for ``request.query`` as Server-Sent Events.

    Every frame is a JSON object sent on the ``message`` event, in this order:

    1. ``{"type": "retrieving"}``
    2. ``{"type": "retrieved", "docs": [...]}``: the hits returned by
       ``get_mock_retrieval(query, top_k=request.top_k)``, reduced to
       ``id``/``score``/``preview``/``doc_name``/``clause``.
    3. ``{"type": "generating", "text": ...}``: a progress notice.
    4. ``{"type": "chunk", "text": ...}``: one frame per non-blank line of the
       answer, with a trailing newline appended.
    5. ``{"type": "done"}``

    The short ``asyncio.sleep`` pauses between stages presumably simulate
    backend latency for the mock implementation.
    """

    async def generate() -> AsyncGenerator[str, None]:
        """Yield the SSE frames for a single chat turn."""
        yield _sse_event({"type": "retrieving"})
        await asyncio.sleep(0.3)

        docs = get_mock_retrieval(request.query, top_k=request.top_k)
        # Forward only the fields the client renders; doc_name/clause default to "" when absent.
        retrieved_data = [
            {
                "id": d["id"],
                "score": d["score"],
                "preview": d["preview"],
                "doc_name": d.get("doc_name", ""),
                "clause": d.get("clause", ""),
            }
            for d in docs
        ]
        yield _sse_event({"type": "retrieved", "docs": retrieved_data})

        yield _sse_event({"type": "generating", "text": "正在生成答案..."})
        await asyncio.sleep(0.2)

        answer = get_mock_rag_answer(request.query)
        # Emit the answer one non-blank line at a time to mimic token streaming.
        # NOTE(review): blank lines between paragraphs are not forwarded, so a client
        # that concatenates chunks gets "para1\npara2\n" rather than "para1\n\npara2".
        # Confirm the frontend does not depend on paragraph breaks.
        for line in answer.split("\n"):
            if not line.strip():
                continue
            await asyncio.sleep(0.05)
            yield _sse_event({"type": "chunk", "text": line + "\n"})

        yield _sse_event({"type": "done"})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        # Disable caching and proxy buffering (X-Accel-Buffering is nginx-specific)
        # so frames reach the client as soon as they are yielded.
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """List the canned quick-question suggestions.

    Each mock entry is projected onto its ``id``, ``question`` and ``category``
    fields, and the results are wrapped in a ``QuickQuestionsResponse``.
    """
    raw_items = get_mock_quick_questions()
    return QuickQuestionsResponse(
        questions=[
            QuickQuestion(
                id=item["id"],
                question=item["question"],
                category=item["category"],
            )
            for item in raw_items
        ]
    )