Files
code_scan/scanner/ai_reviewer.py
Dang Zerong 027cf50759 add web
2026-03-12 14:42:23 +08:00

365 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI 代码审查器
使用大模型进行智能代码审查
"""
import os
import json
import logging
from typing import Dict, Any, List, Optional
from scanner.base import BaseScanner
logger = logging.getLogger(__name__)
class AIReviewer(BaseScanner):
    """AI code reviewer that sends source files to a large language model.

    Two providers are supported: ``ollama`` (Ollama's ``/api/generate``
    endpoint) and ``api`` (a ``/chat/completions`` endpoint with optional
    bearer-token auth). Review results are plain dicts.
    """

    TOOL_NAME = 'AI Code Reviewer'

    # Defaults; ``__init__`` lets the config override the first and the last.
    max_review_files = 5        # files actually sent to the model per run
    max_candidate_files = 10    # files returned by ``_get_code_files``
    request_timeout = 120       # seconds, per HTTP request

    _EXTENSIONS = {
        'python': ('.py',),
        'javascript': ('.js', '.jsx'),
        'typescript': ('.ts', '.tsx'),
    }

    # Substrings that exclude a file from a whole-repository scan.
    _EXCLUDED_MARKERS = ('test_', '_test.', 'venv', 'node_modules', '__pycache__')

    def __init__(self, config: Dict[str, Any]):
        """Initialise the reviewer.

        Args:
            config: AI configuration. Recognised keys: ``scanner`` (passed to
                the base class), ``enabled``, ``provider``, ``model``,
                ``api_url``, ``api_key``, ``max_lines``, ``max_files``,
                ``request_timeout`` and ``language``.
        """
        # Initialise the base class first.
        super().__init__(config.get('scanner', {}))
        self.config = config
        self.enabled = config.get('enabled', True)
        self.provider = config.get('provider', 'ollama')
        self.model = config.get('model', 'llama3')
        self.api_url = config.get('api_url', 'http://localhost:11434')
        self.api_key = config.get('api_key', '')
        self.max_lines = config.get('max_lines', 200)
        self.max_review_files = config.get('max_files', self.max_review_files)
        self.request_timeout = config.get('request_timeout', self.request_timeout)
        if not self.enabled:
            logger.info('AI 审查器已禁用')
            return
        logger.info(f'AI 审查器初始化: {self.provider}/{self.model}')

    def scan(self, repo_url: str, commit_id: Optional[str], branch: str,
             changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
        """Run a code scan (implements the abstract method).

        Args:
            repo_url: Repository URL.
            commit_id: Commit ID.
            branch: Branch name.
            changed_files: Optional list of changed files (from a PR).

        Returns:
            The review result dict produced by ``_do_review``.
        """
        return self._do_review(
            repo_url=repo_url,
            commit_id=commit_id,
            branch=branch,
            # Language comes from the config; defaults to the previous
            # behaviour (always 'python') when not configured.
            language=self.config.get('language', 'python'),
            changed_files=changed_files,
        )

    def _result(self, summary: str,
                reviews: Optional[List[Dict[str, Any]]] = None,
                clone_dir: Optional[str] = None) -> Dict[str, Any]:
        """Build the common result dict for an enabled reviewer."""
        result: Dict[str, Any] = {
            'enabled': True,
            'tool': self.TOOL_NAME,
            'reviews': reviews if reviews is not None else [],
            'summary': summary,
        }
        if clone_dir:
            # Always hand the checkout back so the caller can clean it up,
            # including on early exits and failures.
            result['clone_dir'] = clone_dir
        return result

    def _do_review(self, clone_dir: Optional[str] = None,
                   repo_url: Optional[str] = None,
                   commit_id: Optional[str] = None,
                   branch: Optional[str] = None,
                   language: str = 'python',
                   changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
        """Run the AI code review.

        Args:
            clone_dir: Repository directory (pass it if already cloned).
            repo_url: Repository URL (needed when not yet cloned).
            commit_id: Commit ID.
            branch: Branch name.
            language: Programming language.
            changed_files: Optional list of changed files (from a PR).

        Returns:
            A dict with ``enabled``, ``tool``, ``reviews`` and ``summary``;
            ``files_reviewed`` on success, ``error`` on failure, and
            ``clone_dir`` whenever a directory is known.
        """
        if not self.enabled:
            return {
                'enabled': False,
                'tool': self.TOOL_NAME,
                'reviews': [],
                'summary': 'AI 审查已禁用',
            }
        try:
            # Clone only when no directory was handed in.
            if not clone_dir and repo_url:
                clone_dir = self.clone_repo(repo_url, commit_id, branch)
            if not clone_dir or not os.path.exists(clone_dir):
                return self._result('无法获取代码目录', clone_dir=clone_dir)

            files = self._get_code_files(clone_dir, language, changed_files)
            if not files:
                return self._result('未找到可审查的代码文件', clone_dir=clone_dir)

            all_reviews = []
            for file_path in files[:self.max_review_files]:
                review = self._review_file(file_path, language, clone_dir)
                if review:
                    all_reviews.append(review)

            result = self._result(
                self._generate_summary(all_reviews),
                reviews=all_reviews,
                clone_dir=clone_dir,
            )
            result['files_reviewed'] = len(all_reviews)
            return result
        except Exception as e:
            # Top-level boundary: report the failure in the result instead of
            # raising, but keep the traceback in the log.
            logger.exception(f'AI 审查失败: {str(e)}')
            result = self._result(f'AI 审查出错: {str(e)}', clone_dir=clone_dir)
            result['error'] = str(e)
            return result

    @staticmethod
    def _is_within(root: str, path: str) -> bool:
        """Return True if *path* resolves to a location inside *root*."""
        resolved = os.path.realpath(path)
        return resolved == root or resolved.startswith(root + os.sep)

    def _get_code_files(self, clone_dir: str, language: str,
                        changed_files: Optional[List[str]] = None) -> List[str]:
        """Return up to ``max_candidate_files`` source files to review.

        When *changed_files* is given, only those entries (relative to
        *clone_dir*) with a matching extension are returned. Otherwise the
        whole repository is scanned and tests, virtual environments and
        caches are filtered out. Unknown languages fall back to ``.py``.
        """
        import glob

        exts = tuple(self._EXTENSIONS.get(language, ('.py',)))
        root = os.path.realpath(clone_dir)

        # If a changed-file list was supplied, restrict the review to it.
        if changed_files:
            files = []
            for changed_file in changed_files:
                if not changed_file.endswith(exts):
                    continue
                full_path = os.path.join(clone_dir, changed_file)
                # Reject '../', absolute paths and symlinks that leave the
                # checkout: the content is sent to an external service.
                if not self._is_within(root, full_path):
                    logger.warning(f'跳过仓库目录之外的文件: {changed_file}')
                    continue
                if os.path.isfile(full_path):
                    files.append(full_path)
            return files[:self.max_candidate_files]

        # Otherwise scan the whole repository.
        candidates = []
        for ext in exts:
            pattern = os.path.join(clone_dir, '**', f'*{ext}')
            candidates.extend(glob.glob(pattern, recursive=True))

        files = []
        # Sorted so the selection is deterministic across filesystems.
        for path in sorted(candidates):
            # Match exclusions against the repo-relative path only, so the
            # location of the checkout itself (e.g. /tmp/test_x) is ignored.
            rel_path = os.path.relpath(path, clone_dir)
            if any(marker in rel_path for marker in self._EXCLUDED_MARKERS):
                continue
            if os.path.isfile(path) and self._is_within(root, path):
                files.append(path)
        return files[:self.max_candidate_files]

    def _review_file(self, file_path: str, language: str,
                     clone_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Review a single file.

        Returns:
            A dict with ``file`` (path relative to *clone_dir* when given),
            ``path``, ``truncated`` and ``review``; ``None`` if the file
            cannot be read or the AI returns nothing.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            # Cap the number of lines sent to the model. Counting real lines
            # (not split('\n') pieces) avoids flagging a file of exactly
            # max_lines lines as truncated because of its final newline.
            truncated = len(lines) > self.max_lines
            code = ''.join(lines[:self.max_lines])

            prompt = self._build_prompt(code, language)
            response = self._call_ai(prompt)
            if not response:
                return None

            rel_path = os.path.relpath(file_path, clone_dir) if (clone_dir and file_path) else file_path
            return {
                'file': rel_path,
                'path': file_path,
                'truncated': truncated,
                'review': response,
            }
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            logger.warning(f'审查文件失败 {file_path}: {str(e)}')
            return None

    def _build_prompt(self, code: str, language: str) -> str:
        """Build the review prompt for *code* written in *language*."""
        if language == 'python':
            lang_name = 'Python'
        elif language in ('javascript', 'typescript'):
            lang_name = 'JavaScript/TypeScript'
        else:
            lang_name = language
        # The continuation lines start at column 0 on purpose: they are part
        # of the string literal and must not pick up source indentation.
        prompt = f"""你是一位资深的 {lang_name} 代码审查专家。请审查以下代码,并给出:
