Files
catonline_ai/vw-agentic-rag/service/ai_sdk_adapter.py
2025-09-26 17:15:54 +08:00

147 lines
5.0 KiB
Python

"""
AI SDK Data Stream Protocol adapter
Converts our internal SSE events to AI SDK compatible format
Following the official Data Stream Protocol: TYPE_ID:CONTENT_JSON\n
"""
import json
import uuid
from typing import Dict, Any, AsyncGenerator
def format_data_stream_part(type_id: str, content: Any) -> str:
    """Serialize *content* as a single Data Stream Protocol line.

    The line has the shape ``<type_id>:<json>`` and ends with a newline.
    Non-ASCII characters are emitted verbatim instead of ``\\uXXXX`` escapes.
    """
    serialized = json.dumps(content, ensure_ascii=False)
    return "{}:{}\n".format(type_id, serialized)
def create_text_part(text: str) -> str:
    """Wrap a chunk of generated text as a text part (type ``0``)."""
    return format_data_stream_part(type_id="0", content=text)
def create_data_part(data: list) -> str:
    """Wrap a list of extra payload items as a data part (type ``2``)."""
    return format_data_stream_part(type_id="2", content=data)
def create_error_part(error: str) -> str:
    """Wrap an error message as an error part (type ``3``)."""
    return format_data_stream_part(type_id="3", content=error)
def create_tool_call_part(tool_call_id: str, tool_name: str, args: dict) -> str:
    """Build a tool call part (type ``9``) announcing a tool invocation.

    The JSON object carries the keys ``toolCallId``, ``toolName`` and ``args``
    in that order.
    """
    payload = dict(toolCallId=tool_call_id, toolName=tool_name, args=args)
    return format_data_stream_part(type_id="9", content=payload)
def create_tool_result_part(tool_call_id: str, result: Any) -> str:
    """Build a tool result part (type ``a``) for a previously announced call.

    The JSON object carries ``toolCallId`` followed by ``result``.
    """
    payload = dict(toolCallId=tool_call_id, result=result)
    return format_data_stream_part(type_id="a", content=payload)
def create_finish_step_part(finish_reason: str = "stop", usage: Dict[str, int] | None = None, is_continued: bool = False) -> str:
    """Build a finish step part (type ``e``).

    Args:
        finish_reason: Value emitted as ``finishReason``.
        usage: Token counts emitted as ``usage``. Any falsy value (``None``
            or an empty dict) is replaced by zero prompt/completion counts.
        is_continued: Value emitted as ``isContinued``.
    """
    if not usage:
        usage = {"promptTokens": 0, "completionTokens": 0}
    payload = {"finishReason": finish_reason, "usage": usage, "isContinued": is_continued}
    return format_data_stream_part(type_id="e", content=payload)
def create_finish_message_part(finish_reason: str = "stop", usage: Dict[str, int] | None = None) -> str:
    """Build a finish message part (type ``d``).

    Args:
        finish_reason: Value emitted as ``finishReason``.
        usage: Token counts emitted as ``usage``. Any falsy value (``None``
            or an empty dict) is replaced by zero prompt/completion counts.
    """
    if not usage:
        usage = {"promptTokens": 0, "completionTokens": 0}
    payload = {"finishReason": finish_reason, "usage": usage}
    return format_data_stream_part(type_id="d", content=payload)
class AISDKEventAdapter:
"""Adapter to convert our internal events to AI SDK Data Stream Protocol format"""
def __init__(self):
self.tool_calls = {} # Track tool calls
self.current_message_id = str(uuid.uuid4())
def convert_event(self, event_line: str) -> str | None:
"""Convert our SSE event to AI SDK Data Stream Protocol format"""
if not event_line.strip():
return None
try:
# Handle multi-line SSE format
lines = event_line.strip().split('\n')
event_type = None
data = None
for line in lines:
if line.startswith("event: "):
event_type = line.replace("event: ", "")
elif line.startswith("data: "):
data_str = line[6:] # Remove "data: "
if data_str:
data = json.loads(data_str)
if event_type and data:
return self._convert_by_type(event_type, data)
except (json.JSONDecodeError, IndexError, KeyError) as e:
# Skip malformed events
return None
return None
def _convert_by_type(self, event_type: str, data: Dict[str, Any]) -> str | None:
"""Convert event by type to Data Stream Protocol format"""
if event_type == "tokens":
# Token streaming -> text part (type 0)
delta = data.get("delta", "")
if delta:
return create_text_part(delta)
elif event_type == "tool_start":
# Tool start -> tool call part (type 9)
tool_id = data.get("id", str(uuid.uuid4()))
tool_name = data.get("name", "unknown")
args = data.get("args", {})
self.tool_calls[tool_id] = {"name": tool_name, "args": args}
return create_tool_call_part(tool_id, tool_name, args)
elif event_type == "tool_result":
# Tool result -> tool result part (type a)
tool_id = data.get("id", "")
results = data.get("results", [])
return create_tool_result_part(tool_id, results)
elif event_type == "tool_error":
# Tool error -> error part (type 3)
error = data.get("error", "Tool execution failed")
return create_error_part(error)
elif event_type == "error":
# Error -> error part (type 3)
error = data.get("error", "Unknown error")
return create_error_part(error)
return None
async def stream_ai_sdk_compatible(internal_stream: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    """Re-emit an internal SSE stream as AI SDK Data Stream Protocol parts.

    A single adapter instance handles the whole stream. Each incoming chunk is
    converted independently, and chunks that produce no part are dropped.
    """
    converter = AISDKEventAdapter()
    async for raw_event in internal_stream:
        part = converter.convert_event(raw_event)
        if not part:
            continue
        yield part