237 lines
8.5 KiB
Python
237 lines
8.5 KiB
Python
|
|
"""Rule-based diagnostic engine for RAG evaluation metric scores."""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import math
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class MetricRule:
    """Thresholds and canned diagnostic text attached to a single metric.

    The two thresholds are interpreted according to ``higher_is_better``:
    when it is True a mean *below* a threshold is a problem, when it is
    False a mean *above* a threshold is a problem.
    """

    # Mean score at which a "warning" diagnosis is raised.
    warning_threshold: float
    # Mean score at which the diagnosis is raised as "critical" instead.
    critical_threshold: float
    # Score direction; False for noise_sensitivity, where larger is worse.
    higher_is_better: bool
    # Likely explanations reported when the metric triggers.
    root_causes: list[str]
    # Remediation hints reported when the metric triggers.
    suggested_actions: list[str]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Rule table keyed by metric name; `diagnose` looks metrics up here and skips
# any metric that has no entry.
# For entries with higher_is_better=True a mean BELOW a threshold triggers;
# for higher_is_better=False a mean ABOVE a threshold triggers.
# root_causes / suggested_actions are user-facing text copied verbatim into
# each Diagnosis, so they must not be edited casually.
METRIC_RULES: dict[str, MetricRule] = {
    # Warning below 0.7, critical below 0.5.
    "faithfulness": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "生成回答包含检索片段中不支持的陈述(幻觉)",
            "生成阶段未严格遵循 grounding 约束",
            "校验阶段未开启或未生效",
        ],
        suggested_actions=[
            "强化生成 prompt 的 grounding 约束('只依据参考资料作答')",
            "开启校验阶段(validation: by_scenario)",
            "检查低分样本中模型是否引用了片段外的知识",
        ],
    ),
    # Warning below 0.7, critical below 0.5.
    "answer_relevancy": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "回答偏离问题主旨或包含大量冗余内容",
            "查询改写后问题语义漂移",
            "生成 prompt 格式约束不足",
        ],
        suggested_actions=[
            "优化查询改写 prompt,确保改写后语义不偏移",
            "在生成 prompt 中加入'简洁准确、直接回答问题'的约束",
            "检查低分样本的回答是否存在格式冗余或话题偏移",
        ],
    ),
    # Warning below 0.7, critical below 0.5.
    "context_recall": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "检索未能召回标准答案所涉及的关键信息",
            "单一查询未能覆盖问题的多个角度",
            "过召回数量不足,关键片段被截断",
        ],
        suggested_actions=[
            "启用多查询扩展(use_multi_query)覆盖不同措辞",
            "对多跳问题启用问题分解(sub_questions)",
            "加大过召回宽度(recall_top_k)",
            "对颗粒度细的问题尝试 Step-back 双路检索",
        ],
    ),
    # Looser limits than the 0.7/0.5 entries: warning below 0.6, critical below 0.4.
    "context_precision": MetricRule(
        warning_threshold=0.6,
        critical_threshold=0.4,
        higher_is_better=True,
        root_causes=[
            "检索引入过多与问题无关的片段",
            "重排未能将相关片段排在前列",
            "缺少相关性过滤,噪声片段进入上下文",
        ],
        suggested_actions=[
            "启用或优化 listwise 重排,将相关片段排在前列",
            "启用上下文压缩(compression)过滤无关句子",
            "启用相关性过滤(relevance_filter)丢弃明确无关片段",
            "缩小 rerank_keep_k(如从 8 降到 5)",
        ],
    ),
    # The only lower-is-better entry: warning above 0.3, critical above 0.5,
    # so here warning_threshold < critical_threshold (reverse of the others).
    "noise_sensitivity": MetricRule(
        warning_threshold=0.3,  # higher is worse; triggers when mean > threshold
        critical_threshold=0.5,
        higher_is_better=False,
        root_causes=[
            "回答中包含检索到的噪声片段所引入的错误陈述",
            "相关性过滤未能拦截干扰性片段",
            "生成阶段对噪声片段未加区分地引用",
        ],
        suggested_actions=[
            "启用相关性过滤(relevance_filter)拦截噪声",
            "优化重排,将不相关片段排到截断点之后",
            "在生成 prompt 中强调'来源冲突时并列陈述,不擅自下定论'",
        ],
    ),
    # Warning below 0.6, critical below 0.4.
    "factual_correctness": MetricRule(
        warning_threshold=0.6,
        critical_threshold=0.4,
        higher_is_better=True,
        root_causes=[
            "回答的事实陈述与标准答案存在偏差",
            "检索未能命中标准答案所依据的关键片段",
            "生成阶段对多个来源综合时产生事实错误",
        ],
        suggested_actions=[
            "重点检查低分样本,确认是检索遗漏还是生成错误",
            "提升 context_recall 以确保关键信息被检索到",
            "对事实型问题将 temperature 降至 0",
        ],
    ),
    # Warning below 0.7, critical below 0.5.
    "semantic_similarity": MetricRule(
        warning_threshold=0.7,
        critical_threshold=0.5,
        higher_is_better=True,
        root_causes=[
            "回答语义与标准答案差距较大",
            "回答过于简短或过于冗长,语义偏移",
            "检索到的片段质量不足,导致生成内容偏离",
        ],
        suggested_actions=[
            "检查低分样本的回答与标准答案的表述差异",
            "优化生成 prompt 使回答更贴近标准表述风格",
            "提升检索质量(context_recall / context_precision)",
        ],
    ),
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Diagnosis:
    """Outcome reported for one metric whose mean crossed a threshold."""

    # Name of the metric this diagnosis is about.
    metric: str
    # Mean of the metric's scores that triggered the diagnosis.
    mean_score: float
    # The specific threshold value that was crossed.
    threshold: float
    # Either "warning" or "critical".
    severity: str
    # Candidate explanations for the poor score.
    root_causes: list[str] = field(default_factory=list)
    # Recommended follow-up actions.
    suggested_actions: list[str] = field(default_factory=list)
    # Trimmed copies of the worst-scoring sample rows.
    low_samples: list[dict[str, Any]] = field(default_factory=list)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _mean_ignoring_nan(values: list[float]) -> float | None:
