2026-05-18 16:32:42 +08:00
|
|
|
"""Define API routes for rag."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations

import json
import logging
from typing import AsyncGenerator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.config.settings import settings
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import get_agent_conversation_service
|
2026-05-18 16:32:42 +08:00
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
|
|
|
|
|
# Router for the RAG Q&A endpoints: every route below is served under the
# "/rag" path prefix and grouped under the "RAG问答" tag in the OpenAPI schema.
router = APIRouter(
    prefix="/rag",
    tags=["RAG问答"],
)
|
|
|
|
|
|
2026-05-20 23:34:08 +08:00
|
|
|
_DEFAULT_QUICK_QUESTIONS = [
|
|
|
|
|
{"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"},
|
|
|
|
|
{"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"},
|
|
|
|
|
{"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"},
|
|
|
|
|
{"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"},
|
|
|
|
|
{"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"},
|
|
|
|
|
{"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"},
|
|
|
|
|
]
|
|
|
|
|
|
2026-05-14 15:07:34 +08:00
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)


def _sse_message(payload: dict) -> str:
    """Format *payload* as one SSE ``message`` event with a JSON ``data`` line."""
    # ensure_ascii=False keeps Chinese text readable on the wire instead of \uXXXX escapes.
    return f"event: message\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _source_to_doc(idx: int, source: dict) -> dict:
    """Map one agent ``sources`` entry to the ``retrieved`` doc shape sent to the client.

    Args:
        idx: Position of the entry in the agent's sources list; ``idx + 1`` is
            used as the id when the entry has neither ``chunk_id`` nor ``doc_id``.
        source: One source record from the agent. The keys read here
            (chunk_id, doc_id, score, text/content, doc_title/doc_name,
            section_title) are all optional.

    Returns:
        A JSON-serialisable dict with id, score, preview, doc_name, clause,
        doc_id and download_url.
    """
    doc_id = source.get("doc_id")
    # Fall back with ``or`` (as the id already did) so keys that are present but
    # None/empty fall through to the next candidate instead of leaking None;
    # a None ``text`` previously made the [:200] slice raise TypeError.
    preview = source.get("text") or source.get("content") or ""
    return {
        "id": str(source.get("chunk_id") or doc_id or idx + 1),
        "score": source.get("score", 0),
        "preview": str(preview)[:200],
        "doc_name": source.get("doc_title") or source.get("doc_name") or "",
        "clause": source.get("section_title") or "法规片段",
        "doc_id": doc_id,
        "download_url": f"/api/v1/documents/download/{doc_id}" if doc_id else None,
    }


@router.post("/chat")
async def rag_chat(request: RagChatRequest):
    """Stream a RAG answer for *request* as server-sent events.

    Starts an agent conversation via ``stream_chat`` and re-emits its events as
    SSE ``message`` events whose JSON payload ``type`` is one of ``session``,
    ``retrieved``, ``chunk``, ``status``, ``error`` or ``done``. The first event
    always carries the session id. If the agent stream raises part-way through,
    the failure is logged and a generic ``error`` event is sent so the client is
    not left with a silently truncated stream.
    """
    session_id, event_stream = get_agent_conversation_service().stream_chat(
        query=request.query,
        session_id=request.session_id,
        filters=request.filters,
        # A missing (or 0) top_k falls back to the configured default.
        top_k=request.top_k or settings.rag_top_k,
    )

    async def generate() -> AsyncGenerator[str, None]:
        """Translate agent events to the rag SSE format."""
        # Announce the session before any agent output so the client can resume it.
        yield _sse_message({"type": "session", "session_id": session_id})
        try:
            # iter_in_thread drives the agent's event stream off the event loop;
            # presumably it re-raises errors from the worker here — TODO confirm.
            async for event in iter_in_thread(event_stream):
                event_type = event.get("event", "")
                data = event.get("data", "")
                if event_type == "sources":
                    entries = data if isinstance(data, list) else []
                    docs = [
                        _source_to_doc(idx, source)
                        for idx, source in enumerate(entries)
                        # Skip malformed entries instead of aborting the whole stream.
                        if isinstance(source, dict)
                    ]
                    yield _sse_message({"type": "retrieved", "docs": docs})
                elif event_type == "content":
                    # Empty content deltas carry nothing for the client.
                    if data:
                        yield _sse_message({"type": "chunk", "text": data})
                elif event_type == "done":
                    yield _sse_message({"type": "done", "session_id": session_id})
                elif event_type == "status":
                    yield _sse_message({"type": "status", "text": data})
                elif event_type == "error":
                    yield _sse_message({"type": "error", "text": str(data)})
                # Any other event type is intentionally not forwarded.
        except Exception:
            # Top-level boundary of the streamed response: record the failure and
            # tell the client, without echoing internal exception details to it.
            # (Cancellation on client disconnect is a BaseException and still propagates.)
            logger.exception("RAG chat stream failed (session_id=%s)", session_id)
            yield _sse_message({"type": "error", "text": "问答服务异常,请稍后重试"})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        # Disable caching and proxy buffering (X-Accel-Buffering is nginx's) so
        # events reach the client as soon as they are produced.
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|
2026-05-14 15:07:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """Return the quick questions shown in the RAG chat UI.

    Questions come from ``settings.rag_quick_questions`` when that attribute is
    a list or tuple; otherwise ``_DEFAULT_QUICK_QUESTIONS`` is used. Each
    configured entry may be:

    * a plain string — the question text, with category "法规问答"; or
    * a dict with a ``question`` string and an optional ``category``.

    Entries of any other type, or whose question text is missing or blank, are
    skipped instead of failing the whole request (previously a non-str,
    non-dict entry raised ``AttributeError``). If no usable entry remains, the
    defaults are returned. Ids are assigned sequentially ("1", "2", ...) over
    the entries that were kept.
    """
    raw = getattr(settings, "rag_quick_questions", None)
    questions = []
    if isinstance(raw, (list, tuple)):
        for entry in raw:
            if isinstance(entry, str):
                text, category = entry, None
            elif isinstance(entry, dict):
                text, category = entry.get("question"), entry.get("category")
            else:
                continue
            if not isinstance(text, str) or not text.strip():
                continue
            questions.append(
                QuickQuestion(
                    id=str(len(questions) + 1),
                    question=text,
                    # A missing/None/empty category falls back to the generic one.
                    category=str(category) if category else "法规问答",
                )
            )
    if not questions:
        questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS]
    return QuickQuestionsResponse(questions=questions)
|