2026-06-16 16:27:54 +08:00
|
|
|
import math
|
|
|
|
|
import unittest
|
|
|
|
|
from rag_eval.advisor.rules import Diagnosis, diagnose, METRIC_RULES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDiagnosis(unittest.TestCase):
|
|
|
|
|
def _make_rows(self, metric: str, scores: list[float]) -> list[dict]:
|
|
|
|
|
return [{metric: s, "question": f"q{i}", "answer": f"a{i}",
|
|
|
|
|
"ground_truth": f"gt{i}", "sample_id": f"s{i}"}
|
|
|
|
|
for i, s in enumerate(scores)]
|
|
|
|
|
|
|
|
|
|
def test_no_diagnosis_when_all_scores_above_threshold(self):
|
2026-06-25 11:35:49 +08:00
|
|
|
# Mean exactly 0.85 should NOT trigger any diagnosis (< 0.85 is the condition).
|
2026-06-16 16:27:54 +08:00
|
|
|
rows = self._make_rows("faithfulness", [0.8, 0.9, 0.85])
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness"])
|
|
|
|
|
self.assertEqual(result, [])
|
|
|
|
|
|
2026-06-25 11:35:49 +08:00
|
|
|
def test_no_diagnosis_when_mean_above_advisory_threshold(self):
|
|
|
|
|
rows = self._make_rows("answer_relevancy", [0.9, 0.92, 0.88])
|
|
|
|
|
result = diagnose(rows, metrics=["answer_relevancy"])
|
|
|
|
|
self.assertEqual(result, [])
|
|
|
|
|
|
|
|
|
|
def test_low_severity_when_mean_below_advisory_threshold(self):
|
|
|
|
|
# Score between warning_threshold (0.7) and advisory_threshold (0.85) → "low"
|
|
|
|
|
rows = self._make_rows("faithfulness", [0.78, 0.80, 0.82])
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness"])
|
|
|
|
|
self.assertEqual(len(result), 1)
|
|
|
|
|
self.assertEqual(result[0].severity, "low")
|
|
|
|
|
self.assertAlmostEqual(result[0].threshold, 0.85, places=2)
|
|
|
|
|
|
|
|
|
|
def test_low_severity_answer_relevancy_at_0_84(self):
|
|
|
|
|
rows = self._make_rows("answer_relevancy", [0.84, 0.84, 0.84])
|
|
|
|
|
result = diagnose(rows, metrics=["answer_relevancy"])
|
|
|
|
|
self.assertEqual(len(result), 1)
|
|
|
|
|
self.assertEqual(result[0].severity, "low")
|
|
|
|
|
|
|
|
|
|
def test_low_severity_has_root_causes_and_actions(self):
|
|
|
|
|
rows = self._make_rows("context_precision", [0.75, 0.76, 0.77])
|
|
|
|
|
result = diagnose(rows, metrics=["context_precision"])
|
|
|
|
|
self.assertEqual(len(result), 1)
|
|
|
|
|
self.assertEqual(result[0].severity, "low")
|
|
|
|
|
self.assertTrue(len(result[0].root_causes) > 0)
|
|
|
|
|
self.assertTrue(len(result[0].suggested_actions) > 0)
|
|
|
|
|
|
2026-06-16 16:27:54 +08:00
|
|
|
def test_warning_when_mean_below_warning_threshold(self):
|
|
|
|
|
rows = self._make_rows("faithfulness", [0.65, 0.62, 0.68])
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness"])
|
|
|
|
|
self.assertEqual(len(result), 1)
|
|
|
|
|
self.assertEqual(result[0].metric, "faithfulness")
|
|
|
|
|
self.assertEqual(result[0].severity, "warning")
|
|
|
|
|
self.assertAlmostEqual(result[0].mean_score, 0.65, places=2)
|
|
|
|
|
|
|
|
|
|
def test_critical_when_mean_below_critical_threshold(self):
|
|
|
|
|
rows = self._make_rows("faithfulness", [0.3, 0.4, 0.45])
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness"])
|
|
|
|
|
self.assertEqual(result[0].severity, "critical")
|
|
|
|
|
|
|
|
|
|
def test_low_samples_selected_are_bottom_three(self):
|
|
|
|
|
rows = self._make_rows("faithfulness", [0.1, 0.2, 0.3, 0.8, 0.9])
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness"])
|
|
|
|
|
self.assertEqual(len(result[0].low_samples), 3)
|
|
|
|
|
scores = [s["faithfulness"] for s in result[0].low_samples]
|
|
|
|
|
self.assertEqual(sorted(scores), [0.1, 0.2, 0.3])
|
|
|
|
|
|
|
|
|
|
def test_nan_scores_excluded_from_mean_and_low_samples(self):
|
|
|
|
|
rows = self._make_rows("faithfulness", [0.3, float("nan"), 0.4])
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness"])
|
|
|
|
|
self.assertEqual(len(result), 1)
|
|
|
|
|
for s in result[0].low_samples:
|
|
|
|
|
self.assertFalse(math.isnan(s["faithfulness"]))
|
|
|
|
|
|
|
|
|
|
def test_noise_sensitivity_direction_inverted(self):
|
|
|
|
|
# noise_sensitivity: higher is worse; threshold > 0.3 is warning
|
|
|
|
|
rows = self._make_rows("noise_sensitivity", [0.4, 0.45, 0.5])
|
|
|
|
|
result = diagnose(rows, metrics=["noise_sensitivity"])
|
|
|
|
|
self.assertEqual(len(result), 1)
|
|
|
|
|
self.assertEqual(result[0].metric, "noise_sensitivity")
|
|
|
|
|
|
|
|
|
|
def test_noise_sensitivity_no_diagnosis_when_low(self):
|
|
|
|
|
rows = self._make_rows("noise_sensitivity", [0.1, 0.15, 0.2])
|
|
|
|
|
result = diagnose(rows, metrics=["noise_sensitivity"])
|
|
|
|
|
self.assertEqual(result, [])
|
|
|
|
|
|
|
|
|
|
def test_skips_metric_not_in_rows(self):
|
|
|
|
|
rows = [{"faithfulness": 0.3, "question": "q", "answer": "a",
|
|
|
|
|
"ground_truth": "gt", "sample_id": "s1"}]
|
|
|
|
|
result = diagnose(rows, metrics=["faithfulness", "context_recall"])
|
|
|
|
|
metrics_found = [d.metric for d in result]
|
|
|
|
|
self.assertIn("faithfulness", metrics_found)
|
|
|
|
|
self.assertNotIn("context_recall", metrics_found)
|
|
|
|
|
|
|
|
|
|
def test_all_seven_metrics_have_rules(self):
|
|
|
|
|
expected = {"faithfulness", "answer_relevancy", "context_recall",
|
|
|
|
|
"context_precision", "noise_sensitivity",
|
|
|
|
|
"factual_correctness", "semantic_similarity"}
|
|
|
|
|
self.assertEqual(set(METRIC_RULES.keys()), expected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|