|
|||
|
|
valid = [v for v in values if not math.isnan(v)]
|
|||
|
|
if not valid:
|
|||
|
|
return None
|
|||
|
|
return sum(valid) / len(valid)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _select_low_samples(
|
|||
|
|
rows: list[dict[str, Any]],
|
|||
|
|
metric: str,
|
|||
|
|
top_n: int,
|
|||
|
|
higher_is_better: bool,
|
|||
|
|
) -> list[dict[str, Any]]:
|
|||
|
|
"""Return the top_n worst-scoring rows for a metric, excluding NaN."""
|
|||
|
|
valid = [r for r in rows if metric in r and not math.isnan(float(r[metric]))]
|
|||
|
|
sorted_rows = sorted(valid, key=lambda r: float(r[metric]), reverse=not higher_is_better)
|
|||
|
|
worst = sorted_rows[:top_n]
|
|||
|
|
keep_keys = {"sample_id", "question", "answer", "ground_truth", metric}
|
|||
|
|
return [{k: v for k, v in row.items() if k in keep_keys} for row in worst]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _classify_severity(mean: float, rule: MetricRule) -> tuple[str, float] | None:
    """Map a mean score to ``(severity, triggered_threshold)``; None if healthy."""
    if rule.higher_is_better:
        is_critical = mean < rule.critical_threshold
        is_warning = mean < rule.warning_threshold
    else:
        # Lower is better (noise_sensitivity): exceeding a threshold is the problem.
        is_critical = mean > rule.critical_threshold
        is_warning = mean > rule.warning_threshold

    # Critical is checked first so it wins when both thresholds are crossed.
    if is_critical:
        return "critical", rule.critical_threshold
    if is_warning:
        return "warning", rule.warning_threshold
    return None


def diagnose(
    score_rows: list[dict[str, Any]],
    metrics: list[str],
    top_low_samples: int = 3,
) -> list[Diagnosis]:
    """Analyse score_rows and return a Diagnosis for each metric past its threshold.

    For every requested metric that has an entry in ``METRIC_RULES``, the
    mean of its parseable, non-NaN scores is compared against the rule's
    thresholds in the direction given by ``higher_is_better``. Metrics with
    no rule, or with no usable scores, are skipped silently.

    Args:
        score_rows: List of per-sample score dicts (from EvaluationResult.score_rows).
        metrics: Metric names to evaluate (from Scenario.metrics).
        top_low_samples: How many worst-scoring samples to attach per diagnosis.

    Returns:
        List of Diagnosis objects, one per triggered metric, in the order the
        metrics were given. Empty if all OK.
    """
    diagnoses: list[Diagnosis] = []

    for metric in metrics:
        rule = METRIC_RULES.get(metric)
        if rule is None:
            continue  # unknown metric, skip

        values: list[float] = []
        usable_rows: list[dict[str, Any]] = []
        for row in score_rows:
            raw = row.get(metric)
            if raw is None:
                continue
            try:
                value = float(raw)
            except (TypeError, ValueError):
                continue  # non-numeric score: treat as missing
            values.append(value)
            usable_rows.append(row)

        # Returns None for an empty list as well as for all-NaN input.
        mean = _mean_ignoring_nan(values)
        if mean is None:
            continue

        verdict = _classify_severity(mean, rule)
        if verdict is None:
            continue  # within the healthy range, no diagnosis
        severity, threshold = verdict

        # Only rows whose score parsed are handed to low-sample selection, so
        # a None or non-numeric score skipped above is never re-parsed there.
        low_samples = _select_low_samples(
            usable_rows, metric, top_low_samples, rule.higher_is_better
        )

        diagnoses.append(
            Diagnosis(
                metric=metric,
                mean_score=round(mean, 4),
                threshold=threshold,
                severity=severity,
                # Copy the lists so callers cannot mutate the shared rule table.
                root_causes=list(rule.root_causes),
                suggested_actions=list(rule.suggested_actions),
                low_samples=low_samples,
            )
        )

    return diagnoses
|