"""Rule-based diagnostic engine for RAG evaluation metric scores.

Each supported metric has a :class:`MetricRule` holding its warning and
critical thresholds, its direction (whether higher scores are better), and
canned root-cause / suggested-action text. :func:`diagnose` averages the
per-sample scores of each requested metric and returns a :class:`Diagnosis`
for every metric whose mean crosses one of its thresholds.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Any


@dataclass
class MetricRule:
    """Threshold configuration and diagnostic text for one metric."""

    # Mean score at which a "warning" diagnosis is raised.
    warning_threshold: float
    # Mean score at which a "critical" diagnosis is raised.
    critical_threshold: float
    # Metric direction. False for noise_sensitivity: there a higher score is
    # worse, and thresholds trigger when the mean is strictly above them.
    higher_is_better: bool
    root_causes: list[str]
    suggested_actions: list[str]


# Per-metric rules; metrics not listed here are ignored by diagnose().
METRIC_RULES: dict[str, MetricRule] = {
    "faithfulness": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "生成回答包含检索片段中不支持的陈述(幻觉)",
            "生成阶段未严格遵循 grounding 约束",
            "校验阶段未开启或未生效",
        ],
        suggested_actions=[
            "强化生成 prompt 的 grounding 约束('只依据参考资料作答')",
            "开启校验阶段(validation: by_scenario)",
            "检查低分样本中模型是否引用了片段外的知识",
        ],
    ),
    "answer_relevancy": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "回答偏离问题主旨或包含大量冗余内容",
            "查询改写后问题语义漂移",
            "生成 prompt 格式约束不足",
        ],
        suggested_actions=[
            "优化查询改写 prompt,确保改写后语义不偏移",
            "在生成 prompt 中加入'简洁准确、直接回答问题'的约束",
            "检查低分样本的回答是否存在格式冗余或话题偏移",
        ],
    ),
    "context_recall": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "检索未能召回标准答案所涉及的关键信息",
            "单一查询未能覆盖问题的多个角度",
            "过召回数量不足,关键片段被截断",
        ],
        suggested_actions=[
            "启用多查询扩展(use_multi_query)覆盖不同措辞",
            "对多跳问题启用问题分解(sub_questions)",
            "加大过召回宽度(recall_top_k)",
            "对颗粒度细的问题尝试 Step-back 双路检索",
        ],
    ),
    "context_precision": MetricRule(
        warning_threshold=0.6,
        critical_threshold=0.4,
        higher_is_better=True,
        root_causes=[
            "检索引入过多与问题无关的片段",
            "重排未能将相关片段排在前列",
            "缺少相关性过滤,噪声片段进入上下文",
        ],
        suggested_actions=[
            "启用或优化 listwise 重排,将相关片段排在前列",
            "启用上下文压缩(compression)过滤无关句子",
            "启用相关性过滤(relevance_filter)丢弃明确无关片段",
            "缩小 rerank_keep_k(如从 8 降到 5)",
        ],
    ),
    "noise_sensitivity": MetricRule(
        warning_threshold=0.3,  # higher is worse; trigger when mean > threshold
        critical_threshold=0.5,
        higher_is_better=False,
        root_causes=[
            "回答中包含检索到的噪声片段所引入的错误陈述",
            "相关性过滤未能拦截干扰性片段",
            "生成阶段对噪声片段未加区分地引用",
        ],
        suggested_actions=[
            "启用相关性过滤(relevance_filter)拦截噪声",
            "优化重排,将不相关片段排到截断点之后",
            "在生成 prompt 中强调'来源冲突时并列陈述,不擅自下定论'",
        ],
    ),
    "factual_correctness": MetricRule(
        warning_threshold=0.6,
        critical_threshold=0.4,
        higher_is_better=True,
        root_causes=[
            "回答的事实陈述与标准答案存在偏差",
            "检索未能命中标准答案所依据的关键片段",
            "生成阶段对多个来源综合时产生事实错误",
        ],
        suggested_actions=[
            "重点检查低分样本,确认是检索遗漏还是生成错误",
            "提升 context_recall 以确保关键信息被检索到",
            "对事实型问题将 temperature 降至 0",
        ],
    ),
    "semantic_similarity": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "回答语义与标准答案差距较大",
            "回答过于简短或过于冗长,语义偏移",
            "检索到的片段质量不足,导致生成内容偏离",
        ],
        suggested_actions=[
            "检查低分样本的回答与标准答案的表述差异",
            "优化生成 prompt 使回答更贴近标准表述风格",
            "提升检索质量(context_recall / context_precision)",
        ],
    ),
}


@dataclass
class Diagnosis:
    """Diagnostic result for one metric that triggered a threshold."""

    metric: str
    mean_score: float  # mean over usable samples, rounded to 4 decimals
    threshold: float  # the triggered threshold
    severity: str  # "warning" | "critical"
    root_causes: list[str] = field(default_factory=list)
    suggested_actions: list[str] = field(default_factory=list)
    # Worst-scoring samples, trimmed to sample_id/question/answer/ground_truth
    # plus the metric's own value.
    low_samples: list[dict[str, Any]] = field(default_factory=list)


def _coerce_score(raw: Any) -> float | None:
    """Convert a raw per-sample score to float, or return None if unusable.

    Missing (None), non-numeric, and NaN values all count as "no score", so
    the mean computation and the worst-sample selection filter rows the same
    way.
    """
    if raw is None:
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(value) else value


def _mean_ignoring_nan(values: list[float]) -> float | None:
    """Return the arithmetic mean of the non-NaN values, or None if there are none."""
    valid = [v for v in values if not math.isnan(v)]
    if not valid:
        return None
    return sum(valid) / len(valid)


def _select_low_samples(
    rows: list[dict[str, Any]],
    metric: str,
    top_n: int,
    higher_is_better: bool,
) -> list[dict[str, Any]]:
    """Return the top_n worst-scoring rows for a metric.

    Rows whose score is missing, non-numeric, or NaN are skipped rather than
    raising, which matches the filtering used for the mean in diagnose().
    "Worst" means lowest when ``higher_is_better`` is True and highest
    otherwise. Each returned dict keeps only the identifying fields and the
    metric's original (unconverted) value. A non-positive ``top_n`` yields
    an empty list.
    """
    if top_n <= 0:
        return []
    scored: list[tuple[float, dict[str, Any]]] = []
    for row in rows:
        score = _coerce_score(row.get(metric))
        if score is not None:
            scored.append((score, row))
    # The sort is stable, so tied rows keep their input order.
    scored.sort(key=lambda pair: pair[0], reverse=not higher_is_better)
    keep_keys = {"sample_id", "question", "answer", "ground_truth", metric}
    return [
        {k: v for k, v in row.items() if k in keep_keys}
        for _, row in scored[:top_n]
    ]


def _classify(mean: float, rule: MetricRule) -> tuple[str, float] | None:
    """Return (severity, triggered threshold) for a mean score, or None if healthy.

    The check is direction-aware. For higher-is-better metrics, a mean
    strictly below a threshold triggers it. Otherwise a mean strictly above
    it does. The critical threshold is checked first, so the more severe
    label wins.
    """

    def crosses(threshold: float) -> bool:
        return mean < threshold if rule.higher_is_better else mean > threshold

    if crosses(rule.critical_threshold):
        return "critical", rule.critical_threshold
    if crosses(rule.warning_threshold):
        return "warning", rule.warning_threshold
    return None


def diagnose(
    score_rows: list[dict[str, Any]],
    metrics: list[str],
    top_low_samples: int = 3,
) -> list[Diagnosis]:
    """Analyse score_rows and return a Diagnosis for each metric crossing a threshold.

    Args:
        score_rows: List of per-sample score dicts (from EvaluationResult.score_rows).
            Missing, None, non-numeric, and NaN scores are ignored.
        metrics: Metric names to evaluate (from Scenario.metrics). Names without
            an entry in METRIC_RULES are skipped.
        top_low_samples: How many worst-scoring samples to attach per diagnosis.

    Returns:
        List of Diagnosis objects, one per triggered metric, in ``metrics``
        order. Empty if all OK.
    """
    diagnoses: list[Diagnosis] = []
    for metric in metrics:
        rule = METRIC_RULES.get(metric)
        if rule is None:
            continue  # unknown metric, skip

        values = [
            score
            for score in (_coerce_score(row.get(metric)) for row in score_rows)
            if score is not None
        ]
        mean = _mean_ignoring_nan(values)
        if mean is None:
            continue  # no usable scores for this metric

        verdict = _classify(mean, rule)
        if verdict is None:
            continue  # within the healthy range → no diagnosis
        severity, threshold = verdict

        diagnoses.append(
            Diagnosis(
                metric=metric,
                mean_score=round(mean, 4),
                threshold=threshold,
                severity=severity,
                # Copy so callers mutating a Diagnosis cannot alter METRIC_RULES.
                root_causes=list(rule.root_causes),
                suggested_actions=list(rule.suggested_actions),
                low_samples=_select_low_samples(
                    score_rows, metric, top_low_samples, rule.higher_is_better
                ),
            )
        )
    return diagnoses