Files

244 lines
8.9 KiB
Python
Raw Permalink Normal View History

"""Rule-based diagnostic engine for RAG evaluation metric scores."""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Any
@dataclass
class MetricRule:
"""Threshold configuration and diagnostic text for one metric."""
warning_threshold: float
critical_threshold: float
higher_is_better: bool # False for noise_sensitivity
root_causes: list[str]
suggested_actions: list[str]
# Scores below this threshold trigger a "low" advisory (LLM suggestion requested).
# Only applies to higher_is_better metrics; noise_sensitivity uses existing thresholds.
advisory_threshold: float = 0.85
METRIC_RULES: dict[str, MetricRule] = {
"faithfulness": MetricRule(
warning_threshold=0.7,
critical_threshold=0.5,
higher_is_better=True,
root_causes=[
"生成回答包含检索片段中不支持的陈述(幻觉)",
"生成阶段未严格遵循 grounding 约束",
"校验阶段未开启或未生效",
],
suggested_actions=[
"强化生成 prompt 的 grounding 约束('只依据参考资料作答'",
"开启校验阶段validation: by_scenario",
"检查低分样本中模型是否引用了片段外的知识",
],
),
"answer_relevancy": MetricRule(
warning_threshold=0.7,
critical_threshold=0.5,
higher_is_better=True,
root_causes=[
"回答偏离问题主旨或包含大量冗余内容",
"查询改写后问题语义漂移",
"生成 prompt 格式约束不足",
],
suggested_actions=[
"优化查询改写 prompt确保改写后语义不偏移",
"在生成 prompt 中加入'简洁准确、直接回答问题'的约束",
"检查低分样本的回答是否存在格式冗余或话题偏移",
],
),
"context_recall": MetricRule(
warning_threshold=0.7,
critical_threshold=0.5,
higher_is_better=True,
root_causes=[
"检索未能召回标准答案所涉及的关键信息",
"单一查询未能覆盖问题的多个角度",
"过召回数量不足,关键片段被截断",
],
suggested_actions=[
"启用多查询扩展use_multi_query覆盖不同措辞",
"对多跳问题启用问题分解sub_questions",
"加大过召回宽度recall_top_k",
"对颗粒度细的问题尝试 Step-back 双路检索",
],
),
"context_precision": MetricRule(
warning_threshold=0.6,
critical_threshold=0.4,
higher_is_better=True,
root_causes=[
"检索引入过多与问题无关的片段",
"重排未能将相关片段排在前列",
"缺少相关性过滤,噪声片段进入上下文",
],
suggested_actions=[
"启用或优化 listwise 重排,将相关片段排在前列",
"启用上下文压缩compression过滤无关句子",
"启用相关性过滤relevance_filter丢弃明确无关片段",
"缩小 rerank_keep_k如从 8 降到 5",
],
),
"noise_sensitivity": MetricRule(
warning_threshold=0.3, # higher is worse; trigger when mean > threshold
critical_threshold=0.5,
higher_is_better=False,
root_causes=[
"回答中包含检索到的噪声片段所引入的错误陈述",
"相关性过滤未能拦截干扰性片段",
"生成阶段对噪声片段未加区分地引用",
],
suggested_actions=[
"启用相关性过滤relevance_filter拦截噪声",
"优化重排,将不相关片段排到截断点之后",
"在生成 prompt 中强调'来源冲突时并列陈述,不擅自下定论'",
],
),
"factual_correctness": MetricRule(
warning_threshold=0.6,
critical_threshold=0.4,
higher_is_better=True,
root_causes=[
"回答的事实陈述与标准答案存在偏差",
"检索未能命中标准答案所依据的关键片段",
"生成阶段对多个来源综合时产生事实错误",
],
suggested_actions=[
"重点检查低分样本,确认是检索遗漏还是生成错误",
"提升 context_recall 以确保关键信息被检索到",
"对事实型问题将 temperature 降至 0",
],
),
"semantic_similarity": MetricRule(
warning_threshold=0.7,
critical_threshold=0.5,
higher_is_better=True,
root_causes=[
"回答语义与标准答案差距较大",
"回答过于简短或过于冗长,语义偏移",
"检索到的片段质量不足,导致生成内容偏离",
],
suggested_actions=[
"检查低分样本的回答与标准答案的表述差异",
"优化生成 prompt 使回答更贴近标准表述风格",
"提升检索质量context_recall / context_precision",
],
),
}
@dataclass
class Diagnosis:
"""Diagnostic result for one metric that triggered a threshold."""
metric: str
mean_score: float
threshold: float # the triggered threshold
severity: str # "warning" | "critical"
root_causes: list[str] = field(default_factory=list)
suggested_actions: list[str] = field(default_factory=list)
low_samples: list[dict[str, Any]] = field(default_factory=list)
def _mean_ignoring_nan(values: list[float]) -> float | None:
valid = [v for v in values if not math.isnan(v)]
if not valid:
return None
return sum(valid) / len(valid)
def _select_low_samples(
rows: list[dict[str, Any]],
metric: str,
top_n: int,
higher_is_better: bool,
) -> list[dict[str, Any]]:
"""Return the top_n worst-scoring rows for a metric, excluding NaN."""
valid = [r for r in rows if metric in r and not math.isnan(float(r[metric]))]
sorted_rows = sorted(valid, key=lambda r: float(r[metric]), reverse=not higher_is_better)
worst = sorted_rows[:top_n]
keep_keys = {"sample_id", "question", "answer", "ground_truth", metric}
return [{k: v for k, v in row.items() if k in keep_keys} for row in worst]
def diagnose(
score_rows: list[dict[str, Any]],
metrics: list[str],
top_low_samples: int = 3,
) -> list[Diagnosis]:
"""Analyse score_rows and return a Diagnosis for each metric below threshold.
Args:
score_rows: List of per-sample score dicts (from EvaluationResult.score_rows).
metrics: Metric names to evaluate (from Scenario.metrics).
top_low_samples: How many worst-scoring samples to attach per diagnosis.
Returns:
List of Diagnosis objects, one per triggered metric. Empty if all OK.
"""
diagnoses: list[Diagnosis] = []
for metric in metrics:
rule = METRIC_RULES.get(metric)
if rule is None:
continue # unknown metric, skip
values = []
for row in score_rows:
raw = row.get(metric)
if raw is None:
continue
try:
v = float(raw)
except (TypeError, ValueError):
continue
values.append(v)
if not values:
continue
mean = _mean_ignoring_nan(values)
if mean is None:
continue
# Determine severity (direction-aware)
if rule.higher_is_better:
if mean < rule.critical_threshold:
severity = "critical"
threshold = rule.critical_threshold
elif mean < rule.warning_threshold:
severity = "warning"
threshold = rule.warning_threshold
elif mean < rule.advisory_threshold:
# Score is acceptable but below 0.85 — request LLM optimization advice.
severity = "low"
threshold = rule.advisory_threshold
else:
continue # >= advisory_threshold → no diagnosis needed
else:
# lower is better (noise_sensitivity): keep existing two-tier logic
if mean > rule.critical_threshold:
severity = "critical"
threshold = rule.critical_threshold
elif mean > rule.warning_threshold:
severity = "warning"
threshold = rule.warning_threshold
else:
continue
low_samples = _select_low_samples(score_rows, metric, top_low_samples, rule.higher_is_better)
diagnoses.append(Diagnosis(
metric=metric,
mean_score=round(mean, 4),
threshold=threshold,
severity=severity,
root_causes=list(rule.root_causes),
suggested_actions=list(rule.suggested_actions),
low_samples=low_samples,
))
return diagnoses