feat: add metric/doc weight computation module (weights.py)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
152
rag_eval/metrics/weights.py
Normal file
152
rag_eval/metrics/weights.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""Utility functions for weighted metric aggregation.
|
||||||
|
|
||||||
|
All functions are pure (no side effects, no I/O) and operate on plain dicts/lists.
|
||||||
|
Weights do not need to be pre-normalised — normalisation is done internally.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_weight(weights: dict[str, float], key: str, default: float = 1.0) -> float:
    """Look up the weight stored under *key*, falling back to *default*.

    Args:
        weights: mapping of name -> weight.
        key: name to look up.
        default: value used when *key* is not present in *weights*.

    Returns:
        The selected value passed through ``float()``.
    """
    chosen = weights[key] if key in weights else default
    return float(chosen)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_weighted_score(
    scores: dict[str, float | None],
    metric_weights: dict[str, float],
) -> float | None:
    """Return the weighted mean of the usable metric scores.

    A score is usable when it is not None, converts with ``float()`` and is
    finite. A weight is usable when it is a finite number greater than zero;
    a metric carrying any other weight (zero, negative, NaN, infinite) is
    left out of the mean instead of distorting it.

    Args:
        scores: mapping of metric_name -> raw score (may be NaN or None).
        metric_weights: per-metric weights; absent keys default to 1.0.
            Weights do not need to sum to 1.

    Returns:
        Weighted mean as a float, or None when no metric has both a usable
        score and a usable weight.
    """
    total_weight = 0.0
    total_score = 0.0
    for metric, score in scores.items():
        if score is None:
            continue
        try:
            value = float(score)
        except (TypeError, ValueError):
            # Non-numeric cell (e.g. free text): not a score, ignore it.
            continue
        if not math.isfinite(value):
            continue
        weight = resolve_weight(metric_weights, metric, default=1.0)
        # An infinite weight would give inf / inf = NaN and a NaN weight would
        # poison both sums; a negative weight can push the "mean" outside the
        # range of the scores. None of them is a meaningful weight.
        if not math.isfinite(weight) or weight <= 0.0:
            continue
        total_weight += weight
        total_score += weight * value
    # `> 0` (not `== 0.0`) matches the guard used by the other aggregation
    # helpers in this module.
    if total_weight > 0.0:
        return total_score / total_weight
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_metric_means(
    score_rows: list[dict],
    metrics: list[str],
    doc_weights: dict[str, float],
) -> dict[str, float | None]:
    """Compute per-metric weighted means across all score rows.

    Each row's contribution is scaled by the doc weight looked up under the
    row's ``doc_name`` (a missing or empty ``doc_name`` is looked up as
    ``""``). Rows whose doc weight is not a positive finite number are
    skipped entirely. Within a row, a metric whose value is missing, None,
    non-numeric, NaN or infinite is excluded from that metric's mean only.

    Args:
        score_rows: list of score record dicts (from scores.csv).
        metrics: ordered list of metric names to aggregate.
        doc_weights: mapping doc_name -> weight multiplier; absent keys
            default to 1.0.

    Returns:
        Dict mapping metric_name -> weighted mean (or None if no valid data).
    """
    totals: dict[str, float] = {metric: 0.0 for metric in metrics}
    weights_sum: dict[str, float] = {metric: 0.0 for metric in metrics}

    for row in score_rows:
        doc_name = str(row.get("doc_name", "") or "")
        sample_weight = resolve_weight(doc_weights, doc_name, default=1.0)
        # An infinite weight would turn the mean into inf / inf = NaN, a NaN
        # weight would poison the sums, and a negative one can push the mean
        # outside the range of the scores; such rows carry no usable weight.
        if not math.isfinite(sample_weight) or sample_weight <= 0.0:
            continue
        for metric in metrics:
            raw_value = row.get(metric)
            if raw_value is None:
                continue
            try:
                value = float(raw_value)
            except (TypeError, ValueError):
                # Non-numeric cell: not a score for this metric.
                continue
            if not math.isfinite(value):
                continue
            totals[metric] += sample_weight * value
            weights_sum[metric] += sample_weight

    return {
        metric: (totals[metric] / weights_sum[metric] if weights_sum[metric] > 0 else None)
        for metric in metrics
    }
