365 lines
12 KiB
Python
365 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI 代码审查器
|
||
使用大模型进行智能代码审查
|
||
"""
|
||
import os
|
||
import json
|
||
import logging
|
||
from typing import Dict, Any, List, Optional
|
||
|
||
from scanner.base import BaseScanner
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AIReviewer(BaseScanner):
|
||
"""AI 代码审查器"""
|
||
|
||
def __init__(self, config: Dict[str, Any]):
    """
    Set up the AI reviewer from its configuration mapping.

    Args:
        config: AI settings. The nested ``scanner`` section (empty dict
            when absent) is handed to the base class. The remaining keys
            read here, with their fallbacks, are ``enabled`` (True),
            ``provider`` ('ollama'), ``model`` ('llama3'),
            ``api_url`` ('http://localhost:11434'), ``api_key`` ('')
            and ``max_lines`` (200).
    """
    # The base class is initialised first, and only with the scanner section.
    super().__init__(config.get('scanner', {}))

    self.config = config

    read = config.get
    self.enabled = read('enabled', True)
    self.provider = read('provider', 'ollama')
    self.model = read('model', 'llama3')
    self.api_url = read('api_url', 'http://localhost:11434')
    self.api_key = read('api_key', '')
    self.max_lines = read('max_lines', 200)

    # Exactly one of the two start-up messages is logged.
    if self.enabled:
        logger.info(f'AI 审查器初始化: {self.provider}/{self.model}')
    else:
        logger.info('AI 审查器已禁用')
|
||
|
||
def scan(self, repo_url: str, commit_id: Optional[str], branch: str,
         changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Run a scan (implementation of the scanner entry point).

    All arguments are forwarded by keyword to ``_do_review``; no
    ``clone_dir`` or ``language`` is supplied, so that method's defaults
    apply.

    Args:
        repo_url: Repository URL.
        commit_id: Commit ID, or None.
        branch: Branch name.
        changed_files: Optional list of changed files (from a PR).

    Returns:
        The review result produced by ``_do_review``.
    """
    review_args = {
        'repo_url': repo_url,
        'commit_id': commit_id,
        'branch': branch,
        'changed_files': changed_files,
    }
    return self._do_review(**review_args)
