Fix 法规对话
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
@@ -11,6 +10,7 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.shared.async_utils import iter_in_thread
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ _DEFAULT_QUICK_QUESTIONS = [
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
"""Stream RAG Q&A using the real agent service."""
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
session_id, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
@@ -38,7 +38,11 @@ async def rag_chat(request: RagChatRequest):
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
"""Translate agent SSE events to rag format."""
|
||||
for event in event_stream:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'session', 'session_id': session_id}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
async for event in iter_in_thread(event_stream):
|
||||
event_type = event.get("event", "")
|
||||
data = event.get("data", "")
|
||||
if event_type == "sources":
|
||||
@@ -69,14 +73,18 @@ async def rag_chat(request: RagChatRequest):
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
f"data: {json.dumps({'type': 'done', 'session_id': session_id}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "status":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
elif event_type == "error":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'error', 'text': str(data)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
|
||||
Reference in New Issue
Block a user