"""app/routers/session.py - interactive session routes.

Drives one requirement through the clarify -> PM -> QA -> Dev -> test/fix
pipeline. Most stages expose a blocking ``POST`` endpoint and a
Server-Sent-Events ``GET .../stream`` variant; both store the stage output on
the session object and update ``session.status``.
"""
import json
import logging
from typing import Any, Callable, Iterable, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.session import SessionStore, SessionStatus
from app.agents import ClarifyAgent, PMAgent, QAAgent, DevAgent, FixAgent
from app.test_runner import run_python_tests
from app.message import (
    send_workflow_start,
    send_requirement_result,
    send_test_cases,
    send_generate_code,
)

router = APIRouter(prefix="/session", tags=["session"])
logger = logging.getLogger(__name__)

# Headers for every SSE response: no client caching, and ask reverse proxies
# (nginx's X-Accel-Buffering) not to buffer, so chunks are delivered promptly.
_SSE_HEADERS = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}


# ---------- Request / response models ----------

class StartRequest(BaseModel):
    """Body of ``POST /session/start``: the user's initial requirement."""
    requirement: str


class ClarifyRequest(BaseModel):
    """Body of ``POST /session/{id}/clarify``: the user's additional details."""
    message: str


class RefineRequest(BaseModel):
    """Body of the ``.../refine`` endpoints: free-text feedback on a stage output."""
    feedback: str


class SessionResponse(BaseModel):
    """Common response envelope returned by all non-streaming endpoints."""
    session_id: str
    status: str
    ready: bool = False
    question: Optional[str] = None  # follow-up question returned when ready=False
    data: Optional[dict] = None  # output of the current stage


# ---------- Helpers ----------

def _status_value(status: Any) -> Any:
    """Return ``status.value`` for enum members, otherwise *status* unchanged.

    f-string formatting of an Enum member renders ``SessionStatus.X`` (for
    str-mixin enums from Python 3.11 on), which would not match the
    ``.value`` strings shown as the allowed states.
    """
    return getattr(status, "value", status)


def _get_session_or_404(session_id: str):
    """Look up a session by id.

    Raises:
        HTTPException: 404 when ``SessionStore.get`` returns nothing.
    """
    session = SessionStore.get(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在或已过期")
    return session


def _require_status(session, *allowed: SessionStatus) -> None:
    """Ensure ``session.status`` is one of *allowed*.

    Raises:
        HTTPException: 400 listing the current and the allowed states.
    """
    if session.status not in allowed:
        raise HTTPException(
            status_code=400,
            detail=(
                f"当前状态 [{_status_value(session.status)}] 不允许此操作,"
                f"允许的状态: {[s.value for s in allowed]}"
            ),
        )


def _sse(payload: dict) -> str:
    """Encode *payload* as a single SSE ``data:`` frame (JSON, non-ASCII kept)."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _stream_agent(
    session,
    start_stream: Callable[[], Iterable[Any]],
    on_result: Callable[[Any], None],
    data_key: str,
) -> StreamingResponse:
    """Relay an agent's streaming output to the client as Server-Sent Events.

    Args:
        session: the session whose state is updated on success.
        start_stream: zero-arg factory returning the agent stream. Plain items
            are text chunks; ``(_, result)`` tuples carry the final result
            (the last tuple wins). Called inside the generator so that errors
            raised while starting the stream are reported as SSE errors too.
        on_result: stores the final result on the session (and may notify /
            change status). Called only when a final result was produced.
        data_key: key under which the result is sent in the ``done`` event.

    Events emitted: ``chunk`` per text item, then either ``done`` (with the
    session status after ``on_result``) or ``error``.
    """
    def generate():
        try:
            result = None
            for item in start_stream():
                if isinstance(item, tuple):
                    _, result = item
                else:
                    yield _sse({"type": "chunk", "text": item})
            if result is None:
                # Never overwrite the stored output with None or advance the
                # session to the next stage without a final result.
                raise RuntimeError("Agent 未返回最终结果")
            on_result(result)
            session.touch()
            yield _sse({"type": "done", "status": session.status, "data": {data_key: result}})
        except Exception as e:
            # The HTTP response has already started, so report the failure
            # in-band, but keep a server-side trace as well.
            logger.exception("SSE stream failed for session %s", session.session_id)
            yield _sse({"type": "error", "message": str(e)})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers=dict(_SSE_HEADERS),
    )


def _apply_clarify_result(session, result: dict, fallback_requirement: str) -> SessionResponse:
    """Record a ClarifyAgent result on *session* and build the response.

    A non-empty follow-up question is appended to ``clarify_history`` as an
    assistant turn; ``clarified_requirement`` falls back to
    *fallback_requirement* when the agent omits it; a truthy ``ready`` moves
    the session to ``PM_READY``.
    """
    question = result.get("question", "")
    if question:
        session.clarify_history.append({"role": "assistant", "content": question})
    session.clarified_requirement = result.get("clarified_requirement", fallback_requirement)
    if result.get("ready"):
        session.status = SessionStatus.PM_READY
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=result.get("ready", False),
        question=result.get("question") or None,
    )


# ---------- Endpoints ----------

@router.post("/start", response_model=SessionResponse)
def start_session(body: StartRequest):
    """Create a session and let the ClarifyAgent judge whether the requirement is complete.

    When it is not, the response carries a follow-up question and the status
    is left as set by ``SessionStore.create`` (presumably ``CLARIFYING``,
    which ``/clarify`` requires).
    """
    session = SessionStore.create(body.requirement)
    result = ClarifyAgent().start(body.requirement)
    # NOTE(review): the raw requirement itself is never added to
    # clarify_history, so continue_clarify() only sees the follow-up
    # exchange; confirm ClarifyAgent does not need it.
    return _apply_clarify_result(session, result, body.requirement)


@router.post("/{session_id}/clarify", response_model=SessionResponse)
def clarify(session_id: str, body: ClarifyRequest):
    """Add the user's extra details and ask the ClarifyAgent again whether it is enough."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.CLARIFYING)
    session.clarify_history.append({"role": "user", "content": body.message})
    result = ClarifyAgent().continue_clarify(session.clarify_history, body.message)
    return _apply_clarify_result(session, result, session.clarified_requirement)


@router.get("/{session_id}/pm/stream")
def pm_stream(session_id: str):
    """Stream the PM Agent's requirement analysis (SSE); moves the session to ``PM_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_READY)
    send_workflow_start(session.clarified_requirement)
    agent = PMAgent()
    requirement = session.clarified_requirement

    def on_result(result):
        session.requirement_analysis = result
        send_requirement_result(result)
        session.status = SessionStatus.PM_DONE

    return _stream_agent(
        session,
        lambda: agent.stream_analyze(requirement),
        on_result,
        "requirement_analysis",
    )


@router.post("/{session_id}/pm/run", response_model=SessionResponse)
def pm_run(session_id: str):
    """Run the PM Agent analysis synchronously; moves the session to ``PM_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_READY)
    send_workflow_start(session.clarified_requirement)
    session.requirement_analysis = PMAgent().analyze_requirement(session.clarified_requirement)
    send_requirement_result(session.requirement_analysis)
    session.status = SessionStatus.PM_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"requirement_analysis": session.requirement_analysis},
    )


