"""
SSE stream manager.

Manages the per-task message queues and SSE connections used while a task
runs, so that streams belonging to different ``task_id`` values do not
interfere with each other when several users run tasks concurrently.

Key techniques:
1. ``asyncio.Queue`` provides a non-blocking asynchronous message queue.
2. ``asyncio.Lock`` guards the shared ``task_id -> queue`` registry.
3. Each ``task_id`` gets its own queue, isolating tasks from one another.
4. Events can be published safely from synchronous worker threads (e.g.
   CrewAI callbacks) into the asynchronous queue.
"""
import asyncio
import json
import logging
import threading
import time
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Dict, AsyncGenerator, Optional, Any

logger = logging.getLogger(__name__)


def _utc_timestamp() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``."""
    return datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')


class StreamEvent:
    """
    A single SSE event.

    Serialized to a uniform JSON shape so the front end can parse it easily::

        {
            "task_id": "550e8400-e29b...",
            "sequence": 1,
            "agent_name": "ProductManager",
            "event_type": "thought",   // e.g. thought, action, output, complete
            "content": "Analyzing the user requirements...",
            "timestamp": "2023-10-27T10:00:00Z"
        }

    Entries of ``metadata`` are merged into the top level of that object.
    """

    # Sequence counters shared by all instances, one per task_id.
    _sequence_counters: Dict[str, int] = {}
    # Guards _sequence_counters: events may be constructed on worker threads
    # as well as on the event-loop thread, and ``+=`` on a dict entry is not atomic.
    _sequence_lock = threading.Lock()

    def __init__(
        self,
        event_type: str,
        agent: str,
        content: str,
        task_id: str,
        timestamp: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Args:
            event_type: Free-form label, e.g. start, thought, action, output, end, error.
            agent: Name of the agent that produced the event (serialized as ``agent_name``).
            content: Event payload text.
            task_id: Task the event belongs to; selects the sequence counter.
            timestamp: Explicit timestamp; defaults to the current UTC time.
            metadata: Extra top-level fields merged into ``to_dict()``.
        """
        self.event_type = event_type
        self.agent = agent
        self.content = content
        self.task_id = task_id
        self.timestamp = timestamp or _utc_timestamp()
        self.metadata = metadata or {}

        # Each task_id has its own monotonically increasing sequence number.
        with StreamEvent._sequence_lock:
            sequence = StreamEvent._sequence_counters.get(task_id, 0) + 1
            StreamEvent._sequence_counters[task_id] = sequence
        self.sequence = sequence

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the event as a JSON-serializable dict.

        NOTE: metadata keys are merged last, so a metadata key that collides
        with a core field (e.g. ``event_type``) overrides it.
        """
        return {
            "task_id": self.task_id,
            "sequence": self.sequence,
            "agent_name": self.agent,
            "event_type": self.event_type,
            "content": self.content,
            "timestamp": self.timestamp,
            **(self.metadata if self.metadata else {})
        }

    def to_sse_format(self) -> str:
        """Return the event as one SSE ``data:`` frame (non-ASCII kept as-is)."""
        return f"data: {json.dumps(self.to_dict(), ensure_ascii=False)}\n\n"

    @classmethod
    def reset_sequence(cls, task_id: str):
        """Reset the sequence counter of ``task_id`` to 0 (used when a task restarts)."""
        with cls._sequence_lock:
            cls._sequence_counters[task_id] = 0

    @classmethod
    def discard_sequence(cls, task_id: str):
        """Forget the sequence counter of ``task_id`` entirely (used when its stream is removed)."""
        with cls._sequence_lock:
            cls._sequence_counters.pop(task_id, None)


class TaskStreamQueue:
    """
    Message queue for a single task's event stream.

    CrewAI runs synchronously on worker threads, while FastAPI and SSE are
    asynchronous. ``put_nowait`` bridges the two with
    ``asyncio.run_coroutine_threadsafe``, so synchronous code can hand events
    to the queue owned by the event loop.

    Each event is removed from the queue by exactly one consumer, so several
    concurrent consumers of the same stream split the events between them.
    """

    def __init__(self, task_id: str, max_size: int = 1000):
        self.task_id = task_id
        self.queue: asyncio.Queue[StreamEvent] = asyncio.Queue(maxsize=max_size)
        self.is_closed = False
        self.subscribers: int = 0
        # monotonic timestamps; closed_at is used by StreamManager.cleanup_old_streams
        self.created_at = time.monotonic()
        self.closed_at: Optional[float] = None
        try:
            # Normally constructed from a coroutine: bind to the loop actually running.
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # Constructed outside a running loop: fall back to the legacy lookup.
            self._loop = asyncio.get_event_loop()
        self._lock = threading.Lock()  # guards the is_closed / closed_at transition
        # Set (on the loop) by close() so consumers blocked on an empty queue wake up.
        self._closed_event = asyncio.Event()

    async def put(self, event: StreamEvent) -> bool:
        """
        Enqueue ``event`` from coroutine code running on the loop.

        Waits at most 5 seconds for free space.

        Returns:
            True if enqueued; False if the stream is closed, the wait timed out,
            or the put failed.
        """
        if self.is_closed:
            return False
        try:
            await asyncio.wait_for(self.queue.put(event), timeout=5.0)
            return True
        except Exception as exc:  # includes asyncio.TimeoutError (queue full)
            logger.warning("Failed to enqueue event for task %s: %r", self.task_id, exc)
            return False

    def _on_loop_thread(self) -> bool:
        """Return True when called from inside this queue's running event loop."""
        try:
            return asyncio.get_running_loop() is self._loop
        except RuntimeError:
            return False

    def put_nowait(self, event: StreamEvent) -> bool:
        """
        Publish ``event`` from synchronous code (e.g. a CrewAI callback).

        From a worker thread, the put is submitted to the owning loop with
        ``run_coroutine_threadsafe``, and the call blocks for at most 5 seconds.
        On timeout the pending put is cancelled, so a reported failure never
        turns into a late delivery. When called on the loop thread itself, the
        event is enqueued directly: blocking there would stall the loop that
        must run the put.

        Returns:
            True if the event was enqueued, False otherwise.
        """
        if self.is_closed:
            return False

        if self._on_loop_thread():
            try:
                self.queue.put_nowait(event)
                return True
            except asyncio.QueueFull:
                logger.warning("Stream for task %s is full; event dropped", self.task_id)
                return False

        coro = self.queue.put(event)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except Exception as exc:  # e.g. the loop is already closed
            coro.close()  # avoid a "coroutine was never awaited" warning
            logger.warning("Failed to publish event for task %s: %r", self.task_id, exc)
            return False

        try:
            future.result(timeout=5.0)
        except Exception as exc:
            future.cancel()
            logger.warning("Failed to publish event for task %s: %r", self.task_id, exc)
            return False
        return True

    async def _wait_next(self, timeout: Optional[float]) -> Optional[StreamEvent]:
        """
        Return the next queued event.

        Returns None if the stream is, or becomes, closed while the queue is
        empty, or if ``timeout`` seconds pass without an event
        (``timeout=None`` waits indefinitely).
        """
        try:
            return self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        if self.is_closed:
            return None

        getter = asyncio.ensure_future(self.queue.get())
        closer = asyncio.ensure_future(self._closed_event.wait())
        try:
            done, _ = await asyncio.wait(
                {getter, closer},
                timeout=timeout,
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # Cancel whichever waiter lost the race (both on timeout or cancellation).
            for waiter in (getter, closer):
                if not waiter.done():
                    waiter.cancel()
        if getter in done:
            return getter.result()
        return None

    async def get(self) -> Optional[StreamEvent]:
        """
        Return the next event, waiting if necessary.

        Returns None once the stream is closed and drained, including when
        ``close()`` is called while this call is waiting.
        """
        while not (self.is_closed and self.queue.empty()):
            try:
                event = await self._wait_next(timeout=None)
            except Exception:
                return None
            if event is not None:
                return event
        return None

    def close(self):
        """
        Mark the stream closed and wake any consumer waiting for events.

        Idempotent. Callable from any thread: the wake-up is scheduled on the
        owning loop with ``call_soon_threadsafe``.
        """
        with self._lock:
            if self.is_closed:
                return
            self.is_closed = True
            self.closed_at = time.monotonic()
        try:
            self._loop.call_soon_threadsafe(self._closed_event.set)
        except RuntimeError:
            # The loop is already closed, so nothing can be waiting on it.
            pass

    async def stream_events(self) -> AsyncGenerator[StreamEvent, None]:
        """
        Yield events until the stream is closed and its queue is drained.

        ``close()`` wakes the waiting consumer immediately; the 30-second wait
        only acts as a periodic re-check of the loop condition.
        """
        while not (self.is_closed and self.queue.empty()):
            try:
                event = await self._wait_next(timeout=30.0)
            except Exception:
                break
            if event is not None:
                yield event


class StreamManager:
    """
    Global registry of the SSE streams of all tasks.

    Process-wide singleton: every ``StreamManager()`` call returns the same
    instance, and ``__init__`` only initializes it once.
    """

    _instance: Optional['StreamManager'] = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        # task_id -> TaskStreamQueue
        self.streams: Dict[str, TaskStreamQueue] = {}
        # Guards self.streams. It is never held while awaiting a (possibly full)
        # queue, so one slow task cannot stall the others.
        # NOTE(review): on Python < 3.10 asyncio.Lock binds to the loop current at
        # construction, and this singleton is built at import time.
        self._lock = asyncio.Lock()

    async def create_stream(self, task_id: str) -> TaskStreamQueue:
        """
        Create a fresh stream for ``task_id``.

        Any existing stream for the id is closed, which wakes its consumers,
        and the task's sequence counter is reset.
        """
        async with self._lock:
            old_stream = self.streams.get(task_id)
            if old_stream is not None:
                old_stream.close()
            StreamEvent.reset_sequence(task_id)
            queue = TaskStreamQueue(task_id)
            self.streams[task_id] = queue
            return queue

    async def get_or_create_stream(self, task_id: str) -> TaskStreamQueue:
        """
        Return the registered stream for ``task_id``, even if it is closed.

        If none exists, a new one is created under the same lock acquisition,
        so concurrent callers cannot create and then replace each other's streams.
        """
        async with self._lock:
            stream = self.streams.get(task_id)
            if stream is None:
                StreamEvent.reset_sequence(task_id)
                stream = TaskStreamQueue(task_id)
                self.streams[task_id] = stream
            return stream

    async def get_stream(self, task_id: str) -> Optional[TaskStreamQueue]:
        """Return the stream registered for ``task_id``, or None."""
        async with self._lock:
            return self.streams.get(task_id)

    async def publish_event(
        self,
        task_id: str,
        event_type: str,
        agent: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Build an event and enqueue it on the stream of ``task_id``.

        Returns:
            False if no stream is registered for ``task_id``; otherwise the
            result of ``TaskStreamQueue.put``.
        """
        async with self._lock:
            stream = self.streams.get(task_id)
            if stream is None:
                return False
            # Built under the lock so numbering is serialized with create_stream's reset.
            event = StreamEvent(
                event_type=event_type,
                agent=agent,
                content=content,
                task_id=task_id,
                metadata=metadata
            )
        # The put may wait up to 5 s on a full queue; do it outside the registry lock.
        return await stream.put(event)

    async def close_stream(self, task_id: str):
        """Close the stream of ``task_id``, keeping it registered until cleanup."""
        async with self._lock:
            stream = self.streams.get(task_id)
            if stream is not None:
                stream.close()

    async def cleanup_old_streams(self, max_age_seconds: int = 3600) -> int:
        """
        Remove streams that were closed at least ``max_age_seconds`` ago.

        Open streams are never removed; ``max_age_seconds=0`` removes every
        closed stream. The sequence counter of each removed task is discarded
        too, so counters do not accumulate forever. Intended to be called
        periodically.

        Returns:
            The number of streams removed.
        """
        cutoff = time.monotonic() - max_age_seconds
        # Select and delete under one lock acquisition so a stream re-created
        # for the same task_id in between cannot be deleted by mistake.
        async with self._lock:
            expired = [
                task_id
                for task_id, stream in self.streams.items()
                if stream.is_closed
                and (stream.closed_at is None or stream.closed_at <= cutoff)
            ]
            for task_id in expired:
                del self.streams[task_id]
                StreamEvent.discard_sequence(task_id)
        return len(expired)

    def list_active_streams(self) -> list:
        """List every registered stream, closed ones included, with its closed flag."""
        return [
            {"task_id": tid, "closed": s.is_closed}
            for tid, s in self.streams.items()
        ]


# Module-level singleton instance.
stream_manager = StreamManager()


async def create_sse_generator(task_id: str) -> AsyncGenerator[str, None]:
    """
    Build the async generator of SSE frames for a FastAPI ``StreamingResponse``.

    Streams every event of ``task_id``, creating the stream if it does not
    exist yet. Once the stream is closed and drained, it emits a final
    ``connection_end`` frame.
    """
    stream = await stream_manager.get_or_create_stream(task_id)
    async for event in stream.stream_events():
        yield event.to_sse_format()
    # Emitted only on normal completion. It must not be yielded from a
    # ``finally``: yielding while the generator is being closed (client
    # disconnect, cancellation) raises "async generator ignored GeneratorExit".
    end_event = {
        "type": "connection_end",
        "task_id": task_id,
        "timestamp": _utc_timestamp(),
    }
    yield f"data: {json.dumps(end_event)}\n\n"