use multiple apis
This commit is contained in:
514
app/routers/session.py
Normal file
514
app/routers/session.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
app/routers/session.py - 交互式会话路由
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from app.session import SessionStore, SessionStatus
|
||||
from app.agents import ClarifyAgent, PMAgent, QAAgent, DevAgent, FixAgent
|
||||
from app.test_runner import run_python_tests
|
||||
from app.message import (
|
||||
send_workflow_start,
|
||||
send_requirement_result,
|
||||
send_test_cases,
|
||||
send_generate_code,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/session", tags=["session"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------- 请求 / 响应模型 ----------
|
||||
|
||||
class StartRequest(BaseModel):
|
||||
requirement: str
|
||||
|
||||
class ClarifyRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
class RefineRequest(BaseModel):
|
||||
feedback: str
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
session_id: str
|
||||
status: str
|
||||
ready: bool = False
|
||||
question: Optional[str] = None # 当 ready=False 时返回追问
|
||||
data: Optional[dict] = None # 当前阶段产出
|
||||
|
||||
|
||||
# ---------- 工具函数 ----------
|
||||
|
||||
def _get_session_or_404(session_id: str):
|
||||
session = SessionStore.get(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||||
return session
|
||||
|
||||
def _require_status(session, *allowed: SessionStatus):
|
||||
if session.status not in allowed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"当前状态 [{session.status}] 不允许此操作,允许的状态: {[s.value for s in allowed]}"
|
||||
)
|
||||
|
||||
|
||||
# ---------- 接口 ----------
|
||||
|
||||
@router.post("/start", response_model=SessionResponse)
|
||||
def start_session(body: StartRequest):
|
||||
"""创建会话,AI 判断需求是否完整,不够则追问。"""
|
||||
session = SessionStore.create(body.requirement)
|
||||
agent = ClarifyAgent()
|
||||
result = agent.start(body.requirement)
|
||||
|
||||
q = result.get("question", "")
|
||||
if q:
|
||||
session.clarify_history.append({"role": "assistant", "content": q})
|
||||
session.clarified_requirement = result.get("clarified_requirement", body.requirement)
|
||||
|
||||
if result.get("ready"):
|
||||
session.status = SessionStatus.PM_READY
|
||||
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=result.get("ready", False),
|
||||
question=result.get("question") or None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/clarify", response_model=SessionResponse)
|
||||
def clarify(session_id: str, body: ClarifyRequest):
|
||||
"""用户补充需求,AI 继续判断是否够了。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.CLARIFYING)
|
||||
|
||||
session.clarify_history.append({"role": "user", "content": body.message})
|
||||
|
||||
agent = ClarifyAgent()
|
||||
result = agent.continue_clarify(session.clarify_history, body.message)
|
||||
|
||||
q = result.get("question", "")
|
||||
if q:
|
||||
session.clarify_history.append({"role": "assistant", "content": q})
|
||||
session.clarified_requirement = result.get("clarified_requirement", session.clarified_requirement)
|
||||
|
||||
if result.get("ready"):
|
||||
session.status = SessionStatus.PM_READY
|
||||
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=result.get("ready", False),
|
||||
question=result.get("question") or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/pm/stream")
|
||||
def pm_stream(session_id: str):
|
||||
"""流式返回 PM Agent 分析过程和结果(SSE)。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.PM_READY)
|
||||
|
||||
send_workflow_start(session.clarified_requirement)
|
||||
agent = PMAgent()
|
||||
simple_requirement = session.clarified_requirement
|
||||
|
||||
def generate():
|
||||
try:
|
||||
result = None
|
||||
for item in agent.stream_analyze(simple_requirement):
|
||||
if isinstance(item, tuple):
|
||||
_, result = item
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n"
|
||||
session.requirement_analysis = result
|
||||
send_requirement_result(result)
|
||||
session.status = SessionStatus.PM_DONE
|
||||
session.touch()
|
||||
yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'requirement_analysis': result}}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/pm/run", response_model=SessionResponse)
|
||||
def pm_run(session_id: str):
|
||||
"""触发 PM Agent 分析需求。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.PM_READY)
|
||||
|
||||
send_workflow_start(session.clarified_requirement)
|
||||
|
||||
agent = PMAgent()
|
||||
session.requirement_analysis = agent.analyze_requirement(session.clarified_requirement)
|
||||
send_requirement_result(session.requirement_analysis)
|
||||
|
||||
session.status = SessionStatus.PM_DONE
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"requirement_analysis": session.requirement_analysis},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/pm/refine/stream")
|
||||
def pm_refine_stream(session_id: str, feedback: str):
|
||||
"""流式修改 PM 产出(SSE),feedback 经由 query param 传入。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.PM_DONE)
|
||||
|
||||
agent = PMAgent()
|
||||
previous = session.requirement_analysis
|
||||
|
||||
def generate():
|
||||
try:
|
||||
result = None
|
||||
for item in agent.stream_refine(previous, feedback):
|
||||
if isinstance(item, tuple):
|
||||
_, result = item
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n"
|
||||
session.requirement_analysis = result
|
||||
send_requirement_result(result)
|
||||
session.touch()
|
||||
yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'requirement_analysis': result}}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/pm/refine", response_model=SessionResponse)
|
||||
def pm_refine(session_id: str, body: RefineRequest):
|
||||
"""根据反馈修改 PM 产出。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.PM_DONE)
|
||||
|
||||
agent = PMAgent()
|
||||
session.requirement_analysis = agent.refine(session.requirement_analysis, body.feedback)
|
||||
send_requirement_result(session.requirement_analysis)
|
||||
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"requirement_analysis": session.requirement_analysis},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/qa/stream")
|
||||
def qa_stream(session_id: str):
|
||||
"""流式返回 QA Agent 测试用例生成过程(SSE)。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.PM_DONE)
|
||||
|
||||
if not session.requirement_analysis:
|
||||
raise HTTPException(status_code=400, detail="PM Agent 产出不存在")
|
||||
|
||||
agent = QAAgent()
|
||||
req_analysis = session.requirement_analysis
|
||||
|
||||
def generate():
|
||||
try:
|
||||
result = None
|
||||
for item in agent.stream_generate_test_cases(req_analysis):
|
||||
if isinstance(item, tuple):
|
||||
_, result = item
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n"
|
||||
session.test_cases = result
|
||||
send_test_cases(result)
|
||||
session.status = SessionStatus.QA_DONE
|
||||
session.touch()
|
||||
yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'test_cases': result}}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/qa/run", response_model=SessionResponse)
|
||||
def qa_run(session_id: str):
|
||||
"""触发 QA Agent 生成测试用例。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.PM_DONE)
|
||||
|
||||
if not session.requirement_analysis:
|
||||
raise HTTPException(status_code=400, detail="PM Agent 产出不存在")
|
||||
|
||||
agent = QAAgent()
|
||||
session.test_cases = agent.generate_test_cases(session.requirement_analysis)
|
||||
send_test_cases(session.test_cases)
|
||||
|
||||
session.status = SessionStatus.QA_DONE
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"test_cases": session.test_cases},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/qa/refine/stream")
|
||||
def qa_refine_stream(session_id: str, feedback: str):
|
||||
"""流式修改 QA 产出(SSE),feedback 经由 query param 传入。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.QA_DONE)
|
||||
|
||||
agent = QAAgent()
|
||||
previous = session.test_cases
|
||||
|
||||
def generate():
|
||||
try:
|
||||
result = None
|
||||
for item in agent.stream_refine(previous, feedback):
|
||||
if isinstance(item, tuple):
|
||||
_, result = item
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n"
|
||||
session.test_cases = result
|
||||
send_test_cases(result)
|
||||
session.touch()
|
||||
yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'test_cases': result}}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/qa/refine", response_model=SessionResponse)
|
||||
def qa_refine(session_id: str, body: RefineRequest):
|
||||
"""根据反馈修改 QA 产出。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.QA_DONE)
|
||||
|
||||
agent = QAAgent()
|
||||
session.test_cases = agent.refine(session.test_cases, body.feedback)
|
||||
send_test_cases(session.test_cases)
|
||||
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"test_cases": session.test_cases},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/dev/stream")
|
||||
def dev_stream(session_id: str):
|
||||
"""流式返回 Dev Agent 代码生成过程(SSE)。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.QA_DONE)
|
||||
|
||||
if not session.requirement_analysis or not session.test_cases:
|
||||
raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整")
|
||||
|
||||
agent = DevAgent()
|
||||
req_analysis = session.requirement_analysis
|
||||
test_cases = session.test_cases
|
||||
|
||||
def generate():
|
||||
try:
|
||||
result = None
|
||||
for item in agent.stream_generate_code(req_analysis, test_cases):
|
||||
if isinstance(item, tuple):
|
||||
_, result = item
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n"
|
||||
session.code_generation = result
|
||||
send_generate_code(result)
|
||||
session.status = SessionStatus.DEV_DONE
|
||||
session.touch()
|
||||
yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'code_generation': result}}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/dev/run", response_model=SessionResponse)
|
||||
def dev_run(session_id: str):
|
||||
"""触发 Dev Agent 生成代码。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.QA_DONE)
|
||||
|
||||
if not session.requirement_analysis or not session.test_cases:
|
||||
raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整")
|
||||
|
||||
agent = DevAgent()
|
||||
session.code_generation = agent.generate_code(session.requirement_analysis, session.test_cases)
|
||||
send_generate_code(session.code_generation)
|
||||
|
||||
session.status = SessionStatus.DEV_DONE
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"code_generation": session.code_generation},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/dev/refine", response_model=SessionResponse)
|
||||
def dev_refine(session_id: str, body: RefineRequest):
|
||||
"""根据反馈修改 Dev 产出。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.DEV_DONE)
|
||||
|
||||
agent = DevAgent()
|
||||
session.code_generation = agent.refine(
|
||||
session.code_generation,
|
||||
session.requirement_analysis,
|
||||
session.test_cases,
|
||||
body.feedback,
|
||||
)
|
||||
send_generate_code(session.code_generation)
|
||||
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"code_generation": session.code_generation},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/test/run", response_model=SessionResponse)
|
||||
def test_run(session_id: str):
|
||||
"""在临时目录中真实执行 pytest,返回测试结果。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
|
||||
if not session.code_generation:
|
||||
raise HTTPException(status_code=400, detail="Dev Agent 产出不存在")
|
||||
|
||||
result = run_python_tests(
|
||||
session.code_generation["java_code"],
|
||||
session.code_generation["unit_tests"],
|
||||
)
|
||||
session.test_execution = result
|
||||
session.status = SessionStatus.TEST_DONE
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"test_execution": result},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/test/fix", response_model=SessionResponse)
|
||||
def test_fix(session_id: str):
|
||||
"""调用 FixAgent 根据测试失败信息自动修复代码。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.TEST_DONE)
|
||||
|
||||
if not session.test_execution:
|
||||
raise HTTPException(status_code=400, detail="尚未执行测试")
|
||||
if session.test_execution.get("success"):
|
||||
raise HTTPException(status_code=400, detail="测试已全部通过,无需修复")
|
||||
|
||||
agent = FixAgent()
|
||||
session.code_generation = agent.fix(
|
||||
session.code_generation,
|
||||
session.test_execution["output"],
|
||||
)
|
||||
session.status = SessionStatus.DEV_DONE # 修复后重置为 dev_done,可再次测试
|
||||
session.touch()
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=True,
|
||||
data={"code_generation": session.code_generation},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/test/fix/stream")
|
||||
def test_fix_stream(session_id: str):
|
||||
"""流式返回 FixAgent 代码修复过程(SSE)。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
_require_status(session, SessionStatus.TEST_DONE)
|
||||
|
||||
if not session.test_execution:
|
||||
raise HTTPException(status_code=400, detail="尚未执行测试")
|
||||
if session.test_execution.get("success"):
|
||||
raise HTTPException(status_code=400, detail="测试已全部通过,无需修复")
|
||||
|
||||
agent = FixAgent()
|
||||
code_generation = session.code_generation
|
||||
test_output = session.test_execution["output"]
|
||||
|
||||
def generate():
|
||||
try:
|
||||
result = None
|
||||
for item in agent.stream_fix(code_generation, test_output):
|
||||
if isinstance(item, tuple):
|
||||
_, result = item
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n"
|
||||
session.code_generation = result
|
||||
session.status = SessionStatus.DEV_DONE
|
||||
session.touch()
|
||||
yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'code_generation': result}}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=SessionResponse)
|
||||
def get_session(session_id: str):
|
||||
"""获取当前会话状态和所有产出。"""
|
||||
session = _get_session_or_404(session_id)
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
status=session.status,
|
||||
ready=session.status != SessionStatus.CLARIFYING,
|
||||
data={
|
||||
"raw_requirement": session.raw_requirement,
|
||||
"clarify_history": session.clarify_history,
|
||||
"requirement_analysis": session.requirement_analysis,
|
||||
"test_cases": session.test_cases,
|
||||
"code_generation": session.code_generation,
|
||||
"test_execution": session.test_execution,
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user