#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Database models.

Stores PR scan results and their management state in a local SQLite file.
"""

import sqlite3
import json
import os
from contextlib import closing
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'pr_scans.db')


def get_db_connection():
    """Open a connection to ``DB_PATH`` whose rows are ``sqlite3.Row`` objects.

    The directory containing the database file is created on demand. The
    caller owns the returned connection and must close it.
    """
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    return conn


def init_db():
    """Create the ``pr_scans`` and ``scan_details`` tables if they do not exist."""
    # closing() releases the connection even if a statement fails; the inner
    # ``with conn`` commits on success and rolls back on error.
    with closing(get_db_connection()) as conn, conn:
        # PR scan results table: one row per (repo_name, pr_number).
        conn.execute('''
            CREATE TABLE IF NOT EXISTS pr_scans (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                pr_number INTEGER NOT NULL,
                repo_name TEXT NOT NULL,
                pr_title TEXT,
                pr_url TEXT,
                source_branch TEXT,
                target_branch TEXT,
                author TEXT,
                state TEXT DEFAULT 'pending',
                scan_status TEXT DEFAULT 'pending',
                scan_result TEXT,
                issues_count INTEGER DEFAULT 0,
                security_issues INTEGER DEFAULT 0,
                ai_review TEXT,
                report_path TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                merged_at TIMESTAMP,
                merged_by TEXT,
                UNIQUE(repo_name, pr_number)
            )
        ''')

        # Per-scan detail records, linked to pr_scans.id.
        conn.execute('''
            CREATE TABLE IF NOT EXISTS scan_details (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                pr_scan_id INTEGER NOT NULL,
                scan_type TEXT NOT NULL,
                scan_data TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (pr_scan_id) REFERENCES pr_scans(id)
            )
        ''')


def _count_issues(scan_results: Dict[str, Any]) -> Tuple[int, int]:
    """Return ``(issues_count, security_issues)`` summed over the scan results.

    Only dict-valued results are inspected. Their ``issues`` and
    ``vulnerabilities`` entries are counted with ``len``. A missing or
    ``None`` entry counts as zero.
    """
    issues_count = 0
    security_issues = 0
    for result in scan_results.values():
        if not isinstance(result, dict):
            continue
        # ``or ()`` makes an explicit None count as 0 instead of raising
        # TypeError from len(None).
        issues_count += len(result.get('issues') or ())
        security_issues += len(result.get('vulnerabilities') or ())
    return issues_count, security_issues


class PRScanDB:
    """Static helpers for reading and writing PR scan records."""

    @staticmethod
    def save_pr_scan(pr_info: Dict[str, Any], scan_results: Dict[str, Any],
                     report_path: Optional[str] = None) -> int:
        """Insert or update the scan record for a PR.

        Records are keyed by ``(repo_name, pr_number)``.

        - An existing record has its title, URL (only when one is supplied),
          branches, author, scan results, counts and report path overwritten,
          and its ``updated_at`` refreshed.
        - A new record is created with ``state='open'``.

        Both paths set ``scan_status`` to ``'completed'``.

        Args:
            pr_info: PR metadata. Reads ``repo_name``, ``pr_number``,
                ``pr_title``, ``pr_url``, ``source_branch``, ``target_branch``
                and ``author``.
            scan_results: Mapping of scan type to result. Stored as JSON; the
                ``'ai'`` entry is additionally stored as ``ai_review``.
            report_path: Path of the generated report file, if any.

        Returns:
            The ``id`` of the inserted or updated record.
        """
        issues_count, security_issues = _count_issues(scan_results)
        repo_name = pr_info.get('repo_name')
        pr_number = pr_info.get('pr_number')
        scan_result_json = json.dumps(scan_results, ensure_ascii=False)
        ai_review_json = json.dumps(scan_results.get('ai', {}), ensure_ascii=False)

        with closing(get_db_connection()) as conn, conn:
            existing = conn.execute(
                'SELECT id FROM pr_scans WHERE repo_name = ? AND pr_number = ?',
                (repo_name, pr_number),
            ).fetchone()

            if existing:
                # Refresh the existing record. COALESCE fills in a URL that was
                # missing at first save without wiping a stored one when the
                # caller omits it.
                conn.execute('''
                    UPDATE pr_scans SET
                        pr_title = ?, pr_url = COALESCE(?, pr_url),
                        source_branch = ?, target_branch = ?, author = ?,
                        scan_status = ?, scan_result = ?,
                        issues_count = ?, security_issues = ?,
                        ai_review = ?, report_path = ?,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE repo_name = ? AND pr_number = ?
                ''', (
                    pr_info.get('pr_title'),
                    pr_info.get('pr_url'),
                    pr_info.get('source_branch'),
                    pr_info.get('target_branch'),
                    pr_info.get('author'),
                    'completed',
                    scan_result_json,
                    issues_count,
                    security_issues,
                    ai_review_json,
                    report_path,
                    repo_name,
                    pr_number,
                ))
                return existing['id']

            # First scan of this PR: insert a new record.
            cursor = conn.execute('''
                INSERT INTO pr_scans (
                    pr_number, repo_name, pr_title, pr_url, source_branch,
                    target_branch, author, state, scan_status, scan_result,
                    issues_count, security_issues, ai_review, report_path
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                pr_number,
                repo_name,
                pr_info.get('pr_title'),
                pr_info.get('pr_url'),
                pr_info.get('source_branch'),
                pr_info.get('target_branch'),
                pr_info.get('author'),
                'open',
                'completed',
                scan_result_json,
                issues_count,
                security_issues,
                ai_review_json,
                report_path,
            ))
            return cursor.lastrowid

    @staticmethod
    def get_all_prs(status: Optional[str] = None,
                    state: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return PR scan records, most recently updated first.

        Args:
            status: Optional scan-status filter (e.g. ``pending``/``completed``);
                falsy values disable the filter.
            state: Optional PR-state filter (e.g. ``open``/``merged``/``closed``);
                falsy values disable the filter.

        Returns:
            A list of records as plain dicts.
        """
        query = 'SELECT * FROM pr_scans WHERE 1=1'
        params = []

        if status:
            query += ' AND scan_status = ?'
            params.append(status)

        if state:
            query += ' AND state = ?'
            params.append(state)

        # updated_at has one-second resolution; tie-break on id so the order
        # is deterministic and the newest record still comes first.
        query += ' ORDER BY updated_at DESC, id DESC'

        with closing(get_db_connection()) as conn:
            rows = conn.execute(query, params).fetchall()
        return [dict(row) for row in rows]

    @staticmethod
    def get_pr_by_id(scan_id: int) -> Optional[Dict[str, Any]]:
        """Return the scan record with the given ``id``, or ``None`` if absent."""
        with closing(get_db_connection()) as conn:
            row = conn.execute('SELECT * FROM pr_scans WHERE id = ?', (scan_id,)).fetchone()
        return dict(row) if row else None

    @staticmethod
    def get_pr_by_number(repo_name: str, pr_number: int) -> Optional[Dict[str, Any]]:
        """Return the scan record for ``(repo_name, pr_number)``, or ``None``."""
        with closing(get_db_connection()) as conn:
            row = conn.execute(
                'SELECT * FROM pr_scans WHERE repo_name = ? AND pr_number = ?',
                (repo_name, pr_number),
            ).fetchone()
        return dict(row) if row else None

    @staticmethod
    def update_pr_state(scan_id: int, state: str, merged_by: Optional[str] = None):
        """Set a record's PR ``state`` and refresh ``updated_at``.

        When ``state`` is ``'merged'``, the method also stores ``merged_by`` and
        sets ``merged_at`` to the current timestamp.
        """
        with closing(get_db_connection()) as conn, conn:
            if state == 'merged':
                conn.execute('''
                    UPDATE pr_scans SET
                        state = ?, merged_at = CURRENT_TIMESTAMP, merged_by = ?,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE id = ?
                ''', (state, merged_by, scan_id))
            else:
                conn.execute('''
                    UPDATE pr_scans SET
                        state = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = ?
                ''', (state, scan_id))

    @staticmethod
    def delete_pr(scan_id: int):
        """Delete a scan record together with its ``scan_details`` rows.

        Both deletes run in one transaction.
        """
        with closing(get_db_connection()) as conn, conn:
            # Child rows first; foreign-key cascades are not configured.
            conn.execute('DELETE FROM scan_details WHERE pr_scan_id = ?', (scan_id,))
            conn.execute('DELETE FROM pr_scans WHERE id = ?', (scan_id,))


# Create the tables on import.
init_db()