Files
siemens_ragas/rag_eval/metrics/pipeline.py

146 lines
5.6 KiB
Python
Raw Permalink Normal View History

2026-06-12 14:02:15 +08:00
"""Execution pipeline for scoring normalized samples with RAGAS metrics."""
from __future__ import annotations
import asyncio
import logging
2026-06-12 14:02:15 +08:00
import math
import time
2026-06-12 14:02:15 +08:00
from dataclasses import dataclass
from typing import Any
from rag_eval.shared.models import MetricScore, NormalizedSample
# Module logger shared by MetricPipeline's per-metric and progress messages.
logger = logging.getLogger(name="rag_eval.metrics.pipeline")
2026-06-12 14:02:15 +08:00
@dataclass(slots=True)
class MetricPipeline:
"""Score one or many normalized samples against a configured metric set."""
metrics: dict[str, Any]
metric_timeout_seconds: float | None = None
async def score_sample(self, sample: NormalizedSample) -> MetricScore:
"""Score a single sample and capture metric-level failures without aborting."""
results = {name: math.nan for name in self.metrics}
errors: list[str] = []
sid = sample.sample_id[:12]
ans_len = len(sample.answer or "")
ctx_count = len(sample.contexts or [])
logger.debug(
"[score] sample=%s ans_len=%d ctx_count=%d question=%r",
sid, ans_len, ctx_count,
(sample.question or "")[:80],
)
2026-06-12 14:02:15 +08:00
for name, metric in self.metrics.items():
t0 = time.monotonic()
2026-06-12 14:02:15 +08:00
try:
result = await self._run_metric(name, metric, sample)
score_val = float(result.value)
results[name] = score_val
elapsed = time.monotonic() - t0
logger.info(
"[metric OK ] sample=%-12s %-20s score=%.4f elapsed=%.1fs",
sid, name, score_val, elapsed,
)
except asyncio.TimeoutError:
elapsed = time.monotonic() - t0
msg = f"timeout after {self.metric_timeout_seconds}s"
errors.append(f"{name}: {msg}")
logger.warning(
"[metric TMO] sample=%-12s %-20s TIMEOUT after %.1fs",
sid, name, elapsed,
)
2026-06-12 14:02:15 +08:00
except Exception as exc:
elapsed = time.monotonic() - t0
exc_type = type(exc).__name__
2026-06-12 14:02:15 +08:00
errors.append(f"{name}: {exc}")
logger.warning(
"[metric ERR] sample=%-12s %-20s %s: %s (elapsed=%.1fs)",
sid, name, exc_type, exc, elapsed,
)
2026-06-12 14:02:15 +08:00
return MetricScore(metrics=results, error=" | ".join(errors))
async def _run_metric(self, name: str, metric: Any, sample: NormalizedSample) -> Any:
"""Dispatch one metric call with the argument shape expected by that metric."""
timeout = None
if self.metric_timeout_seconds is not None:
timeout = max(1.0, float(self.metric_timeout_seconds))
if name == "faithfulness":
coroutine = metric.ascore(
user_input=sample.question,
response=sample.answer,
retrieved_contexts=sample.contexts,
)
elif name == "answer_relevancy":
coroutine = metric.ascore(
user_input=sample.question,
response=sample.answer,
)
elif name == "context_recall":
coroutine = metric.ascore(
user_input=sample.question,
retrieved_contexts=sample.contexts,
reference=sample.ground_truth,
)
elif name == "context_precision":
coroutine = metric.ascore(
user_input=sample.question,
reference=sample.ground_truth,
retrieved_contexts=sample.contexts,
)
2026-06-16 18:12:33 +08:00
elif name == "noise_sensitivity":
coroutine = metric.ascore(
user_input=sample.question,
response=sample.answer,
reference=sample.ground_truth,
retrieved_contexts=sample.contexts,
)
elif name == "factual_correctness":
coroutine = metric.ascore(
response=sample.answer,
reference=sample.ground_truth,
)
elif name == "semantic_similarity":
coroutine = metric.ascore(
reference=sample.ground_truth,
response=sample.answer,
)
2026-06-12 14:02:15 +08:00
else:
raise ValueError(f"Unsupported metric: {name}")
if timeout is None:
return await coroutine
return await asyncio.wait_for(coroutine, timeout=timeout)
async def score_samples(
self,
samples: list[NormalizedSample],
max_concurrency: int,
) -> list[MetricScore]:
"""Score all samples while respecting the configured concurrency limit."""
total = len(samples)
logger.info("[pipeline] scoring %d samples concurrency=%d timeout=%ss",
total, max_concurrency, self.metric_timeout_seconds)
2026-06-12 14:02:15 +08:00
semaphore = asyncio.Semaphore(max(1, max_concurrency))
completed = 0
2026-06-12 14:02:15 +08:00
async def guarded(idx: int, sample: NormalizedSample) -> MetricScore:
2026-06-12 14:02:15 +08:00
"""Throttle a single sample-scoring coroutine with the shared semaphore."""
nonlocal completed
2026-06-12 14:02:15 +08:00
async with semaphore:
result = await self.score_sample(sample)
completed += 1
nan_metrics = [k for k, v in result.metrics.items() if math.isnan(v)]
status = f"NaN={nan_metrics}" if nan_metrics else "all OK"
logger.info("[pipeline] progress %d/%d sample=%-12s %s",
completed, total, sample.sample_id[:12], status)
return result
2026-06-12 14:02:15 +08:00
return await asyncio.gather(*(guarded(i, s) for i, s in enumerate(samples)))