"""Define API routes for rag.""" from __future__ import annotations import asyncio import json from typing import AsyncGenerator from fastapi import APIRouter from fastapi.responses import StreamingResponse from app.config.settings import settings from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion from app.shared.bootstrap import get_agent_conversation_service router = APIRouter(prefix="/rag", tags=["RAG问答"]) _DEFAULT_QUICK_QUESTIONS = [ {"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"}, {"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"}, {"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"}, {"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"}, {"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"}, {"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"}, ] @router.post("/chat") async def rag_chat(request: RagChatRequest): """Stream RAG Q&A using the real agent service.""" _, event_stream = get_agent_conversation_service().stream_chat( query=request.query, session_id=request.session_id, filters=request.filters, top_k=request.top_k or settings.rag_top_k, ) async def generate() -> AsyncGenerator[str, None]: """Translate agent SSE events to rag format.""" for event in event_stream: event_type = event.get("event", "") data = event.get("data", "") if event_type == "sources": docs = [ { "id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1), "score": s.get("score", 0), "preview": s.get("content", "")[:200], "doc_name": s.get("doc_name", ""), "clause": s.get("section_title", "法规片段"), "doc_id": s.get("doc_id"), "download_url": ( f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None ), } for idx, s in enumerate(data if isinstance(data, list) else []) ] yield ( "event: message\n" f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n" ) elif event_type == "content": if data: yield ( "event: message\n" f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n" ) elif event_type == "done": yield ( "event: message\n" f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" ) elif event_type == "status": yield ( "event: message\n" f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n" ) await asyncio.sleep(0) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) @router.get("/quick-questions", response_model=QuickQuestionsResponse) async def get_quick_questions(): """Return configurable quick questions from settings or defaults.""" raw = getattr(settings, "rag_quick_questions", None) if raw and isinstance(raw, list): questions = [ QuickQuestion(id=str(i + 1), question=q if isinstance(q, str) else q.get("question", ""), category=q.get("category", "法规问答") if isinstance(q, dict) else "法规问答") for i, q in enumerate(raw) ] else: questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS] return QuickQuestionsResponse(questions=questions)