Files
appium_ui_test/visual_comparator.py
2025-10-31 17:53:12 +08:00

459 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
视觉比对器模块
用于比较截图与设计图的视觉差异支持PNG、JPG和SVG格式
"""
import cv2
import numpy as np
from pathlib import Path
from skimage.metrics import structural_similarity as ssim
from PIL import Image
from config import Config
# --- Optional SVG rendering back-ends ---------------------------------------
# Each probe sets a module-level flag that the SVG loading code checks at call
# time. A failed import only disables that back-end; it never aborts loading
# this module.

# CairoSVG (preferred back-end). OSError is caught alongside ImportError,
# presumably because the import fails that way when the native Cairo library
# is missing.
CAIROSVG_AVAILABLE = False
try:
    import cairosvg
    CAIROSVG_AVAILABLE = True
    print("✅ CairoSVG库可用支持完整SVG处理")
except (ImportError, OSError) as cairo_err:
    print(f"⚠️ CairoSVG库不可用: {cairo_err}")
    print("💡 将使用备用的SVG处理方案")

# Selenium + webdriver-manager, used as the headless-Chrome SVG renderer.
# Only ImportError marks it unavailable; the names imported here are used
# directly by the Selenium rendering method.
SELENIUM_AVAILABLE = False
try:
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.webdriver.chrome.service import Service
    from selenium.webdriver.common.by import By
    from webdriver_manager.chrome import ChromeDriverManager
    import base64
except ImportError as selenium_err:
    print(f"⚠️ Selenium库不可用: {selenium_err}")
    print("💡 将跳过浏览器SVG渲染方案")
else:
    SELENIUM_AVAILABLE = True
    print("✅ Selenium库可用支持浏览器SVG渲染")
class VisualComparator:
    """Compare app screenshots against design mock-ups (PNG, JPG or SVG).

    Similarity is measured with SSIM on grayscale copies of the two images
    after they have been brought to a common size. Every successful
    comparison result is also appended to ``self.comparison_results``.
    """

    def __init__(self):
        # History of result dicts produced by compare_images, oldest first.
        self.comparison_results = []
        # Keys read by this class: 'similarity_threshold', 'diff_threshold',
        # 'min_diff_area', 'supported_formats'.
        self.config = Config.VISUAL_COMPARISON

    # ------------------------------------------------------------------
    # Comparison
    # ------------------------------------------------------------------
    def compare_images(self, screenshot_path, design_path, output_path=None):
        """Compare a screenshot with a design image.

        Args:
            screenshot_path: Path of the captured screenshot.
            design_path: Path of the design image (PNG/JPG/SVG).
            output_path: Optional path. When given, a side-by-side
                "screenshot | design | SSIM map" image is written there.

        Returns:
            A result dict containing:
                - the SSIM score;
                - the difference regions and their total area;
                - the original and processed sizes;
                - whether ``config['similarity_threshold']`` is met.
            Returns ``None`` if either image could not be loaded or any step
            raised (the error is printed).
        """
        try:
            screenshot = self._load_image(screenshot_path)
            design = self._load_image(design_path)
            if screenshot is None or design is None:
                print(f"❌ 无法加载图像: {screenshot_path}, {design_path}")
                return None

            # Bring both images to one shape (aspect-ratio aware, see below).
            screenshot_processed, design_processed = self._smart_resize_images(screenshot, design)

            gray1 = cv2.cvtColor(screenshot_processed, cv2.COLOR_BGR2GRAY)
            gray2 = cv2.cvtColor(design_processed, cv2.COLOR_BGR2GRAY)

            # full=True additionally returns the per-pixel SSIM map.
            similarity_score, ssim_map = ssim(gray1, gray2, full=True)
            diff_image = self._ssim_map_to_uint8(ssim_map)

            diff_regions = self._find_diff_regions(diff_image)

            if output_path:
                self._generate_comparison_image(
                    screenshot_processed, design_processed, diff_image,
                    diff_regions, similarity_score, output_path
                )

            result = {
                'similarity_score': similarity_score,
                'diff_regions': diff_regions,
                'total_diff_area': sum(r['area'] for r in diff_regions),
                'diff_percentage': (1 - similarity_score) * 100,
                # bool() keeps the dict JSON-serialisable (not numpy.bool_).
                'meets_threshold': bool(similarity_score >= self.config['similarity_threshold']),
                'original_screenshot_size': screenshot.shape[:2],
                'original_design_size': design.shape[:2],
                'processed_size': screenshot_processed.shape[:2],
            }
            self.comparison_results.append(result)
            return result
        except Exception as e:
            print(f"❌ 图像比对失败: {e}")
            return None

    @staticmethod
    def _ssim_map_to_uint8(ssim_map):
        """Scale an SSIM map to uint8, where 255 = identical and 0 = SSIM <= 0.

        SSIM values can be negative. They are clipped to 0 first, because a
        bare ``astype(np.uint8)`` does not saturate: negative floats come out
        as large values, making very different pixels look similar.
        """
        return (np.clip(ssim_map, 0.0, 1.0) * 255).astype(np.uint8)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def _load_image(self, image_path):
        """Load an image as a BGR array. Supports PNG, JPG and SVG.

        Returns ``None`` (after printing why) if the file is missing or
        OpenCV cannot decode it. SVG files go through _load_svg_image.
        """
        image_path = Path(image_path)
        if not image_path.exists():
            print(f"❌ 文件不存在: {image_path}")
            return None

        if image_path.suffix.lower() == '.svg':
            return self._load_svg_image(image_path)

        img = cv2.imread(str(image_path))
        if img is None:
            print(f"❌ 无法读取图像文件: {image_path}")
        return img

    @staticmethod
    def _pil_to_bgr(pil_image):
        """Convert a PIL image of any mode to a 3-channel BGR uint8 array.

        Images with an alpha channel (or palette transparency) are composited
        onto white, the same background the Selenium renderer uses. Simply
        dropping alpha would expose the colour stored under transparent
        pixels, which is typically black.
        """
        has_alpha = 'A' in pil_image.getbands() or 'transparency' in pil_image.info
        if has_alpha:
            rgba = pil_image.convert('RGBA')
            white = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
            rgb = Image.alpha_composite(white, rgba).convert('RGB')
        else:
            # Also normalises grayscale / palette modes to 3 channels.
            rgb = pil_image.convert('RGB')
        return cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR)

    def _svg_via_cairosvg(self, svg_path):
        """Rasterise an SVG with CairoSVG (only call when CAIROSVG_AVAILABLE)."""
        import io
        png_data = cairosvg.svg2png(url=str(svg_path))
        return self._pil_to_bgr(Image.open(io.BytesIO(png_data)))

    def _svg_via_wand(self, svg_path):
        """Rasterise an SVG with Wand/ImageMagick. Raises ImportError if Wand is unavailable."""
        import io
        from wand.image import Image as WandImage
        with WandImage(filename=str(svg_path)) as img:
            img.format = 'png'
            blob = img.make_blob()
        return self._pil_to_bgr(Image.open(io.BytesIO(blob)))

    def _svg_via_svglib(self, svg_path):
        """Rasterise an SVG with svglib + ReportLab. Raises ImportError if either is missing."""
        from reportlab.graphics import renderPM
        from svglib.svglib import svg2rlg  # svglib's public SVG -> Drawing entry point
        drawing = svg2rlg(str(svg_path))
        if drawing is None:  # svg2rlg returns None when it cannot parse the file
            raise ValueError(f"svglib无法解析: {svg_path}")
        return self._pil_to_bgr(renderPM.drawToPIL(drawing))

    def _load_svg_image(self, svg_path):
        """Rasterise an SVG into a BGR image, trying several back-ends in turn.

        Order of attempts:
            1. CairoSVG (only if it imported);
            2. Wand/ImageMagick;
            3. svglib + ReportLab;
            4. headless Chrome via Selenium (only if it imported).

        Each failure is printed. If all attempts fail, a grey 600x400
        placeholder image is returned, so the caller still receives an image.
        """
        if CAIROSVG_AVAILABLE:
            try:
                return self._svg_via_cairosvg(svg_path)
            except Exception as e:
                print(f"❌ CairoSVG处理失败: {e}")

        # Fallback 1: Wand (ImageMagick)
        try:
            image = self._svg_via_wand(svg_path)
            print("✅ 使用Wand处理SVG成功")
            return image
        except ImportError:
            print("⚠️ Wand库不可用")
        except Exception as e:
            print(f"❌ Wand处理失败: {e}")

        # Fallback 2: svglib + ReportLab
        try:
            image = self._svg_via_svglib(svg_path)
            print("✅ 使用ReportLab处理SVG成功")
            return image
        except ImportError:
            print("⚠️ ReportLab/svglib库不可用")
        except Exception as e:
            print(f"❌ ReportLab处理失败: {e}")

        # Fallback 3: browser rendering
        if SELENIUM_AVAILABLE:
            try:
                image = self._render_svg_with_selenium(svg_path)
                if image is not None:
                    print("✅ 使用Selenium浏览器渲染SVG成功")
                    return image
            except Exception as e:
                print(f"❌ Selenium渲染失败: {e}")

        # Nothing worked: explain how to fix it and return a placeholder.
        print(f"⚠️ 无法处理SVG文件: {svg_path}")
        print("💡 建议安装以下任一库来支持SVG:")
        print(" - pip install cairosvg (推荐但需要Cairo库)")
        print(" - pip install Wand (需要ImageMagick)")
        print(" - pip install reportlab svglib")
        placeholder_img = np.ones((400, 600, 3), dtype=np.uint8) * 240
        cv2.putText(placeholder_img, "SVG Not Supported", (150, 180),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)
        cv2.putText(placeholder_img, "Install SVG library", (150, 220),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)
        return placeholder_img

    def _render_svg_with_selenium(self, svg_path):
        """Render an SVG in headless Chrome and screenshot its <svg> element.

        Relies on the Selenium names imported at module level (present only
        when SELENIUM_AVAILABLE is True). The driver binary is resolved by
        ChromeDriverManager.

        Returns:
            A BGR image, or ``None`` if anything fails (the error is printed).
        """
        import io
        import os
        import tempfile
        import time

        try:
            with open(svg_path, 'r', encoding='utf-8') as f:
                svg_content = f.read()

            # Inline the SVG in a minimal white page (whitespace is irrelevant to HTML).
            html_content = f"""<!DOCTYPE html>
            <html>
            <head>
            <meta charset="utf-8">
            <style>
            body {{ margin: 0; padding: 20px; background: white; }}
            svg {{ max-width: 100%; height: auto; display: block; }}
            </style>
            </head>
            <body>
            {svg_content}
            </body>
            </html>"""

            # delete=False so the file survives closing and Chrome can open it;
            # it is removed in the finally clause below.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False,
                                             encoding='utf-8') as f:
                f.write(html_content)
                temp_html_path = f.name

            try:
                chrome_options = Options()
                for arg in ('--headless',
                            '--no-sandbox',
                            '--disable-dev-shm-usage',
                            '--disable-gpu',
                            '--disable-web-security',
                            '--allow-running-insecure-content',
                            '--disable-extensions',
                            '--window-size=1200,800'):
                    chrome_options.add_argument(arg)

                service = Service(ChromeDriverManager().install())
                driver = webdriver.Chrome(service=service, options=chrome_options)
                try:
                    # as_uri() builds a proper, percent-encoded file:// URL for the
                    # platform. The hand-built 'file:///' + path form produced
                    # 'file:////...' on POSIX and left spaces unescaped.
                    driver.get(Path(temp_html_path).as_uri())
                    time.sleep(3)  # crude fixed wait for the page to render
                    svg_element = driver.find_element(By.TAG_NAME, "svg")
                    png_bytes = svg_element.screenshot_as_png
                    return self._pil_to_bgr(Image.open(io.BytesIO(png_bytes)))
                finally:
                    driver.quit()
            finally:
                try:
                    os.unlink(temp_html_path)
                except OSError:
                    pass  # best-effort cleanup of the temp page
        except Exception as e:
            print(f"Selenium渲染SVG时出错: {e}")
            return None

    # ------------------------------------------------------------------
    # Geometry
    # ------------------------------------------------------------------
    def _smart_resize_images(self, img1, img2):
        """Bring two images to the same shape for pixel-wise comparison.

        If the aspect ratios differ by less than 10% (relative), both images
        are resized, possibly distorting slightly, to the smaller width and
        the smaller height.

        Otherwise, both are scaled to the smaller height while keeping their
        own aspect ratio. The narrower one is then centred on a white canvas
        as wide as the wider one.
        """
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]
        ratio1 = w1 / h1
        ratio2 = w2 / h2

        if abs(ratio1 - ratio2) / max(ratio1, ratio2) < 0.1:
            target_size = (min(w1, w2), min(h1, h2))  # cv2.resize takes (width, height)
            return cv2.resize(img1, target_size), cv2.resize(img2, target_size)

        new_h = min(h1, h2)
        # max(1, ...) guards against a zero width for extremely tall images.
        new_w1 = max(1, int(new_h * ratio1))
        new_w2 = max(1, int(new_h * ratio2))
        img1_resized = cv2.resize(img1, (new_w1, new_h))
        img2_resized = cv2.resize(img2, (new_w2, new_h))

        max_w = max(new_w1, new_w2)
        return self._pad_to_width(img1_resized, max_w), self._pad_to_width(img2_resized, max_w)

    @staticmethod
    def _pad_to_width(img, width):
        """Centre ``img`` horizontally on a white canvas ``width`` pixels wide.

        Returns ``img`` itself when it already has that width. Any trailing
        channel dimension is preserved.
        """
        h, w = img.shape[:2]
        if w == width:
            return img
        canvas = np.full((h, width) + img.shape[2:], 255, dtype=img.dtype)
        offset = (width - w) // 2
        canvas[:, offset:offset + w] = img
        return canvas

    def _resize_to_match(self, img1, img2):
        """Resize img1 to exactly img2's width and height, ignoring aspect ratio.

        Legacy helper kept as a fallback; not used by compare_images.
        """
        h2, w2 = img2.shape[:2]
        return cv2.resize(img1, (w2, h2))

    # ------------------------------------------------------------------
    # Difference analysis / reporting
    # ------------------------------------------------------------------
    def _find_diff_regions(self, diff_image):
        """Find bounding boxes of areas where the two images differ.

        Args:
            diff_image: uint8 SSIM map from compare_images (255 = identical,
                0 = maximally dissimilar).

        Returns:
            A list of dicts with keys x, y, width, height, area, center_x and
            center_y. Only contours whose area exceeds
            ``config['min_diff_area']`` are included.
        """
        # NOTE(review): assumes config['diff_threshold'] is a dissimilarity
        # level on the 0-255 scale - confirm in Config.
        threshold = self.config['diff_threshold']
        min_area = self.config['min_diff_area']

        # Invert so bright pixels mark differences. Thresholding the raw SSIM
        # map would select the *similar* areas instead.
        dissimilarity = 255 - diff_image
        _, binary = cv2.threshold(dissimilarity, threshold, 255, cv2.THRESH_BINARY)

        # [-2] selects the contour list under both OpenCV 3 (3 return values)
        # and OpenCV 4 (2 return values).
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        regions = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area <= min_area:  # ignore small specks
                continue
            x, y, w, h = cv2.boundingRect(contour)
            regions.append({
                'x': x, 'y': y, 'width': w, 'height': h,
                'area': area,
                'center_x': x + w // 2, 'center_y': y + h // 2,
            })
        return regions

    def _generate_comparison_image(self, img1, img2, diff_img, regions, score, output_path):
        """Write a "screenshot | design | SSIM map" strip with differences boxed.

        img1, img2 and diff_img must share img1's height and width;
        compare_images guarantees this. Prints whether the file was written.
        """
        h, w = img1.shape[:2]
        result_img = np.zeros((h, w * 3, 3), dtype=np.uint8)
        result_img[:, :w] = img1
        result_img[:, w:2 * w] = img2
        result_img[:, 2 * w:] = cv2.cvtColor(diff_img, cv2.COLOR_GRAY2BGR)

        red = (0, 0, 255)  # BGR
        for region in regions:
            x, y, rw, rh = region['x'], region['y'], region['width'], region['height']
            cv2.rectangle(result_img, (x, y), (x + rw, y + rh), red, 2)          # screenshot panel
            cv2.rectangle(result_img, (w + x, y), (w + x + rw, y + rh), red, 2)  # design panel

        font = cv2.FONT_HERSHEY_SIMPLEX
        white = (255, 255, 255)
        cv2.putText(result_img, f"Similarity: {score:.2%}", (10, 30), font, 1, white, 2)
        cv2.putText(result_img, f"Differences: {len(regions)}", (10, 70), font, 0.8, white, 2)
        cv2.putText(result_img, "Screenshot", (10, h - 10), font, 0.7, white, 2)
        cv2.putText(result_img, "Design", (w + 10, h - 10), font, 0.7, white, 2)
        cv2.putText(result_img, "Difference", (2 * w + 10, h - 10), font, 0.7, white, 2)

        # cv2.imwrite reports some failures (e.g. a missing directory) by
        # returning False instead of raising, so check before claiming success.
        if cv2.imwrite(str(output_path), result_img):
            print(f"📊 比对结果已保存: {output_path}")
        else:
            print(f"❌ 比对结果保存失败: {output_path}")

    def get_comparison_summary(self):
        """Summarise all comparisons so far, or return None if there are none.

        The "latest_*" fields and 'meets_threshold' come from the most recent
        result. 'average_similarity' is the mean over all results.
        """
        if not self.comparison_results:
            return None

        latest = self.comparison_results[-1]
        count = len(self.comparison_results)
        return {
            'total_comparisons': count,
            'latest_similarity': latest['similarity_score'],
            'latest_diff_count': len(latest['diff_regions']),
            'meets_threshold': latest['meets_threshold'],
            'average_similarity': sum(r['similarity_score'] for r in self.comparison_results) / count,
        }

    # ------------------------------------------------------------------
    # Batch helpers
    # ------------------------------------------------------------------
    def find_design_files(self, design_dir):
        """List the files in ``design_dir`` whose names end with a supported format.

        Details:
            - Matching is case-insensitive (``Logo.PNG`` matches ``.png``).
            - Only regular files are returned.
            - Each file appears once.
            - Files are grouped in ``config['supported_formats']`` order and
              sorted by name within each group.
            - SVG files come first.
            - A missing directory yields an empty list.
        """
        design_dir = Path(design_dir)
        if not design_dir.is_dir():
            return []

        extensions = [ext.lower() for ext in self.config['supported_formats']]
        candidates = sorted((p for p in design_dir.iterdir() if p.is_file()),
                            key=lambda p: p.name.lower())

        design_files = []
        seen = set()
        for ext in extensions:
            for path in candidates:
                if path not in seen and path.name.lower().endswith(ext):
                    seen.add(path)
                    design_files.append(path)

        # SVG designs take priority.
        svg_files = [f for f in design_files if f.suffix.lower() == '.svg']
        other_files = [f for f in design_files if f.suffix.lower() != '.svg']
        return svg_files + other_files

    def batch_compare(self, screenshot_path, design_dir, output_dir=None):
        """Compare one screenshot against every design file in ``design_dir``.

        Args:
            screenshot_path: Path of the screenshot.
            design_dir: Directory searched by find_design_files.
            output_dir: Optional directory, created if missing. It receives one
                ``comparison_<stem>.png`` per design. When two designs share a
                stem (e.g. logo.svg and logo.png), the later one is saved as
                ``comparison_<stem>_<ext>.png`` instead of overwriting the first.

        Returns:
            A list of the successful result dicts, each with an added
            'design_file' key.
        """
        design_files = self.find_design_files(design_dir)
        if not design_files:
            print(f"❌ 在 {design_dir} 中未找到支持的设计文件")
            return []

        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        results = []
        used_names = set()
        for design_file in design_files:
            print(f"🔍 比对设计文件: {design_file.name}")

            output_path = None
            if output_dir:
                name = f"comparison_{design_file.stem}.png"
                if name in used_names:
                    ext = design_file.suffix.lstrip('.').lower()
                    name = f"comparison_{design_file.stem}_{ext}.png"
                used_names.add(name)
                output_path = output_dir / name

            result = self.compare_images(screenshot_path, design_file, output_path)
            if result:
                result['design_file'] = str(design_file)
                results.append(result)

        return results