515 lines
18 KiB
Python
515 lines
18 KiB
Python
"""
|
||
app/routers/session.py - 交互式会话路由
|
||
"""
|
||
|
||
import logging
|
||
import json
|
||
from fastapi import APIRouter, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
|
||
from app.session import SessionStore, SessionStatus
|
||
from app.agents import ClarifyAgent, PMAgent, QAAgent, DevAgent, FixAgent
|
||
from app.test_runner import run_python_tests
|
||
from app.message import (
|
||
send_workflow_start,
|
||
send_requirement_result,
|
||
send_test_cases,
|
||
send_generate_code,
|
||
)
|
||
|
||
# Module-level logger for this router module.
logger = logging.getLogger(__name__)

# All endpoints below are mounted under the "/session" prefix and grouped
# under the "session" tag in the generated OpenAPI documentation.
router = APIRouter(
    prefix="/session",
    tags=["session"],
)
|
||
|
||
|
||
# ---------- 请求 / 响应模型 ----------
|
||
|
||
# Request body for POST /session/start.
class StartRequest(BaseModel):
    # The user's initial requirement text; it may still be incomplete and is
    # handed to ClarifyAgent.start() for assessment.
    requirement: str
|
||
|
||
# Request body for POST /session/{session_id}/clarify.
class ClarifyRequest(BaseModel):
    # The user's answer to the previous follow-up question; appended to the
    # session's clarify_history as a "user" turn.
    message: str
|
||
|
||
# Request body for the non-streaming PM / QA / Dev refine endpoints.
class RefineRequest(BaseModel):
    # Free-text feedback passed to the corresponding agent's refine() call.
    feedback: str
|
||
|
||
|
||
# Response envelope shared by the non-streaming session endpoints.
class SessionResponse(BaseModel):
    session_id: str
    # Current workflow stage; endpoints pass session.status here (presumably a
    # str-valued SessionStatus member — confirm against app.session).
    status: str
    ready: bool = False
    question: Optional[str] = None  # follow-up question returned when ready=False
    data: Optional[dict] = None  # artifact(s) produced by the current stage
|
||
|
||
|
||
# ---------- 工具函数 ----------
|
||
|
||
def _get_session_or_404(session_id: str):
    """Fetch a session from SessionStore, turning a miss into an HTTP 404.

    Args:
        session_id: Identifier taken from the request path.

    Returns:
        The stored session object.

    Raises:
        HTTPException: 404 when the store returns nothing for ``session_id``
            (unknown or expired session).
    """
    found = SessionStore.get(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="会话不存在或已过期")
|
||
|
||
def _status_value(status):
    """Return the plain ``.value`` of an enum member, or ``status`` unchanged.

    Used so that error messages render every status the same way (its value),
    independent of how the enum class implements ``__format__``.
    """
    return getattr(status, "value", status)


def _require_status(session, *allowed: "SessionStatus"):
    """Ensure the session is in one of the ``allowed`` workflow states.

    Args:
        session: Session object whose ``status`` attribute is checked.
        *allowed: Statuses in which the calling operation is permitted.

    Raises:
        HTTPException: 400 when ``session.status`` is not in ``allowed``; the
            detail names the current status and lists the allowed ones.
    """
    if session.status in allowed:
        return
    # Render the current status by value, like the allowed list. A bare
    # f-string of a (str, Enum) member prints "SessionStatus.X" on Python 3.11+,
    # which made the two halves of the message inconsistent.
    raise HTTPException(
        status_code=400,
        detail=(
            f"当前状态 [{_status_value(session.status)}] 不允许此操作,"
            f"允许的状态: {[_status_value(s) for s in allowed]}"
        ),
    )