@router.get("/{session_id}/pm/refine/stream")
def pm_refine_stream(session_id: str, feedback: str):
    """Stream a revision of the PM output (SSE); ``feedback`` comes from the query string.

    The status stays ``PM_DONE``.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)
    agent = PMAgent()
    previous = session.requirement_analysis

    def on_result(result):
        session.requirement_analysis = result
        send_requirement_result(result)

    return _stream_agent(
        session,
        lambda: agent.stream_refine(previous, feedback),
        on_result,
        "requirement_analysis",
    )


@router.post("/{session_id}/pm/refine", response_model=SessionResponse)
def pm_refine(session_id: str, body: RefineRequest):
    """Revise the PM output according to the user's feedback (status unchanged)."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)
    session.requirement_analysis = PMAgent().refine(session.requirement_analysis, body.feedback)
    send_requirement_result(session.requirement_analysis)
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"requirement_analysis": session.requirement_analysis},
    )


@router.get("/{session_id}/qa/stream")
def qa_stream(session_id: str):
    """Stream QA Agent test-case generation (SSE); moves the session to ``QA_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)
    if not session.requirement_analysis:
        raise HTTPException(status_code=400, detail="PM Agent 产出不存在")
    agent = QAAgent()
    req_analysis = session.requirement_analysis

    def on_result(result):
        session.test_cases = result
        send_test_cases(result)
        session.status = SessionStatus.QA_DONE

    return _stream_agent(
        session,
        lambda: agent.stream_generate_test_cases(req_analysis),
        on_result,
        "test_cases",
    )


@router.post("/{session_id}/qa/run", response_model=SessionResponse)
def qa_run(session_id: str):
    """Generate test cases with the QA Agent synchronously; moves the session to ``QA_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.PM_DONE)
    if not session.requirement_analysis:
        raise HTTPException(status_code=400, detail="PM Agent 产出不存在")
    session.test_cases = QAAgent().generate_test_cases(session.requirement_analysis)
    send_test_cases(session.test_cases)
    session.status = SessionStatus.QA_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"test_cases": session.test_cases},
    )


@router.get("/{session_id}/qa/refine/stream")
def qa_refine_stream(session_id: str, feedback: str):
    """Stream a revision of the QA output (SSE); ``feedback`` comes from the query string.

    The status stays ``QA_DONE``.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)
    agent = QAAgent()
    previous = session.test_cases

    def on_result(result):
        session.test_cases = result
        send_test_cases(result)

    return _stream_agent(
        session,
        lambda: agent.stream_refine(previous, feedback),
        on_result,
        "test_cases",
    )


@router.post("/{session_id}/qa/refine", response_model=SessionResponse)
def qa_refine(session_id: str, body: RefineRequest):
    """Revise the QA test cases according to the user's feedback (status unchanged)."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)
    session.test_cases = QAAgent().refine(session.test_cases, body.feedback)
    send_test_cases(session.test_cases)
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"test_cases": session.test_cases},
    )


