"""Define API routes for rag.""" from __future__ import annotations import json from typing import AsyncGenerator from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse from app.api.dependencies.auth import get_current_user from app.config.settings import settings from app.domain.auth.models import UserClaims from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion from app.shared.async_utils import iter_in_thread from app.shared.bootstrap import get_agent_conversation_service router = APIRouter(prefix="/rag", tags=["RAG问答"]) _DEFAULT_QUICK_QUESTIONS = [ {"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"}, {"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"}, {"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"}, {"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"}, {"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"}, {"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"}, ] @router.post("/chat") async def rag_chat( request: RagChatRequest, current_user: UserClaims = Depends(get_current_user), ): """Stream RAG Q&A using the real agent service.""" session_id, event_stream = get_agent_conversation_service().stream_chat( query=request.query, session_id=request.session_id, filters=request.filters, top_k=request.top_k or settings.rag_top_k, ) async def generate() -> AsyncGenerator[str, None]: """Translate agent SSE events to rag format.""" yield ( "event: message\n" f"data: {json.dumps({'type': 'session', 'session_id': session_id}, ensure_ascii=False)}\n\n" ) async for event in iter_in_thread(event_stream): event_type = event.get("event", "") data = event.get("data", "") if event_type == "sources": docs = [ { "id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1), "score": s.get("score", 0), "preview": s.get("text", s.get("content", ""))[:200], "doc_name": s.get("doc_title", s.get("doc_name", "")), "clause": s.get("section_title", "法规片段"), "doc_id": s.get("doc_id"), "download_url": ( f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None ), } for idx, s in enumerate(data if isinstance(data, list) else []) ] yield ( "event: message\n" f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n" ) elif event_type == "content": if data: yield ( "event: message\n" f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n" ) elif event_type == "done": yield ( "event: message\n" f"data: {json.dumps({'type': 'done', 'session_id': session_id}, ensure_ascii=False)}\n\n" ) elif event_type == "status": yield ( "event: message\n" f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n" ) elif event_type == "error": yield ( "event: message\n" f"data: {json.dumps({'type': 'error', 'text': str(data)}, ensure_ascii=False)}\n\n" ) return StreamingResponse( generate(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) @router.get("/quick-questions", response_model=QuickQuestionsResponse) async def get_quick_questions(): """Return configurable quick questions from settings or defaults.""" raw = getattr(settings, "rag_quick_questions", None) if raw and isinstance(raw, list): questions = [ QuickQuestion(id=str(i + 1), question=q if isinstance(q, str) else q.get("question", ""), category=q.get("category", "法规问答") if isinstance(q, dict) else "法规问答") for i, q in enumerate(raw) ] else: questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS] return QuickQuestionsResponse(questions=questions)