|
||
|
||
|
||
# ---------- 接口 ----------
|
||
|
||
@router.post("/start", response_model=SessionResponse)
def start_session(body: StartRequest):
    """Create a session and let the ClarifyAgent judge whether the requirement is complete.

    If the agent asks a follow-up question, it is recorded in the session's
    clarify history as an "assistant" turn and returned to the caller. When the
    agent reports ``ready``, the session moves straight to ``PM_READY``.
    """
    session = SessionStore.create(body.requirement)
    result = ClarifyAgent().start(body.requirement)

    question = result.get("question") or ""
    if question:
        session.clarify_history.append({"role": "assistant", "content": question})

    # Fall back to the raw requirement when the agent omits the field OR returns
    # an empty/None value, so later stages never receive an empty requirement.
    session.clarified_requirement = result.get("clarified_requirement") or body.requirement

    # Normalise to a real bool: the status transition and the response flag
    # must agree, and a None value would fail response validation.
    ready = bool(result.get("ready"))
    if ready:
        session.status = SessionStatus.PM_READY

    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=ready,
        question=question or None,
    )
|
||
|
||
|
||
@router.post("/{session_id}/clarify", response_model=SessionResponse)
def clarify(session_id: str, body: ClarifyRequest):
    """Record the user's answer and let the ClarifyAgent decide whether the requirement is complete.

    Only allowed while the session is ``CLARIFYING``. When the agent reports
    ``ready``, the session moves to ``PM_READY``.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.CLARIFYING)

    # The user turn is recorded before the agent call; the agent receives both
    # the full history (including this turn) and the raw message.
    session.clarify_history.append({"role": "user", "content": body.message})
    result = ClarifyAgent().continue_clarify(session.clarify_history, body.message)

    question = result.get("question") or ""
    if question:
        session.clarify_history.append({"role": "assistant", "content": question})

    # Keep the previous clarified requirement when the agent omits the field or
    # returns an empty/None value, instead of wiping it.
    session.clarified_requirement = (
        result.get("clarified_requirement") or session.clarified_requirement
    )

    # Normalise to a real bool so the status transition and response agree.
    ready = bool(result.get("ready"))
    if ready:
        session.status = SessionStatus.PM_READY

    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=ready,
        question=question or None,
    )
|
||
|
||
|
||
@router.get("/{session_id}/pm/stream")
def pm_stream(session_id: str):
    """Stream the PM Agent's requirement analysis as Server-Sent Events.

    Requires ``PM_READY``. Text chunks are sent as ``{"type": "chunk"}`` events.
    When the agent yields its final ``(_, result)`` tuple, the analysis is
    stored, the session advances to ``PM_DONE`` and a ``{"type": "done"}``
    event carries the result. Any failure is sent as a ``{"type": "error"}``
    event and logged. A stream that ends without a final result does not touch
    the session, so the step can be retried.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_READY)

    send_workflow_start(session.clarified_requirement)
    agent = PMAgent()
    simple_requirement = session.clarified_requirement

    def generate():
        try:
            result = None
            for item in agent.stream_analyze(simple_requirement):
                if isinstance(item, tuple):
                    _, result = item
                else:
                    chunk = {'type': 'chunk', 'text': item}
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            if result is None:
                # Advancing to PM_DONE with no analysis would strand the
                # session: QA refuses to run and PM can no longer be re-run.
                raise RuntimeError("PM Agent 未返回分析结果")
            session.requirement_analysis = result
            send_requirement_result(result)
            session.status = SessionStatus.PM_DONE
            session.touch()
            done = {'type': 'done', 'status': session.status,
                    'data': {'requirement_analysis': result}}
            yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception("PM stream failed for session %s", session_id)
            error = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||