|
||
|
||
def _do_review(self, clone_dir: str = None, repo_url: str = None,
|
||
commit_id: str = None, branch: str = None,
|
||
language: str = 'python',
|
||
changed_files: Optional[List[str]] = None) -> Dict[str, Any]:
|
||
"""
|
||
执行 AI 代码审查
|
||
|
||
Args:
|
||
clone_dir: 仓库目录(如果已克隆则直接传入)
|
||
repo_url: 仓库 URL(如果未克隆则需要传入)
|
||
commit_id: 提交 ID
|
||
branch: 分支名
|
||
language: 编程语言
|
||
changed_files: 可选的变更文件列表(来自 PR)
|
||
|
||
Returns:
|
||
审查结果
|
||
"""
|
||
if not self.enabled:
|
||
return {
|
||
'enabled': False,
|
||
'tool': 'AI Code Reviewer',
|
||
'reviews': [],
|
||
'summary': 'AI 审查已禁用'
|
||
}
|
||
|
||
try:
|
||
# 如果没有传入 clone_dir,需要克隆
|
||
if not clone_dir and repo_url:
|
||
clone_dir = self.clone_repo(repo_url, commit_id, branch)
|
||
|
||
if not clone_dir or not os.path.exists(clone_dir):
|
||
return {
|
||
'enabled': True,
|
||
'tool': 'AI Code Reviewer',
|
||
'reviews': [],
|
||
'summary': '无法获取代码目录'
|
||
}
|
||
|
||
# 获取要审查的代码文件
|
||
files = self._get_code_files(clone_dir, language, changed_files)
|
||
|
||
if not files:
|
||
return {
|
||
'enabled': True,
|
||
'tool': 'AI Code Reviewer',
|
||
'reviews': [],
|
||
'summary': '未找到可审查的代码文件'
|
||
}
|
||
|
||
# 对每个文件进行 AI 审查
|
||
all_reviews = []
|
||
for file_path in files[:5]: # 限制最多审查 5 个文件
|
||
review = self._review_file(file_path, language, clone_dir)
|
||
if review:
|
||
all_reviews.append(review)
|
||
|
||
# 生成总结
|
||
summary = self._generate_summary(all_reviews)
|
||
|
||
return {
|
||
'enabled': True,
|
||
'tool': 'AI Code Reviewer',
|
||
'reviews': all_reviews,
|
||
'summary': summary,
|
||
'files_reviewed': len(all_reviews),
|
||
'clone_dir': clone_dir # 返回 clone_dir 用于后续清理
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f'AI 审查失败: {str(e)}')
|
||
return {
|
||
'enabled': True,
|
||
'tool': 'AI Code Reviewer',
|
||
'error': str(e),
|
||
'reviews': [],
|
||
'summary': f'AI 审查出错: {str(e)}'
|
||
}
|
||
|
||
def _get_code_files(self, clone_dir: str, language: str, changed_files: Optional[List[str]] = None) -> List[str]:
|
||
"""获取代码文件列表"""
|
||
import glob
|
||
|
||
extensions = {
|
||
'python': ['.py'],
|
||
'javascript': ['.js', '.jsx'],
|
||
'typescript': ['.ts', '.tsx']
|
||
}
|
||
|
||
exts = extensions.get(language, ['.py'])
|
||
|
||
# 如果提供了变更文件列表,只返回这些文件
|
||
if changed_files:
|
||
files = []
|
||
for changed_file in changed_files:
|
||
if any(changed_file.endswith(ext) for ext in exts):
|
||
full_path = os.path.join(clone_dir, changed_file)
|
||
if os.path.exists(full_path):
|
||
files.append(full_path)
|
||
return files[:10]
|
||
|
||
# 否则扫描整个仓库
|
||
files = []
|
||
|
||
for ext in exts:
|
||
pattern = os.path.join(clone_dir, '**', f'*{ext}')
|
||
files.extend(glob.glob(pattern, recursive=True))
|
||
|
||
# 过滤掉测试文件和虚拟环境
|
||
files = [f for f in files if not any(x in f for x in [
|
||
'test_', '_test.', 'venv', 'node_modules', '__pycache__'
|
||
])]
|
||
|
||
return files[:10] # 最多 10 个文件
|
||
|
||
def _review_file(self, file_path: str, language: str, clone_dir: str = None) -> Optional[Dict[str, Any]]:
|
||
"""审查单个文件"""
|
||
try:
|
||
with open(file_path, 'r', encoding='utf-8') as f:
|
||
code = f.read()
|
||
|
||
# 限制代码行数
|
||
lines = code.split('\n')
|
||
if len(lines) > self.max_lines:
|
||
code = '\n'.join(lines[:self.max_lines])
|
||
truncated = True
|
||
else:
|
||
truncated = False
|
||
|
||
# 构建 prompt
|
||
prompt = self._build_prompt(code, language)
|
||
|
||
# 调用 AI
|
||
response = self._call_ai(prompt)
|
||
|
||
if not response:
|
||
return None
|
||
|
||
# 解析响应
|
||
rel_path = os.path.relpath(file_path, clone_dir) if (clone_dir and file_path) else file_path
|
||
return {
|
||
'file': rel_path,
|
||
'path': file_path,
|
||
'truncated': truncated,
|
||
'review': response
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.warning(f'审查文件失败 {file_path}: {str(e)}')
|
||
return None
|
||
|
||
def _build_prompt(self, code: str, language: str) -> str:
|
||
"""构建审查 prompt"""
|
||
if language == 'python':
|
||
lang_name = 'Python'
|
||
elif language in ['javascript', 'typescript']:
|
||
lang_name = 'JavaScript/TypeScript'
|
||
else:
|
||
lang_name = language
|
||
|
||
prompt = f"""你是一位资深的 {lang_name} 代码审查专家。请审查以下代码,并给出:
|
||
|
||
1. **代码优点** - 写得好地方
|
||
2. **问题建议** - 需要改进的地方
|
||
3. **优化建议** - 如何让代码更好
|
||
|
||
请用中文回复,保持简洁,每个文件审查不超过 3 点建议。
|
||
|
||
以下是代码:
|
||
```{language}
|
||
{code}
|
||
```
|
||
|
||
请以 JSON 格式输出:
|
||
```json
|
||
{{
|
||
"优点": ["..."],
|
||
"问题": ["..."],
|
||
"优化": ["..."]
|
||
}}
|
||
```"""
|
||
return prompt
|
||
|
||
def _call_ai(self, prompt: str) -> Optional[Dict[str, Any]]:
|
||
"""调用 AI 服务"""
|
||
try:
|
||
if self.provider == 'ollama':
|
||
return self._call_ollama(prompt)
|
||
elif self.provider == 'api':
|
||
return self._call_api(prompt)
|
||
else:
|
||
logger.warning(f'未知的 AI provider: {self.provider}')
|
||
return None
|
||
except Exception as e:
|
||
print("异常追踪信息:", e.__traceback__)
|
||
logger.error(f'AI 调用失败: {str(e)}')
|
||
return None
|
||
|
||
def _call_ollama(self, prompt: str) -> Optional[Dict[str, Any]]:
|
||
"""调用 Ollama 本地模型"""
|
||
import requests
|
||
|
||
url = f"{self.api_url}/api/generate"
|
||
payload = {
|
||
"model": self.model,
|
||
"prompt": prompt,
|
||
"stream": False,
|
||
"format": "json"
|
||
}
|
||
|
||
response = requests.post(url, json=payload, timeout=120)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
content = result.get('response', '')
|
||
|
||
# 尝试解析 JSON
|
||
try:
|
||
# 提取 JSON 部分
|
||
if '```json' in content:
|
||
content = content.split('```json')[1].split('```')[0]
|
||
elif '```' in content:
|
||
content = content.split('```')[1].split('```')[0]
|
||
|
||
return json.loads(content.strip())
|
||
except json.JSONDecodeError:
|
||
# 如果不是 JSON,直接返回文本
|
||
return {'raw_review': content}
|
||
|
||
logger.warning(f'Ollama 返回错误: {response.status_code}')
|
||
return None
|
||
|
||
def _call_api(self, prompt: str) -> Optional[Dict[str, Any]]:
|
||
"""调用在线 API"""
|
||
import requests
|
||
|
||
headers = {
|
||
'Content-Type': 'application/json'
|
||
}
|
||
|
||
if self.api_key:
|
||
headers['Authorization'] = f'Bearer {self.api_key}'
|
||
|
||
# 根据 API URL 自动判断 provider
|
||
if 'siliconflow' in self.api_url:
|
||
url = f"{self.api_url}/chat/completions"
|
||
payload = {
|
||
"model": self.model,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"max_tokens": 1024,
|
||
"temperature": 0.7
|
||
}
|
||
elif 'deepseek' in self.api_url:
|
||
url = f"{self.api_url}/chat/completions"
|
||
payload = {
|
||
"model": self.model,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"max_tokens": 1024,
|
||
"temperature": 0.7
|
||
}
|
||
else:
|
||
url = f"{self.api_url}/chat/completions"
|
||
payload = {
|
||
"model": self.model,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"max_tokens": 1024,
|
||
"temperature": 0.7
|
||
}
|
||
|
||
response = requests.post(url, json=payload, headers=headers, timeout=120)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
content = result['choices'][0]['message']['content']
|
||
|
||
try:
|
||
if '```json' in content:
|
||
content = content.split('```json')[1].split('```')[0]
|
||
elif '```' in content:
|
||
content = content.split('```')[1].split('```')[0]
|
||
|
||
return json.loads(content.strip())
|
||
except json.JSONDecodeError:
|
||
return {'raw_review': content}
|
||
|
||
logger.warning(f'API 返回错误: {response.status_code}')
|
||
return None
|
||
|
||
def _generate_summary(self, reviews: List[Dict[str, Any]]) -> str:
|
||
"""生成审查总结"""
|
||
if not reviews:
|
||
return '未找到需要审查的代码'
|
||
|
||
total_issues = sum(
|
||
len(r.get('review', {}).get('问题', [])) +
|
||
len(r.get('review', {}).get('优化', []))
|
||
for r in reviews
|
||
)
|
||
|
||
files_count = len(reviews)
|
||
|
||
if total_issues == 0:
|
||
return f'✅ AI 审查通过!审查了 {files_count} 个文件,未发现问题'
|
||
|
||
return f'🤖 AI 审查了 {files_count} 个文件,发现 {total_issues} 个改进建议'
|