74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
|
|
from fastapi import APIRouter
|
||
|
|
from sse_starlette.sse import EventSourceResponse
|
||
|
|
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||
|
|
from app.services.mock_data import (
|
||
|
|
get_mock_quick_questions,
|
||
|
|
get_mock_retrieval,
|
||
|
|
get_mock_rag_answer,
|
||
|
|
)
|
||
|
|
import json
|
||
|
|
import asyncio
|
||
|
|
|
||
|
|
# Router for the RAG question-answering endpoints; its routes are prefixed with "/rag".
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/chat")
async def rag_chat(request: RagChatRequest):
    """Stream a RAG answer to the client as Server-Sent Events.

    The stream emits, in order: a ``retrieving`` event, a ``retrieved`` event
    carrying the retrieved documents, a ``generating`` event, one ``chunk``
    event per non-blank line of the answer, and a final ``done`` event.
    Every event uses the SSE event name ``message`` and a JSON-encoded
    ``data`` payload whose ``type`` field identifies the stage.
    """

    def sse_message(payload):
        # Wrap a payload dict in the event mapping handed to EventSourceResponse.
        return {"event": "message", "data": json.dumps(payload)}

    async def event_stream():
        # Announce that retrieval has started.
        yield sse_message({"type": "retrieving"})

        # Simulated retrieval latency.
        await asyncio.sleep(0.3)

        # Run the (mock) retrieval and keep only the fields sent to the client.
        hits = get_mock_retrieval(request.query, top_k=request.top_k)
        doc_summaries = [
            {
                "id": hit["id"],
                "score": hit["score"],
                "preview": hit["preview"],
                "doc_name": hit.get("doc_name", ""),
                "clause": hit.get("clause", ""),
            }
            for hit in hits
        ]
        yield sse_message({"type": "retrieved", "docs": doc_summaries})

        # Announce that answer generation has started.
        yield sse_message({"type": "generating", "text": "正在生成答案..."})

        # Simulated generation latency.
        await asyncio.sleep(0.2)

        # Fetch the canned answer for this query.
        full_answer = get_mock_rag_answer(request.query)

        # Stream the answer one non-blank line at a time.  Splitting on
        # single newlines yields the same non-blank pieces, in the same
        # order, as splitting on blank lines first and then on newlines.
        for line in full_answer.split("\n"):
            if not line.strip():
                continue
            await asyncio.sleep(0.05)  # simulated per-chunk generation latency
            yield sse_message({"type": "chunk", "text": f"{line}\n"})

        # Signal completion.
        yield sse_message({"type": "done"})

    return EventSourceResponse(event_stream())
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """Return the preset quick questions.

    Each entry from ``get_mock_quick_questions()`` is converted into a
    ``QuickQuestion`` using its ``id``, ``question`` and ``category`` keys,
    and the resulting list is wrapped in a ``QuickQuestionsResponse``.
    """
    field_names = ("id", "question", "category")
    quick_questions = [
        QuickQuestion(**{name: entry[name] for name in field_names})
        for entry in get_mock_quick_questions()
    ]
    return QuickQuestionsResponse(questions=quick_questions)
|