|
||
|
||
@router.post("/{session_id}/pm/run", response_model=SessionResponse)
def pm_run(session_id: str):
    """Run the PM Agent synchronously and return its requirement analysis.

    Requires ``PM_READY``; on success the session advances to ``PM_DONE``.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_READY)

    requirement = session.clarified_requirement
    send_workflow_start(requirement)

    analysis = PMAgent().analyze_requirement(requirement)
    session.requirement_analysis = analysis
    send_requirement_result(analysis)

    session.status = SessionStatus.PM_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"requirement_analysis": analysis},
    )
|
||
|
||
|
||
@router.get("/{session_id}/pm/refine/stream")
def pm_refine_stream(session_id: str, feedback: str):
    """Stream a feedback-driven revision of the PM analysis as Server-Sent Events.

    ``feedback`` arrives as a query parameter. Requires ``PM_DONE``; the status
    is left unchanged. If the agent's stream ends without a final result, an
    error event is sent and the previous analysis is kept rather than being
    overwritten with ``None``. Failures are also logged server-side.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)

    agent = PMAgent()
    previous = session.requirement_analysis

    def generate():
        try:
            result = None
            for item in agent.stream_refine(previous, feedback):
                if isinstance(item, tuple):
                    _, result = item
                else:
                    chunk = {'type': 'chunk', 'text': item}
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            if result is None:
                # Do not discard the existing analysis when refinement fails.
                raise RuntimeError("PM Agent 未返回修改结果")
            session.requirement_analysis = result
            send_requirement_result(result)
            session.touch()
            done = {'type': 'done', 'status': session.status,
                    'data': {'requirement_analysis': result}}
            yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception("PM refine stream failed for session %s", session_id)
            error = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||
|
||
|
||
@router.post("/{session_id}/pm/refine", response_model=SessionResponse)
def pm_refine(session_id: str, body: RefineRequest):
    """Revise the PM analysis according to user feedback (non-streaming).

    Requires ``PM_DONE``; the status is left unchanged.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)

    revised = PMAgent().refine(session.requirement_analysis, body.feedback)
    session.requirement_analysis = revised
    send_requirement_result(revised)

    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"requirement_analysis": revised},
    )
|
||
|
||
|
||
@router.get("/{session_id}/qa/stream")
def qa_stream(session_id: str):
    """Stream the QA Agent's test-case generation as Server-Sent Events.

    Requires ``PM_DONE`` and a stored requirement analysis. On a final result
    the test cases are stored and the session advances to ``QA_DONE``. If the
    stream ends without a result, an error event is sent and the session is
    left untouched so generation can be retried. Failures are logged.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)

    if not session.requirement_analysis:
        raise HTTPException(status_code=400, detail="PM Agent 产出不存在")

    agent = QAAgent()
    req_analysis = session.requirement_analysis

    def generate():
        try:
            result = None
            for item in agent.stream_generate_test_cases(req_analysis):
                if isinstance(item, tuple):
                    _, result = item
                else:
                    chunk = {'type': 'chunk', 'text': item}
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            if result is None:
                # Advancing to QA_DONE with no test cases would block Dev.
                raise RuntimeError("QA Agent 未返回测试用例")
            session.test_cases = result
            send_test_cases(result)
            session.status = SessionStatus.QA_DONE
            session.touch()
            done = {'type': 'done', 'status': session.status,
                    'data': {'test_cases': result}}
            yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception("QA stream failed for session %s", session_id)
            error = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||
|
||
|
||
@router.post("/{session_id}/qa/run", response_model=SessionResponse)
def qa_run(session_id: str):
    """Run the QA Agent synchronously and return the generated test cases.

    Requires ``PM_DONE`` and an existing requirement analysis; on success the
    session advances to ``QA_DONE``.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)

    analysis = session.requirement_analysis
    if not analysis:
        raise HTTPException(status_code=400, detail="PM Agent 产出不存在")

    cases = QAAgent().generate_test_cases(analysis)
    session.test_cases = cases
    send_test_cases(cases)

    session.status = SessionStatus.QA_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"test_cases": cases},
    )
|
||
|
||
|
||
@router.get("/{session_id}/qa/refine/stream")
def qa_refine_stream(session_id: str, feedback: str):
    """Stream a feedback-driven revision of the QA test cases as Server-Sent Events.

    ``feedback`` arrives as a query parameter. Requires ``QA_DONE``; the status
    is left unchanged. If the agent's stream ends without a final result, an
    error event is sent and the existing test cases are kept rather than being
    overwritten with ``None``. Failures are also logged server-side.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)

    agent = QAAgent()
    previous = session.test_cases

    def generate():
        try:
            result = None
            for item in agent.stream_refine(previous, feedback):
                if isinstance(item, tuple):
                    _, result = item
                else:
                    chunk = {'type': 'chunk', 'text': item}
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            if result is None:
                # Do not discard the existing test cases when refinement fails.
                raise RuntimeError("QA Agent 未返回修改结果")
            session.test_cases = result
            send_test_cases(result)
            session.touch()
            done = {'type': 'done', 'status': session.status,
                    'data': {'test_cases': result}}
            yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception("QA refine stream failed for session %s", session_id)
            error = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||
