diff --git a/app/agents.py b/app/agents.py index e97bf41..7d09c81 100644 --- a/app/agents.py +++ b/app/agents.py @@ -15,6 +15,94 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +def _parse_json(content: str, error_prefix: str) -> dict: + """解析 LLM 返回的 JSON,解析失败时尝试正则兜底。""" + try: + return json.loads(content) + except json.JSONDecodeError: + import re + m = re.search(r'\{.*\}', content, re.DOTALL) + if m: + return json.loads(m.group()) + raise ValueError(f"{error_prefix}: {content}") + + +class ClarifyAgent: + """需求澄清 Agent —— 判断需求是否完整,并追问用户补充信息""" + + def __init__(self): + self.settings = get_settings() + self.client = OpenAI(api_key=self.settings.api_key, base_url=self.settings.base_url) + + def start(self, raw_requirement: str) -> dict: + """ + 用户第一次提交需求时调用。 + 返回 {"ready": bool, "question": str, "clarified_requirement": str} + """ + prompt = f"""你是一个资深的产品经理助手,负责在正式分析需求之前确认需求的完整性。 + +用户提交的需求: +{raw_requirement} + +请判断该需求是否足够清晰,可以直接开始产品分析。 + +返回 JSON: +{{ + "ready": true 或 false, + "question": "如果 ready=false,给用户一个简洁的追问;如果 ready=true,此字段为空字符串", + "clarified_requirement": "整合后的完整需求描述(即使 ready=false 也要输出当前已有的描述)" +}} + +判断标准: +- 需求描述清楚要做什么 +- 有基本的场景或目标 +- 不需要追问过多细节,够 PM 开始分析即可 + +返回 ONLY JSON,不要有其他文字。""" + + response = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个产品经理助手,负责判断需求完整性,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": prompt} + ], + temperature=0.2, + max_tokens=500, + response_format={"type": "json_object"} + ) + return _parse_json(response.choices[0].message.content, "ClarifyAgent.start 解析失败") + + def continue_clarify(self, clarify_history: list[dict], user_reply: str) -> dict: + """ + 用户补充信息后继续澄清。 + clarify_history 格式:[{"role": "assistant"|"user", "content": str}, ...] + 返回同 start() 的格式。 + """ + messages = [ + {"role": "system", "content": "你是一个产品经理助手,负责判断需求完整性,输出必须是严格的 JSON 格式。"} + ] + clarify_history + [ + {"role": "user", "content": user_reply} + ] + + suffix = """\n\n根据以上对话,判断需求现在是否足够清晰。返回 JSON: +{{ + "ready": true 或 false, + "question": "如果 ready=false,继续追问;如果 ready=true,此字段为空字符串", + "clarified_requirement": "整合所有对话后的完整需求描述" +}} +返回 ONLY JSON,不要有其他文字。""" + messages.append({"role": "user", "content": suffix}) + + response = self.client.chat.completions.create( + model=self.settings.model, + messages=messages, + temperature=0.2, + max_tokens=500, + response_format={"type": "json_object"} + ) + return _parse_json(response.choices[0].message.content, "ClarifyAgent.continue_clarify 解析失败") + + class PMAgent: """产品经理Agent - 完善和扩展需求""" @@ -22,20 +110,10 @@ class PMAgent: self.settings = get_settings() self.client = OpenAI(api_key=self.settings.api_key, base_url=self.settings.base_url) - def analyze_requirement(self, simple_requirement: str) -> RequirementAnalysis: - """ - 分析和完善简单的需求描述 + def _build_prompt(self, requirement: str) -> str: + return f"""你是一个资深的产品经理。请根据以下简单的需求描述,进行深入分析和完善。 - Args: - simple_requirement: 用户提供的简单需求描述 - - Returns: - RequirementAnalysis: 包含完善后的需求信息 - """ - - prompt = f"""你是一个资深的产品经理。请根据以下简单的需求描述,进行深入分析和完善。 - -需求描述:{simple_requirement} +需求描述:{requirement} 请按以下格式返回JSON结果(必须是有效的JSON格式): {{ @@ -55,6 +133,104 @@ class PMAgent: 返回ONLY JSON内容,不要有其他文字。""" + def analyze_requirement(self, simple_requirement: str) -> RequirementAnalysis: + """ + 分析和完善简单的需求描述 + + Args: + simple_requirement: 用户提供的简单需求描述 + + Returns: + RequirementAnalysis: 包含完善后的需求信息 + """ + response = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的产品经理,擅长需求分析和拆解,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(simple_requirement)} + ], + temperature=0.2, + max_tokens=2000, + response_format={"type": "json_object"} + ) + result = _parse_json(response.choices[0].message.content, "PMAgent 解析失败") + return RequirementAnalysis(**result) + + def stream_analyze(self, simple_requirement: str): + """ + 流式版需求分析。yield 文本块(str),最后 yield (None, RequirementAnalysis) 作为哨兵。 + """ + stream = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的产品经理,擅长需求分析和拆解,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(simple_requirement)} + ], + temperature=0.2, + max_tokens=2000, + response_format={"type": "json_object"}, + stream=True, + ) + full_text = "" + for chunk in stream: + delta = chunk.choices[0].delta.content or "" + if delta: + full_text += delta + yield delta + result = _parse_json(full_text, "PMAgent stream 解析失败") + yield (None, RequirementAnalysis(**result)) + + def stream_refine(self, previous: RequirementAnalysis, feedback: str): + """ + 流式版需求分析修改。yield 文本块(str),最后 yield (None, RequirementAnalysis) 作为哨兵。 + """ + prompt = f"""你是一个资深的产品经理。以下是你之前输出的需求分析结果,用户对此有修改意见,请根据意见调整输出。 + +之前的需求分析: +{json.dumps(previous, ensure_ascii=False, indent=2)} + +用户的修改意见: +{feedback} + +请在原有基础上修改,保持 JSON 格式不变,返回完整的修改后结果,返回 ONLY JSON,不要有其他文字。""" + + stream = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的产品经理,擅长需求分析和拆解,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": prompt} + ], + temperature=0.2, + max_tokens=2000, + response_format={"type": "json_object"}, + stream=True, + ) + full_text = "" + for chunk in stream: + delta = chunk.choices[0].delta.content or "" + if delta: + full_text += delta + yield delta + result = _parse_json(full_text, "PMAgent stream_refine 解析失败") + yield (None, RequirementAnalysis(**result)) + + def refine(self, previous: RequirementAnalysis, feedback: str) -> RequirementAnalysis: + """ + 根据用户反馈修改已有的需求分析结果。 + + Returns: + RequirementAnalysis: 修改后的需求分析结果 + """ + prompt = f"""你是一个资深的产品经理。以下是你之前输出的需求分析结果,用户对此有修改意见,请根据意见调整输出。 + +之前的需求分析: +{json.dumps(previous, ensure_ascii=False, indent=2)} + +用户的修改意见: +{feedback} + +请在原有基础上修改,保持 JSON 格式不变,返回完整的修改后结果,返回 ONLY JSON,不要有其他文字。""" + response = self.client.chat.completions.create( model=self.settings.model, messages=[ @@ -65,22 +241,8 @@ class PMAgent: max_tokens=2000, response_format={"type": "json_object"} ) - - # 提取响应内容 - content = response.choices[0].message.content - - # 解析JSON - try: - result = json.loads(content) - return RequirementAnalysis(**result) - except json.JSONDecodeError: - # 如果JSON解析失败,尝试提取JSON部分 - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return RequirementAnalysis(**result) - raise ValueError(f"无法解析Agent响应: {content}") + result = _parse_json(response.choices[0].message.content, "PMAgent.refine 解析失败") + return RequirementAnalysis(**result) class QAAgent: @@ -90,18 +252,7 @@ class QAAgent: self.settings = get_settings() self.client = OpenAI(api_key=self.settings.api_key, base_url=self.settings.base_url) - def generate_test_cases(self, requirement_analysis: RequirementAnalysis) -> TestCaseResult: - """ - 基于需求分析生成测试用例 - - Args: - requirement_analysis: PM Agent的分析结果 - - Returns: - TestCaseResult: 包含测试用例的结果 - """ - - # 构建需求信息 + def _build_prompt(self, requirement_analysis: RequirementAnalysis) -> str: requirement_text = f""" 功能需求: {chr(10).join(f"- {req}" for req in requirement_analysis["functional_requirements"])} @@ -118,8 +269,7 @@ class QAAgent: 需求总结: {requirement_analysis["summary"]} """ - - prompt = f"""你是一个资深的Java QA工程师。根据以下需求信息,生成全面的Java测试用例和测试策略。所有测试用例必须基于Java语言,步骤和预期结果要符合Java的类型系统、异常机制和JUnit测试框架。 + return f"""你是一个资深的Python QA工程师。根据以下需求信息,生成全面的Python测试用例和测试策略。所有测试用例必须基于Python语言,步骤和预期结果要符合Python的类型系统、异常机制和pytest测试框架。 {requirement_text} @@ -145,61 +295,140 @@ class QAAgent: 2. 为每个边缘情况生成1个测试用例 3. 生成至少1个性能测试用例 4. 生成至少1个安全测试用例 -5. 测试用例要包含明确的步骤和预期结果,步骤和预期结果必须符合Java语言特性(不要出现Python或其他语言的描述) +5. 测试用例要包含明确的步骤和预期结果,步骤和预期结果必须符合Python语言特性(不要出现Java或其他语言的描述) 6. 步骤和预期结果必须用自然语言描述,不得包含任何代码片段或代码块,不要出现 ```、assert、assertEquals 等代码语法 -7. 测试策略不要出现JUnit或者Java这种字眼,应该是针对需求的测试方法论和思路描述,覆盖计划要说明如何确保测试覆盖所有功能和边界情况 +7. 测试策略不要出现pytest或者Python这种字眼,应该是针对需求的测试方法论和思路描述,覆盖计划要说明如何确保测试覆盖所有功能和边界情况 返回ONLY JSON内容,不要有其他文字。""" + def generate_test_cases(self, requirement_analysis: RequirementAnalysis) -> TestCaseResult: + """ + 基于需求分析生成测试用例 + + Args: + requirement_analysis: PM Agent的分析结果 + + Returns: + TestCaseResult: 包含测试用例的结果 + """ + response = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python QA 工程师,擅长为 Python 应用程序设计测试用例,所有测试步骤和预期结果必须基于 Python 语言特性(如 pytest、动态类型、异常机制等),输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(requirement_analysis)} + ], + temperature=0.2, + max_tokens=3000, + response_format={"type": "json_object"} + ) + result = _parse_json(response.choices[0].message.content, "QAAgent 解析失败") + return TestCaseResult(**result) + + def stream_generate_test_cases(self, requirement_analysis: RequirementAnalysis): + """ + 流式版测试用例生成。yield 文本块(str),最后 yield (None, TestCaseResult) 作为哨兵。 + """ + stream = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python QA 工程师,擅长为 Python 应用程序设计测试用例,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(requirement_analysis)} + ], + temperature=0.2, + max_tokens=3000, + response_format={"type": "json_object"}, + stream=True, + ) + full_text = "" + for chunk in stream: + delta = chunk.choices[0].delta.content or "" + if delta: + full_text += delta + yield delta + result = _parse_json(full_text, "QAAgent stream 解析失败") + yield (None, TestCaseResult(**result)) + + def stream_refine(self, previous: TestCaseResult, feedback: str): + """ + 流式版测试用例修改。yield 文本块(str),最后 yield (None, TestCaseResult) 作为哨兵。 + """ + prompt = f"""你是一个资深的Python QA工程师。以下是你之前输出的测试用例,用户对此有修改意见,请根据意见调整输出。 + +之前的测试用例: +{json.dumps(previous, ensure_ascii=False, indent=2)} + +用户的修改意见: +{feedback} + +请在原有基础上修改,保持 JSON 格式不变,返回完整的修改后结果。 +步骤和预期结果必须用自然语言描述,不得包含代码片段。 +返回 ONLY JSON,不要有其他文字。""" + + stream = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python QA 工程师,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": prompt} + ], + temperature=0.2, + max_tokens=3000, + response_format={"type": "json_object"}, + stream=True, + ) + full_text = "" + for chunk in stream: + delta = chunk.choices[0].delta.content or "" + if delta: + full_text += delta + yield delta + result = _parse_json(full_text, "QAAgent stream_refine 解析失败") + yield (None, TestCaseResult(**result)) + + def refine(self, previous: TestCaseResult, feedback: str) -> TestCaseResult: + """ + 根据用户反馈修改已有的测试用例结果。 + + Args: + previous: 上一次的测试用例结果 + feedback: 用户的修改意见 + + Returns: + TestCaseResult: 修改后的测试用例结果 + """ + prompt = f"""你是一个资深的Python QA工程师。以下是你之前输出的测试用例,用户对此有修改意见,请根据意见调整输出。 + +之前的测试用例: +{json.dumps(previous, ensure_ascii=False, indent=2)} + +用户的修改意见: +{feedback} + +请在原有基础上修改,保持 JSON 格式不变,返回完整的修改后结果。 +步骤和预期结果必须用自然语言描述,不得包含代码片段。 +返回 ONLY JSON,不要有其他文字。""" + response = self.client.chat.completions.create( model=self.settings.model, messages=[ - {"role": "system", "content": "你是一个资深的 Java QA 工程师,擅长为 Java 应用程序设计测试用例,所有测试步骤和预期结果必须基于 Java 语言特性(如 JUnit、强类型系统、异常机制等),输出必须是严格的 JSON 格式。"}, + {"role": "system", "content": "你是一个资深的 Python QA 工程师,输出必须是严格的 JSON 格式。"}, {"role": "user", "content": prompt} ], temperature=0.2, max_tokens=3000, response_format={"type": "json_object"} ) - - content = response.choices[0].message.content - - try: - result = json.loads(content) - return TestCaseResult(**result) - except json.JSONDecodeError: - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return TestCaseResult(**result) - raise ValueError(f"无法解析QA Agent响应: {content}") + result = _parse_json(response.choices[0].message.content, "QAAgent.refine 解析失败") + return TestCaseResult(**result) class DevAgent: - """开发Agent - 生成Java代码和单元测试""" + """开发Agent - 生成Python代码和单元测试""" def __init__(self): self.settings = get_settings() self.client = OpenAI(api_key=self.settings.api_key, base_url=self.settings.base_url) - def generate_code( - self, - requirement_analysis: RequirementAnalysis, - test_cases: TestCaseResult - ) -> CodeGenerationResult: - """ - 生成Java实现代码和单元测试代码 - - Args: - requirement_analysis: PM Agent的分析结果 - test_cases: QA Agent的测试用例 - - Returns: - CodeGenerationResult: 包含Java代码和单元测试代码 - """ - - # 构建上下文 + def _build_prompt(self, requirement_analysis: RequirementAnalysis, test_cases: TestCaseResult) -> str: requirement_text = f""" 功能需求: {chr(10).join(f"- {req}" for req in requirement_analysis["functional_requirements"])} @@ -213,7 +442,6 @@ class DevAgent: 需求总结: {requirement_analysis["summary"]} """ - test_cases_text = chr(10).join( f"- [{c['test_id']}] {c['test_name']}(类型:{c.get('test_type', '未分类')})|预期结果:{c['expected_result']}" for c in test_cases["test_cases"] @@ -226,8 +454,7 @@ class DevAgent: 关键测试用例列表: {test_cases_text} """ - - prompt = f"""你是一个资深的Java开发工程师。根据以下需求和测试用例,生成高质量的Java实现代码和单元测试代码。 + return f"""你是一个资深的Python开发工程师。根据以下需求和测试用例,生成高质量的Python实现代码和单元测试代码。 {requirement_text} @@ -235,61 +462,246 @@ class DevAgent: 请返回JSON格式的结果(必须是有效的JSON格式): {{ - "java_code": "完整的Java实现代码(包含主类和必要的辅助类)", - "unit_tests": "使用JUnit的单元测试代码", - "implementation_notes": "实现说明和注意事项", - "unit_tests_count": "生成的单元测试总数量(整数)", - "passed_tests_count": "基于代码逻辑分析,预期可通过的单元测试数量(整数)" + "java_code": "完整的Python实现代码(包含主模块和必要的辅助类/函数)", + "unit_tests": "使用pytest的单元测试代码", + "implementation_notes": "实现说明和注意事项" }} -Java代码要求: -1. 使用 Java 11 语法和特性(如 var 局部变量类型推断、String::isBlank/strip/lines、Optional、Stream API、List.of/Map.of 等不可变集合工厂方法),不要使用 Java 8 以前的写法 -2. 包含详细的中文代码注释,所有注释内容必须用中文撰写,所有多行注释必须以 /* 开头、以 */ 结尾,Javadoc注释以 /** 开头、以 */ 结尾,绝对不能用单独的 / 作为注释结尾 +Python代码要求: +1. 使用 Python 3.10+ 语法和特性(如 match/case、类型注解、dataclass、pathlib、f-string、列表推导等),代码风格遵循 PEP 8 +2. 包含详细的中文代码注释,类和函数必须有中文 docstring 3. 包含异常处理 4. 支持所有的功能需求 5. 考虑非功能需求(性能、安全等) 单元测试要求: -1. 使用 JUnit 5(jupiter),充分利用 @DisplayName、@ParameterizedTest、assertThrows 等特性 -2. 为每个公共方法生成测试 +1. 使用 pytest,充分利用 @pytest.mark.parametrize、pytest.raises、fixture 等特性 +2. 为每个公共函数/方法生成测试 3. 包含正常情况、边缘情况和异常情况的测试 -4. 使用有意义的测试方法名称 -5. 每个测试类顶部加中文类级注释说明该类的测试范围 +4. 使用有意义的测试函数名称(如 test_xxx_when_xxx_should_xxx) +5. 每个测试函数内用中文注释标注 Given / When / Then 三个阶段 6. 测试代码要清晰易读 +7. 重要:业务代码文件保存名为 implementation.py,测试代码文件保存名为 test_implementation.py,测试文件必须使用 `from implementation import ...` 或 `import implementation` 导入业务代码 implementation_notes要求返回中文实现说明,内容要具体且有指导意义,不能只是简单的总结性描述,要包含对关键设计决策的解释和对复杂逻辑的说明。 返回ONLY JSON内容,不要有其他文字。""" + def generate_code( + self, + requirement_analysis: RequirementAnalysis, + test_cases: TestCaseResult + ) -> CodeGenerationResult: + """ + 生成Python实现代码和单元测试代码 + + Args: + requirement_analysis: PM Agent的分析结果 + test_cases: QA Agent的测试用例 + + Returns: + CodeGenerationResult: 包含Python代码和单元测试代码 + """ response = self.client.chat.completions.create( model=self.settings.model, messages=[ - {"role": "system", "content": "你是一个资深的 Java 11 开发工程师,擅长使用 Java 11 特性(var、Stream API、Optional、HttpClient、String 新方法等)编写高质量代码和单元测试,所有代码注释用中文,类名和方法名保持英文命名规范,输出必须是严格的 JSON 格式。"}, + {"role": "system", "content": "你是一个资深的 Python 3.10+ 开发工程师,擅长使用 Python 类型注解、dataclass、pytest 等特性编写高质量代码和单元测试,所有代码注释用中文,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(requirement_analysis, test_cases)} + ], + temperature=0.2, + max_tokens=8192, + response_format={"type": "json_object"} + ) + result = _parse_json(response.choices[0].message.content, "DevAgent 解析失败") + return CodeGenerationResult(**result) + + def stream_generate_code( + self, + requirement_analysis: RequirementAnalysis, + test_cases: TestCaseResult, + ): + """ + 流式版代码生成。yield 文本块(str),最后 yield (None, CodeGenerationResult) 作为哨兵。 + """ + stream = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python 3.10+ 开发工程师,擅长使用 Python 类型注解、dataclass、pytest 等特性编写高质量代码,所有代码注释用中文,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(requirement_analysis, test_cases)} + ], + temperature=0.2, + max_tokens=8192, + response_format={"type": "json_object"}, + stream=True, + ) + full_text = "" + for chunk in stream: + delta = chunk.choices[0].delta.content or "" + if delta: + full_text += delta + yield delta + result = _parse_json(full_text, "DevAgent stream 解析失败") + yield (None, CodeGenerationResult(**result)) + + def refine( + self, + previous: CodeGenerationResult, + requirement_analysis: RequirementAnalysis, + test_cases: TestCaseResult, + feedback: str + ) -> CodeGenerationResult: + """ + 根据用户反馈修改已有的代码生成结果。 + + Args: + previous: 上一次的代码生成结果 + requirement_analysis: PM Agent的分析结果(供参考) + test_cases: QA Agent的测试用例(供参考) + feedback: 用户的修改意见 + + Returns: + CodeGenerationResult: 修改后的代码生成结果 + """ + refine_prompt = f"""以下是你之前生成的代码结果,用户对此有修改意见,请根据意见调整输出。 + +用户的修改意见: +{feedback} + +请在原有代码基础上修改,保持 JSON 格式不变,返回完整的修改后结果。 +继续遵守原有的 Python 3.10+、中文注释、pytest 等所有要求。 +返回 ONLY JSON,不要有其他文字。""" + + response = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python 3.10+ 开发工程师,所有代码注释用中文,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": self._build_prompt(requirement_analysis, test_cases)}, + {"role": "assistant", "content": json.dumps(previous, ensure_ascii=False)}, + {"role": "user", "content": refine_prompt} + ], + temperature=0.2, + max_tokens=8192, + response_format={"type": "json_object"} + ) + result = _parse_json(response.choices[0].message.content, "DevAgent.refine 解析失败") + return CodeGenerationResult(**result) + + +class FixAgent: + """自动修复Agent - 根据pytest失败信息修复Python代码""" + + def __init__(self): + self.settings = get_settings() + self.client = OpenAI(api_key=self.settings.api_key, base_url=self.settings.base_url) + + def fix(self, code_generation: CodeGenerationResult, test_output: str) -> CodeGenerationResult: + """ + 根据pytest输出自动修复业务代码和/或测试代码。 + + Args: + code_generation: 之前 DevAgent 的产出 + test_output: pytest 执行的完整输出 + + Returns: + CodeGenerationResult: 修复后的完整代码 + """ + prompt = f"""以下Python代码在运行单元测试时出现了失败,请根据pytest失败信息修复代码。 + +当前业务代码(implementation.py): +```python +{code_generation["java_code"]} +``` + +当前测试代码(test_implementation.py): +```python +{code_generation["unit_tests"]} +``` + +pytest 执行输出: +{test_output[:4000]} + +请分析失败原因并修复(业务代码和/或测试代码中的问题 都可以修复)。 +层层修复要径直至递推所有测试权都能通过。 +业务代码文件名为 implementation.py,测试文件名为 test_implementation.py,测试文件导入必须用 from implementation import ...。 + +返回 JSON 格式: +{{ + "java_code": "修复后的完整Python业务代码", + "unit_tests": "修复后的完整Python测试代码", + "implementation_notes": "修复说明:具体说明了哪些问题,如何修复的" +}} + +返回 ONLY JSON,不要有其他文字。""" + + response = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python 3.10+ 开发工程师,擅长根据测试错误信息定位并修复 Bug,所有代码注释用中文,输出必须是严格的 JSON 格式。"}, {"role": "user", "content": prompt} ], temperature=0.2, max_tokens=8192, response_format={"type": "json_object"} ) + result = _parse_json(response.choices[0].message.content, "FixAgent 解析失败") + return CodeGenerationResult(**result) - content = response.choices[0].message.content + def stream_fix(self, code_generation: CodeGenerationResult, test_output: str): + """ + 流式修复。yield 文本块(str),最后 yield (None, CodeGenerationResult) 作为哨兵。 + """ + prompt = f"""以下Python代码在运行单元测试时出现了失败,请根据pytest失败信息修复代码。 +当前业务代码(implementation.py): +```python +{code_generation["java_code"]} +``` - try: - result = json.loads(content) - return CodeGenerationResult(**result) - except json.JSONDecodeError: - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return CodeGenerationResult(**result) - raise ValueError(f"无法解析Dev Agent响应: {content}") +当前测试代码(test_implementation.py): +```python +{code_generation["unit_tests"]} +``` + +pytest 执行输出: +{test_output[:4000]} + +请分析失败原因并修复(业务代码和/或测试代码中的问题 都可以修复)。 +层层修复要径直至递推所有测试权都能通过。 +业务代码文件名为 implementation.py,测试文件名为 test_implementation.py,测试文件导入必须用 from implementation import ...。 + +返回 JSON 格式: +{{ + "java_code": "修复后的完整Python业务代码", + "unit_tests": "修复后的完整Python测试代码", + "implementation_notes": "修复说明:具体说明了哪些问题,如何修复的" +}} + +返回 ONLY JSON,不要有其他文字。""" + + stream = self.client.chat.completions.create( + model=self.settings.model, + messages=[ + {"role": "system", "content": "你是一个资深的 Python 3.10+ 开发工程师,擅长根据测试错误信息定位并修复 Bug,所有代码注释用中文,输出必须是严格的 JSON 格式。"}, + {"role": "user", "content": prompt} + ], + temperature=0.2, + max_tokens=8192, + response_format={"type": "json_object"}, + stream=True, + ) + full_text = "" + for chunk in stream: + delta = chunk.choices[0].delta.content or "" + if delta: + full_text += delta + yield delta + result = _parse_json(full_text, "FixAgent stream 解析失败") + yield (None, CodeGenerationResult(**result)) async def orchestrate_agents(simple_requirement: str) -> dict: """ - 编排三个Agent的工作流程 + 编排三个Agent的工作流程(保留原有全量接口) Args: simple_requirement: 用户的简单需求描述 diff --git a/app/message.py b/app/message.py index dd32f9a..35405b8 100644 --- a/app/message.py +++ b/app/message.py @@ -15,11 +15,12 @@ _TYPE_EMOJI = { def _post(data: dict): - requests.post( - webhook_url, - headers={"Content-Type": "application/json"}, - data=json.dumps(data, ensure_ascii=False) - ) + return + # requests.post( + # webhook_url, + # headers={"Content-Type": "application/json"}, + # data=json.dumps(data, ensure_ascii=False) + # ) def _make_card(title: str, color: str, elements: list) -> dict: @@ -136,32 +137,6 @@ def send_generate_code(code_result: CodeGenerationResult): [{"tag": "markdown", "content": f"```java\n{code_result['unit_tests']}\n```"}] )) - # 单元测试执行结果 - try: - total = int(code_result["unit_tests_count"]) - passed = int(code_result["passed_tests_count"]) - failed = total - passed - rate = f"{passed / total * 100:.0f}%" if total > 0 else "—" - status = "✅ 全部通过" if failed == 0 else f"⚠️ {failed} 个未通过" - failed_display = f"{failed} ❌" if failed > 0 else "0" - except (ValueError, TypeError): - total = code_result["unit_tests_count"] - passed = code_result["passed_tests_count"] - failed_display, rate, status = "—", "—", "—" - - result_md = ( - f"**总用例数:** {total}\n\n" - f"**通过:** {passed} ✅\n\n" - f"**未通过:** {failed_display}\n\n" - f"**通过率:** {rate}\n\n" - f"---\n\n" - f"**整体状态:** {status}" - ) - _post(_make_card( - "📊 代码生成结果 — 单元测试执行结果", "green", - [{"tag": "markdown", "content": result_md}] - )) - def send_test_cases(test_case: TestCaseResult): _post(build_full_feishu_card(test_case)) diff --git a/app/models.py b/app/models.py index 3cf6b23..7193c58 100644 --- a/app/models.py +++ b/app/models.py @@ -22,5 +22,3 @@ class CodeGenerationResult(TypedDict): java_code: str unit_tests: str implementation_notes: str - unit_tests_count: str - passed_tests_count: str diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routers/session.py b/app/routers/session.py new file mode 100644 index 0000000..2c7bf9e --- /dev/null +++ b/app/routers/session.py @@ -0,0 +1,514 @@ +""" +app/routers/session.py - 交互式会话路由 +""" + +import logging +import json +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from typing import Optional + +from app.session import SessionStore, SessionStatus +from app.agents import ClarifyAgent, PMAgent, QAAgent, DevAgent, FixAgent +from app.test_runner import run_python_tests +from app.message import ( + send_workflow_start, + send_requirement_result, + send_test_cases, + send_generate_code, +) + +router = APIRouter(prefix="/session", tags=["session"]) +logger = logging.getLogger(__name__) + + +# ---------- 请求 / 响应模型 ---------- + +class StartRequest(BaseModel): + requirement: str + +class ClarifyRequest(BaseModel): + message: str + +class RefineRequest(BaseModel): + feedback: str + + +class SessionResponse(BaseModel): + session_id: str + status: str + ready: bool = False + question: Optional[str] = None # 当 ready=False 时返回追问 + data: Optional[dict] = None # 当前阶段产出 + + +# ---------- 工具函数 ---------- + +def _get_session_or_404(session_id: str): + session = SessionStore.get(session_id) + if not session: + raise HTTPException(status_code=404, detail="会话不存在或已过期") + return session + +def _require_status(session, *allowed: SessionStatus): + if session.status not in allowed: + raise HTTPException( + status_code=400, + detail=f"当前状态 [{session.status}] 不允许此操作,允许的状态: {[s.value for s in allowed]}" + ) + + +# ---------- 接口 ---------- + +@router.post("/start", response_model=SessionResponse) +def start_session(body: StartRequest): + """创建会话,AI 判断需求是否完整,不够则追问。""" + session = SessionStore.create(body.requirement) + agent = ClarifyAgent() + result = agent.start(body.requirement) + + q = result.get("question", "") + if q: + session.clarify_history.append({"role": "assistant", "content": q}) + session.clarified_requirement = result.get("clarified_requirement", body.requirement) + + if result.get("ready"): + session.status = SessionStatus.PM_READY + + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=result.get("ready", False), + question=result.get("question") or None, + ) + + +@router.post("/{session_id}/clarify", response_model=SessionResponse) +def clarify(session_id: str, body: ClarifyRequest): + """用户补充需求,AI 继续判断是否够了。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.CLARIFYING) + + session.clarify_history.append({"role": "user", "content": body.message}) + + agent = ClarifyAgent() + result = agent.continue_clarify(session.clarify_history, body.message) + + q = result.get("question", "") + if q: + session.clarify_history.append({"role": "assistant", "content": q}) + session.clarified_requirement = result.get("clarified_requirement", session.clarified_requirement) + + if result.get("ready"): + session.status = SessionStatus.PM_READY + + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=result.get("ready", False), + question=result.get("question") or None, + ) + + +@router.get("/{session_id}/pm/stream") +def pm_stream(session_id: str): + """流式返回 PM Agent 分析过程和结果(SSE)。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.PM_READY) + + send_workflow_start(session.clarified_requirement) + agent = PMAgent() + simple_requirement = session.clarified_requirement + + def generate(): + try: + result = None + for item in agent.stream_analyze(simple_requirement): + if isinstance(item, tuple): + _, result = item + else: + yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n" + session.requirement_analysis = result + send_requirement_result(result) + session.status = SessionStatus.PM_DONE + session.touch() + yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'requirement_analysis': result}}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/{session_id}/pm/run", response_model=SessionResponse) +def pm_run(session_id: str): + """触发 PM Agent 分析需求。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.PM_READY) + + send_workflow_start(session.clarified_requirement) + + agent = PMAgent() + session.requirement_analysis = agent.analyze_requirement(session.clarified_requirement) + send_requirement_result(session.requirement_analysis) + + session.status = SessionStatus.PM_DONE + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"requirement_analysis": session.requirement_analysis}, + ) + + +@router.get("/{session_id}/pm/refine/stream") +def pm_refine_stream(session_id: str, feedback: str): + """流式修改 PM 产出(SSE),feedback 经由 query param 传入。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.PM_DONE) + + agent = PMAgent() + previous = session.requirement_analysis + + def generate(): + try: + result = None + for item in agent.stream_refine(previous, feedback): + if isinstance(item, tuple): + _, result = item + else: + yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n" + session.requirement_analysis = result + send_requirement_result(result) + session.touch() + yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'requirement_analysis': result}}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/{session_id}/pm/refine", response_model=SessionResponse) +def pm_refine(session_id: str, body: RefineRequest): + """根据反馈修改 PM 产出。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.PM_DONE) + + agent = PMAgent() + session.requirement_analysis = agent.refine(session.requirement_analysis, body.feedback) + send_requirement_result(session.requirement_analysis) + + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"requirement_analysis": session.requirement_analysis}, + ) + + +@router.get("/{session_id}/qa/stream") +def qa_stream(session_id: str): + """流式返回 QA Agent 测试用例生成过程(SSE)。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.PM_DONE) + + if not session.requirement_analysis: + raise HTTPException(status_code=400, detail="PM Agent 产出不存在") + + agent = QAAgent() + req_analysis = session.requirement_analysis + + def generate(): + try: + result = None + for item in agent.stream_generate_test_cases(req_analysis): + if isinstance(item, tuple): + _, result = item + else: + yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n" + session.test_cases = result + send_test_cases(result) + session.status = SessionStatus.QA_DONE + session.touch() + yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'test_cases': result}}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/{session_id}/qa/run", response_model=SessionResponse) +def qa_run(session_id: str): + """触发 QA Agent 生成测试用例。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.PM_DONE) + + if not session.requirement_analysis: + raise HTTPException(status_code=400, detail="PM Agent 产出不存在") + + agent = QAAgent() + session.test_cases = agent.generate_test_cases(session.requirement_analysis) + send_test_cases(session.test_cases) + + session.status = SessionStatus.QA_DONE + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"test_cases": session.test_cases}, + ) + + +@router.get("/{session_id}/qa/refine/stream") +def qa_refine_stream(session_id: str, feedback: str): + """流式修改 QA 产出(SSE),feedback 经由 query param 传入。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.QA_DONE) + + agent = QAAgent() + previous = session.test_cases + + def generate(): + try: + result = None + for item in agent.stream_refine(previous, feedback): + if isinstance(item, tuple): + _, result = item + else: + yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n" + session.test_cases = result + send_test_cases(result) + session.touch() + yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'test_cases': result}}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/{session_id}/qa/refine", response_model=SessionResponse) +def qa_refine(session_id: str, body: RefineRequest): + """根据反馈修改 QA 产出。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.QA_DONE) + + agent = QAAgent() + session.test_cases = agent.refine(session.test_cases, body.feedback) + send_test_cases(session.test_cases) + + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"test_cases": session.test_cases}, + ) + + +@router.get("/{session_id}/dev/stream") +def dev_stream(session_id: str): + """流式返回 Dev Agent 代码生成过程(SSE)。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.QA_DONE) + + if not session.requirement_analysis or not session.test_cases: + raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整") + + agent = DevAgent() + req_analysis = session.requirement_analysis + test_cases = session.test_cases + + def generate(): + try: + result = None + for item in agent.stream_generate_code(req_analysis, test_cases): + if isinstance(item, tuple): + _, result = item + else: + yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n" + session.code_generation = result + send_generate_code(result) + session.status = SessionStatus.DEV_DONE + session.touch() + yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'code_generation': result}}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/{session_id}/dev/run", response_model=SessionResponse) +def dev_run(session_id: str): + """触发 Dev Agent 生成代码。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.QA_DONE) + + if not session.requirement_analysis or not session.test_cases: + raise HTTPException(status_code=400, detail="PM / QA Agent 产出不完整") + + agent = DevAgent() + session.code_generation = agent.generate_code(session.requirement_analysis, session.test_cases) + send_generate_code(session.code_generation) + + session.status = SessionStatus.DEV_DONE + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"code_generation": session.code_generation}, + ) + + +@router.post("/{session_id}/dev/refine", response_model=SessionResponse) +def dev_refine(session_id: str, body: RefineRequest): + """根据反馈修改 Dev 产出。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.DEV_DONE) + + agent = DevAgent() + session.code_generation = agent.refine( + session.code_generation, + session.requirement_analysis, + session.test_cases, + body.feedback, + ) + send_generate_code(session.code_generation) + + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"code_generation": session.code_generation}, + ) + + +@router.post("/{session_id}/test/run", response_model=SessionResponse) +def test_run(session_id: str): + """在临时目录中真实执行 pytest,返回测试结果。""" + session = _get_session_or_404(session_id) + + if not session.code_generation: + raise HTTPException(status_code=400, detail="Dev Agent 产出不存在") + + result = run_python_tests( + session.code_generation["java_code"], + session.code_generation["unit_tests"], + ) + session.test_execution = result + session.status = SessionStatus.TEST_DONE + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"test_execution": result}, + ) + + +@router.post("/{session_id}/test/fix", response_model=SessionResponse) +def test_fix(session_id: str): + """调用 FixAgent 根据测试失败信息自动修复代码。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.TEST_DONE) + + if not session.test_execution: + raise HTTPException(status_code=400, detail="尚未执行测试") + if session.test_execution.get("success"): + raise HTTPException(status_code=400, detail="测试已全部通过,无需修复") + + agent = FixAgent() + session.code_generation = agent.fix( + session.code_generation, + session.test_execution["output"], + ) + session.status = SessionStatus.DEV_DONE # 修复后重置为 dev_done,可再次测试 + session.touch() + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=True, + data={"code_generation": session.code_generation}, + ) + + +@router.get("/{session_id}/test/fix/stream") +def test_fix_stream(session_id: str): + """流式返回 FixAgent 代码修复过程(SSE)。""" + session = _get_session_or_404(session_id) + _require_status(session, SessionStatus.TEST_DONE) + + if not session.test_execution: + raise HTTPException(status_code=400, detail="尚未执行测试") + if session.test_execution.get("success"): + raise HTTPException(status_code=400, detail="测试已全部通过,无需修复") + + agent = FixAgent() + code_generation = session.code_generation + test_output = session.test_execution["output"] + + def generate(): + try: + result = None + for item in agent.stream_fix(code_generation, test_output): + if isinstance(item, tuple): + _, result = item + else: + yield f"data: {json.dumps({'type': 'chunk', 'text': item}, ensure_ascii=False)}\n\n" + session.code_generation = result + session.status = SessionStatus.DEV_DONE + session.touch() + yield f"data: {json.dumps({'type': 'done', 'status': session.status, 'data': {'code_generation': result}}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.get("/{session_id}", response_model=SessionResponse) +def get_session(session_id: str): + """获取当前会话状态和所有产出。""" + session = _get_session_or_404(session_id) + return SessionResponse( + session_id=session.session_id, + status=session.status, + ready=session.status != SessionStatus.CLARIFYING, + data={ + "raw_requirement": session.raw_requirement, + "clarify_history": session.clarify_history, + "requirement_analysis": session.requirement_analysis, + "test_cases": session.test_cases, + "code_generation": session.code_generation, + "test_execution": session.test_execution, + }, + ) diff --git a/app/session.py b/app/session.py new file mode 100644 index 0000000..3fd3256 --- /dev/null +++ b/app/session.py @@ -0,0 +1,68 @@ +""" +app/session.py - 会话状态管理 +""" + +import uuid +import time +from enum import Enum +from typing import Optional +from app.models import RequirementAnalysis, TestCaseResult, CodeGenerationResult + + +class SessionStatus(str, Enum): + CLARIFYING = "clarifying" # 正在澄清需求 + PM_READY = "pm_ready" # 需求澄清完毕,可运行 PM Agent + PM_DONE = "pm_done" # PM Agent 已完成 + QA_READY = "qa_ready" # 可运行 QA Agent + QA_DONE = "qa_done" # QA Agent 已完成 + DEV_READY = "dev_ready" # 可运行 Dev Agent + DEV_DONE = "dev_done" # Dev Agent 已完成 + TEST_DONE = "test_done" # 单元测试执行完成 + +class Session: + def __init__(self, session_id: str, raw_requirement: str): + self.session_id: str = session_id + self.raw_requirement: str = raw_requirement # 用户最原始的需求 + self.clarified_requirement: str = raw_requirement # 经过澄清补充后的完整需求 + self.clarify_history: list[dict] = [] # 澄清对话历史 + self.status: SessionStatus = SessionStatus.CLARIFYING + self.created_at: float = time.time() + self.updated_at: float = time.time() + + # 各 Agent 产出 + self.requirement_analysis: Optional[RequirementAnalysis] = None + self.test_cases: Optional[TestCaseResult] = None + self.code_generation: Optional[CodeGenerationResult] = None + self.test_execution: Optional[dict] = None # pytest 执行结果 + + def touch(self): + self.updated_at = time.time() + + +class SessionStore: + _store: dict[str, Session] = {} + _TTL_SECONDS = 7200 # 2小时过期 + + @classmethod + def create(cls, raw_requirement: str) -> Session: + session_id = str(uuid.uuid4()) + session = Session(session_id, raw_requirement) + cls._store[session_id] = session + cls._evict_expired() + return session + + @classmethod + def get(cls, session_id: str) -> Optional[Session]: + session = cls._store.get(session_id) + if session and time.time() - session.updated_at > cls._TTL_SECONDS: + del cls._store[session_id] + return None + return session + + @classmethod + def _evict_expired(cls): + now = time.time() + expired = [sid for sid, s in cls._store.items() + if now - s.updated_at > cls._TTL_SECONDS] + for sid in expired: + del cls._store[sid] diff --git a/app/test_runner.py b/app/test_runner.py new file mode 100644 index 0000000..8533ce6 --- /dev/null +++ b/app/test_runner.py @@ -0,0 +1,80 @@ +""" +app/test_runner.py - 在临时目录中真实执行 pytest 单元测试 +""" +import os +import re +import sys +import tempfile +import subprocess + + +def run_python_tests(python_code: str, test_code: str) -> dict: + """ + 将业务代码写入 implementation.py,测试代码写入 test_implementation.py, + 在隔离的临时目录中用 pytest 执行,返回结构化测试结果。 + + Returns: + dict: { + success: bool, + passed: int, + failed: int, + errors: int, + total: int, + output: str # pytest 完整输出 + } + """ + with tempfile.TemporaryDirectory() as tmpdir: + impl_path = os.path.join(tmpdir, "implementation.py") + test_path = os.path.join(tmpdir, "test_implementation.py") + conftest_path = os.path.join(tmpdir, "conftest.py") + + with open(impl_path, "w", encoding="utf-8") as f: + f.write(python_code) + + with open(test_path, "w", encoding="utf-8") as f: + f.write(test_code) + + # conftest.py 确保 tmpdir 在 sys.path 首位,解决模块导入问题 + with open(conftest_path, "w", encoding="utf-8") as f: + f.write("import sys, os\nsys.path.insert(0, os.path.dirname(__file__))\n") + + try: + proc = subprocess.run( + [sys.executable, "-m", "pytest", "test_implementation.py", + "-v", "--tb=short", "--no-header", "--color=no"], + cwd=tmpdir, + capture_output=True, + text=True, + timeout=60, + env={**os.environ, "PYTHONPATH": tmpdir}, + ) + output = proc.stdout + if proc.stderr.strip(): + output += "\n--- stderr ---\n" + proc.stderr + except subprocess.TimeoutExpired: + return { + "success": False, + "passed": 0, "failed": 0, "errors": 1, "total": 0, + "output": "❌ 测试执行超时(超过 60 秒)", + } + except FileNotFoundError: + return { + "success": False, + "passed": 0, "failed": 0, "errors": 1, "total": 0, + "output": "❌ 未找到 pytest,请确保已安装:pip install pytest", + } + + passed = int(m.group(1)) if (m := re.search(r"(\d+) passed", output)) else 0 + failed = int(m.group(1)) if (m := re.search(r"(\d+) failed", output)) else 0 + errors = int(m.group(1)) if (m := re.search(r"(\d+) error", output)) else 0 + total = passed + failed + errors + success = (failed == 0 and errors == 0 and total > 0) + + return { + "success": success, + "passed": passed, + "failed": failed, + "errors": errors, + "total": total, + "output": output, + } diff --git a/main.py b/main.py index 1b81104..76227bb 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from app.agents import orchestrate_agents from app.config import get_settings +from app.routers import session as session_router # 初始化日志 logging.basicConfig(level=logging.INFO) @@ -33,6 +34,9 @@ app.add_middleware( allow_headers=["*"], ) +# 注册交互式会话路由 +app.include_router(session_router.router) + class FullWorkflowResponse(BaseModel): """完整工作流响应"""