""" AI SDK compatible chat endpoint """ import asyncio import logging from typing import AsyncGenerator from fastapi import Request from fastapi.responses import StreamingResponse from langchain_core.messages import HumanMessage from .config import get_config from .graph.state import TurnState, Message from .schemas.messages import ChatRequest from .ai_sdk_adapter import stream_ai_sdk_compatible from .sse import create_error_event logger = logging.getLogger(__name__) async def handle_ai_sdk_chat(request: ChatRequest, app_state) -> StreamingResponse: """Handle chat request with AI SDK Data Stream Protocol""" async def ai_sdk_stream() -> AsyncGenerator[str, None]: try: app_config = get_config() memory_manager = app_state.memory_manager graph = app_state.graph # Prepare the new user message for LangGraph (session memory handled automatically) graph_config = { "configurable": { "thread_id": request.session_id } } # Get the latest user message from AI SDK format new_user_message = None if request.messages: last_message = request.messages[-1] if last_message.get("role") == "user": new_user_message = HumanMessage(content=last_message.get("content", "")) if not new_user_message: logger.error("No user message found in request") yield create_error_event("No user message provided") return # Create event queue for internal streaming event_queue = asyncio.Queue() async def stream_callback(event_str: str): await event_queue.put(event_str) async def run_workflow(): try: # Set stream callback in context for the workflow from .graph.graph import stream_callback_context stream_callback_context.set(stream_callback) # Create TurnState with the new user message # AgenticWorkflow will handle LangGraph interaction and session history from .graph.state import TurnState, Message turn_state = TurnState( messages=[Message( role="user", content=str(new_user_message.content), timestamp=None )], session_id=request.session_id, tool_results=[], final_answer="" ) # Use AgenticWorkflow.astream with stream_callback parameter async for final_state in graph.astream(turn_state, stream_callback=stream_callback): # The workflow handles all streaming internally via stream_callback pass # final_state contains the complete result await event_queue.put(None) # Signal completion except Exception as e: logger.error(f"Workflow execution error: {e}", exc_info=True) await event_queue.put(create_error_event(f"Processing error: {str(e)}")) await event_queue.put(None) # Start workflow task workflow_task = asyncio.create_task(run_workflow()) # Convert internal events to AI SDK format async def internal_stream(): try: while True: event = await event_queue.get() if event is None: break yield event finally: if not workflow_task.done(): workflow_task.cancel() # Stream converted events async for ai_sdk_event in stream_ai_sdk_compatible(internal_stream()): yield ai_sdk_event except Exception as e: logger.error(f"AI SDK chat error: {e}") # Send error in AI SDK format from .ai_sdk_adapter import create_error_part yield create_error_part(f"Server error: {str(e)}") return StreamingResponse( ai_sdk_stream(), media_type="text/plain", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*", "x-vercel-ai-data-stream": "v1", # AI SDK Data Stream Protocol header } )