"""
app/test_runner.py - 在临时目录中真实执行 pytest 单元测试
(Runs pytest unit tests for real inside a temporary directory.)
"""
|
||
import os
|
||
import re
|
||
import sys
|
||
import tempfile
|
||
import subprocess
|
||
|
||
|
||
def run_python_tests(python_code: str, test_code: str) -> dict:
|
||
"""
|
||
将业务代码写入 implementation.py,测试代码写入 test_implementation.py,
|
||
在隔离的临时目录中用 pytest 执行,返回结构化测试结果。
|
||
|
||
Returns:
|
||
dict: {
|
||
success: bool,
|
||
passed: int,
|
||
failed: int,
|
||
errors: int,
|
||
total: int,
|
||
output: str # pytest 完整输出
|
||
}
|
||
"""
|
||
with tempfile.TemporaryDirectory() as tmpdir:
|
||
impl_path = os.path.join(tmpdir, "implementation.py")
|
||
test_path = os.path.join(tmpdir, "test_implementation.py")
|
||
conftest_path = os.path.join(tmpdir, "conftest.py")
|
||
|
||
with open(impl_path, "w", encoding="utf-8") as f:
|
||
f.write(python_code)
|
||
|
||
with open(test_path, "w", encoding="utf-8") as f:
|
||
f.write(test_code)
|
||
|
||
# conftest.py 确保 tmpdir 在 sys.path 首位,解决模块导入问题
|
||
with open(conftest_path, "w", encoding="utf-8") as f:
|
||
f.write("import sys, os\nsys.path.insert(0, os.path.dirname(__file__))\n")
|
||
|
||
try:
|
||
proc = subprocess.run(
|
||
[sys.executable, "-m", "pytest", "test_implementation.py",
|
||
"-v", "--tb=short", "--no-header", "--color=no"],
|
||
cwd=tmpdir,
|
||
capture_output=True,
|
||
text=True,
|
||
timeout=60,
|
||
env={**os.environ, "PYTHONPATH": tmpdir},
|
||
)
|
||
output = proc.stdout
|
||
if proc.stderr.strip():
|
||
output += "\n--- stderr ---\n" + proc.stderr
|
||
except subprocess.TimeoutExpired:
|
||
return {
|
||
"success": False,
|
||
"passed": 0, "failed": 0, "errors": 1, "total": 0,
|
||
"output": "❌ 测试执行超时(超过 60 秒)",
|
||
}
|
||
except FileNotFoundError:
|
||
return {
|
||
"success": False,
|
||
"passed": 0, "failed": 0, "errors": 1, "total": 0,
|
||
"output": "❌ 未找到 pytest,请确保已安装:pip install pytest",
|
||
}
|
||
|
||
passed = int(m.group(1)) if (m := re.search(r"(\d+) passed", output)) else 0
|
||
failed = int(m.group(1)) if (m := re.search(r"(\d+) failed", output)) else 0
|
||
errors = int(m.group(1)) if (m := re.search(r"(\d+) error", output)) else 0
|
||
total = passed + failed + errors
|
||
success = (failed == 0 and errors == 0 and total > 0)
|
||
|
||
return {
|
||
"success": success,
|
||
"passed": passed,
|
||
"failed": failed,
|
||
"errors": errors,
|
||
"total": total,
|
||
"output": output,
|
||
}
|