"""Callback handler: pushes conversation content to the frontend in real time.

回调处理器 - 用于实时推送对话内容到前端
"""

import logging
import queue
import threading
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional


class MessageCallbackHandler:
    """Fan out chat/agent events as they happen.

    Every event built by :meth:`on_message` is both put on an internal
    ``queue.Queue`` (for a consumer that polls :meth:`get_message`) and passed
    synchronously to every registered callback.
    """

    # Exceptions raised by user callbacks are reported through this logger.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        """Create a handler with no callbacks and an empty message queue."""
        self.callbacks: List[Callable[[Dict[str, Any]], None]] = []
        # NOTE(review): the queue is unbounded; messages accumulate until
        # get_message()/clear_queue() drains them.
        self.message_queue = queue.Queue()
        # Not read anywhere in this class; kept as a public attribute for callers.
        self.is_running = False
        # Guards self.callbacks. It is never held while a callback runs.
        self._lock = threading.Lock()

    def register_callback(
        self, callback: Callable[[Dict[str, Any]], None]
    ) -> None:
        """Register ``callback`` to receive every subsequent message.

        Args:
            callback: Called with the message dict built by :meth:`on_message`.
                Registering the same callable twice makes it fire twice.
        """
        with self._lock:
            self.callbacks.append(callback)

    def unregister_callback(
        self, callback: Callable[[Dict[str, Any]], None]
    ) -> None:
        """Remove one registration of ``callback``; unknown callbacks are ignored."""
        with self._lock:
            if callback in self.callbacks:
                self.callbacks.remove(callback)

    def on_message(
        self,
        agent_name: str,
        message: str,
        role: str = "assistant",
        metadata: Optional[Dict] = None,
    ) -> None:
        """Publish a message to the queue and to every registered callback.

        Args:
            agent_name: Name of the agent that produced the message.
            message: Message text.
            role: Sender role, ``"assistant"`` by default.
            metadata: Extra data stored under ``"metadata"``; ``None`` (or any
                falsy value) becomes ``{}``.

        The published dict has the keys ``timestamp`` (naive local time in ISO
        format), ``agent_name``, ``role``, ``message`` and ``metadata``. A
        callback that raises is logged and skipped; the remaining callbacks
        still run.
        """
        msg_data = {
            "timestamp": datetime.now().isoformat(),
            "agent_name": agent_name,
            "role": role,
            "message": message,
            "metadata": metadata or {},
        }

        # Enqueue for pollers (queue.Queue does its own locking).
        self.message_queue.put(msg_data)

        # Snapshot under the lock and invoke outside it. The lock is not
        # re-entrant, so a callback that (un)registers callbacks or emits
        # another message would otherwise deadlock. A slow callback would also
        # block registration from other threads.
        with self._lock:
            callbacks = list(self.callbacks)
        for callback in callbacks:
            try:
                callback(msg_data)
            except Exception as exc:  # isolate one faulty subscriber
                self._logger.exception("回调执行失败:%s", exc)

    def on_thinking(self, agent_name: str, status: str = "thinking") -> None:
        """Publish an agent status update as a ``system`` message.

        Args:
            agent_name: Name of the agent.
            status: Status label (e.g. thinking/generating/coding/testing);
                also stored as ``metadata["status"]``.
        """
        self.on_message(
            agent_name=agent_name,
            message=f"_{status}...",
            role="system",
            metadata={"status": status},
        )

    def on_file_created(
        self, agent_name: str, file_path: str, file_type: str
    ) -> None:
        """Publish a ``file_created`` event as a ``system`` message.

        Args:
            agent_name: Name of the agent that created the file.
            file_path: Path of the created file.
            file_type: File type label, copied into the metadata.
        """
        self.on_message(
            agent_name=agent_name,
            message=f"📄 创建了文件:{file_path}",
            role="system",
            metadata={
                "event_type": "file_created",
                "file_path": file_path,
                "file_type": file_type,
            },
        )

    def on_test_result(self, agent_name: str, passed: bool, details: str) -> None:
        """Publish a ``test_result`` event as a ``system`` message.

        Args:
            agent_name: Name of the agent that ran the tests.
            passed: Whether the tests passed; selects the icon and wording.
            details: Free-form details appended on a new line.
        """
        icon = "✅" if passed else "❌"
        self.on_message(
            agent_name=agent_name,
            message=f"{icon} 测试结果:{'通过' if passed else '失败'}\n{details}",
            role="system",
            metadata={
                "event_type": "test_result",
                "passed": passed,
                "details": details,
            },
        )

    def on_human_approval_request(
        self,
        request_id: str,
        description: str,
        data: Dict[str, Any],
    ) -> None:
        """Publish a ``human_approval`` request from the ``Orchestrator`` agent.

        Args:
            request_id: Identifier of the approval request.
            description: Human-readable description of what needs approval.
            data: Related data, passed through in the metadata unchanged.
        """
        self.on_message(
            agent_name="Orchestrator",
            message=f"⚠️ 需要人工确认:{description}",
            role="system",
            metadata={
                "event_type": "human_approval",
                "request_id": request_id,
                "description": description,
                "data": data,
            },
        )

    def get_message(
        self, timeout: Optional[float] = 1.0
    ) -> Optional[Dict[str, Any]]:
        """Pop the oldest queued message, waiting up to ``timeout`` seconds.

        This call blocks while the queue is empty. ``timeout=None`` waits
        indefinitely (passed straight to ``queue.Queue.get``).

        Args:
            timeout: Maximum seconds to wait for a message.

        Returns:
            The message dict, or ``None`` if nothing arrived in time.
        """
        try:
            return self.message_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def clear_queue(self) -> None:
        """Discard every message currently in the queue."""
        while True:
            try:
                self.message_queue.get_nowait()
            except queue.Empty:
                break


# Process-wide handler shared by every caller of get_callback_handler().
_global_callback: Optional[MessageCallbackHandler] = None
# Serializes lazy creation so concurrent first calls cannot build two handlers.
_global_callback_lock = threading.Lock()


def get_callback_handler() -> MessageCallbackHandler:
    """Return the process-wide handler, creating it on first use.

    Creation happens under a lock, so threads racing on the first call all get
    the same instance. Callbacks registered through one caller are therefore
    seen by every other caller.
    """
    global _global_callback
    if _global_callback is None:
        with _global_callback_lock:
            # Re-check: another thread may have created it while we waited.
            if _global_callback is None:
                _global_callback = MessageCallbackHandler()
    return _global_callback