Files
appium_ui_test/visual_comparator.py

459 lines
17 KiB
Python
Raw Normal View History

2025-10-31 17:53:12 +08:00
#!/usr/bin/env python3
"""
视觉比对器模块
用于比较截图与设计图的视觉差异支持PNGJPG和SVG格式
"""
import cv2
import numpy as np
from pathlib import Path
from skimage.metrics import structural_similarity as ssim
from PIL import Image
from config import Config
# Optional dependency: cairosvg gives the most faithful SVG rasterisation.
# It needs the native Cairo library, so an OSError is possible even when the
# Python package itself is installed.
try:
    import cairosvg
except (ImportError, OSError) as e:
    CAIROSVG_AVAILABLE = False
    print(f"⚠️ CairoSVG库不可用: {e}")
    print("💡 将使用备用的SVG处理方案")
else:
    CAIROSVG_AVAILABLE = True
    print("✅ CairoSVG库可用支持完整SVG处理")
# Optional dependency: Selenium plus webdriver-manager let a headless Chrome
# render SVG files as a late fallback in VisualComparator._load_svg_image.
try:
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.webdriver.chrome.service import Service
    from selenium.webdriver.common.by import By
    from webdriver_manager.chrome import ChromeDriverManager
    import base64
except ImportError as e:
    SELENIUM_AVAILABLE = False
    print(f"⚠️ Selenium库不可用: {e}")
    print("💡 将跳过浏览器SVG渲染方案")
else:
    SELENIUM_AVAILABLE = True
    print("✅ Selenium库可用支持浏览器SVG渲染")
