第一次提交
This commit is contained in:
283
stream_manager.py
Normal file
283
stream_manager.py
Normal file
@@ -0,0 +1,283 @@
|
||||
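"""
SSE stream manager.

Manages the message queues and SSE connections used while tasks execute,
and keeps the streams of different task_ids from interfering with each
other when several users run tasks concurrently.

Key techniques:
1. asyncio.Queue provides an asynchronous, non-blocking message queue.
2. asyncio.Lock provides concurrency safety.
3. Every task_id has its own queue, which isolates tasks from each other.
4. Events can be published safely from a synchronous thread (CrewAI) into
   the asynchronous queue.
"""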
|
||||
|
||||
import asyncio
import json
import threading
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
|
||||
class StreamEvent:
    """A single SSE event produced while a task is running.

    The serialised form (see ``to_dict``) is a flat JSON object::

        {
            "task_id": "550e8400-e29b...",
            "sequence": 1,
            "agent_name": "ProductManager",
            "event_type": "thought",
            "content": "...",
            "timestamp": "2023-10-27T10:00:00Z"
        }

    ``sequence`` starts at 1 and grows by one for every event constructed
    with the same ``task_id``; different task_ids are numbered independently.
    Numbers are handed out under a lock so that events created concurrently
    from several threads never share a sequence number.
    """

    # Last sequence number issued for each task_id (shared by all instances).
    _sequence_counters: Dict[str, int] = {}
    # Serialises the read-modify-write on ``_sequence_counters``.
    _sequence_lock = threading.Lock()

    def __init__(
        self,
        event_type: str,
        agent: str,
        content: str,
        task_id: str,
        timestamp: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Create an event and assign it the next sequence number of its task.

        Args:
            event_type: Free-form event kind; the values mentioned in this
                file are start, thought, action, output, end and error.
            agent: Name of the emitting agent (serialised as ``agent_name``).
            content: Human-readable payload.
            task_id: Task the event belongs to; selects the sequence counter.
            timestamp: Pre-formatted timestamp. When omitted (or empty) the
                current UTC time is used, formatted ``YYYY-MM-DDTHH:MM:SSZ``.
            metadata: Extra key/value pairs merged into the serialised dict.
        """
        self.event_type = event_type
        self.agent = agent
        self.content = content
        self.task_id = task_id
        # Aware-UTC replacement for the deprecated datetime.utcnow();
        # the formatted string is identical.
        self.timestamp = timestamp or datetime.now(timezone.utc).strftime(
            '%Y-%m-%dT%H:%M:%SZ'
        )
        self.metadata = metadata or {}
        self.sequence = self._next_sequence(task_id)

    @classmethod
    def _next_sequence(cls, task_id: str) -> int:
        """Atomically increment and return the counter for ``task_id``."""
        with cls._sequence_lock:
            value = cls._sequence_counters.get(task_id, 0) + 1
            cls._sequence_counters[task_id] = value
            return value

    def to_dict(self) -> Dict[str, Any]:
        """Return the event as a JSON-serialisable dict.

        Metadata entries are merged last, so a metadata key that collides
        with one of the core fields replaces that field.
        """
        return {
            "task_id": self.task_id,
            "sequence": self.sequence,
            "agent_name": self.agent,
            "event_type": self.event_type,
            "content": self.content,
            "timestamp": self.timestamp,
            **self.metadata
        }

    def to_sse_format(self) -> str:
        """Return the event as one SSE ``data:`` frame (blank-line terminated)."""
        return f"data: {json.dumps(self.to_dict(), ensure_ascii=False)}\n\n"

    @classmethod
    def reset_sequence(cls, task_id: str):
        """Reset the counter of ``task_id`` so its next event is number 1."""
        with cls._sequence_lock:
            cls._sequence_counters[task_id] = 0
|
||||
|
||||
|
||||
class TaskStreamQueue:
    """Bounded event queue for one task, fed from coroutines or worker threads.

    The queue is an ``asyncio.Queue`` owned by the event loop that was
    current when the instance was created. Coroutines use ``put``;
    synchronous code running in another thread (the file mentions CrewAI
    callbacks) uses ``put_nowait``, which hands the put over to that loop
    with ``asyncio.run_coroutine_threadsafe``.

    Consumers (``get`` / ``stream_events``) wait in slices of
    ``poll_interval`` seconds so that they notice ``close()`` promptly even
    when no further event arrives.
    """

    # Seconds a producer waits for free space before giving up.
    PUT_TIMEOUT = 5.0

    def __init__(self, task_id: str, max_size: int = 1000,
                 poll_interval: float = 1.0):
        """Create an open, empty queue.

        Args:
            task_id: Identifier of the task this queue belongs to.
            max_size: Maximum number of buffered events.
            poll_interval: How often (seconds) a waiting consumer re-checks
                whether the queue has been closed.
        """
        self.task_id = task_id
        self.queue: "asyncio.Queue[StreamEvent]" = asyncio.Queue(maxsize=max_size)
        self.is_closed = False
        self.subscribers: int = 0
        self.poll_interval = poll_interval
        try:
            # Normal case: constructed from a coroutine on the serving loop.
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # Constructed outside a running loop: keep the original lookup.
            self._loop = asyncio.get_event_loop()
        self._lock = threading.Lock()  # guards the is_closed flip in close()

    async def put(self, event: "StreamEvent") -> bool:
        """Enqueue ``event`` from a coroutine.

        Returns:
            True when the event was enqueued; False when the queue is closed,
            stayed full for ``PUT_TIMEOUT`` seconds, or the put failed.
        """
        if self.is_closed:
            return False
        try:
            await asyncio.wait_for(self.queue.put(event), timeout=self.PUT_TIMEOUT)
        except Exception:
            # Best effort by contract: timeouts and any other failure are
            # reported through the return value, never raised.
            return False
        return True

    def _on_loop_thread(self) -> bool:
        """Return True when the caller is running on the queue's own loop."""
        try:
            return asyncio.get_running_loop() is self._loop
        except RuntimeError:
            return False

    def put_nowait(self, event: "StreamEvent") -> bool:
        """Enqueue ``event`` from synchronous code.

        From a worker thread the put is scheduled on the owning loop and this
        call blocks for at most ``PUT_TIMEOUT`` seconds. When called on the
        loop's own thread, blocking would stall the loop that has to run the
        put, so the event is inserted directly without waiting.

        Returns:
            True when the event was enqueued, False otherwise.
        """
        if self.is_closed:
            return False

        if self._on_loop_thread():
            try:
                self.queue.put_nowait(event)
            except asyncio.QueueFull:
                return False
            return True

        pending_put = self.queue.put(event)
        try:
            future = asyncio.run_coroutine_threadsafe(pending_put, self._loop)
        except Exception:
            # Scheduling failed (e.g. the loop is closed); dispose of the
            # coroutine so it does not trigger a "never awaited" warning.
            pending_put.close()
            return False

        try:
            future.result(timeout=self.PUT_TIMEOUT)
        except Exception:
            # Cancel the scheduled put so an event reported as failed is not
            # enqueued later behind the caller's back.
            future.cancel()
            return False
        return True

    async def get(self) -> "Optional[StreamEvent]":
        """Return the next event, or None once the queue is closed and drained.

        Waits in ``poll_interval`` slices so a ``close()`` issued while this
        call is pending on an empty queue makes it return None.
        """
        while not (self.is_closed and self.queue.empty()):
            try:
                return await asyncio.wait_for(
                    self.queue.get(), timeout=self.poll_interval
                )
            except asyncio.TimeoutError:
                continue
            except Exception:
                return None
        return None

    def close(self):
        """Mark the queue closed; buffered events can still be consumed."""
        with self._lock:
            self.is_closed = True

    async def stream_events(self) -> "AsyncGenerator[StreamEvent, None]":
        """Yield events until the queue is both closed and empty."""
        while not (self.is_closed and self.queue.empty()):
            try:
                event = await asyncio.wait_for(
                    self.queue.get(), timeout=self.poll_interval
                )
            except asyncio.TimeoutError:
                # Nothing arrived in this slice; the loop condition re-checks
                # whether the queue has been closed meanwhile.
                continue
            except Exception:
                break
            yield event
|
||||
|
||||
|
||||
class StreamManager:
    """Singleton registry mapping each task_id to its TaskStreamQueue.

    Every call to ``StreamManager()`` returns the same instance; the
    registry and its lock are initialised only on the first call.
    """

    _instance: Optional['StreamManager'] = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        # task_id -> TaskStreamQueue
        self.streams: Dict[str, "TaskStreamQueue"] = {}
        # task_id -> moment the stream was first seen closed (for cleanup).
        self._closed_at: Dict[str, datetime] = {}
        self._lock = asyncio.Lock()

    async def create_stream(self, task_id: str) -> "TaskStreamQueue":
        """Create and register a fresh queue for ``task_id``.

        A queue already registered under the same id is closed and replaced.
        The task's sequence counter is reset so a reused task_id numbers its
        events from 1 again.
        """
        async with self._lock:
            previous = self.streams.get(task_id)
            if previous is not None:
                previous.close()

            StreamEvent.reset_sequence(task_id)

            stream = TaskStreamQueue(task_id)
            self.streams[task_id] = stream
            # The new stream is open; forget when its predecessor closed.
            self._closed_at.pop(task_id, None)
            return stream

    async def get_stream(self, task_id: str) -> "Optional[TaskStreamQueue]":
        """Return the queue registered for ``task_id``, or None."""
        async with self._lock:
            return self.streams.get(task_id)

    async def publish_event(
        self,
        task_id: str,
        event_type: str,
        agent: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Build a StreamEvent and enqueue it on the task's stream.

        Returns:
            False when no stream is registered for ``task_id`` or the queue
            rejected the event; True otherwise.
        """
        async with self._lock:
            stream = self.streams.get(task_id)
        if stream is None:
            return False

        # The put may wait for free space; it runs outside the registry lock
        # so a slow consumer of one task cannot stall the other tasks.
        event = StreamEvent(
            event_type=event_type,
            agent=agent,
            content=content,
            task_id=task_id,
            metadata=metadata
        )
        return await stream.put(event)

    async def close_stream(self, task_id: str):
        """Close the stream of ``task_id`` but keep it registered.

        The entry stays in ``streams`` until ``cleanup_old_streams`` removes
        it; the closing time is recorded for that purpose.
        """
        async with self._lock:
            stream = self.streams.get(task_id)
            if stream is not None:
                stream.close()
                self._closed_at.setdefault(task_id, datetime.now())

    async def cleanup_old_streams(self, max_age_seconds: int = 3600):
        """Remove streams that have been closed for ``max_age_seconds`` or more.

        Intended to be called periodically. A stream found closed without a
        recorded closing time (closed directly rather than through
        ``close_stream``) is timed from the first cleanup run that sees it.
        Selection and deletion happen under a single lock acquisition, so a
        stream re-created in between can never be deleted by mistake.
        """
        now = datetime.now()

        async with self._lock:
            expired = []
            for task_id, stream in self.streams.items():
                if not stream.is_closed:
                    continue
                closed_at = self._closed_at.setdefault(task_id, now)
                if (now - closed_at).total_seconds() >= max_age_seconds:
                    expired.append(task_id)

            for task_id in expired:
                del self.streams[task_id]

            # Drop timestamps whose stream is no longer registered.
            self._closed_at = {
                task_id: closed_at
                for task_id, closed_at in self._closed_at.items()
                if task_id in self.streams
            }

    def list_active_streams(self) -> list:
        """List every registered stream with its closed flag.

        Closed streams that have not been cleaned up yet are included.
        """
        return [
            {"task_id": tid, "closed": s.is_closed}
            for tid, s in self.streams.items()
        ]
|
||||
|
||||
|
||||
# Global singleton instance
stream_manager = StreamManager()
|
||||
|
||||
|
||||
async def create_sse_generator(task_id: str) -> AsyncGenerator[str, None]:
    """Yield SSE frames for ``task_id``; meant for FastAPI's StreamingResponse.

    If no stream is registered for the task yet, one is created so that a
    client may subscribe before the first event is published. Each event is
    emitted as a ``data:`` frame, followed by one final ``connection_end``
    frame when the stream finishes (or fails with an ordinary exception).

    The end frame is deliberately NOT sent from a ``finally`` block: when the
    client disconnects the generator is closed at a ``yield``, and yielding
    again during that shutdown raises
    ``RuntimeError: async generator ignored GeneratorExit``.
    """
    stream = await stream_manager.get_stream(task_id)
    if stream is None:
        stream = await stream_manager.create_stream(task_id)

    def _connection_end_frame() -> str:
        # Built on demand so the timestamp reflects when the stream ended.
        end_event = {
            "type": "connection_end",
            "task_id": task_id,
            "timestamp": datetime.now().isoformat()
        }
        return f"data: {json.dumps(end_event)}\n\n"

    try:
        async for event in stream.stream_events():
            yield event.to_sse_format()
    except Exception:
        # Ordinary failure: still tell the client the connection is ending,
        # then let the error propagate. GeneratorExit / CancelledError are
        # not Exception subclasses and pass through without a yield.
        yield _connection_end_frame()
        raise

    yield _connection_end_frame()
|
||||
Reference in New Issue
Block a user