fix
This commit is contained in:
@@ -105,6 +105,24 @@ def show_chat():
|
||||
st.markdown(f"**{info['name']}** *{msg['time']}* - {msg['task']}")
|
||||
st.markdown(msg["content"][:800] + ("..." if len(msg["content"]) > 800 else ""))
|
||||
|
||||
def extract_code(content):
|
||||
"""从 Markdown 代码块中提取纯代码"""
|
||||
# 检查是否有 ```python 标记
|
||||
if "```python" in content:
|
||||
# 提取 ```python 和 ``` 之间的内容
|
||||
parts = content.split("```python")
|
||||
if len(parts) > 1:
|
||||
code = parts[1].split("```")[0].strip()
|
||||
return code
|
||||
elif "```" in content:
|
||||
# 通用的 ``` 标记
|
||||
parts = content.split("```")
|
||||
if len(parts) > 1:
|
||||
code = parts[1].strip()
|
||||
return code
|
||||
# 没有标记,返回原内容
|
||||
return content
|
||||
|
||||
def save_files():
|
||||
"""保存生成的文件到 workspace/"""
|
||||
workspace = Path("workspace")
|
||||
@@ -121,21 +139,28 @@ def save_files():
|
||||
if agent == "PM_Agent" and ("需求" in content or "SRS" in content):
|
||||
file = workspace / "SRS.md"
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
f.write(f"# 软件需求规格说明书\n\n生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n{content}")
|
||||
f.write(f"# 软件需求规格说明书\n\n生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
f.write(content)
|
||||
files.append(str(file))
|
||||
|
||||
# QA Agent 生成测试
|
||||
if agent == "QA_Agent" and ("test" in content.lower() or "测试" in content):
|
||||
if agent == "QA_Agent" and ("test" in content.lower() or "测试" in content or "def test_" in content):
|
||||
file = workspace / "test_sample.py"
|
||||
# 提取纯代码
|
||||
code = extract_code(content)
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
f.write(f"# 测试用例\n# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n{content}")
|
||||
f.write(f"# 测试用例\n# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
f.write(code)
|
||||
files.append(str(file))
|
||||
|
||||
# Dev Agent 生成代码
|
||||
if agent == "Dev_Agent" and ("def " in content or "class " in content):
|
||||
file = workspace / "src_sample.py"
|
||||
# 提取纯代码
|
||||
code = extract_code(content)
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
f.write(f"# 源代码\n# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n{content}")
|
||||
f.write(f"# 源代码\n# 生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
f.write(code)
|
||||
files.append(str(file))
|
||||
|
||||
return files
|
||||
@@ -194,18 +219,24 @@ def main():
|
||||
|
||||
def run_workflow(api_key, model, max_round):
|
||||
"""运行工作流"""
|
||||
# 保存现有消息(包括用户需求)
|
||||
existing_messages = list(st.session_state.messages)
|
||||
|
||||
# 只清空计数,不清空消息
|
||||
st.session_state.running = True
|
||||
st.session_state.messages = []
|
||||
st.session_state.agent_counts = {k: 0 for k in AGENTS}
|
||||
|
||||
progress = st.empty()
|
||||
progress.info("🚀 启动工作流...")
|
||||
|
||||
try:
|
||||
# 获取需求
|
||||
user_msgs = [m for m in st.session_state.messages if m["agent"] == "User_Proxy"]
|
||||
# 获取需求 - 在清空前已经保存了
|
||||
user_msgs = [m for m in existing_messages if m["agent"] == "User_Proxy"]
|
||||
requirement = user_msgs[-1]["content"] if user_msgs else "开发一个电池健康预测 API"
|
||||
|
||||
# 显示需求
|
||||
st.info(f"📋 用户需求:{requirement}")
|
||||
|
||||
# 创建 Agent
|
||||
llm_config = get_llm_config(model=model, api_key=api_key)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user