init
This commit is contained in:
121
vw-agentic-rag/service/ai_sdk_chat.py
Normal file
121
vw-agentic-rag/service/ai_sdk_chat.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
AI SDK compatible chat endpoint
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from .config import get_config
|
||||
from .graph.state import TurnState, Message
|
||||
from .schemas.messages import ChatRequest
|
||||
from .ai_sdk_adapter import stream_ai_sdk_compatible
|
||||
from .sse import create_error_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_ai_sdk_chat(request: ChatRequest, app_state) -> StreamingResponse:
|
||||
"""Handle chat request with AI SDK Data Stream Protocol"""
|
||||
|
||||
async def ai_sdk_stream() -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
app_config = get_config()
|
||||
memory_manager = app_state.memory_manager
|
||||
graph = app_state.graph
|
||||
|
||||
# Prepare the new user message for LangGraph (session memory handled automatically)
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"thread_id": request.session_id
|
||||
}
|
||||
}
|
||||
|
||||
# Get the latest user message from AI SDK format
|
||||
new_user_message = None
|
||||
if request.messages:
|
||||
last_message = request.messages[-1]
|
||||
if last_message.get("role") == "user":
|
||||
new_user_message = HumanMessage(content=last_message.get("content", ""))
|
||||
|
||||
if not new_user_message:
|
||||
logger.error("No user message found in request")
|
||||
yield create_error_event("No user message provided")
|
||||
return
|
||||
|
||||
# Create event queue for internal streaming
|
||||
event_queue = asyncio.Queue()
|
||||
|
||||
async def stream_callback(event_str: str):
|
||||
await event_queue.put(event_str)
|
||||
|
||||
async def run_workflow():
|
||||
try:
|
||||
# Set stream callback in context for the workflow
|
||||
from .graph.graph import stream_callback_context
|
||||
stream_callback_context.set(stream_callback)
|
||||
|
||||
# Create TurnState with the new user message
|
||||
# AgenticWorkflow will handle LangGraph interaction and session history
|
||||
from .graph.state import TurnState, Message
|
||||
|
||||
turn_state = TurnState(
|
||||
messages=[Message(
|
||||
role="user",
|
||||
content=str(new_user_message.content),
|
||||
timestamp=None
|
||||
)],
|
||||
session_id=request.session_id,
|
||||
tool_results=[],
|
||||
final_answer=""
|
||||
)
|
||||
|
||||
# Use AgenticWorkflow.astream with stream_callback parameter
|
||||
async for final_state in graph.astream(turn_state, stream_callback=stream_callback):
|
||||
# The workflow handles all streaming internally via stream_callback
|
||||
pass # final_state contains the complete result
|
||||
await event_queue.put(None) # Signal completion
|
||||
except Exception as e:
|
||||
logger.error(f"Workflow execution error: {e}", exc_info=True)
|
||||
await event_queue.put(create_error_event(f"Processing error: {str(e)}"))
|
||||
await event_queue.put(None)
|
||||
|
||||
# Start workflow task
|
||||
workflow_task = asyncio.create_task(run_workflow())
|
||||
|
||||
# Convert internal events to AI SDK format
|
||||
async def internal_stream():
|
||||
try:
|
||||
while True:
|
||||
event = await event_queue.get()
|
||||
if event is None:
|
||||
break
|
||||
yield event
|
||||
finally:
|
||||
if not workflow_task.done():
|
||||
workflow_task.cancel()
|
||||
|
||||
# Stream converted events
|
||||
async for ai_sdk_event in stream_ai_sdk_compatible(internal_stream()):
|
||||
yield ai_sdk_event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI SDK chat error: {e}")
|
||||
# Send error in AI SDK format
|
||||
from .ai_sdk_adapter import create_error_part
|
||||
yield create_error_part(f"Server error: {str(e)}")
|
||||
|
||||
return StreamingResponse(
|
||||
ai_sdk_stream(),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers": "*",
|
||||
"x-vercel-ai-data-stream": "v1", # AI SDK Data Stream Protocol header
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user