init
This commit is contained in:
146
vw-agentic-rag/service/ai_sdk_adapter.py
Normal file
146
vw-agentic-rag/service/ai_sdk_adapter.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
AI SDK Data Stream Protocol adapter
|
||||
Converts our internal SSE events to AI SDK compatible format
|
||||
Following the official Data Stream Protocol: TYPE_ID:CONTENT_JSON\n
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
|
||||
|
||||
def format_data_stream_part(type_id: str, content: Any) -> str:
    """Render a single Data Stream Protocol part.

    The result is ``TYPE_ID:JSON`` terminated by a newline. The content is
    JSON-encoded with ``ensure_ascii=False``, so non-ASCII characters are
    written as-is instead of ``\\uXXXX`` escapes.
    """
    encoded = json.dumps(content, ensure_ascii=False)
    return "{}:{}\n".format(type_id, encoded)
|
||||
|
||||
|
||||
def create_text_part(text: str) -> str:
    """Build a text part (type ``0``) carrying the given text."""
    text_type_id = "0"
    return format_data_stream_part(text_type_id, text)
|
||||
|
||||
|
||||
def create_data_part(data: list) -> str:
    """Build a data part (type ``2``) carrying additional data."""
    data_type_id = "2"
    return format_data_stream_part(data_type_id, data)
|
||||
|
||||
|
||||
def create_error_part(error: str) -> str:
    """Build an error part (type ``3``) carrying the error message."""
    error_type_id = "3"
    return format_data_stream_part(error_type_id, error)
|
||||
|
||||
|
||||
def create_tool_call_part(tool_call_id: str, tool_name: str, args: dict) -> str:
    """Build a tool call part (type ``9``).

    The payload carries the call id, the tool name and the call arguments
    under the keys ``toolCallId``, ``toolName`` and ``args``.
    """
    payload = {"toolCallId": tool_call_id, "toolName": tool_name, "args": args}
    return format_data_stream_part("9", payload)
|
||||
|
||||
|
||||
def create_tool_result_part(tool_call_id: str, result: Any) -> str:
    """Build a tool result part (type ``a``).

    The payload pairs the id of the originating call (``toolCallId``) with
    the tool output (``result``).
    """
    payload = {"toolCallId": tool_call_id, "result": result}
    return format_data_stream_part("a", payload)
|
||||
|
||||
|
||||
def create_finish_step_part(
    finish_reason: str = "stop",
    usage: Dict[str, int] | None = None,
    is_continued: bool = False,
) -> str:
    """Build a finish step part (type ``e``).

    When ``usage`` is missing or empty, zeroed prompt/completion token
    counts are reported instead.
    """
    if not usage:
        usage = {"promptTokens": 0, "completionTokens": 0}
    payload = {
        "finishReason": finish_reason,
        "usage": usage,
        "isContinued": is_continued,
    }
    return format_data_stream_part("e", payload)
|
||||
|
||||
|
||||
def create_finish_message_part(
    finish_reason: str = "stop",
    usage: Dict[str, int] | None = None,
) -> str:
    """Build a finish message part (type ``d``).

    When ``usage`` is missing or empty, zeroed prompt/completion token
    counts are reported instead.
    """
    if not usage:
        usage = {"promptTokens": 0, "completionTokens": 0}
    payload = {"finishReason": finish_reason, "usage": usage}
    return format_data_stream_part("d", payload)
|
||||
|
||||
|
||||
class AISDKEventAdapter:
|
||||
"""Adapter to convert our internal events to AI SDK Data Stream Protocol format"""
|
||||
|
||||
def __init__(self):
|
||||
self.tool_calls = {} # Track tool calls
|
||||
self.current_message_id = str(uuid.uuid4())
|
||||
|
||||
def convert_event(self, event_line: str) -> str | None:
|
||||
"""Convert our SSE event to AI SDK Data Stream Protocol format"""
|
||||
if not event_line.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
# Handle multi-line SSE format
|
||||
lines = event_line.strip().split('\n')
|
||||
event_type = None
|
||||
data = None
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("event: "):
|
||||
event_type = line.replace("event: ", "")
|
||||
elif line.startswith("data: "):
|
||||
data_str = line[6:] # Remove "data: "
|
||||
if data_str:
|
||||
data = json.loads(data_str)
|
||||
|
||||
if event_type and data:
|
||||
return self._convert_by_type(event_type, data)
|
||||
|
||||
except (json.JSONDecodeError, IndexError, KeyError) as e:
|
||||
# Skip malformed events
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _convert_by_type(self, event_type: str, data: Dict[str, Any]) -> str | None:
|
||||
"""Convert event by type to Data Stream Protocol format"""
|
||||
|
||||
if event_type == "tokens":
|
||||
# Token streaming -> text part (type 0)
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
return create_text_part(delta)
|
||||
|
||||
elif event_type == "tool_start":
|
||||
# Tool start -> tool call part (type 9)
|
||||
tool_id = data.get("id", str(uuid.uuid4()))
|
||||
tool_name = data.get("name", "unknown")
|
||||
args = data.get("args", {})
|
||||
self.tool_calls[tool_id] = {"name": tool_name, "args": args}
|
||||
return create_tool_call_part(tool_id, tool_name, args)
|
||||
|
||||
elif event_type == "tool_result":
|
||||
# Tool result -> tool result part (type a)
|
||||
tool_id = data.get("id", "")
|
||||
results = data.get("results", [])
|
||||
return create_tool_result_part(tool_id, results)
|
||||
|
||||
elif event_type == "tool_error":
|
||||
# Tool error -> error part (type 3)
|
||||
error = data.get("error", "Tool execution failed")
|
||||
return create_error_part(error)
|
||||
|
||||
elif event_type == "error":
|
||||
# Error -> error part (type 3)
|
||||
error = data.get("error", "Unknown error")
|
||||
return create_error_part(error)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def stream_ai_sdk_compatible(internal_stream: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    """Re-emit an internal SSE stream as AI SDK Data Stream Protocol parts.

    A single adapter instance is used for the whole stream. Events for
    which the adapter produces nothing are dropped from the output.
    """
    converter = AISDKEventAdapter()

    async for raw_event in internal_stream:
        part = converter.convert_event(raw_event)
        if not part:
            continue
        yield part
|
||||
Reference in New Issue
Block a user