1. **代码优点** - 写得好地方
2. **问题建议** - 需要改进的地方
3. **优化建议** - 如何让代码更好
请用中文回复,保持简洁,每个文件审查不超过 3 点建议。
以下是代码:
```{language}
{code}
```
请以 JSON 格式输出:
```json
{{
"优点": ["..."],
"问题": ["..."],
"优化": ["..."]
}}
```"""
        return prompt

    def _call_ai(self, prompt: str) -> Optional[Dict[str, Any]]:
        """Dispatch *prompt* to the configured provider.

        Returns:
            The parsed review dict, or ``None`` on an unknown provider or any
            failure (which is logged with its traceback).
        """
        try:
            if self.provider == 'ollama':
                return self._call_ollama(prompt)
            elif self.provider == 'api':
                return self._call_api(prompt)
            else:
                logger.warning(f'未知的 AI provider: {self.provider}')
                return None
        except Exception as e:
            # logger.exception records the real traceback; the original
            # printed only the traceback object's repr to stdout.
            logger.exception(f'AI 调用失败: {str(e)}')
            return None

    @staticmethod
    def _parse_ai_content(content: Any) -> Optional[Dict[str, Any]]:
        """Turn a model reply into a review dict.

        A fenced block (three backticks, optionally tagged ``json``) is
        unwrapped first. If the text is not a JSON object it is returned as
        ``{'raw_review': text}``. Empty or non-string replies give ``None``.
        """
        if not isinstance(content, str) or not content.strip():
            return None
        text = content
        if '```json' in text:
            text = text.split('```json', 1)[1].split('```', 1)[0]
        elif '```' in text:
            text = text.split('```', 1)[1].split('```', 1)[0]
        try:
            parsed = json.loads(text.strip())
        except json.JSONDecodeError:
            # Not JSON: hand the text back as-is.
            return {'raw_review': text}
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (list, number, ...): downstream
            # code expects a dict, so treat it as raw text as well.
            return {'raw_review': text}
        return parsed

    def _call_ollama(self, prompt: str) -> Optional[Dict[str, Any]]:
        """Call a local Ollama model through ``/api/generate``."""
        import requests

        url = f"{self.api_url.rstrip('/')}/api/generate"
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
        }
        response = requests.post(url, json=payload, timeout=self.request_timeout)
        if response.status_code == 200:
            result = response.json()
            return self._parse_ai_content(result.get('response', ''))
        logger.warning(f'Ollama 返回错误: {response.status_code}')
        return None

    def _call_api(self, prompt: str) -> Optional[Dict[str, Any]]:
        """Call an online ``/chat/completions`` API."""
        import requests

        headers = {'Content-Type': 'application/json'}
        if self.api_key:
            headers['Authorization'] = f'Bearer {self.api_key}'
        # The original branched on 'siliconflow' / 'deepseek' in the URL, but
        # all three branches built the same request, so one suffices.
        url = f"{self.api_url.rstrip('/')}/chat/completions"
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 1024,
            "temperature": 0.7,
        }
        response = requests.post(url, json=payload, headers=headers,
                                 timeout=self.request_timeout)
        if response.status_code == 200:
            result = response.json()
            # A malformed body raises here and is handled by _call_ai.
            content = result['choices'][0]['message']['content']
            return self._parse_ai_content(content)
        logger.warning(f'API 返回错误: {response.status_code}')
        return None

    @staticmethod
    def _count_items(review: Any, key: str) -> int:
        """Count the suggestions stored under *key* in one review dict.

        Model output is untrusted: the value may be missing, ``null``, a
        single string, or the review may not be a dict at all.
        """
        if not isinstance(review, dict):
            return 0
        items = review.get(key)
        if isinstance(items, (list, tuple)):
            return len(items)
        if isinstance(items, str):
            # One suggestion given as a bare string, not one per character.
            return 1 if items.strip() else 0
        return 0

    def _generate_summary(self, reviews: List[Dict[str, Any]]) -> str:
        """Generate a one-line summary of all file reviews."""
        if not reviews:
            return '未找到需要审查的代码'
        total_issues = sum(
            self._count_items(r.get('review'), '问题')
            + self._count_items(r.get('review'), '优化')
            for r in reviews
        )
        files_count = len(reviews)
        if total_issues == 0:
            return f'✅ AI 审查通过!审查了 {files_count} 个文件,未发现问题'
        return f'🤖 AI 审查了 {files_count} 个文件,发现 {total_issues} 个改进建议'