|
||
|
||
@router.post("/{session_id}/qa/refine", response_model=SessionResponse)
def qa_refine(session_id: str, body: RefineRequest):
    """Revise the QA test cases according to user feedback (non-streaming).

    Requires ``QA_DONE``; the status is left unchanged.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)

    revised = QAAgent().refine(session.test_cases, body.feedback)
    session.test_cases = revised
    send_test_cases(revised)

    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"test_cases": revised},
    )
|
||
|
||
|
||
@router.get("/{session_id}/dev/stream")
def dev_stream(session_id: str):
    """Stream the Dev Agent's code generation as Server-Sent Events.

    Requires ``QA_DONE`` plus both the requirement analysis and the test cases.
    On a final result the generated code is stored and the session advances to
    ``DEV_DONE``. If the stream ends without a result, an error event is sent
    and the session is left untouched so generation can be retried. Failures
    are logged.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)

    if not session.requirement_analysis or not session.test_cases:
        raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整")

    agent = DevAgent()
    req_analysis = session.requirement_analysis
    test_cases = session.test_cases

    def generate():
        try:
            result = None
            for item in agent.stream_generate_code(req_analysis, test_cases):
                if isinstance(item, tuple):
                    _, result = item
                else:
                    chunk = {'type': 'chunk', 'text': item}
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            if result is None:
                # Advancing to DEV_DONE with no code would break test/run.
                raise RuntimeError("Dev Agent 未返回代码")
            session.code_generation = result
            send_generate_code(result)
            session.status = SessionStatus.DEV_DONE
            session.touch()
            done = {'type': 'done', 'status': session.status,
                    'data': {'code_generation': result}}
            yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception("Dev stream failed for session %s", session_id)
            error = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||
|
||
|
||
@router.post("/{session_id}/dev/run", response_model=SessionResponse)
def dev_run(session_id: str):
    """Run the Dev Agent synchronously and return the generated code.

    Requires ``QA_DONE`` with both PM and QA artifacts present; on success the
    session advances to ``DEV_DONE``.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)

    analysis = session.requirement_analysis
    cases = session.test_cases
    if not (analysis and cases):
        raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整")

    generated = DevAgent().generate_code(analysis, cases)
    session.code_generation = generated
    send_generate_code(generated)

    session.status = SessionStatus.DEV_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"code_generation": generated},
    )
|
||
|
||
|
||
@router.post("/{session_id}/dev/refine", response_model=SessionResponse)
def dev_refine(session_id: str, body: RefineRequest):
    """Revise the generated code according to user feedback (non-streaming).

    The Dev Agent sees the current code, the requirement analysis, the test
    cases and the feedback. Requires ``DEV_DONE``; the status is left unchanged.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.DEV_DONE)

    revised = DevAgent().refine(
        session.code_generation,
        session.requirement_analysis,
        session.test_cases,
        body.feedback,
    )
    session.code_generation = revised
    send_generate_code(revised)

    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"code_generation": revised},
    )
