Files
code_scan/scanner/ai_reviewer.py
Dang Zerong cb90b66f09 代码测试
2026-03-13 11:26:01 +08:00

593 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI 代码审查器
使用大模型进行智能代码审查
"""
import os
import re
import json
import logging
from typing import Dict, Any, List, Optional
from scanner.base import BaseScanner
logger = logging.getLogger(__name__)
class AIReviewer(BaseScanner):
"""AI 代码审查器"""
def __init__(self, config: Dict[str, Any]):
"""
初始化 AI 审查器
Args:
config: AI 配置
"""
# 先初始化基类
super().__init__(config.get('scanner', {}))
self.config = config
self.enabled = config.get('enabled', True)
self.provider = config.get('provider', 'ollama')
self.model = config.get('model', 'llama3')
self.api_url = config.get('api_url', 'http://localhost:11434')
self.api_key = config.get('api_key', '')
self.max_lines = config.get('max_lines', 200)
if not self.enabled:
logger.info('AI 审查器已禁用')
return
logger.info(f'AI 审查器初始化: {self.provider}/{self.model}')
def scan(self, repo_url: str, commit_id: Optional[str], branch: str, changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
"""
执行代码扫描(实现抽象方法)
Args:
repo_url: 仓库 URL
commit_id: 提交 ID
branch: 分支名
changed_files: 可选的变更文件列表(来自 PR
Returns:
审查结果
"""
# 调用实际的审查逻辑
return self._do_review(repo_url=repo_url, commit_id=commit_id, branch=branch, changed_files=changed_files)
def _do_review(self, clone_dir: str = None, repo_url: str = None,
commit_id: str = None, branch: str = None,
language: str = 'python',
changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
"""
执行 AI 代码审查
Args:
clone_dir: 仓库目录(如果已克隆则直接传入)
repo_url: 仓库 URL如果未克隆则需要传入
commit_id: 提交 ID
branch: 分支名
language: 编程语言
changed_files: 可选的变更文件列表(来自 PR
Returns:
审查结果(与 python_scanner.py 兼容的格式)
"""
result = {
'tool': 'AI Code Reviewer',
'language': language,
'status': 'success',
'issues': [],
'summary': {
'total': 0,
'error': 0,
'warning': 0,
'info': 0
},
'files_scanned': 0
}
if not self.enabled:
result['status'] = 'disabled'
result['summary'] = 'AI 审查已禁用'
return result
try:
# 如果没有传入 clone_dir需要克隆
if not clone_dir and repo_url:
clone_dir = self.clone_repo(repo_url, commit_id, branch)
if not clone_dir or not os.path.exists(clone_dir):
result['status'] = 'error'
result['error'] = '无法获取代码目录'
return result
# 获取要审查的代码文件
files = self._get_code_files(clone_dir, language, changed_files)
if not files:
result['summary'] = '未找到可审查的代码文件'
return result
# 对每个文件进行 AI 审查
all_issues = []
for file_path in files[:5]: # 限制最多审查 5 个文件
review = self._review_file(file_path, language, clone_dir)
if review and review.get('issues'):
all_issues.extend(review['issues'])
result['issues'] = all_issues[:self.max_issues] if self.detailed else all_issues
result['summary'] = self._calculate_summary(all_issues)
result['files_scanned'] = len(files[:5])
result['clone_dir'] = clone_dir
# 生成质量评分
result['quality_score'] = self._calculate_quality_score(all_issues, files[:5])
return result
except Exception as e:
logger.error(f'AI 审查失败: {str(e)}')
result['status'] = 'error'
result['error'] = str(e)
return result
def _calculate_summary(self, issues: List[Dict]) -> Dict[str, int]:
"""计算问题摘要"""
summary = {
'total': len(issues),
'error': 0,
'warning': 0,
'info': 0
}
for issue in issues:
severity = issue.get('severity', '').lower()
if severity in ['error', 'critical', 'fatal']:
summary['error'] += 1
elif severity in ['warning', 'moderate']:
summary['warning'] += 1
else:
summary['info'] += 1
return summary
def _calculate_quality_score(self, issues: List[Dict], files: List[str]) -> Dict[str, Any]:
"""
计算代码质量评分
返回:总分(0-100)及各维度评分
"""
if not files:
return {'total': 100, 'maintainability': 100, 'security': 100, 'readability': 100, 'best_practices': 100}
# 统计问题
error_count = sum(1 for i in issues if i.get('severity', '').lower() in ['error', 'critical'])
warning_count = sum(1 for i in issues if i.get('severity', '').lower() == 'warning')
info_count = sum(1 for i in issues if i.get('severity', '').lower() == 'info')
# 分类统计
security_keywords = ['sql injection', 'xss', 'csrf', 'password', 'secret', 'token', '权限', '注入', '认证']
security_issues = sum(1 for i in issues if any(k in (i.get('message', '') + i.get('symbol', '')).lower() for k in security_keywords))
# 计算各维度分数
# 可维护性:基于错误和警告数量
issue_weight = error_count * 5 + warning_count * 2 + info_count * 0.5
maintainability = max(0, 100 - issue_weight)
# 安全性:基于安全问题
security_score = max(0, 100 - security_issues * 15)
# 可读性:基于 info 级别问题(风格类)
readability = max(0, 100 - info_count * 3)
# 最佳实践:基于 warning 级别
best_practices = max(0, 100 - warning_count * 5)
# 总分:加权平均
total = int((maintainability * 0.3 + security_score * 0.35 + readability * 0.15 + best_practices * 0.2))
return {
'total': total,
'maintainability': maintainability,
'security': security_score,
'readability': readability,
'best_practices': best_practices,
'details': {
'error_count': error_count,
'warning_count': warning_count,
'info_count': info_count,
'security_issues': security_issues
}
}
def generate_fix_suggestion(self, file_path: str, line: int, message: str, code: str) -> Optional[str]:
"""
对指定问题生成修复建议代码
"""
prompt = f"""你是一位代码修复专家。请根据以下问题,生成修复后的代码。
问题描述:{message}
问题所在行号:{line}
原始代码:
```
{code}
```
请以 JSON 格式输出修复建议:
```json
{{
"fixed_code": "修复后的完整代码或关键片段",
"explanation": "修复说明50字以内",
"confidence": "high/medium/low 修复把握度"
}}
```
如果无法修复,请返回:{{"fixed_code": "", "explanation": "无法自动修复", "confidence": "low"}}"""
try:
response = self._call_ai(prompt)
if response and response.get('fixed_code'):
return response
except Exception as e:
logger.warning(f'生成修复建议失败: {e}')
return None
def _get_code_files(self, clone_dir: str, language: str, changed_files: Optional[List[str]] = None) -> List[str]:
"""获取代码文件列表"""
import glob
extensions = {
'python': ['.py'],
'javascript': ['.js', '.jsx'],
'typescript': ['.ts', '.tsx']
}
exts = extensions.get(language, ['.py'])
# 如果提供了变更文件列表,只返回这些文件
if changed_files:
files = []
for changed_file in changed_files:
if any(changed_file.endswith(ext) for ext in exts):
full_path = os.path.join(clone_dir, changed_file)
if os.path.exists(full_path):
files.append(full_path)
return files[:10]
# 否则扫描整个仓库
files = []
for ext in exts:
pattern = os.path.join(clone_dir, '**', f'*{ext}')
files.extend(glob.glob(pattern, recursive=True))
# 过滤掉测试文件和虚拟环境
files = [f for f in files if not any(x in f for x in [
'test_', '_test.', 'venv', 'node_modules', '__pycache__'
])]
return files[:10] # 最多 10 个文件
def _review_file(self, file_path: str, language: str, clone_dir: str = None) -> Optional[Dict[str, Any]]:
"""审查单个文件"""
issues = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
code = f.read()
# 限制代码行数
lines = code.split('\n')
if len(lines) > self.max_lines:
code = '\n'.join(lines[:self.max_lines])
truncated = True
else:
truncated = False
# 给代码加行号再发给模型,便于模型返回准确行号
code_with_lines = self._code_with_line_numbers(code)
prompt = self._build_prompt(code_with_lines, language)
# 调用 AI
response = self._call_ai(prompt)
# 获取相对路径
rel_path = os.path.relpath(file_path, clone_dir) if (clone_dir and file_path) else file_path
if not response:
return {
'file': rel_path,
'path': file_path,
'truncated': truncated,
'issues': []
}
# 解析 AI 响应,转换为标准 issues 格式,并校正行号
ai_issues = response.get('issues', [])
for issue in ai_issues:
self._correct_issue_line(issue, code)
issues.append({
'tool': 'ai_reviewer',
'type': issue.get('type', 'info'),
'severity': issue.get('severity', 'Info'),
'message': issue.get('message', ''),
'file': rel_path,
'line': issue.get('line', 0),
'column': issue.get('column', 0),
'symbol': issue.get('symbol', ''),
'code_context': issue.get('code_context', ''),
'defect_reason': issue.get('defect_reason', '')
})
return {
'file': rel_path,
'path': file_path,
'truncated': truncated,
'issues': issues
}
except Exception as e:
logger.warning(f'审查文件失败 {file_path}: {str(e)}')
return None
def _build_prompt(self, code: str, language: str) -> str:
"""构建审查 prompt"""
if language == 'python':
lang_name = 'Python'
elif language in ['javascript', 'typescript']:
lang_name = 'JavaScript/TypeScript'
else:
lang_name = language
prompt = f"""你是一位资深的 {lang_name} 代码审查专家。请审查以下代码,找出潜在的问题和缺陷。
请以 JSON 格式输出审查结果,必须包含以下字段:
```json
{{
"issues": [
{{
"line": 行号,
"column": 列号,
"message": "问题描述",
"type": "error/warning/info 之一",
"severity": "Error/Warning/Info 之一",
"symbol": "错误标识符如 unused-variable, syntax-error 等",
"code_context": "问题代码的上下文(包含问题的那行或几行代码)",
"defect_reason": "缺陷原因分析30字以内简洁描述"
}}
]
}}
```
注意:
1. line 和 column 是问题所在的行号和列号(从 1 开始)
2. type: error=错误, warning=警告, info=信息
3. severity: Error=严重, Warning=一般, Info=提示
4. code_context: 包含问题代码的那一行或相邻的几行
5. defect_reason: 精简描述30字以内说明问题原因和风险
如果代码没有问题,返回空数组: {{"issues": []}}
重要:以下代码每行前已标注行号(格式为 "行号|"),请根据问题实际出现的代码行,严格使用该行前的行号填写 issues 中的 line 字段,不要猜测或使用错误行号。
以下是待审查的代码(行号已标注):
```{language}
{code}
```"""
return prompt
def _code_with_line_numbers(self, code: str) -> str:
"""给代码每行前加上行号,便于模型返回准确行号"""
lines = code.split('\n')
width = len(str(len(lines)))
return '\n'.join(f'{i:>{width}}| {line}' for i, line in enumerate(lines, 1))
def _correct_issue_line(self, issue: Dict[str, Any], code: str) -> None:
"""
根据 message/symbol 在源码中搜索,尽量把 issue 的 line 校正到真实出现位置。
AI 返回的行号常不准确,通过匹配问题相关的标识符(如 'unused_module')修正行号。
"""
line = issue.get('line')
if not line or not code:
return
lines = code.split('\n')
if line < 1 or line > len(lines):
return
# 从 message 中提取被引用的标识符(如 'unused_module' -> unused_module
message = (issue.get('message') or '')
symbol = (issue.get('symbol') or '').strip()
candidates = []
if symbol:
candidates.append(symbol)
for m in re.finditer(r"['\"]([a-zA-Z_][a-zA-Z0-9_]*)['\"]", message or ''):
candidates.append(m.group(1))
# 若 message 里没有引号标识符,取首段英文/数字/下划线作为关键词
if not candidates:
first_word = re.search(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', message)
if first_word:
candidates.append(first_word.group(1))
for token in candidates:
if not token:
continue
for i, code_line in enumerate(lines):
if token in code_line:
issue['line'] = i + 1
return
def _call_ai(self, prompt: str) -> Optional[Dict[str, Any]]:
"""调用 AI 服务"""
try:
if self.provider == 'ollama':
return self._call_ollama(prompt)
elif self.provider == 'api':
return self._call_api(prompt)
else:
logger.warning(f'未知的 AI provider: {self.provider}')
return None
except Exception as e:
print("异常追踪信息:", e.__traceback__)
logger.error(f'AI 调用失败: {str(e)}')
return None
def _extract_json_obj(self, content: Any) -> Optional[Dict[str, Any]]:
"""
从模型输出中尽可能提取 JSON 对象(dict)。
兼容场景:
- content 已经是 dict
- content 是 JSON 字符串
- content 被 ```json ... ``` 或 ``` ... ``` 包裹
- content 前后夹杂说明文字,只要包含一个最外层 { ... } 就尝试解析
"""
if content is None:
logger.debug("_extract_json_obj: content is None")
return None
# 如果已经是 dict直接返回
if isinstance(content, dict):
logger.debug("_extract_json_obj: content is already dict")
return content
if not isinstance(content, str):
content = str(content)
text = content.strip()
logger.debug(f"_extract_json_obj: 原始内容长度 = {len(text)}")
logger.debug(f"_extract_json_obj: 原始内容前100字符: {text[:100]}")
# 去掉代码块包裹(兼容 ```json / ``` json / ```JSON 等)
lowered = text.lower()
fence_start = lowered.find('```')
if fence_start != -1:
logger.debug(f"_extract_json_obj: 发现代码块 fence_start={fence_start}")
# 找到第一段 fence
after = text[fence_start + 3:]
after_l = after.lower()
# 如果 fence 后紧跟语言标识json 或其他),跳过这一行直到换行
newline_idx = after.find('\n')
if newline_idx != -1:
lang_header = after_l[:newline_idx].strip()
logger.debug(f"_extract_json_obj: 语言标识: {lang_header}")
body = after[newline_idx + 1:]
# 截取到下一个 fence 结束
end_idx = body.lower().find('```')
if end_idx != -1:
candidate = body[:end_idx].strip()
else:
# 没有结束 fence直接用 body 作为候选(可能是截断的 JSON
candidate = body.strip()
# 只有在确实像 json 的情况下才替换,避免误伤普通文本
if '{' in candidate and '}' in candidate:
text = candidate
logger.debug(f"_extract_json_obj: 提取代码块内容成功,长度={len(text)}")
else:
# 没有换行就按旧逻辑尽量截取
pass
# 直接解析
try:
obj = json.loads(text)
logger.debug("_extract_json_obj: 直接解析成功")
return obj if isinstance(obj, dict) else None
except Exception as e:
logger.debug(f"_extract_json_obj: 直接解析失败: {e}")
# 兜底:截取最外层 { ... } 再解析
start = text.find('{')
end = text.rfind('}')
logger.debug(f"_extract_json_obj: 查找大括号 start={start}, end={end}")
if start != -1 and end != -1 and end > start:
candidate = text[start:end + 1].strip()
logger.debug(f"_extract_json_obj: 候选内容长度={len(candidate)}, 前50字符: {candidate[:50]}")
try:
obj = json.loads(candidate)
logger.debug("_extract_json_obj: 兜底解析成功")
return obj if isinstance(obj, dict) else None
except Exception as e:
logger.debug(f"_extract_json_obj: 兜底解析失败: {e}")
return None
logger.debug("_extract_json_obj: 未能提取到有效的 JSON 对象")
return None
def _call_ollama(self, prompt: str) -> Optional[Dict[str, Any]]:
"""调用 Ollama 本地模型"""
import requests
url = f"{self.api_url}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False,
"format": "json"
}
logger.info(f"调用 Ollama: {url}, model={self.model}")
response = requests.post(url, json=payload, timeout=120)
if response.status_code == 200:
result = response.json()
content = result.get('response', '')
logger.info(f"Ollama 返回内容长度: {len(content) if content else 0}")
logger.debug(f"Ollama 返回内容预览: {content[:200] if content else 'empty'}")
parsed = self._extract_json_obj(content)
return parsed
logger.warning(f'Ollama 返回错误: {response.status_code}')
return None
def _call_api(self, prompt: str) -> Optional[Dict[str, Any]]:
"""调用在线 API"""
import requests
headers = {
'Content-Type': 'application/json'
}
if self.api_key:
headers['Authorization'] = f'Bearer {self.api_key}'
# 根据 API URL 自动判断 provider
if 'siliconflow' in self.api_url:
url = f"{self.api_url}/chat/completions"
payload = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1024*5,
"temperature": 0.7
}
elif 'deepseek' in self.api_url:
url = f"{self.api_url}/chat/completions"
payload = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1024*5,
"temperature": 0.7
}
else:
url = f"{self.api_url}/chat/completions"
payload = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1024*5,
"temperature": 0.7
}
response = requests.post(url, json=payload, headers=headers, timeout=120)
if response.status_code == 200:
result = response.json()
content = result['choices'][0]['message']['content']
parsed = self._extract_json_obj(content)
return parsed
logger.warning(f'API 返回错误: {response.status_code}')
return None