# Source file: catonline_ai/vw-agentic-rag/service/ai_sdk_chat.py
# Snapshot: 2025-09-26 17:15:54 +08:00 — 122 lines, 4.8 KiB, Python
"""
AI SDK compatible chat endpoint
"""
import asyncio
import logging
from typing import AsyncGenerator
from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_core.messages import HumanMessage
from .config import get_config
from .graph.state import TurnState, Message
from .schemas.messages import ChatRequest
from .ai_sdk_adapter import stream_ai_sdk_compatible
from .sse import create_error_event
logger = logging.getLogger(__name__)
async def handle_ai_sdk_chat(request: ChatRequest, app_state) -> StreamingResponse:
    """Handle a chat request, streaming the reply via the AI SDK Data Stream Protocol.

    Args:
        request: Parsed chat request carrying the AI SDK message list and a
            ``session_id`` used as the LangGraph thread id.
        app_state: Application state object; only ``app_state.graph``
            (the AgenticWorkflow) is used here.

    Returns:
        A ``StreamingResponse`` that emits AI SDK data-stream events as plain
        text, tagged with the ``x-vercel-ai-data-stream: v1`` header.
    """

    async def ai_sdk_stream() -> AsyncGenerator[str, None]:
        try:
            graph = app_state.graph

            # NOTE(review): graph_config is built but never passed to the
            # workflow — AgenticWorkflow appears to manage session threading
            # itself; confirm before removing.
            graph_config = {
                "configurable": {
                    "thread_id": request.session_id
                }
            }

            # Extract the most recent user message from the AI SDK-format list.
            new_user_message = None
            if request.messages:
                last_message = request.messages[-1]
                if last_message.get("role") == "user":
                    new_user_message = HumanMessage(content=last_message.get("content", ""))

            if not new_user_message:
                logger.error("No user message found in request")
                yield create_error_event("No user message provided")
                return

            # Queue decouples the background workflow task from this
            # HTTP response generator; None is the completion sentinel.
            event_queue: asyncio.Queue = asyncio.Queue()

            async def stream_callback(event_str: str):
                # Forward each workflow event into the response queue.
                await event_queue.put(event_str)

            async def run_workflow():
                """Drive the agentic workflow to completion, streaming via callback."""
                try:
                    # Expose the callback to graph nodes through a contextvar.
                    from .graph.graph import stream_callback_context
                    stream_callback_context.set(stream_callback)

                    # Build the turn state; the workflow handles LangGraph
                    # interaction and session history itself.
                    # (TurnState/Message come from the module-level import.)
                    turn_state = TurnState(
                        messages=[Message(
                            role="user",
                            content=str(new_user_message.content),
                            timestamp=None
                        )],
                        session_id=request.session_id,
                        tool_results=[],
                        final_answer=""
                    )

                    # All output is streamed through stream_callback; we only
                    # drain the iterator to push the workflow to completion.
                    async for _final_state in graph.astream(turn_state, stream_callback=stream_callback):
                        pass

                    await event_queue.put(None)  # Signal normal completion
                except Exception as e:
                    logger.error(f"Workflow execution error: {e}", exc_info=True)
                    await event_queue.put(create_error_event(f"Processing error: {str(e)}"))
                    await event_queue.put(None)

            workflow_task = asyncio.create_task(run_workflow())

            async def internal_stream():
                """Yield queued events until the sentinel; cancel the workflow on early exit."""
                try:
                    while True:
                        event = await event_queue.get()
                        if event is None:
                            break
                        yield event
                finally:
                    # Client disconnect / downstream error: stop the worker.
                    if not workflow_task.done():
                        workflow_task.cancel()

            # Translate internal events into AI SDK data-stream parts.
            async for ai_sdk_event in stream_ai_sdk_compatible(internal_stream()):
                yield ai_sdk_event

        except Exception as e:
            logger.error(f"AI SDK chat error: {e}", exc_info=True)
            # Surface the failure to the client in AI SDK format.
            from .ai_sdk_adapter import create_error_part
            yield create_error_part(f"Server error: {str(e)}")

    return StreamingResponse(
        ai_sdk_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "x-vercel-ai-data-stream": "v1",  # AI SDK Data Stream Protocol header
        }
    )