# crewai/stream_manager.py - 284 lines, 9.1 KiB, Python (revision of 2026-03-13 14:20:58 +08:00)
"""
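SSE stream manager.

Manages the message queues and SSE connections used while tasks execute, and
makes sure that the streams of different task_ids do not interfere with each
other when several users run tasks concurrently.

Key techniques:
1. asyncio.Queue provides an asynchronous, non-blocking message queue.
2. asyncio.Lock guards concurrent access.
3. Every task_id owns an independent queue, which isolates tasks from each other.
4. Events can be published safely from synchronous threads (CrewAI) into the
   asynchronous queues.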
"""
import asyncio
import json
import logging
import threading
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, Optional
class StreamEvent:
    """One server-sent event produced while a task is running.

    Every event serializes to the same flat JSON object so the frontend can
    parse all of them uniformly::

        {
            "task_id": "550e8400-e29b...",
            "sequence": 1,
            "agent_name": "ProductManager",
            "event_type": "thought",   // action, output, complete
            "content": "Analyzing the user's requirements...",
            "timestamp": "2023-10-27T10:00:00Z"
        }
    """

    # Last sequence number handed out for each task_id (shared by all instances).
    _sequence_counters: Dict[str, int] = {}
    # Events are constructed from worker threads as well as from the event
    # loop; the read-modify-write on the counter must therefore be atomic.
    _sequence_lock = threading.Lock()

    def __init__(
        self,
        event_type: str,
        agent: str,
        content: str,
        task_id: str,
        timestamp: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Create an event and assign it the next sequence number of its task.

        Args:
            event_type: Kind of event (start, thought, action, output, end, error).
            agent: Name of the agent that produced the event.
            content: Human-readable payload.
            task_id: Task the event belongs to; sequence numbers are per task.
            timestamp: Preformatted timestamp; defaults to the current UTC time
                formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
            metadata: Extra keys merged into the serialized event.
        """
        self.event_type = event_type  # start, thought, action, output, end, error
        self.agent = agent
        self.content = content
        self.task_id = task_id
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the rendered string is identical.
        self.timestamp = timestamp or datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
        self.metadata = metadata or {}

        with StreamEvent._sequence_lock:
            next_sequence = StreamEvent._sequence_counters.get(task_id, 0) + 1
            StreamEvent._sequence_counters[task_id] = next_sequence
        self.sequence = next_sequence

    def to_dict(self) -> Dict[str, Any]:
        """Return the event as a JSON-serializable dict.

        Metadata keys are merged last, so a metadata key that has the same
        name as a core field replaces that field in the result.
        """
        return {
            "task_id": self.task_id,
            "sequence": self.sequence,
            "agent_name": self.agent,
            "event_type": self.event_type,
            "content": self.content,
            "timestamp": self.timestamp,
            **(self.metadata or {}),
        }

    def to_sse_format(self) -> str:
        """Return the event as one SSE ``data:`` frame (non-ASCII kept verbatim)."""
        return f"data: {json.dumps(self.to_dict(), ensure_ascii=False)}\n\n"

    @classmethod
    def reset_sequence(cls, task_id: str):
        """Restart numbering for ``task_id`` (called when a task starts over)."""
        with cls._sequence_lock:
            cls._sequence_counters[task_id] = 0
class TaskStreamQueue:
    """Event queue of a single task, usable from the event loop and from threads.

    Concurrency model:
    - CrewAI runs synchronously by default while the FastAPI SSE endpoint is
      asynchronous.
    - The underlying ``asyncio.Queue`` is only ever touched on the event loop
      that was current when the queue was created; other threads hand their
      events over with ``asyncio.run_coroutine_threadsafe``.
    - Closing the queue wakes up consumers that are waiting for an event, so
      they do not have to sit out a timeout first.
    """

    # Returned by _next_event() when it stopped waiting without an event.
    _NO_EVENT = object()

    def __init__(self, task_id: str, max_size: int = 1000):
        """Create the queue for ``task_id`` holding at most ``max_size`` events."""
        self.task_id = task_id
        self.queue: "asyncio.Queue[StreamEvent]" = asyncio.Queue(maxsize=max_size)
        self.is_closed = False
        self.subscribers: int = 0
        try:
            # Normal case: constructed inside a coroutine.
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # Constructed from synchronous code: keep the legacy lookup.
            self._loop = asyncio.get_event_loop()
        self._lock = threading.Lock()  # protects the synchronous close()
        # Set (on the loop thread) once close() has been called.
        self._closed_event = asyncio.Event()

    def _running_on_own_loop(self) -> bool:
        """Return True when the caller executes on this queue's event loop."""
        try:
            return asyncio.get_running_loop() is self._loop
        except RuntimeError:
            return False

    async def _next_event(self, timeout: Optional[float]):
        """Wait for the next event, a close signal, or ``timeout`` seconds.

        Returns the event, or ``_NO_EVENT`` when the wait ended for any other
        reason. A cancelled ``queue.get()`` never removes an item, so no event
        is lost when the wait is abandoned.
        """
        if not self.queue.empty():
            return self.queue.get_nowait()
        if self.is_closed:
            return self._NO_EVENT

        getter = asyncio.ensure_future(self.queue.get())
        waiters = {getter}
        closer = None
        # If the close signal is already set there is nothing left to wait
        # for on that side; waiting on it again would only cause a busy loop.
        if not self._closed_event.is_set():
            closer = asyncio.ensure_future(self._closed_event.wait())
            waiters.add(closer)
        try:
            await asyncio.wait(
                waiters, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            if closer is not None:
                closer.cancel()
            if not getter.done():
                getter.cancel()

        if getter.done() and not getter.cancelled():
            return getter.result()
        return self._NO_EVENT

    async def put(self, event: "StreamEvent") -> bool:
        """Enqueue ``event`` from a coroutine.

        Returns False when the queue is closed or stays full for 5 seconds.
        """
        if self.is_closed:
            return False
        try:
            await asyncio.wait_for(self.queue.put(event), timeout=5.0)
            return True
        except Exception:
            # Best effort: a timeout (queue full) or any other failure only
            # drops this one event.
            return False

    def put_nowait(self, event: "StreamEvent") -> bool:
        """Enqueue ``event`` from synchronous code (e.g. a CrewAI callback).

        From a worker thread the put is submitted to the event loop and
        awaited for up to 5 seconds. When called on the loop thread itself,
        blocking would stall the loop, so the event is inserted directly and
        False is returned immediately if the queue is full.

        Returns True when the event was enqueued, False otherwise.
        """
        if self.is_closed:
            return False

        if self._running_on_own_loop():
            try:
                self.queue.put_nowait(event)
                return True
            except asyncio.QueueFull:
                logging.getLogger(__name__).warning(
                    "Dropping event for task %s: queue is full", self.task_id
                )
                return False

        future = None
        try:
            future = asyncio.run_coroutine_threadsafe(
                self.queue.put(event),
                self._loop
            )
            future.result(timeout=5.0)
            return True
        except Exception as exc:
            # Withdraw the pending put; otherwise it could still enqueue the
            # event later although the caller was told it failed.
            if future is not None:
                future.cancel()
            logging.getLogger(__name__).warning(
                "Failed to publish event for task %s: %r", self.task_id, exc
            )
            return False

    async def get(self) -> Optional["StreamEvent"]:
        """Return the next event, or None once the queue is closed and drained."""
        if self.is_closed and self.queue.empty():
            return None
        try:
            event = await self._next_event(timeout=None)
        except Exception:
            return None
        return None if event is self._NO_EVENT else event

    def close(self):
        """Mark the queue closed and wake up consumers waiting for events.

        Events that are already queued can still be consumed afterwards.
        """
        with self._lock:
            self.is_closed = True
        try:
            # asyncio.Event is not thread-safe; set it on the loop thread.
            self._loop.call_soon_threadsafe(self._closed_event.set)
        except RuntimeError:
            # The loop is already closed, so nobody can be waiting on it.
            pass

    async def stream_events(self) -> AsyncGenerator["StreamEvent", None]:
        """Yield events until the queue is closed and fully drained."""
        while not (self.is_closed and self.queue.empty()):
            try:
                # The 30 s timeout is only a safety net for the case where
                # is_closed was set without going through close().
                event = await self._next_event(timeout=30.0)
            except Exception:
                break
            if event is not self._NO_EVENT:
                yield event
class StreamManager:
    """Process-wide registry of the SSE streams of all tasks (singleton)."""

    _instance: Optional['StreamManager'] = None

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize the registry once; later constructions are no-ops."""
        if self._initialized:
            return
        self._initialized = True
        # task_id -> TaskStreamQueue
        self.streams: Dict[str, "TaskStreamQueue"] = {}
        self._lock = asyncio.Lock()
        # task_id -> moment the manager first saw that stream closed; used by
        # cleanup_old_streams() to honour max_age_seconds.
        self._closed_at: Dict[str, datetime] = {}

    async def create_stream(self, task_id: str) -> "TaskStreamQueue":
        """Create a fresh stream for ``task_id``, closing any previous one.

        Sequence numbering for the task restarts with the new stream.
        """
        async with self._lock:
            old_stream = self.streams.get(task_id)
            if old_stream is not None:
                # A stream already exists: close it before replacing it.
                old_stream.close()
            StreamEvent.reset_sequence(task_id)
            queue = TaskStreamQueue(task_id)
            self.streams[task_id] = queue
            self._closed_at.pop(task_id, None)
            return queue

    async def get_stream(self, task_id: str) -> Optional["TaskStreamQueue"]:
        """Return the stream registered for ``task_id``, or None."""
        async with self._lock:
            return self.streams.get(task_id)

    async def publish_event(
        self,
        task_id: str,
        event_type: str,
        agent: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Publish an event to the stream of ``task_id``.

        Returns False when no stream is registered for the task or the stream
        refuses the event (closed, or full for too long).
        """
        async with self._lock:
            stream = self.streams.get(task_id)
        if stream is None:
            return False
        event = StreamEvent(
            event_type=event_type,
            agent=agent,
            content=content,
            task_id=task_id,
            metadata=metadata
        )
        # Deliberately outside the registry lock: put() may wait several
        # seconds on a full queue and must not block every other task.
        return await stream.put(event)

    async def close_stream(self, task_id: str):
        """Close the stream of ``task_id``; it stays registered until cleanup."""
        async with self._lock:
            stream = self.streams.get(task_id)
            if stream is not None:
                stream.close()
                self._closed_at.setdefault(task_id, datetime.now())

    async def cleanup_old_streams(self, max_age_seconds: int = 3600):
        """Remove streams that have been closed for at least ``max_age_seconds``.

        Intended to be called periodically. A stream that was closed without
        going through close_stream() is aged from the first cleanup run that
        observes it closed. Pass ``max_age_seconds=0`` to drop every closed
        stream immediately.

        NOTE(review): the per-task counters in StreamEvent._sequence_counters
        are not released here and therefore grow with the number of task_ids.
        """
        now = datetime.now()
        # Find and delete inside one critical section so a stream that gets
        # recreated concurrently can never be deleted by mistake.
        async with self._lock:
            for task_id in list(self._closed_at):
                stream = self.streams.get(task_id)
                if stream is None or not stream.is_closed:
                    del self._closed_at[task_id]

            expired = []
            for task_id, stream in self.streams.items():
                if not stream.is_closed:
                    continue
                closed_at = self._closed_at.setdefault(task_id, now)
                if (now - closed_at).total_seconds() >= max_age_seconds:
                    expired.append(task_id)

            for task_id in expired:
                del self.streams[task_id]
                del self._closed_at[task_id]

    def list_active_streams(self) -> list:
        """Return ``{"task_id", "closed"}`` for every registered stream."""
        # Iterate over a snapshot so a concurrent registry change cannot
        # invalidate the iteration.
        return [
            {"task_id": tid, "closed": s.is_closed}
            for tid, s in list(self.streams.items())
        ]
# Global singleton: the StreamManager instance created when this module is imported
stream_manager = StreamManager()
async def create_sse_generator(task_id: str) -> AsyncGenerator[str, None]:
    """Yield the SSE frames of ``task_id`` for a FastAPI ``StreamingResponse``.

    Uses the stream registered for the task, creating one if none exists.
    Every event is emitted as a ``data:`` frame; after the stream has been
    closed and drained, a final ``connection_end`` frame is sent.

    The end frame is emitted only on normal completion. When the client
    disconnects, the generator is closed or cancelled at a ``yield``;
    yielding again from a ``finally`` block at that point raises
    ``RuntimeError`` ("async generator ignored GeneratorExit").
    """
    stream = await stream_manager.get_stream(task_id)
    if stream is None:
        # No stream yet for this task: create one.
        stream = await stream_manager.create_stream(task_id)

    async for event in stream.stream_events():
        yield event.to_sse_format()

    # End-of-stream marker.
    end_event = {
        "type": "connection_end",
        "task_id": task_id,
        "timestamp": datetime.now().isoformat()
    }
    yield f"data: {json.dumps(end_event)}\n\n"