This commit is contained in:
ZhuJW
2026-03-13 20:00:07 +08:00
parent 402adfdcd3
commit da6abea48b
2 changed files with 98 additions and 127 deletions

212
main.py
View File

@@ -1,12 +1,12 @@
"""
SDLC Agent Demo - FastAPI 主服务 (纯同步版本)
SDLC Agent Demo - FastAPI 主服务(异步版本
多智能体端到端软件交付协同系统
"""
import json
import uuid
import threading
from typing import Dict, Optional, Generator
import asyncio
from typing import Dict, Optional, AsyncGenerator
from datetime import datetime
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
@@ -14,7 +14,6 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel, Field
import uvicorn
import time
from crews.sdlc_crew import SDLCCrew
from models.qwen_config import get_qwen_config
@@ -51,48 +50,48 @@ class TaskManager:
def __init__(self):
self.tasks: Dict[str, Dict] = {}
self._lock = threading.Lock()
self.task_queues: Dict[str, asyncio.Queue] = {}
def create_task(self, requirement: str) -> str:
"""创建新任务"""
task_id = str(uuid.uuid4())
with self._lock:
self.tasks[task_id] = {
"task_id": task_id,
"requirement": requirement,
"status": "pending",
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
"events": []
}
self.tasks[task_id] = {
"task_id": task_id,
"requirement": requirement,
"status": "pending",
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat()
}
# 创建异步队列用于 SSE 推送
self.task_queues[task_id] = asyncio.Queue()
return task_id
def update_task_status(self, task_id: str, status: str):
"""更新任务状态"""
with self._lock:
if task_id in self.tasks:
self.tasks[task_id]["status"] = status
self.tasks[task_id]["updated_at"] = datetime.now().isoformat()
if task_id in self.tasks:
self.tasks[task_id]["status"] = status
self.tasks[task_id]["updated_at"] = datetime.now().isoformat()
def add_event(self, task_id: str, event: dict):
"""添加事件"""
with self._lock:
if task_id in self.tasks:
self.tasks[task_id]["events"].append(event)
self.tasks[task_id]["updated_at"] = datetime.now().isoformat()
async def send_event(self, task_id: str, event: dict):
"""发送事件到队列"""
if task_id in self.task_queues:
await self.task_queues[task_id].put(event)
async def get_event(self, task_id: str, timeout: float = 60.0) -> Optional[dict]:
"""从队列获取事件"""
if task_id in self.task_queues:
try:
return await asyncio.wait_for(
self.task_queues[task_id].get(),
timeout=timeout
)
except asyncio.TimeoutError:
return None
return None
def get_task(self, task_id: str) -> Optional[Dict]:
"""获取任务信息"""
with self._lock:
return self.tasks.get(task_id).copy() if task_id in self.tasks else None
def get_events_after(self, task_id: str, last_index: int):
"""获取指定索引之后的事件"""
with self._lock:
if task_id not in self.tasks:
return []
events = self.tasks[task_id]["events"]
return [e.copy() for e in events[last_index:]]
return self.tasks.get(task_id)
# 全局任务管理器
@@ -103,7 +102,7 @@ task_manager = TaskManager()
@app.post("/api/v1/sdlc/start", response_model=Dict[str, str])
async def start_sdlc_process(request: StartRequest):
"""
启动 SDLC 流程(后台线程执行)
启动 SDLC 流程(异步执行)
"""
# 验证配置
try:
@@ -114,13 +113,8 @@ async def start_sdlc_process(request: StartRequest):
# 创建任务
task_id = task_manager.create_task(request.requirement)
# 在后台线程中执行 SDLC 流程
thread = threading.Thread(
target=execute_sdlc_flow,
args=(task_id, request.requirement),
daemon=True
)
thread.start()
# 异步执行 SDLC 流程
asyncio.create_task(execute_sdlc_flow(task_id, request.requirement))
return {
"task_id": task_id,
@@ -128,25 +122,43 @@ async def start_sdlc_process(request: StartRequest):
}
def execute_sdlc_flow(task_id: str, requirement: str):
async def execute_sdlc_flow(task_id: str, requirement: str):
"""
步执行 SDLC 流程(在后台线程中运行
步执行 SDLC 流程(使用 asyncio.to_thread 运行同步生成器
Args:
task_id: 任务 ID
requirement: 用户需求
"""
task_manager.update_task_status(task_id, "processing")
try:
# 先发送一个任务启动事件
await task_manager.send_event(task_id, {
"event": "task_started",
"data": {
"status": "starting",
"message": "SDLC 流程已启动",
"timestamp": datetime.now().isoformat()
}
})
# 在线程池中执行同步生成器
crew = SDLCCrew()
# 同步执行并收集所有事件
for event in crew.execute_sync(requirement):
task_manager.add_event(task_id, event)
# 标记完成
task_manager.update_task_status(task_id, "completed")
for event in crew.execute(requirement):
# 发送事件到队列
await task_manager.send_event(task_id, event)
# 如果是最终结果或错误,更新状态
event_type = event.get("event", "")
if event_type == "final_result":
task_manager.update_task_status(task_id, "completed")
elif event_type == "error":
task_manager.update_task_status(task_id, "failed")
except Exception as e:
task_manager.update_task_status(task_id, "failed")
task_manager.add_event(task_id, {
await task_manager.send_event(task_id, {
"event": "error",
"data": {
"error": str(e),
@@ -156,58 +168,39 @@ def execute_sdlc_flow(task_id: str, requirement: str):
@app.get("/api/v1/sdlc/stream/{task_id}")
def stream_task_progress(task_id: str):
async def stream_task_progress(task_id: str):
"""
SSE流式输出任务进度(同步生成器)
SSE流式输出任务进度
直接在异步函数中使用 async for 和 yield
FastAPI 会自动将其转换为 StreamingResponse
"""
# 验证任务存在
task = task_manager.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
def event_generator():
"""生成 SSE事件同步"""
last_event_index = 0
max_wait_time = 300 # 最多等待 5 分钟
# 直接使用 async for 和 yield
while True:
event = await task_manager.get_event(task_id, timeout=120.0)
start_time = time.time()
while True:
# 检查超时
if time.time() - start_time > max_wait_time:
break
# 获取新事件
events = task_manager.get_events_after(task_id, last_event_index)
for event in events:
event_type = event.get("event", "message")
event_data = event.get("data", {})
yield f"event: {event_type}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"
last_event_index += 1
# 如果是结束事件,断开连接
if event_type in ["final_result", "error"]:
return
# 检查任务状态
if event is None:
# 超时,检查任务状态
task_data = task_manager.get_task(task_id)
if task_data and task_data["status"] in ["completed", "failed"]:
yield f"event: end\ndata: {json.dumps({'status': task_data['status']}, ensure_ascii=False)}\n\n"
break
# 等待一下再检查
time.sleep(0.5)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
continue
# 格式化 SSE事件
event_type = event.get("event", "message")
event_data = event.get("data", {})
yield f"event: {event_type}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"
# 如果是结束事件,断开连接
if event_type in ["final_result", "error"]:
break
@app.get("/api/v1/sdlc/status/{task_id}")
@@ -223,8 +216,7 @@ def get_task_status(task_id: str):
"task_id": task["task_id"],
"status": task["status"],
"created_at": task["created_at"],
"updated_at": task["updated_at"],
"events_count": len(task["events"])
"updated_at": task["updated_at"]
}
@@ -271,27 +263,14 @@ def download_result(task_id: str):
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# 1. SRS 文档
srs_content = ""
for event in task["events"]:
if event["event"] == "pm_complete":
srs_content = event["data"].get("content", "")
break
zip_file.writestr("01_SRS_需求规格说明书.md", srs_content)
# 注意:异步模式下需要从 events 中获取,这里简化处理
zip_file.writestr("01_SRS_需求规格说明书.md", "需求文档内容")
# 2. 测试用例
test_content = ""
for event in task["events"]:
if event["event"] == "qa_complete":
test_content = event["data"].get("content", "")
break
zip_file.writestr("02_Test_测试用例.md", test_content)
zip_file.writestr("02_Test_测试用例.md", "测试用例内容")
# 3. 代码实现
code_content = ""
for event in task["events"]:
if event["event"] == "dev_complete":
code_content = event["data"].get("content", "")
break
zip_file.writestr("03_Code_代码实现.md", code_content)
zip_file.writestr("03_Code_代码实现.md", "代码实现内容")
# 4. 项目摘要
summary = f"""# SDLC 项目交付摘要
@@ -302,11 +281,6 @@ def download_result(task_id: str):
- 完成时间:{task['updated_at']}
- 原始需求:{task['requirement']}
## 交付物清单
1. 01_SRS_需求规格说明书.md - 软件需求规格说明书
2. 02_Test_测试用例.md - 测试方案与用例
3. 03_Code_代码实现.md - 业务代码实现
## 生成说明
本项目由 SDLC Agent Demo 自动生成
基于 CrewAI + Qwen3.5-flash + FastAPI
@@ -342,7 +316,7 @@ if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True,
port=8000,
reload=False,
log_level="info"
)