|
||||||
|
|
||||||
|
|
||||||
|
def compute_overall_weighted_score_mean(
    score_rows: list[dict],
    metric_weights: dict[str, float],
    doc_weights: dict[str, float],
) -> float | None:
    """Compute the overall weighted-score mean across all samples.

    For each sample (row):
      1. Every column whose name is not in ``_META_COLUMNS`` is treated as a
         metric and passed to ``compute_weighted_score`` to obtain the
         per-sample weighted score.
      2. That score is scaled by the doc weight looked up under the row's
         ``doc_name`` (absent keys default to 1.0).

    Rows are skipped when they have no usable per-sample score (None or
    non-finite) or when their doc weight is not a positive finite number.

    Args:
        score_rows: list of score record dicts.
        metric_weights: per-metric weights forwarded to
            ``compute_weighted_score``.
        doc_weights: mapping doc_name -> weight multiplier.

    Returns:
        The doc-weighted mean of the per-sample weighted scores, or None when
        no row contributes.
    """
    total_weight = 0.0
    total_score = 0.0
    for row in score_rows:
        metric_scores: dict[str, float | None] = {
            key: value for key, value in row.items() if key not in _META_COLUMNS
        }

        weighted_score = compute_weighted_score(metric_scores, metric_weights)
        # A non-finite per-sample score would poison the running total.
        if weighted_score is None or not math.isfinite(weighted_score):
            continue
        doc_name = str(row.get("doc_name", "") or "")
        sample_weight = resolve_weight(doc_weights, doc_name, default=1.0)
        # Infinite weights lead to inf / inf = NaN, NaN weights poison the
        # sums, and negative weights can push the mean outside the range of
        # the scores; such rows are ignored.
        if not math.isfinite(sample_weight) or sample_weight <= 0.0:
            continue
        total_weight += sample_weight
        total_score += sample_weight * weighted_score

    return total_score / total_weight if total_weight > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
