147 lines
5.0 KiB
Python
147 lines
5.0 KiB
Python
"""
|
|
AI SDK Data Stream Protocol adapter
|
|
Converts our internal SSE events to AI SDK compatible format
|
|
Following the official Data Stream Protocol: TYPE_ID:CONTENT_JSON\n
|
|
"""
|
|
import json
|
|
import uuid
|
|
from typing import Dict, Any, AsyncGenerator
|
|
|
|
|
|
def format_data_stream_part(type_id: str, content: Any) -> str:
    """Encode *content* as a single Data Stream Protocol line.

    Args:
        type_id: Part identifier written before the colon (e.g. ``"0"``).
        content: Any JSON-serializable value. Non-ASCII characters are
            written as-is rather than escaped (``ensure_ascii=False``).

    Returns:
        ``"<type_id>:<json>"`` terminated by a newline.
    """
    serialized = json.dumps(content, ensure_ascii=False)
    return "{}:{}\n".format(type_id, serialized)
|
|
|
|
|
|
def create_text_part(text: str) -> str:
    """Build a text part (type ``0``) carrying the chunk *text*."""
    return format_data_stream_part(type_id="0", content=text)
|
|
|
|
|
|
def create_data_part(data: list) -> str:
    """Build a data part (type ``2``) whose payload is the list *data*."""
    return format_data_stream_part(type_id="2", content=data)
|
|
|
|
|
|
def create_error_part(error: str) -> str:
    """Build an error part (type ``3``) whose payload is the message *error*."""
    return format_data_stream_part(type_id="3", content=error)
|
|
|
|
|
|
def create_tool_call_part(tool_call_id: str, tool_name: str, args: dict) -> str:
    """Build a tool-call part (type ``9``).

    Args:
        tool_call_id: Identifier emitted as ``toolCallId``.
        tool_name: Name emitted as ``toolName``.
        args: Arguments object emitted as ``args``.
    """
    call_payload = dict(toolCallId=tool_call_id, toolName=tool_name, args=args)
    return format_data_stream_part("9", call_payload)
|
|
|
|
|
|
def create_tool_result_part(tool_call_id: str, result: Any) -> str:
    """Build a tool-result part (type ``a``).

    Args:
        tool_call_id: Identifier emitted as ``toolCallId``.
        result: JSON-serializable value emitted as ``result``.
    """
    result_payload = dict(toolCallId=tool_call_id, result=result)
    return format_data_stream_part("a", result_payload)
|
|
|
|
|
|
def create_finish_step_part(
    finish_reason: str = "stop",
    usage: Dict[str, int] | None = None,
    is_continued: bool = False,
) -> str:
    """Build a finish-step part (type ``e``).

    Args:
        finish_reason: Emitted as ``finishReason``; defaults to ``"stop"``.
        usage: Token counts emitted as ``usage``. Any falsy value (``None``
            or an empty dict) is replaced by zero prompt/completion counts.
        is_continued: Emitted as ``isContinued``.
    """
    if not usage:
        usage = {"promptTokens": 0, "completionTokens": 0}
    step_payload = {
        "finishReason": finish_reason,
        "usage": usage,
        "isContinued": is_continued,
    }
    return format_data_stream_part("e", step_payload)
|
|
|
|
|
|
def create_finish_message_part(
    finish_reason: str = "stop",
    usage: Dict[str, int] | None = None,
) -> str:
    """Build a finish-message part (type ``d``).

    Args:
        finish_reason: Emitted as ``finishReason``; defaults to ``"stop"``.
        usage: Token counts emitted as ``usage``. Any falsy value (``None``
            or an empty dict) is replaced by zero prompt/completion counts.
    """
    if not usage:
        usage = {"promptTokens": 0, "completionTokens": 0}
    message_payload = {"finishReason": finish_reason, "usage": usage}
    return format_data_stream_part("d", message_payload)
|
|
|
|
|
|
class AISDKEventAdapter:
|
|
"""Adapter to convert our internal events to AI SDK Data Stream Protocol format"""
|
|
|
|
def __init__(self):
|
|
self.tool_calls = {} # Track tool calls
|
|
self.current_message_id = str(uuid.uuid4())
|
|
|
|
def convert_event(self, event_line: str) -> str | None:
|
|
"""Convert our SSE event to AI SDK Data Stream Protocol format"""
|
|
if not event_line.strip():
|
|
return None
|
|
|
|
try:
|
|
# Handle multi-line SSE format
|
|
lines = event_line.strip().split('\n')
|
|
event_type = None
|
|
data = None
|
|
|
|
for line in lines:
|
|
if line.startswith("event: "):
|
|
event_type = line.replace("event: ", "")
|
|
elif line.startswith("data: "):
|
|
data_str = line[6:] # Remove "data: "
|
|
if data_str:
|
|
data = json.loads(data_str)
|
|
|
|
if event_type and data:
|
|
return self._convert_by_type(event_type, data)
|
|
|
|
except (json.JSONDecodeError, IndexError, KeyError) as e:
|
|
# Skip malformed events
|
|
return None
|
|
|
|
return None
|
|
|
|
def _convert_by_type(self, event_type: str, data: Dict[str, Any]) -> str | None:
|
|
"""Convert event by type to Data Stream Protocol format"""
|
|
|
|
if event_type == "tokens":
|
|
# Token streaming -> text part (type 0)
|
|
delta = data.get("delta", "")
|
|
if delta:
|
|
return create_text_part(delta)
|
|
|
|
elif event_type == "tool_start":
|
|
# Tool start -> tool call part (type 9)
|
|
tool_id = data.get("id", str(uuid.uuid4()))
|
|
tool_name = data.get("name", "unknown")
|
|
args = data.get("args", {})
|
|
self.tool_calls[tool_id] = {"name": tool_name, "args": args}
|
|
return create_tool_call_part(tool_id, tool_name, args)
|
|
|
|
elif event_type == "tool_result":
|
|
# Tool result -> tool result part (type a)
|
|
tool_id = data.get("id", "")
|
|
results = data.get("results", [])
|
|
return create_tool_result_part(tool_id, results)
|
|
|
|
elif event_type == "tool_error":
|
|
# Tool error -> error part (type 3)
|
|
error = data.get("error", "Tool execution failed")
|
|
return create_error_part(error)
|
|
|
|
elif event_type == "error":
|
|
# Error -> error part (type 3)
|
|
error = data.get("error", "Unknown error")
|
|
return create_error_part(error)
|
|
|
|
return None
|
|
|
|
|
|
async def stream_ai_sdk_compatible(internal_stream: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    """Re-emit an internal SSE stream as AI SDK Data Stream Protocol parts.

    Every incoming event goes through a single AISDKEventAdapter created for
    this stream. Events that convert to a falsy result (``None`` or an empty
    string) are not yielded.
    """
    converter = AISDKEventAdapter()
    async for raw_event in internal_stream:
        if part := converter.convert_event(raw_event):
            yield part
|