fix
This commit is contained in:
90
main.py
90
main.py
@@ -6,6 +6,8 @@ SDLC Agent Demo - FastAPI 主服务(异步版本)
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
import threading
|
||||
import queue
|
||||
from typing import Dict, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
from fastapi import FastAPI, HTTPException
|
||||
@@ -50,7 +52,7 @@ class TaskManager:
|
||||
|
||||
def __init__(self):
|
||||
self.tasks: Dict[str, Dict] = {}
|
||||
self.task_queues: Dict[str, asyncio.Queue] = {}
|
||||
self.task_events: Dict[str, list] = {} # 存储任务的所有事件
|
||||
|
||||
def create_task(self, requirement: str) -> str:
|
||||
"""创建新任务"""
|
||||
@@ -62,8 +64,8 @@ class TaskManager:
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat()
|
||||
}
|
||||
# 创建异步队列用于 SSE 推送
|
||||
self.task_queues[task_id] = asyncio.Queue()
|
||||
# 创建事件列表用于存储所有事件
|
||||
self.task_events[task_id] = []
|
||||
return task_id
|
||||
|
||||
def update_task_status(self, task_id: str, status: str):
|
||||
@@ -72,22 +74,27 @@ class TaskManager:
|
||||
self.tasks[task_id]["status"] = status
|
||||
self.tasks[task_id]["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
async def send_event(self, task_id: str, event: dict):
|
||||
"""发送事件到队列"""
|
||||
if task_id in self.task_queues:
|
||||
await self.task_queues[task_id].put(event)
|
||||
def add_event(self, task_id: str, event: dict):
|
||||
"""添加事件到任务"""
|
||||
if task_id in self.task_events:
|
||||
self.task_events[task_id].append(event)
|
||||
|
||||
async def get_event(self, task_id: str, timeout: float = 60.0) -> Optional[dict]:
|
||||
"""从队列获取事件"""
|
||||
if task_id in self.task_queues:
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.task_queues[task_id].get(),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
return None
|
||||
def get_events(self, task_id: str, last_index: int = 0) -> Dict:
|
||||
"""获取任务的新事件"""
|
||||
if task_id not in self.task_events:
|
||||
return {"events": [], "has_more": False}
|
||||
|
||||
events = self.task_events[task_id]
|
||||
new_events = events[last_index:]
|
||||
|
||||
task = self.tasks.get(task_id, {})
|
||||
has_more = task.get("status") not in ["completed", "failed", "pending"]
|
||||
|
||||
return {
|
||||
"events": new_events,
|
||||
"has_more": has_more,
|
||||
"status": task.get("status", "unknown")
|
||||
}
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Dict]:
|
||||
"""获取任务信息"""
|
||||
@@ -102,7 +109,7 @@ task_manager = TaskManager()
|
||||
@app.post("/api/v1/sdlc/start", response_model=Dict[str, str])
|
||||
async def start_sdlc_process(request: StartRequest):
|
||||
"""
|
||||
启动 SDLC 流程(异步执行)
|
||||
启动 SDLC 流程(使用线程池异步执行)
|
||||
"""
|
||||
# 验证配置
|
||||
try:
|
||||
@@ -113,8 +120,9 @@ async def start_sdlc_process(request: StartRequest):
|
||||
# 创建任务
|
||||
task_id = task_manager.create_task(request.requirement)
|
||||
|
||||
# 异步执行 SDLC 流程
|
||||
asyncio.create_task(execute_sdlc_flow(task_id, request.requirement))
|
||||
# 在线程池中异步执行 SDLC 流程
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_in_executor(executor, execute_sdlc_flow, task_id, request.requirement)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
@@ -122,9 +130,16 @@ async def start_sdlc_process(request: StartRequest):
|
||||
}
|
||||
|
||||
|
||||
async def execute_sdlc_flow(task_id: str, requirement: str):
|
||||
import threading
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# 线程池
|
||||
executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
def execute_sdlc_flow(task_id: str, requirement: str):
|
||||
"""
|
||||
异步执行 SDLC 流程(使用 asyncio.to_thread 运行同步生成器)
|
||||
在线程池中执行 SDLC 流程,将事件保存到任务列表
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
@@ -133,8 +148,8 @@ async def execute_sdlc_flow(task_id: str, requirement: str):
|
||||
task_manager.update_task_status(task_id, "processing")
|
||||
|
||||
try:
|
||||
# 先发送一个任务启动事件
|
||||
await task_manager.send_event(task_id, {
|
||||
# 添加启动事件
|
||||
task_manager.add_event(task_id, {
|
||||
"event": "task_started",
|
||||
"data": {
|
||||
"status": "starting",
|
||||
@@ -143,13 +158,13 @@ async def execute_sdlc_flow(task_id: str, requirement: str):
|
||||
}
|
||||
})
|
||||
|
||||
# 在线程池中执行同步生成器
|
||||
# 直接执行 CrewAI (同步阻塞)
|
||||
crew = SDLCCrew()
|
||||
for event in crew.execute(requirement):
|
||||
# 发送事件到队列
|
||||
await task_manager.send_event(task_id, event)
|
||||
# 添加事件到任务
|
||||
task_manager.add_event(task_id, event)
|
||||
|
||||
# 如果是最终结果或错误,更新状态
|
||||
# 更新状态
|
||||
event_type = event.get("event", "")
|
||||
if event_type == "final_result":
|
||||
task_manager.update_task_status(task_id, "completed")
|
||||
@@ -158,7 +173,7 @@ async def execute_sdlc_flow(task_id: str, requirement: str):
|
||||
|
||||
except Exception as e:
|
||||
task_manager.update_task_status(task_id, "failed")
|
||||
await task_manager.send_event(task_id, {
|
||||
task_manager.add_event(task_id, {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"error": str(e),
|
||||
@@ -167,6 +182,19 @@ async def execute_sdlc_flow(task_id: str, requirement: str):
|
||||
})
|
||||
|
||||
|
||||
@app.get("/api/v1/sdlc/poll/{task_id}")
|
||||
async def poll_task_events(task_id: str, last_index: int = 0):
|
||||
"""
|
||||
轮询任务事件(替代 SSE)
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
last_index: 最后已知的 event 索引
|
||||
"""
|
||||
result = task_manager.get_events(task_id, last_index)
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/v1/sdlc/stream/{task_id}")
|
||||
async def stream_task_progress(task_id: str):
|
||||
"""
|
||||
@@ -316,7 +344,7 @@ if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
port=8080,
|
||||
reload=False,
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user