#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI code reviewer.

Sends source files to a chat-completions style LLM endpoint and converts the
model's JSON answer into the issue format used by the other scanners.
"""

import os
import re
import json
import logging
from typing import Dict, Any, List, Optional

from scanner.base import BaseScanner

logger = logging.getLogger(__name__)


class AIReviewer(BaseScanner):
    """Scanner that reviews code with a large language model."""

    # At most this many files are collected per run ...
    MAX_CANDIDATE_FILES = 10
    # ... and at most this many of them are actually sent to the model.
    MAX_REVIEW_FILES = 5

    # Severity strings (lower-cased) that map to the "error" / "warning"
    # buckets; anything else is counted as "info".
    _ERROR_SEVERITIES = frozenset({'error', 'critical', 'fatal'})
    _WARNING_SEVERITIES = frozenset({'warning', 'moderate'})

    # Substrings that mark an issue as security related for scoring.
    _SECURITY_KEYWORDS = ('sql injection', 'xss', 'csrf', 'password', 'secret',
                          'token', '权限', '注入', '认证')

    # Repo-relative path fragments that exclude a file from a full scan.
    _EXCLUDED_PATH_MARKERS = ('test_', '_test.', 'venv', 'node_modules', '__pycache__')

    _QUOTED_IDENTIFIER_RE = re.compile(r"['\"]([a-zA-Z_][a-zA-Z0-9_]*)['\"]")
    _FIRST_IDENTIFIER_RE = re.compile(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b')

    def __init__(self, config: Dict[str, Any]):
        """
        Initialise the reviewer.

        Args:
            config: AI configuration. Keys read here: ``scanner`` (passed to
                the base class), ``enabled``, ``provider``, ``model``,
                ``api_url``, ``api_key`` and ``max_lines``.
        """
        # The base class is initialised first, with the nested scanner settings.
        super().__init__(config.get('scanner', {}))

        self.config = config
        self.enabled = config.get('enabled', True)
        self.provider = config.get('provider', 'api')
        self.model = config.get('model', 'llama3')
        self.api_url = config.get('api_url', 'http://localhost:11434')
        self.api_key = config.get('api_key', '')
        self.max_lines = config.get('max_lines', 200)

        if not self.enabled:
            logger.info('AI 审查器已禁用')
            return

        logger.info(f'AI 审查器初始化: {self.provider}/{self.model}')

    def scan(self, repo_url: str, commit_id: Optional[str], branch: str,
             changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Run the scan; delegates to :meth:`_do_review`.

        Args:
            repo_url: Repository URL.
            commit_id: Commit id.
            branch: Branch name.
            changed_files: Optional list of changed files (from a PR).

        Returns:
            The review result dictionary.
        """
        return self._do_review(repo_url=repo_url, commit_id=commit_id, branch=branch,
                               changed_files=changed_files)

    def _do_review(self, clone_dir: Optional[str] = None, repo_url: Optional[str] = None,
                   commit_id: Optional[str] = None, branch: Optional[str] = None,
                   language: str = 'python',
                   changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Run the AI code review.

        Args:
            clone_dir: Repository directory, when it is already cloned.
            repo_url: Repository URL, used to clone when ``clone_dir`` is empty.
            commit_id: Commit id.
            branch: Branch name.
            language: Programming language of the files to review.
            changed_files: Optional list of changed files (from a PR).

        Returns:
            A result dictionary with ``tool``, ``language``, ``status``,
            ``issues``, ``summary`` and ``files_scanned``; on success also
            ``clone_dir`` and ``quality_score``, on failure ``error``.
            Note that ``summary`` is a plain string (not a dict) when the
            reviewer is disabled or no file was found.
        """
        result = {
            'tool': 'AI Code Reviewer',
            'language': language,
            'status': 'success',
            'issues': [],
            'summary': {
                'total': 0,
                'error': 0,
                'warning': 0,
                'info': 0
            },
            'files_scanned': 0
        }

        if not self.enabled:
            result['status'] = 'disabled'
            result['summary'] = 'AI 审查已禁用'
            return result

        try:
            # Clone only when the caller did not hand us a checkout.
            if not clone_dir and repo_url:
                clone_dir = self.clone_repo(repo_url, commit_id, branch)

            if not clone_dir or not os.path.exists(clone_dir):
                result['status'] = 'error'
                result['error'] = '无法获取代码目录'
                return result

            files = self._get_code_files(clone_dir, language, changed_files)
            if not files:
                result['summary'] = '未找到可审查的代码文件'
                return result

            # Cap the number of model calls per run.
            reviewed_files = files[:self.MAX_REVIEW_FILES]

            all_issues = []
            for file_path in reviewed_files:
                review = self._review_file(file_path, language, clone_dir)
                if review and review.get('issues'):
                    all_issues.extend(review['issues'])

            # NOTE(review): max_issues / detailed are not set in this class;
            # presumably they come from BaseScanner - confirm.
            result['issues'] = all_issues[:self.max_issues] if self.detailed else all_issues
            result['summary'] = self._calculate_summary(all_issues)
            result['files_scanned'] = len(reviewed_files)
            result['clone_dir'] = clone_dir
            result['quality_score'] = self._calculate_quality_score(all_issues, reviewed_files)

            return result

        except Exception as e:
            # Top-level boundary: report the failure in the result instead of
            # raising, but keep the traceback in the log.
            logger.error(f'AI 审查失败: {str(e)}', exc_info=True)
            result['status'] = 'error'
            result['error'] = str(e)
            return result

    @classmethod
    def _severity_bucket(cls, issue: Dict[str, Any]) -> str:
        """Map an issue's severity to 'error', 'warning' or 'info' (null-safe)."""
        severity = str(issue.get('severity') or '').strip().lower()
        if severity in cls._ERROR_SEVERITIES:
            return 'error'
        if severity in cls._WARNING_SEVERITIES:
            return 'warning'
        return 'info'

    def _calculate_summary(self, issues: List[Dict]) -> Dict[str, int]:
        """
        Count issues per severity bucket.

        Returns:
            ``{'total', 'error', 'warning', 'info'}``; unknown or missing
            severities are counted as ``info``.
        """
        summary = {
            'total': len(issues),
            'error': 0,
            'warning': 0,
            'info': 0
        }
        for issue in issues:
            summary[self._severity_bucket(issue)] += 1
        return summary

    def _calculate_quality_score(self, issues: List[Dict], files: List[str]) -> Dict[str, Any]:
        """
        Compute a code quality score.

        Args:
            issues: Issues in the standard format.
            files: Files that were reviewed; an empty list yields a perfect score.

        Returns:
            ``total`` (int, 0-100) plus the per-dimension scores
            ``maintainability``, ``security``, ``readability`` and
            ``best_practices``; when ``files`` is non-empty also ``details``
            with the raw counts.
        """
        if not files:
            return {'total': 100, 'maintainability': 100, 'security': 100,
                    'readability': 100, 'best_practices': 100}

        # Use the same bucketing as the summary so both always agree.
        summary = self._calculate_summary(issues)
        error_count = summary['error']
        warning_count = summary['warning']
        info_count = summary['info']

        # An issue is security related when its message or symbol mentions a
        # keyword; "or ''" guards against null values from the model.
        security_issues = sum(
            1 for issue in issues
            if any(keyword in f"{issue.get('message') or ''}{issue.get('symbol') or ''}".lower()
                   for keyword in self._SECURITY_KEYWORDS)
        )

        # Maintainability: weighted count of all issues.
        issue_weight = error_count * 5 + warning_count * 2 + info_count * 0.5
        maintainability = max(0, 100 - issue_weight)

        # Security: security related issues only.
        security_score = max(0, 100 - security_issues * 15)

        # Readability: info-level (style) issues.
        readability = max(0, 100 - info_count * 3)

        # Best practices: warning-level issues.
        best_practices = max(0, 100 - warning_count * 5)

        # Total: weighted average of the four dimensions.
        total = int((maintainability * 0.3 + security_score * 0.35 +
                     readability * 0.15 + best_practices * 0.2))

        return {
            'total': total,
            'maintainability': maintainability,
            'security': security_score,
            'readability': readability,
            'best_practices': best_practices,
            'details': {
                'error_count': error_count,
                'warning_count': warning_count,
                'info_count': info_count,
                'security_issues': security_issues
            }
        }

    def generate_fix_suggestion(self, file_path: str, line: int, message: str,
                                code: str) -> Optional[Dict[str, Any]]:
        """
        Ask the model for a fix for one reported problem.

        Args:
            file_path: Path of the affected file (not used in the prompt).
            line: Line number of the problem.
            message: Problem description.
            code: Original code to fix.

        Returns:
            The parsed model answer (a dict with ``fixed_code``,
            ``explanation`` and ``confidence`` as requested by the prompt)
            when it contains a non-empty ``fixed_code``; otherwise ``None``.
        """
        prompt = f"""你是一位代码修复专家。请根据以下问题,生成修复后的代码。

问题描述:{message}
问题所在行号:{line}

原始代码:
```
{code}
```

请以 JSON 格式输出修复建议:
```json
{{
    "fixed_code": "修复后的完整代码或关键片段",
    "explanation": "修复说明(50字以内)",
    "confidence": "high/medium/low 修复把握度"
}}
```

如果无法修复,请返回:{{"fixed_code": "", "explanation": "无法自动修复", "confidence": "low"}}"""

        try:
            response = self._call_ai(prompt)
            if response and response.get('fixed_code'):
                return response
        except Exception as e:
            logger.warning(f'生成修复建议失败: {e}')

        return None

    @staticmethod
    def _is_within(root: str, path: str) -> bool:
        """Return True when ``path`` resolves to a location inside ``root`` (a real path)."""
        resolved = os.path.realpath(path)
        try:
            return os.path.commonpath([root, resolved]) == root
        except ValueError:
            # Raised e.g. for paths on different drives.
            return False

    def _get_code_files(self, clone_dir: str, language: str,
                        changed_files: Optional[List[str]] = None) -> List[str]:
        """
        Collect the files to review.

        Args:
            clone_dir: Repository checkout directory.
            language: Selects the file extensions; unknown languages fall
                back to ``.py``.
            changed_files: Optional repo-relative paths (from a PR). When
                given, only these files are considered.

        Returns:
            At most ``MAX_CANDIDATE_FILES`` paths.
        """
        import glob

        extensions = {
            'python': ['.py'],
            'javascript': ['.js', '.jsx'],
            'typescript': ['.ts', '.tsx']
        }
        exts = tuple(extensions.get(language, ['.py']))

        # A changed-file list restricts the review to exactly those files.
        if changed_files:
            root = os.path.realpath(clone_dir)
            files = []
            for changed_file in changed_files:
                if not isinstance(changed_file, str) or not changed_file.endswith(exts):
                    continue
                full_path = os.path.join(clone_dir, changed_file)
                # The list comes from a PR: never follow "../" or absolute
                # paths out of the checkout, since the content is sent to an
                # external service.
                if not self._is_within(root, full_path):
                    logger.warning(f'忽略仓库目录之外的变更文件: {changed_file}')
                    continue
                if os.path.isfile(full_path):
                    files.append(full_path)
            return files[:self.MAX_CANDIDATE_FILES]

        # Otherwise scan the whole repository.
        files = []
        for ext in exts:
            pattern = os.path.join(glob.escape(clone_dir), '**', f'*{ext}')
            files.extend(glob.glob(pattern, recursive=True))

        # Drop tests, virtualenvs and caches. The markers are matched against
        # the repo-relative path so that the checkout location itself (e.g.
        # /tmp/test_repo) cannot exclude every file. Sorting makes the capped
        # selection deterministic.
        files = sorted(
            path for path in files
            if not any(marker in os.path.relpath(path, clone_dir)
                       for marker in self._EXCLUDED_PATH_MARKERS)
        )

        return files[:self.MAX_CANDIDATE_FILES]

    def _review_file(self, file_path: str, language: str,
                     clone_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Review a single file.

        Args:
            file_path: Path of the file to read.
            language: Language name passed to the prompt.
            clone_dir: Checkout root used to compute the relative path.

        Returns:
            ``{'file', 'path', 'truncated', 'issues'}``, or ``None`` when the
            file could not be reviewed.
        """
        try:
            # Undecodable bytes are replaced instead of aborting the review.
            with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                code = f.read()

            # Only the first max_lines lines are sent to the model.
            lines = code.split('\n')
            truncated = len(lines) > self.max_lines
            if truncated:
                code = '\n'.join(lines[:self.max_lines])

            # Number the lines so the model can report accurate positions.
            code_with_lines = self._code_with_line_numbers(code)
            prompt = self._build_prompt(code_with_lines, language)

            response = self._call_ai(prompt)

            rel_path = os.path.relpath(file_path, clone_dir) if (clone_dir and file_path) else file_path

            review = {
                'file': rel_path,
                'path': file_path,
                'truncated': truncated,
                'issues': []
            }
            if not response:
                return review

            # The model output is untrusted: tolerate a missing/null/non-list
            # "issues" value and skip entries that are not objects.
            ai_issues = response.get('issues')
            if not isinstance(ai_issues, list):
                ai_issues = []

            for issue in ai_issues:
                if not isinstance(issue, dict):
                    continue
                self._correct_issue_line(issue, code)
                review['issues'].append({
                    'tool': 'ai_reviewer',
                    'type': issue.get('type', 'info'),
                    'severity': issue.get('severity', 'Info'),
                    'message': issue.get('message', ''),
                    'file': rel_path,
                    'line': issue.get('line', 0),
                    'column': issue.get('column', 0),
                    'symbol': issue.get('symbol', ''),
                    'code_context': issue.get('code_context', ''),
                    'defect_reason': issue.get('defect_reason', '')
                })

            return review

        except Exception as e:
            logger.warning(f'审查文件失败 {file_path}: {str(e)}')
            return None

    def _build_prompt(self, code: str, language: str) -> str:
        """
        Build the review prompt.

        Args:
            code: Source code, already prefixed with line numbers.
            language: Language key; used for the display name and the fence tag.
        """
        if language == 'python':
            lang_name = 'Python'
        elif language in ['javascript', 'typescript']:
            lang_name = 'JavaScript/TypeScript'
        else:
            lang_name = language

        prompt = f"""你是一位资深的 {lang_name} 代码审查专家。请审查以下代码,找出潜在的问题和缺陷。

请以 JSON 格式输出审查结果,必须包含以下字段:
```json
{{
    "issues": [
        {{
            "line": 行号,
            "column": 列号,
            "message": "问题描述",
            "type": "error/warning/info 之一",
            "severity": "Error/Warning/Info 之一",
            "symbol": "错误标识符如 unused-variable, syntax-error 等",
            "code_context": "问题代码的上下文(包含问题的那行或几行代码)",
            "defect_reason": "缺陷原因分析(30字以内简洁描述)"
        }}
    ]
}}
```

注意:
1. line 和 column 是问题所在的行号和列号(从 1 开始)
2. type: error=错误, warning=警告, info=信息
3. severity: Error=严重, Warning=一般, Info=提示
4. code_context: 包含问题代码的那一行或相邻的几行
5. defect_reason: 精简描述,30字以内,说明问题原因和风险

如果代码没有问题,返回空数组: {{"issues": []}}

重要:以下代码每行前已标注行号(格式为 "行号|"),请根据问题实际出现的代码行,严格使用该行前的行号填写 issues 中的 line 字段,不要猜测或使用错误行号。

以下是待审查的代码(行号已标注):
```{language}
{code}
```"""
        return prompt

    def _code_with_line_numbers(self, code: str) -> str:
        """Prefix every line with its right-aligned 1-based number and ``| ``."""
        lines = code.split('\n')
        width = len(str(len(lines)))
        return '\n'.join(f'{i:>{width}}| {line}' for i, line in enumerate(lines, 1))

    def _correct_issue_line(self, issue: Dict[str, Any], code: str) -> None:
        """
        Move ``issue['line']`` to where the referenced identifier really occurs.

        Line numbers returned by the model are often inaccurate. Candidate
        tokens are taken from ``symbol`` and from quoted identifiers in
        ``message`` (or, when there are none, the first identifier-like word
        of the message). For the first token that occurs in ``code``, the
        reported line is kept if it contains the token, otherwise the line is
        moved to the nearest occurrence.

        The issue is modified in place. Nothing happens when the reported
        line is missing, not a number, or outside the code.
        """
        raw_line = issue.get('line')
        if not raw_line or not code:
            return
        try:
            # Models sometimes return the number as a string.
            line = int(raw_line)
        except (TypeError, ValueError):
            return

        lines = code.split('\n')
        if line < 1 or line > len(lines):
            return
        issue['line'] = line

        message = str(issue.get('message') or '')
        symbol = str(issue.get('symbol') or '').strip()

        candidates = []
        if symbol:
            candidates.append(symbol)
        # Identifiers quoted in the message, e.g. 'unused_module'.
        candidates.extend(m.group(1) for m in self._QUOTED_IDENTIFIER_RE.finditer(message))
        # Fall back to the first identifier-like word of the message.
        if not candidates:
            first_word = self._FIRST_IDENTIFIER_RE.search(message)
            if first_word:
                candidates.append(first_word.group(1))

        for token in candidates:
            if not token:
                continue
            # Match on identifier boundaries so "os" does not hit "post".
            pattern = re.compile(rf'(?<![A-Za-z0-9_]){re.escape(token)}(?![A-Za-z0-9_])')
            hits = [number for number, code_line in enumerate(lines, 1)
                    if pattern.search(code_line)]
            if not hits:
                continue
            if line not in hits:
                # Nearest occurrence; ties resolve to the earlier line.
                issue['line'] = min(hits, key=lambda number: (abs(number - line), number))
            return

    def _call_ai(self, prompt: str) -> Optional[Dict[str, Any]]:
        """Call the AI service; returns the parsed answer or ``None`` on any failure."""
        try:
            return self._call_api(prompt)
        except Exception as e:
            # exc_info records the traceback in the log (the previous print()
            # only wrote the traceback object's repr to stdout).
            logger.error(f'AI 调用失败: {str(e)}', exc_info=True)
            return None

    @staticmethod
    def _fenced_body(text: str) -> Optional[str]:
        """
        Return the content of the first ``` fenced block in ``text``.

        The remainder of the opening fence line (language tag) is skipped. A
        missing closing fence is tolerated (possibly truncated output).
        Returns ``None`` when there is no fence, no newline after it, or the
        content does not contain both ``{`` and ``}``.
        """
        fence_start = text.find('```')
        if fence_start == -1:
            return None
        after = text[fence_start + 3:]
        newline_idx = after.find('\n')
        if newline_idx == -1:
            return None
        body = after[newline_idx + 1:]
        end_idx = body.find('```')
        candidate = (body[:end_idx] if end_idx != -1 else body).strip()
        if '{' in candidate and '}' in candidate:
            return candidate
        return None

    def _extract_json_obj(self, content: Any) -> Optional[Dict[str, Any]]:
        """
        Extract a JSON object (dict) from model output, as leniently as possible.

        Handled cases:
        - ``content`` already is a dict
        - ``content`` is a JSON string
        - the JSON is wrapped in ```json ... ``` or ``` ... ```
        - the JSON is surrounded by explanatory text; the outermost
          ``{ ... }`` span is tried

        The whole text is tried first so that valid JSON containing "```"
        inside a string value is not damaged by the fence handling.

        Returns:
            The first candidate that parses to a dict, else ``None``.
        """
        if content is None:
            logger.debug("_extract_json_obj: content is None")
            return None

        if isinstance(content, dict):
            logger.debug("_extract_json_obj: content is already dict")
            return content

        if not isinstance(content, str):
            content = str(content)

        text = content.strip()
        logger.debug(f"_extract_json_obj: 原始内容长度 = {len(text)}")
        logger.debug(f"_extract_json_obj: 原始内容前100字符: {text[:100]}")

        # Candidates in order of preference.
        candidates = [text]
        fenced = self._fenced_body(text)
        if fenced is not None:
            candidates.append(fenced)
        for source in (fenced, text):
            if source is None:
                continue
            start = source.find('{')
            end = source.rfind('}')
            if start != -1 and end > start:
                candidates.append(source[start:end + 1].strip())

        for candidate in candidates:
            try:
                obj = json.loads(candidate)
            except (ValueError, RecursionError) as e:
                logger.debug(f"_extract_json_obj: 解析失败: {e}")
                continue
            if isinstance(obj, dict):
                logger.debug("_extract_json_obj: 解析成功")
                return obj

        logger.debug("_extract_json_obj: 未能提取到有效的 JSON 对象")
        return None

    def _call_api(self, prompt: str) -> Optional[Dict[str, Any]]:
        """
        Call the online chat-completions API.

        Posts to ``{api_url}/chat/completions`` and parses the first choice's
        message content as JSON.

        Returns:
            The parsed dict, or ``None`` on a non-200 response, a request
            error or unparsable content.
        """
        import requests

        headers = {
            'Content-Type': 'application/json'
        }
        if self.api_key:
            headers['Authorization'] = f'Bearer {self.api_key}'

        # All supported providers share the same endpoint and payload shape;
        # strip a trailing slash so the URL has no "//".
        url = f"{self.api_url.rstrip('/')}/chat/completions"
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 1024,
            "temperature": 0.7
        }
        if 'dashscope' in self.api_url:
            # Alibaba Cloud dashscope: switch streaming off explicitly.
            payload["stream"] = False

        logger.info(f"调用 API: {url}, model={self.model}")

        try:
            response = requests.post(url, json=payload, headers=headers, timeout=120)
            if response.status_code == 200:
                result = response.json()
                content = result['choices'][0]['message']['content']
                logger.info(f"API 返回内容长度: {len(content) if content else 0}")
                return self._extract_json_obj(content)

            logger.warning(f'API 返回错误: {response.status_code}, {response.text[:200]}')
            return None
        except Exception as e:
            logger.warning(f'API 调用失败: {e}')
            return None