|
||
|
||
|
||
@router.post("/{session_id}/test/run", response_model=SessionResponse)
def test_run(session_id: str):
    """Actually run pytest on the generated code in a temporary directory and store the result.

    There is no status guard: the endpoint runs whenever generated code exists.
    On completion the result is stored and the session moves to ``TEST_DONE``.

    Raises:
        HTTPException: 400 when there is no generated code, or when it lacks
            the fields the test runner needs.
    """
    session = _get_session_or_404(session_id)

    code_generation = session.code_generation
    if not code_generation:
        raise HTTPException(status_code=400, detail="Dev Agent 产出不存在")

    # The artifact is model output; a missing key used to escape as a KeyError
    # (HTTP 500). Report it as a 400 that names the missing fields.
    required = ("java_code", "unit_tests")
    missing = [key for key in required if key not in code_generation]
    if missing:
        raise HTTPException(status_code=400, detail=f"Dev Agent 产出缺少字段: {missing}")

    # NOTE(review): the source key is named "java_code" although it is fed to
    # run_python_tests; presumably a legacy field name — confirm with DevAgent.
    result = run_python_tests(
        code_generation["java_code"],
        code_generation["unit_tests"],
    )
    session.test_execution = result
    session.status = SessionStatus.TEST_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"test_execution": result},
    )
|
||
|
||
|
||
@router.post("/{session_id}/test/fix", response_model=SessionResponse)
def test_fix(session_id: str):
    """Ask the FixAgent to repair the code from the recorded test-failure output.

    Requires ``TEST_DONE`` and a recorded test run that did not succeed. The
    repaired code replaces ``code_generation`` and the session returns to
    ``DEV_DONE`` so it can be tested again.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.TEST_DONE)

    execution = session.test_execution
    if not execution:
        raise HTTPException(status_code=400, detail="尚未执行测试")
    if execution.get("success"):
        raise HTTPException(status_code=400, detail="测试已全部通过,无需修复")

    repaired = FixAgent().fix(session.code_generation, execution["output"])
    session.code_generation = repaired
    # Back to DEV_DONE so the repaired code can be run through the tests again.
    session.status = SessionStatus.DEV_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"code_generation": repaired},
    )
|
||
|
||
|
||
@router.get("/{session_id}/test/fix/stream")
def test_fix_stream(session_id: str):
    """Stream the FixAgent's repair of failing code as Server-Sent Events.

    Requires ``TEST_DONE`` and a recorded test run that did not succeed. On a
    final result the repaired code replaces ``code_generation`` and the session
    returns to ``DEV_DONE`` so it can be re-tested. If the stream ends without
    a result, an error event is sent and the existing code is kept rather than
    being replaced by ``None``. Failures are logged.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.TEST_DONE)

    if not session.test_execution:
        raise HTTPException(status_code=400, detail="尚未执行测试")
    if session.test_execution.get("success"):
        raise HTTPException(status_code=400, detail="测试已全部通过,无需修复")

    agent = FixAgent()
    code_generation = session.code_generation
    test_output = session.test_execution["output"]

    def generate():
        try:
            result = None
            for item in agent.stream_fix(code_generation, test_output):
                if isinstance(item, tuple):
                    _, result = item
                else:
                    chunk = {'type': 'chunk', 'text': item}
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            if result is None:
                # Replacing the code with None would make test/run impossible.
                raise RuntimeError("FixAgent 未返回修复结果")
            session.code_generation = result
            session.status = SessionStatus.DEV_DONE
            session.touch()
            done = {'type': 'done', 'status': session.status,
                    'data': {'code_generation': result}}
            yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.exception("Fix stream failed for session %s", session_id)
            error = {'type': 'error', 'message': str(e)}
            yield f"data: {json.dumps(error, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
||
|
||
|
||
@router.get("/{session_id}", response_model=SessionResponse)
def get_session(session_id: str):
    """Return the session's current status together with every artifact produced so far."""
    session = _get_session_or_404(session_id)

    snapshot = {
        "raw_requirement": session.raw_requirement,
        "clarify_history": session.clarify_history,
        "requirement_analysis": session.requirement_analysis,
        "test_cases": session.test_cases,
        "code_generation": session.code_generation,
        "test_execution": session.test_execution,
    }
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        # Any stage past clarification is reported as ready.
        ready=session.status != SessionStatus.CLARIFYING,
        data=snapshot,
    )
|