# Columns in scores.csv that are sample metadata, not metric scores.
|
||||||
|
_META_COLUMNS = frozenset(
|
||||||
|
{
|
||||||
|
"sample_id",
|
||||||
|
"question",
|
||||||
|
"contexts",
|
||||||
|
"answer",
|
||||||
|
"ground_truth",
|
||||||
|
"scenario",
|
||||||
|
"language",
|
||||||
|
"retrieval_config",
|
||||||
|
"error",
|
||||||
|
"judge_model",
|
||||||
|
"embedding_model",
|
||||||
|
"run_id",
|
||||||
|
"difficulty",
|
||||||
|
"question_type",
|
||||||
|
"doc_id",
|
||||||
|
"doc_name",
|
||||||
|
"section_path",
|
||||||
|
"page_start",
|
||||||
|
"page_end",
|
||||||
|
"source_chunk_ids",
|
||||||
|
"review_status",
|
||||||
|
"review_notes",
|
||||||
|
"weighted_score",
|
||||||
|
"sample_weight",
|
||||||
|
}
|
||||||
|
)
|
||||||
124
tests/test_weights.py
Normal file
124
tests/test_weights.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""Unit tests for rag_eval/metrics/weights.py"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from rag_eval.metrics.weights import (
|
||||||
|
compute_overall_weighted_score_mean,
|
||||||
|
compute_weighted_score,
|
||||||
|
resolve_weight,
|
||||||
|
weighted_metric_means,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveWeight:
    """Checks for the ``resolve_weight`` lookup helper."""

    def test_returns_value_when_key_present(self):
        configured = {"faith": 0.5}
        assert resolve_weight(configured, "faith") == 0.5

    def test_returns_default_when_key_missing(self):
        # The implicit default is 1.0.
        looked_up = resolve_weight({}, "faith")
        assert looked_up == 1.0

    def test_returns_custom_default_when_key_missing(self):
        looked_up = resolve_weight({}, "faith", default=2.0)
        assert looked_up == 2.0

    def test_empty_dict_returns_default(self):
        looked_up = resolve_weight({}, "anything")
        assert looked_up == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeWeightedScore:
    """Checks for the per-sample weighted mean over metric scores."""

    def test_equal_weights_is_simple_mean(self):
        # No explicit weights, so every metric counts as 1.0.
        outcome = compute_weighted_score({"faithfulness": 0.8, "context_recall": 0.6}, {})
        assert outcome == pytest.approx(0.7, rel=1e-4)

    def test_explicit_weights(self):
        # (3.0 * 1.0 + 1.0 * 0.0) / 4.0
        per_metric = {"faithfulness": 1.0, "context_recall": 0.0}
        importance = {"faithfulness": 3.0, "context_recall": 1.0}
        outcome = compute_weighted_score(per_metric, importance)
        assert outcome == pytest.approx(0.75, rel=1e-4)

    def test_nan_values_excluded(self):
        per_metric = {"faithfulness": float("nan"), "context_recall": 0.8}
        outcome = compute_weighted_score(per_metric, {})
        assert outcome == pytest.approx(0.8, rel=1e-4)

    def test_none_values_excluded(self):
        per_metric = {"faithfulness": None, "context_recall": 0.6}
        outcome = compute_weighted_score(per_metric, {})
        assert outcome == pytest.approx(0.6, rel=1e-4)

    def test_all_nan_returns_none(self):
        not_a_number = float("nan")
        per_metric = {"faithfulness": not_a_number, "context_recall": not_a_number}
        assert compute_weighted_score(per_metric, {}) is None

    def test_empty_scores_returns_none(self):
        outcome = compute_weighted_score({}, {})
        assert outcome is None

    def test_missing_metric_in_weights_uses_default_1(self):
        # (2.0 * 0.8 + 1.0 * 0.4) / 3.0
        per_metric = {"faithfulness": 0.8, "context_recall": 0.4}
        importance = {"faithfulness": 2.0}
        outcome = compute_weighted_score(per_metric, importance)
        assert outcome == pytest.approx(2.0 / 3, rel=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightedMetricMeans:
    """Checks for per-metric means weighted by document."""

    def _rows(self):
        """Return two fully scored rows, one per document."""
        first = {"doc_name": "a.pdf", "faithfulness": 1.0, "context_recall": 0.5}
        second = {"doc_name": "b.pdf", "faithfulness": 0.6, "context_recall": 0.8}
        return [first, second]

    def test_equal_weights_gives_arithmetic_mean(self):
        wanted = ["faithfulness", "context_recall"]
        means = weighted_metric_means(self._rows(), wanted, {})
        assert means["faithfulness"] == pytest.approx(0.8, rel=1e-4)
        assert means["context_recall"] == pytest.approx(0.65, rel=1e-4)

    def test_doc_weight_amplifies_contribution(self):
        # (3.0 * 1.0 + 1.0 * 0.6) / 4.0
        per_doc = {"a.pdf": 3.0, "b.pdf": 1.0}
        means = weighted_metric_means(self._rows(), ["faithfulness"], per_doc)
        assert means["faithfulness"] == pytest.approx(0.9, rel=1e-4)

    def test_nan_rows_skipped_per_metric(self):
        # The NaN only removes row "a.pdf" from the faithfulness mean.
        records = [
            {"doc_name": "a.pdf", "faithfulness": float("nan"), "context_recall": 0.5},
            {"doc_name": "b.pdf", "faithfulness": 0.8, "context_recall": 0.9},
        ]
        means = weighted_metric_means(records, ["faithfulness", "context_recall"], {})
        assert means["faithfulness"] == pytest.approx(0.8, rel=1e-4)
        assert means["context_recall"] == pytest.approx(0.7, rel=1e-4)

    def test_missing_metric_column_returns_none(self):
        records = [{"doc_name": "a.pdf", "faithfulness": 0.8}]
        means = weighted_metric_means(records, ["faithfulness", "unknown_metric"], {})
        assert means["faithfulness"] == pytest.approx(0.8, rel=1e-4)
        assert means["unknown_metric"] is None

    def test_empty_rows_returns_none_for_all(self):
        means = weighted_metric_means([], ["faithfulness"], {})
        assert means["faithfulness"] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeOverallWeightedScoreMean:
    """Checks for the doc-weighted mean of per-sample weighted scores."""

    def test_basic_weighted_mean_of_weighted_scores(self):
        # Both samples score 0.5 on their own, so the overall mean is 0.5.
        records = [
            {"doc_name": "a.pdf", "faithfulness": 1.0, "context_recall": 0.0},
            {"doc_name": "b.pdf", "faithfulness": 0.5, "context_recall": 0.5},
        ]
        importance = {"faithfulness": 1.0, "context_recall": 1.0}
        overall = compute_overall_weighted_score_mean(records, importance, {})
        assert overall == pytest.approx(0.5, rel=1e-4)

    def test_doc_weight_amplifies_sample(self):
        # (9.0 * 1.0 + 1.0 * 0.0) / 10.0
        records = [
            {"doc_name": "important.pdf", "faithfulness": 1.0},
            {"doc_name": "other.pdf", "faithfulness": 0.0},
        ]
        per_doc = {"important.pdf": 9.0, "other.pdf": 1.0}
        overall = compute_overall_weighted_score_mean(records, {}, per_doc)
        assert overall == pytest.approx(0.9, rel=1e-4)

    def test_all_nan_returns_none(self):
        records = [{"doc_name": "a.pdf", "faithfulness": float("nan")}]
        overall = compute_overall_weighted_score_mean(records, {}, {})
        assert overall is None
|
||||||
Reference in New Issue
Block a user