284 lines
9.1 KiB
Python
284 lines
9.1 KiB
Python
|
|
"""
|
|||
|
|
SSE 流管理器
|
|||
|
|
负责管理任务执行过程中的消息队列和 SSE 连接
|
|||
|
|
确保多用户并发时不同 task_id 的流互不干扰
|
|||
|
|
|
|||
|
|
关键技术点:
|
|||
|
|
1. 使用 asyncio.Queue 实现异步非阻塞消息队列
|
|||
|
|
2. 通过 asyncio.Lock 保证并发安全
|
|||
|
|
3. 每个 task_id 独立队列,实现任务隔离
|
|||
|
|
4. 支持从同步线程(CrewAI)安全地发布事件到异步队列
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
import json
import logging
import threading
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Dict, AsyncGenerator, Optional, Any
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StreamEvent:
    """
    One SSE event emitted while a task is running.

    Every event is serialized as a single flat JSON object so the frontend
    can parse all events the same way::

        {
            "task_id": "550e8400-e29b...",
            "sequence": 1,
            "agent_name": "ProductManager",
            "event_type": "thought",
            "content": "...",
            "timestamp": "2023-10-27T10:00:00Z"
        }

    Keys from ``metadata`` are merged into the same object, but they can
    never replace one of the six core keys above.

    ``sequence`` starts at 1 and increases by one for every event created
    for the same ``task_id``; different task_ids are numbered independently.
    """

    # Per-task sequence counters, keyed by task_id (shared by all instances).
    _sequence_counters: Dict[str, int] = {}
    # Events may be constructed from worker threads as well as from the event
    # loop, so the read-increment-write on the counter must be atomic.
    _sequence_lock = threading.Lock()

    def __init__(
        self,
        event_type: str,
        agent: str,
        content: str,
        task_id: str,
        timestamp: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Create an event and assign it the next sequence number of its task.

        Args:
            event_type: Kind of event (start, thought, action, output, end,
                error).
            agent: Name of the agent that produced the event; serialized as
                ``agent_name``.
            content: Event payload text.
            task_id: Task the event belongs to.
            timestamp: Preformatted timestamp; defaults to the current UTC
                time formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
            metadata: Optional extra key/value pairs merged into the
                serialized event.
        """
        self.event_type = event_type
        self.agent = agent
        self.content = content
        self.task_id = task_id
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted string is identical.
        self.timestamp = timestamp or datetime.now(timezone.utc).strftime(
            '%Y-%m-%dT%H:%M:%SZ'
        )
        self.metadata = metadata or {}

        with StreamEvent._sequence_lock:
            next_sequence = StreamEvent._sequence_counters.get(task_id, 0) + 1
            StreamEvent._sequence_counters[task_id] = next_sequence
        self.sequence = next_sequence

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(task_id={self.task_id!r}, "
            f"sequence={self.sequence!r}, event_type={self.event_type!r}, "
            f"agent={self.agent!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the event as a JSON-serializable dict.

        Core keys come first; metadata keys are appended only when they do
        not collide with a core key, so metadata cannot corrupt fields such
        as ``sequence`` or ``task_id``.
        """
        payload: Dict[str, Any] = {
            "task_id": self.task_id,
            "sequence": self.sequence,
            "agent_name": self.agent,
            "event_type": self.event_type,
            "content": self.content,
            "timestamp": self.timestamp,
        }
        for key, value in self.metadata.items():
            payload.setdefault(key, value)
        return payload

    def to_sse_format(self) -> str:
        """Return the event as one SSE ``data:`` frame (non-ASCII kept as-is)."""
        return f"data: {json.dumps(self.to_dict(), ensure_ascii=False)}\n\n"

    @classmethod
    def reset_sequence(cls, task_id: str):
        """Restart numbering for ``task_id``; the next event gets sequence 1."""
        with cls._sequence_lock:
            cls._sequence_counters[task_id] = 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TaskStreamQueue:
    """
    Event queue for a single task, usable from the event loop and from
    synchronous worker threads.

    The consumer side (``get`` / ``stream_events``) and ``put`` are
    coroutines that run on the event loop.  ``put_nowait`` is the entry
    point for synchronous code (e.g. CrewAI callbacks running in a worker
    thread): it hands the event over to the owning loop with
    ``asyncio.run_coroutine_threadsafe``.

    Closing the queue rejects further events, lets consumers drain what is
    already buffered, and wakes a consumer that is blocked on an empty queue.
    """

    # Seconds a producer waits for free space before giving up.
    PUT_TIMEOUT: float = 5.0
    # Seconds stream_events() waits for an event before re-checking is_closed.
    POLL_TIMEOUT: float = 30.0
    # Internal marker enqueued by close() to wake a blocked consumer; it is
    # never handed to callers.
    _CLOSED = object()
    _log = logging.getLogger(__name__)

    def __init__(self, task_id: str, max_size: int = 1000):
        """
        Args:
            task_id: Task this queue belongs to.
            max_size: Maximum number of buffered events.
        """
        self.task_id = task_id
        self.queue: "asyncio.Queue[StreamEvent]" = asyncio.Queue(maxsize=max_size)
        self.is_closed = False
        self.subscribers: int = 0
        # Remember the loop that owns this queue.  asyncio.get_event_loop()
        # is deprecated outside a running loop, so when constructed from
        # plain synchronous code the loop is bound on first async use instead.
        self._loop: Optional[asyncio.AbstractEventLoop]
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None
        self._lock = threading.Lock()  # guards the is_closed transition

    def _bind_loop(self) -> None:
        """Bind the queue to the running loop if no loop is bound yet."""
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

    def _wake_consumer(self) -> None:
        """Enqueue the close marker so a consumer blocked in get() returns."""
        try:
            self.queue.put_nowait(self._CLOSED)
        except asyncio.QueueFull:
            # A full queue means no consumer is blocked on an empty queue;
            # it will notice is_closed once it has drained the buffer.
            pass

    async def put(self, event: "StreamEvent") -> bool:
        """
        Enqueue ``event`` from a coroutine.

        Returns:
            True when the event was enqueued; False when the queue is closed,
            stays full for ``PUT_TIMEOUT`` seconds, or the put fails.
        """
        if self.is_closed:
            return False
        self._bind_loop()
        try:
            await asyncio.wait_for(self.queue.put(event), timeout=self.PUT_TIMEOUT)
        except asyncio.TimeoutError:
            self._log.warning(
                "Stream %s: queue full, event dropped after %.1fs",
                self.task_id, self.PUT_TIMEOUT,
            )
            return False
        except Exception:
            self._log.exception("Stream %s: failed to enqueue event", self.task_id)
            return False
        return True

    def put_nowait(self, event: "StreamEvent") -> bool:
        """
        Enqueue ``event`` from synchronous code.

        From a worker thread the event is submitted to the owning loop with
        ``run_coroutine_threadsafe`` and this call blocks for at most
        ``PUT_TIMEOUT`` seconds.  When called on the owning loop's own thread
        the event is enqueued directly, because blocking there would stall
        the loop that has to execute the put.

        Returns:
            True when the event was enqueued; False when the queue is closed,
            full, has no usable loop, or the hand-over fails or times out.
        """
        if self.is_closed:
            return False

        loop = self._loop
        if loop is None or loop.is_closed():
            self._log.warning(
                "Stream %s: no usable event loop, event dropped", self.task_id
            )
            return False

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is loop:
            try:
                self.queue.put_nowait(event)
            except asyncio.QueueFull:
                self._log.warning("Stream %s: queue full, event dropped", self.task_id)
                return False
            return True

        pending_put = self.queue.put(event)
        try:
            future = asyncio.run_coroutine_threadsafe(pending_put, loop)
        except Exception:
            pending_put.close()  # never scheduled; avoid "never awaited" warning
            self._log.exception("Stream %s: failed to submit event", self.task_id)
            return False

        try:
            future.result(timeout=self.PUT_TIMEOUT)
        except Exception:
            # Cancel so the event is not enqueued later although the caller
            # has already been told that publishing failed.
            future.cancel()
            self._log.warning(
                "Stream %s: publishing from a worker thread failed",
                self.task_id, exc_info=True,
            )
            return False
        return True

    async def get(self) -> Optional["StreamEvent"]:
        """
        Return the next event, waiting if the queue is empty.

        Returns:
            The next event, or None when the queue is closed and drained or
            reading from it fails.
        """
        self._bind_loop()
        try:
            while not (self.is_closed and self.queue.empty()):
                item = await self.queue.get()
                if item is not self._CLOSED:
                    return item
        except Exception:
            self._log.exception("Stream %s: failed to read event", self.task_id)
        return None

    def close(self):
        """Reject further events and wake a consumer blocked on an empty queue.

        Events that are already buffered remain available to consumers.
        """
        with self._lock:
            if self.is_closed:
                return
            self.is_closed = True
            loop = self._loop

        if loop is not None and not loop.is_closed():
            try:
                # close() may run on any thread; the queue itself must only be
                # touched from its loop.
                loop.call_soon_threadsafe(self._wake_consumer)
            except RuntimeError:
                pass  # the loop was closed in the meantime; nobody is waiting

    async def stream_events(self) -> AsyncGenerator["StreamEvent", None]:
        """Yield events until the queue is closed and fully drained."""
        self._bind_loop()
        while not (self.is_closed and self.queue.empty()):
            try:
                event = await asyncio.wait_for(
                    self.queue.get(), timeout=self.POLL_TIMEOUT
                )
            except asyncio.TimeoutError:
                # Nothing arrived; the loop condition re-checks is_closed.
                continue
            except Exception:
                self._log.exception("Stream %s: event stream aborted", self.task_id)
                break
            if event is self._CLOSED:
                continue
            yield event
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StreamManager:
    """
    Process-wide registry of per-task SSE streams.

    The class is a singleton: every ``StreamManager()`` call returns the same
    instance, and ``__init__`` only initializes it once.  All access to
    ``streams`` made through the coroutine methods is serialized by an
    ``asyncio.Lock``.
    """

    _instance: Optional['StreamManager'] = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        # task_id -> TaskStreamQueue
        self.streams: Dict[str, "TaskStreamQueue"] = {}
        # task_id -> time the stream was closed (or first seen closed);
        # used by cleanup_old_streams() to honor its retention period.
        self._closed_at: Dict[str, datetime] = {}
        # Created on first use, see the _lock property.
        self._lock_obj: Optional[asyncio.Lock] = None

    @property
    def _lock(self) -> asyncio.Lock:
        """Lock guarding ``streams``, created on first use.

        The singleton is instantiated at import time, before the
        application's event loop exists.  On Python versions before 3.10 an
        ``asyncio.Lock`` binds to the current loop when it is constructed, so
        it is created lazily, from inside a coroutine, instead.
        """
        lock = self._lock_obj
        if lock is None:
            lock = self._lock_obj = asyncio.Lock()
        return lock

    async def create_stream(self, task_id: str) -> "TaskStreamQueue":
        """
        Create and register a fresh stream for ``task_id``.

        An existing stream for the same task is closed and replaced, and the
        task's event sequence numbering is reset.

        Returns:
            The newly registered stream.
        """
        async with self._lock:
            previous = self.streams.get(task_id)
            if previous is not None:
                previous.close()
            self._closed_at.pop(task_id, None)

            StreamEvent.reset_sequence(task_id)

            stream = TaskStreamQueue(task_id)
            self.streams[task_id] = stream
            return stream

    async def get_stream(self, task_id: str) -> Optional["TaskStreamQueue"]:
        """Return the stream registered for ``task_id``, or None."""
        async with self._lock:
            return self.streams.get(task_id)

    async def publish_event(
        self,
        task_id: str,
        event_type: str,
        agent: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Build an event and enqueue it on the stream of ``task_id``.

        Returns:
            False when no stream is registered for the task or the stream
            does not accept the event; True otherwise.
        """
        # Hold the registry lock only for the lookup; enqueueing may wait.
        async with self._lock:
            stream = self.streams.get(task_id)
        if stream is None:
            return False

        event = StreamEvent(
            event_type=event_type,
            agent=agent,
            content=content,
            task_id=task_id,
            metadata=metadata
        )
        return await stream.put(event)

    async def close_stream(self, task_id: str):
        """Close the stream of ``task_id`` (no-op for unknown tasks).

        The closed stream stays registered until cleanup_old_streams()
        removes it.
        """
        async with self._lock:
            stream = self.streams.get(task_id)
            if stream is not None:
                stream.close()
                # setdefault: closing twice must not extend the retention.
                self._closed_at.setdefault(task_id, datetime.now())

    async def cleanup_old_streams(self, max_age_seconds: int = 3600):
        """
        Remove closed streams that have been closed for at least
        ``max_age_seconds`` (intended to be called periodically).

        Open streams are never removed.  A stream that was closed without
        going through close_stream() is timed from the first cleanup run
        that sees it closed.  Selection and removal happen under a single
        lock acquisition so a stream re-created in between cannot be deleted
        by mistake.
        """
        async with self._lock:
            now = datetime.now()
            expired = []
            for task_id, stream in self.streams.items():
                if not stream.is_closed:
                    # Drop a stale close time left over from a replaced stream.
                    self._closed_at.pop(task_id, None)
                    continue
                closed_at = self._closed_at.setdefault(task_id, now)
                if (now - closed_at).total_seconds() >= max_age_seconds:
                    expired.append(task_id)

            for task_id in expired:
                del self.streams[task_id]

            # Forget close times of streams that are no longer registered.
            for task_id in [t for t in self._closed_at if t not in self.streams]:
                del self._closed_at[task_id]

    def list_active_streams(self) -> list:
        """Return one ``{"task_id", "closed"}`` dict per registered stream."""
        return [
            {"task_id": tid, "closed": s.is_closed}
            for tid, s in self.streams.items()
        ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Global singleton
|
|||
|
|
# StreamManager.__new__ caches its instance, so any later StreamManager()
# call returns this same object.
stream_manager = StreamManager()
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_sse_generator(task_id: str) -> AsyncGenerator[str, None]:
    """
    Yield the SSE frames of a task, for use with FastAPI's StreamingResponse.

    The stream registered for ``task_id`` is used; if none exists a new one
    is created.  Every event is yielded as one ``data:`` frame.  When the
    stream ends, or an ordinary exception interrupts it, a final
    ``connection_end`` frame is yielded; an interrupting exception is then
    re-raised.

    No frame is yielded when the generator itself is being closed or
    cancelled (for example after a client disconnect): yielding while
    GeneratorExit is being handled makes Python raise
    ``RuntimeError: async generator ignored GeneratorExit``.
    """
    stream = await stream_manager.get_stream(task_id)
    if stream is None:
        # No producer has registered a stream yet: create one to wait on.
        stream = await stream_manager.create_stream(task_id)

    failure: Optional[Exception] = None
    try:
        async for event in stream.stream_events():
            yield event.to_sse_format()
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        # GeneratorExit is a BaseException and therefore not caught here;
        # it propagates without any further yield.
        failure = exc

    # End-of-connection marker.
    end_event = {
        "type": "connection_end",
        "task_id": task_id,
        "timestamp": datetime.now().isoformat()
    }
    yield f"data: {json.dumps(end_event)}\n\n"

    if failure is not None:
        raise failure
|