class VisualComparator:
    """Compare app screenshots against design images (PNG, JPG or SVG).

    Similarity is measured with SSIM (``skimage.metrics.structural_similarity``)
    on grayscale versions of the two images, after they have been brought to
    a common size. Areas with low local SSIM are reported as difference
    regions. Every successful comparison is appended to ``comparison_results``.
    """

    def __init__(self):
        # Results of every successful compare_images() call, oldest first.
        self.comparison_results = []
        # Keys read by this class: 'similarity_threshold', 'diff_threshold',
        # 'min_diff_area' and 'supported_formats'.
        self.config = Config.VISUAL_COMPARISON

    def compare_images(self, screenshot_path, design_path, output_path=None):
        """Compare a screenshot with a design image.

        Args:
            screenshot_path: Path of the captured screenshot.
            design_path: Path of the design file (PNG, JPG or SVG).
            output_path: Optional path. When given, a side-by-side
                visualisation (screenshot | design | difference) is written.

        Returns:
            A dict with ``similarity_score``, ``diff_regions``,
            ``total_diff_area``, ``diff_percentage``, ``meets_threshold`` and
            the original/processed image sizes. Returns ``None`` when an image
            cannot be loaded or the comparison raises.
        """
        try:
            screenshot = self._load_image(screenshot_path)
            design = self._load_image(design_path)

            if screenshot is None or design is None:
                print(f"❌ 无法加载图像: {screenshot_path}, {design_path}")
                return None

            # Bring both images to one common size while keeping aspect ratios.
            screenshot_processed, design_processed = self._smart_resize_images(screenshot, design)

            gray1 = cv2.cvtColor(screenshot_processed, cv2.COLOR_BGR2GRAY)
            gray2 = cv2.cvtColor(design_processed, cv2.COLOR_BGR2GRAY)

            # ssim_map holds the local SSIM per pixel: ~1 where the images agree.
            similarity_score, ssim_map = ssim(gray1, gray2, full=True)
            # Invert it so that bright pixels mark *differences*, which is what
            # _find_diff_regions thresholds on and what the diff panel shows.
            diff_image = self._ssim_map_to_diff_image(ssim_map)

            diff_regions = self._find_diff_regions(diff_image)

            if output_path:
                self._generate_comparison_image(
                    screenshot_processed, design_processed, diff_image,
                    diff_regions, similarity_score, output_path
                )

            result = {
                'similarity_score': similarity_score,
                'diff_regions': diff_regions,
                'total_diff_area': sum(r['area'] for r in diff_regions),
                'diff_percentage': (1 - similarity_score) * 100,
                'meets_threshold': similarity_score >= self.config['similarity_threshold'],
                'original_screenshot_size': screenshot.shape[:2],
                'original_design_size': design.shape[:2],
                'processed_size': screenshot_processed.shape[:2]
            }

            self.comparison_results.append(result)
            return result

        except Exception as e:
            # Best effort: report the failure and let the caller carry on.
            print(f"❌ 图像比对失败: {e}")
            return None

    @staticmethod
    def _ssim_map_to_diff_image(ssim_map):
        """Convert a local-SSIM map into an 8-bit dissimilarity image.

        SSIM values lie in [-1, 1], where 1 means identical. Values are
        clipped to [0, 1], so negative correlation counts as fully different.
        They are then inverted so that 0 means "identical" and 255 means
        "maximally different". Clipping also prevents negative values from
        wrapping around in the uint8 cast.
        """
        clipped = np.clip(np.asarray(ssim_map, dtype=np.float64), 0.0, 1.0)
        return np.rint((1.0 - clipped) * 255.0).astype(np.uint8)

    def _load_image(self, image_path):
        """Load an image as a 3-channel BGR array (PNG, JPG and SVG supported).

        Returns ``None`` (after printing a message) when the file is missing
        or cannot be decoded.
        """
        image_path = Path(image_path)
        if not image_path.exists():
            print(f"❌ 文件不存在: {image_path}")
            return None

        if image_path.suffix.lower() == '.svg':
            return self._load_svg_image(image_path)

        img = self._read_raster_image(image_path)
        if img is None:
            print(f"❌ 无法读取图像文件: {image_path}")
        return img

    @staticmethod
    def _read_raster_image(image_path):
        """Read a raster image as 3-channel BGR, including non-ASCII paths.

        ``cv2.imread`` cannot open paths with non-ASCII characters on Windows.
        The bytes are therefore read with NumPy and decoded in memory.
        Returns ``None`` if the file cannot be read or decoded.
        """
        try:
            data = np.fromfile(str(image_path), dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            # cv2.imdecode asserts on an empty buffer.
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    @staticmethod
    def _composite_rgba_on_white_bgr(rgba):
        """Flatten an (H, W, 4) RGBA uint8 array onto white; return (H, W, 3) BGR.

        SVG rasterisers usually produce transparent backgrounds. Dropping the
        alpha channel would make those areas black. Compositing onto white
        matches the white padding used by _smart_resize_images.
        """
        rgba = np.asarray(rgba)
        rgb = rgba[..., :3].astype(np.float64)
        alpha = rgba[..., 3:4].astype(np.float64) / 255.0
        blended = rgb * alpha + 255.0 * (1.0 - alpha)
        # Reverse the channel axis: RGB -> BGR (OpenCV order).
        bgr = np.rint(blended).astype(np.uint8)[..., ::-1]
        return np.ascontiguousarray(bgr)

    def _pil_to_bgr(self, pil_image):
        """Convert a PIL image of any mode (RGB, RGBA, P, L, ...) to BGR."""
        return self._composite_rgba_on_white_bgr(np.array(pil_image.convert('RGBA')))

    def _load_svg_image(self, svg_path):
        """Rasterise an SVG file to a BGR array, trying several back-ends.

        Back-ends are tried in this order:
          1. cairosvg, if available;
          2. Wand (ImageMagick);
          3. ReportLab + svglib;
          4. Selenium, if available.
        If all of them fail, a grey "SVG Not Supported" placeholder image is
        returned, so callers always receive an image.
        """
        import io

        if CAIROSVG_AVAILABLE:
            try:
                png_data = cairosvg.svg2png(url=str(svg_path))
                return self._pil_to_bgr(Image.open(io.BytesIO(png_data)))
            except Exception as e:
                print(f"❌ CairoSVG处理失败: {e}")

        # Fallback 1: Wand (ImageMagick)
        try:
            from wand.image import Image as WandImage
            with WandImage(filename=str(svg_path)) as img:
                img.format = 'png'
                blob = img.make_blob()
            opencv_image = self._pil_to_bgr(Image.open(io.BytesIO(blob)))
            print("✅ 使用Wand处理SVG成功")
            return opencv_image
        except ImportError:
            print("⚠️ Wand库不可用")
        except Exception as e:
            print(f"❌ Wand处理失败: {e}")

        # Fallback 2: ReportLab + svglib. svglib's public entry point is
        # svg2rlg. It may return None when the SVG cannot be parsed.
        try:
            from reportlab.graphics import renderPM
            from svglib.svglib import svg2rlg
            drawing = svg2rlg(str(svg_path))
            if drawing is None:
                raise ValueError("svg2rlg returned None")
            opencv_image = self._pil_to_bgr(renderPM.drawToPIL(drawing))
            print("✅ 使用ReportLab处理SVG成功")
            return opencv_image
        except ImportError:
            print("⚠️ ReportLab/svglib库不可用")
        except Exception as e:
            print(f"❌ ReportLab处理失败: {e}")

        # Fallback 3: render in a headless browser
        if SELENIUM_AVAILABLE:
            try:
                result = self._render_svg_with_selenium(svg_path)
                if result is not None:
                    print("✅ 使用Selenium浏览器渲染SVG成功")
                    return result
            except Exception as e:
                print(f"❌ Selenium渲染失败: {e}")

        # Fallback 4: a placeholder image with an explanatory text
        print(f"⚠️ 无法处理SVG文件: {svg_path}")
        print("💡 建议安装以下任一库来支持SVG:")
        print(" - pip install cairosvg (推荐但需要Cairo库)")
        print(" - pip install Wand (需要ImageMagick)")
        print(" - pip install reportlab svglib")

        placeholder_img = np.ones((400, 600, 3), dtype=np.uint8) * 240
        cv2.putText(placeholder_img, "SVG Not Supported", (150, 180),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)
        cv2.putText(placeholder_img, "Install SVG library", (150, 220),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)
        return placeholder_img

    def _render_svg_with_selenium(self, svg_path):
        """Render an SVG with headless Chrome and screenshot the <svg> element.

        The SVG markup is embedded in a temporary HTML page (white background,
        20px padding). The temporary file is always removed and the driver is
        always quit. Returns a BGR array, or ``None`` on any error.
        """
        import io
        import os
        import tempfile
        import time

        try:
            chrome_options = Options()
            for argument in (
                '--headless',
                '--no-sandbox',
                '--disable-dev-shm-usage',
                '--disable-gpu',
                '--disable-web-security',
                '--allow-running-insecure-content',
                '--disable-extensions',
                '--window-size=1200,800',
            ):
                chrome_options.add_argument(argument)

            service = Service(ChromeDriverManager().install())
            driver = webdriver.Chrome(service=service, options=chrome_options)

            try:
                with open(svg_path, 'r', encoding='utf-8') as f:
                    svg_content = f.read()

                html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body {{ margin: 0; padding: 20px; background: white; }}
svg {{ max-width: 100%; height: auto; display: block; }}
</style>
</head>
<body>
{svg_content}
</body>
</html>"""

                with tempfile.NamedTemporaryFile(mode='w', suffix='.html', delete=False, encoding='utf-8') as f:
                    f.write(html_content)
                    temp_html_path = f.name

                try:
                    # as_uri() builds a correct, percent-encoded file:// URL
                    # on both Windows and POSIX.
                    driver.get(Path(temp_html_path).as_uri())
                    # Fixed wait so the page has time to load and render.
                    time.sleep(3)

                    svg_element = driver.find_element(By.TAG_NAME, "svg")
                    screenshot = svg_element.screenshot_as_png
                    return self._pil_to_bgr(Image.open(io.BytesIO(screenshot)))
                finally:
                    try:
                        os.unlink(temp_html_path)
                    except OSError:
                        # Best-effort cleanup of the temporary page.
                        pass
            finally:
                driver.quit()

        except Exception as e:
            print(f"Selenium渲染SVG时出错: {e}")
            return None

    def _smart_resize_images(self, img1, img2):
        """Resize two images to a common size, preserving aspect ratios.

        If the aspect ratios differ by less than 10%, both images are resized
        to (min width, min height). Otherwise both are scaled to the smaller
        height, keeping their own aspect ratio. The narrower result is then
        centred on a white canvas of the wider width.
        """
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]

        ratio1 = w1 / h1
        ratio2 = w2 / h2

        if abs(ratio1 - ratio2) / max(ratio1, ratio2) < 0.1:
            target_size = (min(w1, w2), min(h1, h2))
            return cv2.resize(img1, target_size), cv2.resize(img2, target_size)

        new_h = min(h1, h2)
        # Guard against a zero width, which cv2.resize rejects.
        new_w1 = max(1, int(new_h * ratio1))
        new_w2 = max(1, int(new_h * ratio2))

        img1_resized = cv2.resize(img1, (new_w1, new_h))
        img2_resized = cv2.resize(img2, (new_w2, new_h))

        max_w = max(new_w1, new_w2)
        return self._pad_to_width(img1_resized, max_w), self._pad_to_width(img2_resized, max_w)

    @staticmethod
    def _pad_to_width(image, width):
        """Centre *image* horizontally on a white canvas *width* pixels wide.

        Returns *image* unchanged when it already has that width.
        """
        h, w = image.shape[:2]
        if w == width:
            return image
        canvas = np.full((h, width) + image.shape[2:], 255, dtype=np.uint8)
        offset = (width - w) // 2
        canvas[:, offset:offset + w] = image
        return canvas

    def _resize_to_match(self, img1, img2):
        """Resize img1 to exactly img2's size (kept as a fallback helper)."""
        h2, w2 = img2.shape[:2]
        return cv2.resize(img1, (w2, h2))

    def _find_diff_regions(self, diff_image):
        """Locate connected regions of strong difference.

        Args:
            diff_image: 8-bit single-channel dissimilarity image, where
                brighter means more different.

        Returns:
            One dict per external contour whose area exceeds
            ``config['min_diff_area']``. Each dict holds the bounding box
            (``x``, ``y``, ``width``, ``height``), the contour ``area`` and
            the box centre (``center_x``, ``center_y``).
        """
        # NOTE(review): assumes config['diff_threshold'] is on the 0-255
        # dissimilarity scale produced by _ssim_map_to_diff_image.
        threshold = self.config['diff_threshold']
        min_area = self.config['min_diff_area']

        _, binary = cv2.threshold(diff_image, threshold, 255, cv2.THRESH_BINARY)

        # [-2] picks the contour list on both OpenCV 3.x (3-tuple) and 4.x (2-tuple).
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        regions = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area <= min_area:
                # Ignore tiny differences (noise, anti-aliasing).
                continue
            x, y, w, h = cv2.boundingRect(contour)
            regions.append({
                'x': x, 'y': y, 'width': w, 'height': h,
                'area': area,
                'center_x': x + w // 2, 'center_y': y + h // 2
            })

        return regions

    @staticmethod
    def _write_image(output_path, image):
        """Encode *image* and write it to *output_path*. Return True on success.

        Uses ``cv2.imencode`` + ``ndarray.tofile`` instead of ``cv2.imwrite``
        so that non-ASCII paths work on Windows. The parent directory is
        created if needed. The format follows the file extension, or PNG
        when there is none.
        """
        output_path = Path(output_path)
        extension = output_path.suffix or '.png'
        try:
            ok, buffer = cv2.imencode(extension, image)
            if not ok:
                return False
            output_path.parent.mkdir(parents=True, exist_ok=True)
            buffer.tofile(str(output_path))
        except (cv2.error, OSError):
            return False
        return True

    def _generate_comparison_image(self, img1, img2, diff_img, regions, score, output_path):
        """Write a screenshot | design | difference panel image to *output_path*.

        Difference regions are outlined in red on the screenshot and design
        panels. The similarity score and region count are drawn in the top
        left corner.
        """
        h, w = img1.shape[:2]
        result_img = np.zeros((h, w * 3, 3), dtype=np.uint8)

        result_img[:, :w] = img1
        result_img[:, w:2 * w] = img2
        result_img[:, 2 * w:] = cv2.cvtColor(diff_img, cv2.COLOR_GRAY2BGR)

        for region in regions:
            x, y, rw, rh = region['x'], region['y'], region['width'], region['height']
            # Outline on the screenshot panel
            cv2.rectangle(result_img, (x, y), (x + rw, y + rh), (0, 0, 255), 2)
            # Outline on the design panel (shifted by one panel width)
            cv2.rectangle(result_img, (w + x, y), (w + x + rw, y + rh), (0, 0, 255), 2)

        font = cv2.FONT_HERSHEY_SIMPLEX

        similarity_text = f"Similarity: {score:.2%}"
        cv2.putText(result_img, similarity_text, (10, 30), font, 1, (255, 255, 255), 2)

        diff_count = len(regions)
        diff_text = f"Differences: {diff_count}"
        cv2.putText(result_img, diff_text, (10, 70), font, 0.8, (255, 255, 255), 2)

        cv2.putText(result_img, "Screenshot", (10, h - 10), font, 0.7, (255, 255, 255), 2)
        cv2.putText(result_img, "Design", (w + 10, h - 10), font, 0.7, (255, 255, 255), 2)
        cv2.putText(result_img, "Difference", (2 * w + 10, h - 10), font, 0.7, (255, 255, 255), 2)

        if self._write_image(output_path, result_img):
            print(f"📊 比对结果已保存: {output_path}")
        else:
            print(f"❌ 比对结果保存失败: {output_path}")

    def get_comparison_summary(self):
        """Summarise the comparisons made so far.

        Returns ``None`` if no comparison has succeeded yet. Otherwise returns
        a dict with the total count, the latest result's similarity,
        difference count and threshold verdict, and the average similarity.
        """
        if not self.comparison_results:
            return None

        latest_result = self.comparison_results[-1]

        summary = {
            'total_comparisons': len(self.comparison_results),
            'latest_similarity': latest_result['similarity_score'],
            'latest_diff_count': len(latest_result['diff_regions']),
            'meets_threshold': latest_result['meets_threshold'],
            'average_similarity': sum(r['similarity_score'] for r in self.comparison_results) / len(self.comparison_results)
        }

        return summary

    def find_design_files(self, design_dir):
        """Return the design files in *design_dir* that match a supported format.

        Each entry of ``config['supported_formats']`` is used as a
        ``*<entry>`` glob pattern (non-recursive). Only regular files are
        kept, and each file appears once even if several patterns match it.
        SVG files come first, then all others. Within each pattern, matches
        are returned in sorted order.
        """
        design_dir = Path(design_dir)
        # dict preserves insertion order and removes duplicates.
        found = {}
        for format_ext in self.config['supported_formats']:
            for path in sorted(design_dir.glob(f"*{format_ext}")):
                if path.is_file():
                    found.setdefault(path, None)

        design_files = list(found)
        svg_files = [f for f in design_files if f.suffix.lower() == '.svg']
        other_files = [f for f in design_files if f.suffix.lower() != '.svg']
        return svg_files + other_files

    def batch_compare(self, screenshot_path, design_dir, output_dir=None):
        """Compare one screenshot against every design file in *design_dir*.

        When *output_dir* is given, it is created if needed and a
        ``comparison_<design stem>.png`` visualisation is written there for
        each design. Returns the successful result dicts, each annotated with
        ``design_file``.
        """
        design_files = self.find_design_files(design_dir)

        if not design_files:
            print(f"❌ 在 {design_dir} 中未找到支持的设计文件")
            return []

        if output_dir:
            # Create the output directory once, not on every iteration.
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        results = []
        for design_file in design_files:
            print(f"🔍 比对设计文件: {design_file.name}")

            output_path = output_dir / f"comparison_{design_file.stem}.png" if output_dir else None

            result = self.compare_images(screenshot_path, design_file, output_path)
            if result:
                result['design_file'] = str(design_file)
                results.append(result)

        return results