Files
autogen/frontend/app.py
2026-03-12 17:58:15 +08:00

438 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Streamlit 实时 Agent 协作平台
功能:
1. 实时展示每个 Agent 的状态和动作
2. 自动保存生成的文件到 workspace/
3. Agent 自主决定对话对象(动态发言顺序)
4. 实时更新对话流和 Agent 状态
"""
import streamlit as st
import os
from pathlib import Path
from datetime import datetime
import time
import re
try:
from autogen import AssistantAgent, UserProxyAgent, GroupChat, GroupChatManager
AUTOGEN_AVAILABLE = True
except ImportError:
AUTOGEN_AVAILABLE = False
# 添加项目根目录到路径
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.llm_config import get_llm_config, PM_PROMPT, QA_PROMPT, DEV_PROMPT, ORCH_PROMPT
from utils.callback_handler import get_callback_handler
# 页面配置
st.set_page_config(page_title="多 Agent 协作平台", page_icon="🤖", layout="wide")
# Agent 配置
AGENTS = {
"PM_Agent": {"name": "产品经理", "avatar": "📋", "color": "blue", "desc": "需求分析与 SRS 生成"},
"QA_Agent": {"name": "测试工程师", "avatar": "", "color": "green", "desc": "测试用例设计"},
"Dev_Agent": {"name": "开发工程师", "avatar": "💻", "color": "orange", "desc": "代码实现"},
"Orchestrator": {"name": "协调器", "avatar": "🎯", "color": "purple", "desc": "流程协调与验证"},
"User_Proxy": {"name": "用户代理", "avatar": "👤", "color": "gray", "desc": "测试执行"}
}
def init_state():
"""初始化 session state"""
if "messages" not in st.session_state:
st.session_state.messages = []
if "running" not in st.session_state:
st.session_state.running = False
if "current_agent" not in st.session_state:
st.session_state.current_agent = None
if "agent_counts" not in st.session_state:
st.session_state.agent_counts = {k: 0 for k in AGENTS}
if "agent_status" not in st.session_state:
st.session_state.agent_status = {k: "⚪ 等待中" for k in AGENTS}
if "current_task" not in st.session_state:
st.session_state.current_task = {k: "" for k in AGENTS}
if "saved_files" not in st.session_state:
st.session_state.saved_files = []
if "conversation_history" not in st.session_state:
st.session_state.conversation_history = []
def add_message(agent, content, task=""):
"""添加消息"""
msg = {
"agent": agent,
"content": content,
"task": task,
"time": datetime.now().strftime("%H:%M:%S")
}
st.session_state.messages.append(msg)
st.session_state.agent_counts[agent] = st.session_state.agent_counts.get(agent, 0) + 1
st.session_state.current_agent = agent
# 更新 Agent 状态
for ag in AGENTS:
if ag == agent:
st.session_state.agent_status[ag] = "🟢 发言中"
st.session_state.current_task[ag] = task
else:
st.session_state.agent_status[ag] = "⚪ 等待中"
# 添加到对话历史(用于展示完整的对话流)
st.session_state.conversation_history.append(msg)
def show_agent_status():
"""显示 Agent 状态"""
st.subheader("🎯 Agent 实时状态")
cols = st.columns(len(AGENTS))
for i, (agent_key, info) in enumerate(AGENTS.items()):
with cols[i]:
is_active = st.session_state.current_agent == agent_key
count = st.session_state.agent_counts.get(agent_key, 0)
status = st.session_state.agent_status.get(agent_key, "⚪ 等待中")
current_task = st.session_state.current_task.get(agent_key, "")
border_color = info["color"]
bg_color = "#e8f5e9" if is_active else "white"
st.markdown(f"""
<div style='
padding: 15px;
border-radius: 10px;
border: 3px {"solid" if is_active else "dashed"} {border_color};
background: {bg_color};
text-align: center;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
'>
<div style='font-size: 2.5rem;'>{info["avatar"]}</div>
<div style='font-weight: bold; margin: 5px 0; font-size: 1.1rem;'>{info["name"]}</div>
<div style='font-size: 0.75rem; color: #666; margin-bottom: 5px;'>{info["desc"]}</div>
<div style='font-size: 0.85rem; color: #2e7d32; margin: 5px 0;'>
{status}
</div>
<div style='font-size: 0.75rem; color: #999; margin-top: 5px;'>
💬 {count} 条消息
</div>
{f"<div style='font-size: 0.7rem; color: #1976d2; margin-top: 5px; font-style: italic;'>📋 {current_task}</div>" if current_task else ""}
</div>
""", unsafe_allow_html=True)
def show_chat():
"""显示对话流"""
st.subheader("💬 Agent 对话流")
if not st.session_state.conversation_history:
st.info("👈 暂无对话,请在下方输入需求并启动")
return
# 使用容器展示对话,支持滚动
chat_container = st.container()
with chat_container:
for i, msg in enumerate(st.session_state.conversation_history):
agent = msg["agent"]
info = AGENTS.get(agent, {"name": "未知", "avatar": "🤖", "color": "gray"})
# 创建对话气泡
with st.chat_message(agent.lower(), avatar=info["avatar"]):
# 显示 Agent 名称、时间和任务
task_info = f"- {msg['task']}" if msg['task'] else ""
st.markdown(f"**{info['name']}** *{msg['time']}* {task_info}")
# 处理长内容,提供折叠选项
content = msg["content"]
if len(content) > 1000:
# 短预览 + 展开详情
preview = content[:800] + "..."
st.markdown(preview)
with st.expander("查看完整内容"):
st.markdown(content)
else:
st.markdown(content)
# 如果是代码内容,提供语法高亮
if "```" in content:
code_blocks = re.findall(r'```(\w+)?\n(.*?)```', content, re.DOTALL)
for lang, code in code_blocks:
language = lang if lang else "python"
st.code(code, language=language)
def extract_code(content):
"""从 Markdown 代码块中提取纯代码"""
# 检查是否有 ```python 标记
if "```python" in content:
# 提取 ```python 和 ``` 之间的内容
parts = content.split("```python")
if len(parts) > 1:
code = parts[1].split("```")[0].strip()
return code
elif "```" in content:
# 通用的 ``` 标记
parts = content.split("```")
if len(parts) > 1:
code = parts[1].strip()
return code
# 没有标记,返回原内容
return content
def is_code_complete(code):
"""检查代码是否完整"""
if not code or not code.strip():
return False
# 检查是否有未闭合的括号
if code.count('(') != code.count(')'):
print(f"⚠️ 括号不匹配:( {code.count('(')} vs ) {code.count(')')}")
return False
if code.count('[') != code.count(']'):
print(f"⚠️ 方括号不匹配:[ {code.count('[')} vs ] {code.count(']')}")
return False
if code.count('{') != code.count('}'):
print(f"⚠️ 花括号不匹配:{{ {code.count('{')} vs }} {code.count('}')}")
return False
# 检查是否以完整的方式结束(不是突然截断)
lines = code.strip().split('\n')
if lines:
last_line = lines[-1].rstrip()
# 如果最后一行以这些符号结尾,说明可能被截断了
if last_line.endswith(':') or last_line.endswith('\\') or last_line.endswith('(') or last_line.endswith('['):
print(f"⚠️ 代码可能截断:最后一行是 '{last_line}'")
return False
# 尝试简单的语法检查
try:
compile(code, '<string>', 'exec')
except SyntaxError as e:
print(f"⚠️ 语法错误:{e}")
return False
return True
def save_files():
"""保存生成的文件到 workspace/"""
workspace = Path("workspace")
workspace.mkdir(exist_ok=True)
files = []
# 1. 保存 PM Agent 的 SRS 文档
for msg in st.session_state.messages:
if msg["agent"] == "PM_Agent" and ("需求" in msg["content"] or "SRS" in msg["content"]):
file = workspace / "SRS.md"
with open(file, "w", encoding="utf-8") as f:
f.write(f"# 软件需求规格说明书\n\n生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
f.write(msg["content"])
files.append(str(file))
break # 只保存第一个
# 2. 保存 QA Agent 的测试代码
for msg in st.session_state.messages:
if msg["agent"] == "QA_Agent" and ("test" in msg["content"].lower() or "测试" in msg["content"] or "def test_" in msg["content"]):
file = workspace / "test_sample.py"
code = extract_code(msg["content"])
with open(file, "w", encoding="utf-8") as f:
f.write(f"# 测试用例\n# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
f.write(code)
files.append(str(file))
break # 只保存第一个
# 3. 保存 Dev Agent 的代码 - 收集所有消息合并处理
dev_messages = [msg for msg in st.session_state.messages if msg["agent"] == "Dev_Agent"]
if dev_messages:
# 合并所有 Dev Agent 的消息内容
all_dev_content = "\n\n".join([msg["content"] for msg in dev_messages])
# 尝试提取带文件名的代码块
import re
code_blocks = re.findall(r'```python\s*\n#(?:\s*File:|\s*filename:)?\s*([^\n]+)\n(.*?)```', all_dev_content, re.DOTALL)
if code_blocks:
# 保存多个文件
for filename, code_content in code_blocks:
filename = filename.strip()
file_path = workspace / filename
code = extract_code(f"```python\n{code_content}```")
# 检查完整性
print(f"\n🔍 检查文件 {filename}...")
if not is_code_complete(code):
print(f"⚠️ 警告:{filename} 代码可能不完整,但仍会保存")
else:
print(f"{filename} 代码完整")
with open(file_path, "w", encoding="utf-8") as f:
f.write(f"# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n{code}")
files.append(str(file_path))
else:
# 没有文件名标记,合并所有代码保存为 src_sample.py
all_code = "\n\n".join([extract_code(msg["content"]) for msg in dev_messages])
file = workspace / "src_sample.py"
# 检查完整性
print(f"\n🔍 检查文件 src_sample.py...")
if not is_code_complete(all_code):
print(f"⚠️ 警告src_sample.py 代码可能不完整,但仍会保存")
else:
print(f"✅ src_sample.py 代码完整")
with open(file, "w", encoding="utf-8") as f:
f.write(f"# 源代码\n# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
f.write(all_code)
files.append(str(file))
return files
def main():
st.title("🤖 多 Agent 协作平台")
st.markdown("**实时展示 Agent 状态 · 自动生成文件**")
init_state()
# 侧边栏
with st.sidebar:
st.title("⚙️ 配置")
api_key = st.text_input("API Key", type="password", value=os.getenv("DASHSCOPE_API_KEY", ""))
model = st.selectbox("模型", ["qwen3.5-flash", "qwen-max", "qwen-plus"], index=0)
max_round = st.slider("最大轮数", 5, 300, 15)
st.divider()
if st.button("▶️ 启动工作流", type="primary", use_container_width=True):
if not api_key:
st.error("请先设置 API Key")
elif not AUTOGEN_AVAILABLE:
st.error("请先安装 AutoGen")
else:
run_workflow(api_key, model, max_round)
st.divider()
if st.button("🗑️ 清空对话", use_container_width=True):
st.session_state.messages = []
st.session_state.current_agent = None
st.session_state.agent_counts = {k: 0 for k in AGENTS}
st.rerun()
# 显示生成的文件
st.divider()
st.subheader("📁 生成的文件")
workspace = Path("workspace")
if workspace.exists():
for file in workspace.glob("*"):
if file.is_file():
st.caption(f"📄 {file.name}")
# 主界面
show_agent_status()
st.divider()
show_chat()
# 输入框
st.divider()
if user_input := st.chat_input("输入需求..."):
add_message("User_Proxy", user_input, "提出需求")
st.rerun()
def run_workflow(api_key, model, max_round):
"""运行工作流"""
# 保存现有消息(包括用户需求)
existing_messages = list(st.session_state.messages)
# 只清空计数,不清空消息
st.session_state.running = True
st.session_state.agent_counts = {k: 0 for k in AGENTS}
progress = st.empty()
progress.info("🚀 启动工作流...")
try:
# 获取需求 - 在清空前已经保存了
user_msgs = [m for m in existing_messages if m["agent"] == "User_Proxy"]
requirement = user_msgs[-1]["content"] if user_msgs else "开发一个电池健康预测 API"
# 显示需求
st.info(f"📋 用户需求:{requirement}")
# 创建 Agent - 配置大 token 数,确保生成完整内容
llm_config = get_llm_config(model=model, api_key=api_key, temperature=0.7, max_tokens=8192)
# 优化提示词,要求生成完整代码
pm_prompt = PM_PROMPT + "\n\n【重要】请生成完整的、详细的需求文档,不要省略任何内容。"
qa_prompt = QA_PROMPT + "\n\n【重要】请生成完整的测试代码,包含所有必要的导入、测试函数和断言,不要省略。"
dev_prompt = DEV_PROMPT + "\n\n【重要】请生成完整的、可运行的代码,包含所有必要的导入、类和函数实现,不要省略任何代码。"
orch_prompt = ORCH_PROMPT + "\n\n【重要】请确保所有输出完整详细。"
pm = AssistantAgent("PM_Agent", system_message=pm_prompt, llm_config=llm_config, human_input_mode="NEVER")
qa = AssistantAgent("QA_Agent", system_message=qa_prompt, llm_config=llm_config, human_input_mode="NEVER")
dev = AssistantAgent("Dev_Agent", system_message=dev_prompt, llm_config=llm_config, human_input_mode="NEVER")
orch = AssistantAgent("Orchestrator", system_message=orch_prompt, llm_config=llm_config, human_input_mode="NEVER")
user = UserProxyAgent("User_Proxy", human_input_mode="NEVER", max_consecutive_auto_reply=0,
code_execution_config={"work_dir": "workspace", "use_docker": False})
# 创建 GroupChat
groupchat = GroupChat(
agents=[pm, qa, dev, orch, user],
messages=[],
max_round=max_round,
speaker_selection_method="round_robin"
)
manager = GroupChatManager(groupchat=groupchat, llm_config=llm_config)
# 初始消息
initial_msg = f"""请启动完整的 SDLC 流程:
【用户需求】{requirement}
【流程】
1. PM_Agent → SRS 文档
2. QA_Agent → 测试用例
3. Dev_Agent → 编写代码
4. User_Proxy → 执行测试
5. Orchestrator → 汇总
开始协作!"""
# 执行对话
with st.spinner("💬 Agent 们正在协作中..."):
chat_result = user.initiate_chat(manager, message=initial_msg, max_turns=max_round)
# 记录所有对话
task_map = {
"PM_Agent": "需求分析",
"QA_Agent": "测试设计",
"Dev_Agent": "代码实现",
"Orchestrator": "流程协调",
"User_Proxy": "测试执行"
}
for msg in groupchat.messages:
agent = msg.get("name", "Unknown")
content = msg.get("content", "")
task = task_map.get(agent, "工作中")
add_message(agent, content, task)
# 保存文件
progress.info("💾 正在保存文件...")
files = save_files()
if files:
progress.success(f"✅ 完成!已保存 {len(files)} 个文件到 workspace/")
else:
progress.success("✅ 工作流完成!")
st.session_state.running = False
st.rerun()
except Exception as e:
st.session_state.running = False
progress.error(f"❌ 错误:{e}")
import traceback
st.error(traceback.format_exc())
if __name__ == "__main__":
main()