From e0b064587f67d041afad62df2951d42f630dbcfb Mon Sep 17 00:00:00 2001 From: wangwei Date: Thu, 18 Jun 2026 16:47:47 +0800 Subject: [PATCH] feat: add metric/doc weight computation module (weights.py) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rag_eval/metrics/weights.py | 152 ++++++++++++++++++++++++++++++++++++ tests/test_weights.py | 124 +++++++++++++++++++++++++++++ 2 files changed, 276 insertions(+) create mode 100644 rag_eval/metrics/weights.py create mode 100644 tests/test_weights.py diff --git a/rag_eval/metrics/weights.py b/rag_eval/metrics/weights.py new file mode 100644 index 0000000..d84de57 --- /dev/null +++ b/rag_eval/metrics/weights.py @@ -0,0 +1,152 @@ +"""Utility functions for weighted metric aggregation. + +All functions are pure (no side effects, no I/O) and operate on plain dicts/lists. +Weights do not need to be pre-normalised — normalisation is done internally. +""" + +from __future__ import annotations + +import math + + +def resolve_weight(weights: dict[str, float], key: str, default: float = 1.0) -> float: + """Return the weight for *key*, or *default* when absent.""" + return float(weights.get(key, default)) + + +def compute_weighted_score( + scores: dict[str, float | None], + metric_weights: dict[str, float], +) -> float | None: + """Return the weighted mean of valid (non-NaN, non-None) metric scores. + + Args: + scores: mapping of metric_name -> raw score (may be NaN or None). + metric_weights: optional per-metric weights; absent keys default to 1.0. + + Returns: + Weighted mean as a float, or None when no valid score exists. + """ + total_weight = 0.0 + total_score = 0.0 + for metric, score in scores.items(): + if score is None: + continue + try: + value = float(score) + except (TypeError, ValueError): + continue + if math.isnan(value) or math.isinf(value): + continue + weight = resolve_weight(metric_weights, metric, default=1.0) + total_weight += weight + total_score += weight * value + if total_weight == 0.0: + return None + return total_score / total_weight + + +def weighted_metric_means( + score_rows: list[dict], + metrics: list[str], + doc_weights: dict[str, float], +) -> dict[str, float | None]: + """Compute per-metric weighted means across all score rows. + + Each row's contribution is scaled by the doc_weight for its ``doc_name``. + Rows with NaN/None for a given metric are excluded from that metric's mean. + + Args: + score_rows: list of score record dicts (from scores.csv). + metrics: ordered list of metric names to aggregate. + doc_weights: mapping doc_name -> weight multiplier; absent keys default to 1.0. + + Returns: + Dict mapping metric_name -> weighted mean (or None if no valid data). + """ + totals: dict[str, float] = {metric: 0.0 for metric in metrics} + weights_sum: dict[str, float] = {metric: 0.0 for metric in metrics} + + for row in score_rows: + doc_name = str(row.get("doc_name", "") or "") + sample_weight = resolve_weight(doc_weights, doc_name, default=1.0) + for metric in metrics: + raw_value = row.get(metric) + if raw_value is None: + continue + try: + value = float(raw_value) + except (TypeError, ValueError): + continue + if math.isnan(value) or math.isinf(value): + continue + totals[metric] += sample_weight * value + weights_sum[metric] += sample_weight + + return { + metric: (totals[metric] / weights_sum[metric] if weights_sum[metric] > 0 else None) + for metric in metrics + } + + +def compute_overall_weighted_score_mean( + score_rows: list[dict], + metric_weights: dict[str, float], + doc_weights: dict[str, float], +) -> float | None: + """Compute the overall weighted-score mean across all samples. + + For each sample: + 1. Compute per-sample weighted_score via compute_weighted_score. + 2. Scale by the doc weight for that sample's doc_name. + Then return the weighted mean of all per-sample weighted_scores. + """ + total_weight = 0.0 + total_score = 0.0 + for row in score_rows: + metric_scores: dict[str, float | None] = {} + for key, value in row.items(): + if key in _META_COLUMNS: + continue + metric_scores[key] = value # type: ignore[assignment] + + weighted_score = compute_weighted_score(metric_scores, metric_weights) + if weighted_score is None: + continue + doc_name = str(row.get("doc_name", "") or "") + sample_weight = resolve_weight(doc_weights, doc_name, default=1.0) + total_weight += sample_weight + total_score += sample_weight * weighted_score + + return total_score / total_weight if total_weight > 0 else None + + +# Columns in scores.csv that are sample metadata, not metric scores. +_META_COLUMNS = frozenset( + { + "sample_id", + "question", + "contexts", + "answer", + "ground_truth", + "scenario", + "language", + "retrieval_config", + "error", + "judge_model", + "embedding_model", + "run_id", + "difficulty", + "question_type", + "doc_id", + "doc_name", + "section_path", + "page_start", + "page_end", + "source_chunk_ids", + "review_status", + "review_notes", + "weighted_score", + "sample_weight", + } +) diff --git a/tests/test_weights.py b/tests/test_weights.py new file mode 100644 index 0000000..0cfba92 --- /dev/null +++ b/tests/test_weights.py @@ -0,0 +1,124 @@ +"""Unit tests for rag_eval/metrics/weights.py""" +import math + +import pytest + +from rag_eval.metrics.weights import ( + compute_overall_weighted_score_mean, + compute_weighted_score, + resolve_weight, + weighted_metric_means, +) + + +class TestResolveWeight: + def test_returns_value_when_key_present(self): + assert resolve_weight({"faith": 0.5}, "faith") == 0.5 + + def test_returns_default_when_key_missing(self): + assert resolve_weight({}, "faith") == 1.0 + + def test_returns_custom_default_when_key_missing(self): + assert resolve_weight({}, "faith", default=2.0) == 2.0 + + def test_empty_dict_returns_default(self): + assert resolve_weight({}, "anything") == 1.0 + + +class TestComputeWeightedScore: + def test_equal_weights_is_simple_mean(self): + scores = {"faithfulness": 0.8, "context_recall": 0.6} + result = compute_weighted_score(scores, {}) + assert result == pytest.approx(0.7, rel=1e-4) + + def test_explicit_weights(self): + scores = {"faithfulness": 1.0, "context_recall": 0.0} + weights = {"faithfulness": 3.0, "context_recall": 1.0} + result = compute_weighted_score(scores, weights) + assert result == pytest.approx(0.75, rel=1e-4) + + def test_nan_values_excluded(self): + scores = {"faithfulness": float("nan"), "context_recall": 0.8} + result = compute_weighted_score(scores, {}) + assert result == pytest.approx(0.8, rel=1e-4) + + def test_none_values_excluded(self): + scores = {"faithfulness": None, "context_recall": 0.6} + result = compute_weighted_score(scores, {}) + assert result == pytest.approx(0.6, rel=1e-4) + + def test_all_nan_returns_none(self): + scores = {"faithfulness": float("nan"), "context_recall": float("nan")} + assert compute_weighted_score(scores, {}) is None + + def test_empty_scores_returns_none(self): + assert compute_weighted_score({}, {}) is None + + def test_missing_metric_in_weights_uses_default_1(self): + scores = {"faithfulness": 0.8, "context_recall": 0.4} + weights = {"faithfulness": 2.0} + result = compute_weighted_score(scores, weights) + assert result == pytest.approx(2.0 / 3, rel=1e-4) + + +class TestWeightedMetricMeans: + def _rows(self): + return [ + {"doc_name": "a.pdf", "faithfulness": 1.0, "context_recall": 0.5}, + {"doc_name": "b.pdf", "faithfulness": 0.6, "context_recall": 0.8}, + ] + + def test_equal_weights_gives_arithmetic_mean(self): + rows = self._rows() + result = weighted_metric_means(rows, ["faithfulness", "context_recall"], {}) + assert result["faithfulness"] == pytest.approx(0.8, rel=1e-4) + assert result["context_recall"] == pytest.approx(0.65, rel=1e-4) + + def test_doc_weight_amplifies_contribution(self): + rows = self._rows() + doc_weights = {"a.pdf": 3.0, "b.pdf": 1.0} + result = weighted_metric_means(rows, ["faithfulness"], doc_weights) + assert result["faithfulness"] == pytest.approx(0.9, rel=1e-4) + + def test_nan_rows_skipped_per_metric(self): + rows = [ + {"doc_name": "a.pdf", "faithfulness": float("nan"), "context_recall": 0.5}, + {"doc_name": "b.pdf", "faithfulness": 0.8, "context_recall": 0.9}, + ] + result = weighted_metric_means(rows, ["faithfulness", "context_recall"], {}) + assert result["faithfulness"] == pytest.approx(0.8, rel=1e-4) + assert result["context_recall"] == pytest.approx(0.7, rel=1e-4) + + def test_missing_metric_column_returns_none(self): + rows = [{"doc_name": "a.pdf", "faithfulness": 0.8}] + result = weighted_metric_means(rows, ["faithfulness", "unknown_metric"], {}) + assert result["faithfulness"] == pytest.approx(0.8, rel=1e-4) + assert result["unknown_metric"] is None + + def test_empty_rows_returns_none_for_all(self): + result = weighted_metric_means([], ["faithfulness"], {}) + assert result["faithfulness"] is None + + +class TestComputeOverallWeightedScoreMean: + def test_basic_weighted_mean_of_weighted_scores(self): + rows = [ + {"doc_name": "a.pdf", "faithfulness": 1.0, "context_recall": 0.0}, + {"doc_name": "b.pdf", "faithfulness": 0.5, "context_recall": 0.5}, + ] + metric_weights = {"faithfulness": 1.0, "context_recall": 1.0} + result = compute_overall_weighted_score_mean(rows, metric_weights, {}) + assert result == pytest.approx(0.5, rel=1e-4) + + def test_doc_weight_amplifies_sample(self): + rows = [ + {"doc_name": "important.pdf", "faithfulness": 1.0}, + {"doc_name": "other.pdf", "faithfulness": 0.0}, + ] + doc_weights = {"important.pdf": 9.0, "other.pdf": 1.0} + result = compute_overall_weighted_score_mean(rows, {}, doc_weights) + assert result == pytest.approx(0.9, rel=1e-4) + + def test_all_nan_returns_none(self): + rows = [{"doc_name": "a.pdf", "faithfulness": float("nan")}] + assert compute_overall_weighted_score_mean(rows, {}, {}) is None