@router.get("/{session_id}/dev/stream")
def dev_stream(session_id: str):
    """Stream Dev Agent code generation (SSE); moves the session to ``DEV_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)
    if not session.requirement_analysis or not session.test_cases:
        raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整")
    agent = DevAgent()
    req_analysis = session.requirement_analysis
    test_cases = session.test_cases

    def on_result(result):
        session.code_generation = result
        send_generate_code(result)
        session.status = SessionStatus.DEV_DONE

    return _stream_agent(
        session,
        lambda: agent.stream_generate_code(req_analysis, test_cases),
        on_result,
        "code_generation",
    )


@router.post("/{session_id}/dev/run", response_model=SessionResponse)
def dev_run(session_id: str):
    """Generate code with the Dev Agent synchronously; moves the session to ``DEV_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.QA_DONE)
    if not session.requirement_analysis or not session.test_cases:
        raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整")
    session.code_generation = DevAgent().generate_code(
        session.requirement_analysis, session.test_cases
    )
    send_generate_code(session.code_generation)
    session.status = SessionStatus.DEV_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"code_generation": session.code_generation},
    )


@router.post("/{session_id}/dev/refine", response_model=SessionResponse)
def dev_refine(session_id: str, body: RefineRequest):
    """Revise the generated code according to the user's feedback (status unchanged)."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.DEV_DONE)
    session.code_generation = DevAgent().refine(
        session.code_generation,
        session.requirement_analysis,
        session.test_cases,
        body.feedback,
    )
    send_generate_code(session.code_generation)
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"code_generation": session.code_generation},
    )


@router.post("/{session_id}/test/run", response_model=SessionResponse)
def test_run(session_id: str):
    """Actually run pytest (in a temporary directory) on the generated code and tests.

    Stores the runner output in ``session.test_execution`` and moves the
    session to ``TEST_DONE``.
    """
    session = _get_session_or_404(session_id)
    if not session.code_generation:
        raise HTTPException(status_code=400, detail="Dev Agent 产出不存在")
    # NOTE(review): the source is read from the "java_code" key although it is
    # executed by run_python_tests; confirm this matches DevAgent's output schema.
    result = run_python_tests(
        session.code_generation["java_code"],
        session.code_generation["unit_tests"],
    )
    session.test_execution = result
    session.status = SessionStatus.TEST_DONE
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"test_execution": result},
    )


def _require_failed_test_run(session) -> None:
    """Ensure a test run exists and did not fully pass (precondition for fixing).

    Raises:
        HTTPException: 400 when no tests were run or all tests passed.
    """
    if not session.test_execution:
        raise HTTPException(status_code=400, detail="尚未执行测试")
    if session.test_execution.get("success"):
        raise HTTPException(status_code=400, detail="测试已全部通过,无需修复")


@router.post("/{session_id}/test/fix", response_model=SessionResponse)
def test_fix(session_id: str):
    """Let the FixAgent repair the code from the failing test output.

    Resets the session to ``DEV_DONE`` so the tests can be run again.
    """
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.TEST_DONE)
    _require_failed_test_run(session)
    # NOTE(review): unlike dev/run and dev/refine, the fixed code is not
    # broadcast via send_generate_code; confirm this is intentional.
    session.code_generation = FixAgent().fix(
        session.code_generation,
        session.test_execution["output"],
    )
    session.status = SessionStatus.DEV_DONE  # back to dev_done so tests can be re-run
    session.touch()
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=True,
        data={"code_generation": session.code_generation},
    )


@router.get("/{session_id}/test/fix/stream")
def test_fix_stream(session_id: str):
    """Stream the FixAgent's code repair (SSE); resets the session to ``DEV_DONE``."""
    session = _get_session_or_404(session_id)
    _require_status(session, SessionStatus.TEST_DONE)
    _require_failed_test_run(session)
    agent = FixAgent()
    code_generation = session.code_generation
    test_output = session.test_execution["output"]

    def on_result(result):
        session.code_generation = result
        session.status = SessionStatus.DEV_DONE

    return _stream_agent(
        session,
        lambda: agent.stream_fix(code_generation, test_output),
        on_result,
        "code_generation",
    )


@router.get("/{session_id}", response_model=SessionResponse)
def get_session(session_id: str):
    """Return the current session status together with every stage output."""
    session = _get_session_or_404(session_id)
    return SessionResponse(
        session_id=session.session_id,
        status=session.status,
        ready=session.status != SessionStatus.CLARIFYING,
        data={
            "raw_requirement": session.raw_requirement,
            "clarify_history": session.clarify_history,
            "requirement_analysis": session.requirement_analysis,
            "test_cases": session.test_cases,
            "code_generation": session.code_generation,
            "test_execution": session.test_execution,
        },
    )