"""Execution pipeline for scoring normalized samples with RAGAS metrics.""" from __future__ import annotations import asyncio import logging import math import time from dataclasses import dataclass from typing import Any from rag_eval.shared.models import MetricScore, NormalizedSample logger = logging.getLogger("rag_eval.metrics.pipeline") @dataclass(slots=True) class MetricPipeline: """Score one or many normalized samples against a configured metric set.""" metrics: dict[str, Any] metric_timeout_seconds: float | None = None async def score_sample(self, sample: NormalizedSample) -> MetricScore: """Score a single sample and capture metric-level failures without aborting.""" results = {name: math.nan for name in self.metrics} errors: list[str] = [] sid = sample.sample_id[:12] ans_len = len(sample.answer or "") ctx_count = len(sample.contexts or []) logger.debug( "[score] sample=%s ans_len=%d ctx_count=%d question=%r", sid, ans_len, ctx_count, (sample.question or "")[:80], ) for name, metric in self.metrics.items(): t0 = time.monotonic() try: result = await self._run_metric(name, metric, sample) score_val = float(result.value) results[name] = score_val elapsed = time.monotonic() - t0 logger.info( "[metric OK ] sample=%-12s %-20s score=%.4f elapsed=%.1fs", sid, name, score_val, elapsed, ) except asyncio.TimeoutError: elapsed = time.monotonic() - t0 msg = f"timeout after {self.metric_timeout_seconds}s" errors.append(f"{name}: {msg}") logger.warning( "[metric TMO] sample=%-12s %-20s TIMEOUT after %.1fs", sid, name, elapsed, ) except Exception as exc: elapsed = time.monotonic() - t0 exc_type = type(exc).__name__ errors.append(f"{name}: {exc}") logger.warning( "[metric ERR] sample=%-12s %-20s %s: %s (elapsed=%.1fs)", sid, name, exc_type, exc, elapsed, ) return MetricScore(metrics=results, error=" | ".join(errors)) async def _run_metric(self, name: str, metric: Any, sample: NormalizedSample) -> Any: """Dispatch one metric call with the argument shape expected by that metric.""" timeout = None if self.metric_timeout_seconds is not None: timeout = max(1.0, float(self.metric_timeout_seconds)) if name == "faithfulness": coroutine = metric.ascore( user_input=sample.question, response=sample.answer, retrieved_contexts=sample.contexts, ) elif name == "answer_relevancy": coroutine = metric.ascore( user_input=sample.question, response=sample.answer, ) elif name == "context_recall": coroutine = metric.ascore( user_input=sample.question, retrieved_contexts=sample.contexts, reference=sample.ground_truth, ) elif name == "context_precision": coroutine = metric.ascore( user_input=sample.question, reference=sample.ground_truth, retrieved_contexts=sample.contexts, ) elif name == "noise_sensitivity": coroutine = metric.ascore( user_input=sample.question, response=sample.answer, reference=sample.ground_truth, retrieved_contexts=sample.contexts, ) elif name == "factual_correctness": coroutine = metric.ascore( response=sample.answer, reference=sample.ground_truth, ) elif name == "semantic_similarity": coroutine = metric.ascore( reference=sample.ground_truth, response=sample.answer, ) else: raise ValueError(f"Unsupported metric: {name}") if timeout is None: return await coroutine return await asyncio.wait_for(coroutine, timeout=timeout) async def score_samples( self, samples: list[NormalizedSample], max_concurrency: int, ) -> list[MetricScore]: """Score all samples while respecting the configured concurrency limit.""" total = len(samples) logger.info("[pipeline] scoring %d samples concurrency=%d timeout=%ss", total, max_concurrency, self.metric_timeout_seconds) semaphore = asyncio.Semaphore(max(1, max_concurrency)) completed = 0 async def guarded(idx: int, sample: NormalizedSample) -> MetricScore: """Throttle a single sample-scoring coroutine with the shared semaphore.""" nonlocal completed async with semaphore: result = await self.score_sample(sample) completed += 1 nan_metrics = [k for k, v in result.metrics.items() if math.isnan(v)] status = f"NaN={nan_metrics}" if nan_metrics else "all OK" logger.info("[pipeline] progress %d/%d sample=%-12s %s", completed, total, sample.sample_id[:12], status) return result return await asyncio.gather(*(guarded(i, s) for i, s in enumerate(samples)))