"""
|
|
AI SDK compatible chat endpoint
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
from typing import AsyncGenerator
|
|
|
|
from fastapi import Request
|
|
from fastapi.responses import StreamingResponse
|
|
from langchain_core.messages import HumanMessage
|
|
|
|
from .config import get_config
|
|
from .graph.state import TurnState, Message
|
|
from .schemas.messages import ChatRequest
|
|
from .ai_sdk_adapter import stream_ai_sdk_compatible
|
|
from .sse import create_error_event
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def handle_ai_sdk_chat(request: ChatRequest, app_state) -> StreamingResponse:
    """Handle chat request with AI SDK Data Stream Protocol."""

    async def ai_sdk_stream() -> AsyncGenerator[str, None]:
        try:
            app_config = get_config()
            memory_manager = app_state.memory_manager
            graph = app_state.graph

            # LangGraph config for this session (session memory is handled
            # automatically via the thread_id)
            graph_config = {
                "configurable": {
                    "thread_id": request.session_id
                }
            }

            # Get the latest user message from the AI SDK message format
            new_user_message = None
            if request.messages:
                last_message = request.messages[-1]
                if last_message.get("role") == "user":
                    new_user_message = HumanMessage(content=last_message.get("content", ""))
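
            # Only the newest turn is forwarded to the graph; session history
            # is handled server-side (see the thread_id config above) rather
            # than replayed from the client payload.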

            if not new_user_message:
                logger.error("No user message found in request")
                yield create_error_event("No user message provided")
                return

            # Create event queue for internal streaming
            event_queue = asyncio.Queue()

            async def stream_callback(event_str: str):
                await event_queue.put(event_str)
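
            # The unbounded queue decouples the producer (run_workflow below,
            # which pushes serialized events through stream_callback) from the
            # consumer (internal_stream, which drains them into the HTTP
            # response), so graph execution never blocks on a slow client.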

            async def run_workflow():
                try:
                    # Set the stream callback in context for the workflow
                    from .graph.graph import stream_callback_context
                    stream_callback_context.set(stream_callback)

                    # Create a TurnState with the new user message;
                    # AgenticWorkflow handles the LangGraph interaction and
                    # session history.
                    turn_state = TurnState(
                        messages=[Message(
                            role="user",
                            content=str(new_user_message.content),
                            timestamp=None
                        )],
                        session_id=request.session_id,
                        tool_results=[],
                        final_answer=""
                    )

                    # Use AgenticWorkflow.astream with the stream_callback
                    # parameter; the workflow streams everything internally
                    # via the callback, so the loop only awaits completion.
                    async for final_state in graph.astream(turn_state, stream_callback=stream_callback):
                        pass  # final_state contains the complete result
                    await event_queue.put(None)  # Signal completion
                except Exception as e:
                    logger.error(f"Workflow execution error: {e}", exc_info=True)
                    await event_queue.put(create_error_event(f"Processing error: {str(e)}"))
                    await event_queue.put(None)
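                    # Even on failure, the sentinel lands in the queue, so the
                    # consumer loop below still terminates.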

            # Start workflow task
            workflow_task = asyncio.create_task(run_workflow())

            # Convert internal events to AI SDK format
            async def internal_stream():
                try:
                    while True:
                        event = await event_queue.get()
                        if event is None:
                            break
                        yield event
                finally:
                    if not workflow_task.done():
                        workflow_task.cancel()
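
            # Design note: cancelling the producer in the generator's finally
            # block means a client disconnect (which closes this generator)
            # also stops the still-running workflow task.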

            # Stream converted events
            async for ai_sdk_event in stream_ai_sdk_compatible(internal_stream()):
                yield ai_sdk_event

        except Exception as e:
            logger.error(f"AI SDK chat error: {e}", exc_info=True)
            # Send the error in AI SDK format
            from .ai_sdk_adapter import create_error_part
            yield create_error_part(f"Server error: {str(e)}")

    return StreamingResponse(
        ai_sdk_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "x-vercel-ai-data-stream": "v1",  # AI SDK Data Stream Protocol header
        }
    )
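

# Usage sketch (hypothetical wiring, not exercised by this module): a minimal
# FastAPI route, assuming the compiled graph and memory manager are attached
# to ``app.state`` at startup; the route path and names below are illustrative.
#
#     from fastapi import FastAPI, Request
#
#     app = FastAPI()
#
#     @app.post("/api/chat")
#     async def chat(raw: Request, body: ChatRequest):
#         return await handle_ai_sdk_chat(body, raw.app.state)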