146 lines
5.6 KiB
Python
146 lines
5.6 KiB
Python
"""Execution pipeline for scoring normalized samples with RAGAS metrics."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import math
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from rag_eval.shared.models import MetricScore, NormalizedSample
|
|
|
|
# Module-level logger. It uses a fixed dotted name in the "rag_eval.metrics"
# hierarchy rather than __name__, so its name does not depend on import path.
logger = logging.getLogger("rag_eval.metrics.pipeline")
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class MetricPipeline:
|
|
"""Score one or many normalized samples against a configured metric set."""
|
|
|
|
metrics: dict[str, Any]
|
|
metric_timeout_seconds: float | None = None
|
|
|
|
async def score_sample(self, sample: NormalizedSample) -> MetricScore:
|
|
"""Score a single sample and capture metric-level failures without aborting."""
|
|
results = {name: math.nan for name in self.metrics}
|
|
errors: list[str] = []
|
|
|
|
sid = sample.sample_id[:12]
|
|
ans_len = len(sample.answer or "")
|
|
ctx_count = len(sample.contexts or [])
|
|
logger.debug(
|
|
"[score] sample=%s ans_len=%d ctx_count=%d question=%r",
|
|
sid, ans_len, ctx_count,
|
|
(sample.question or "")[:80],
|
|
)
|
|
|
|
for name, metric in self.metrics.items():
|
|
t0 = time.monotonic()
|
|
try:
|
|
result = await self._run_metric(name, metric, sample)
|
|
score_val = float(result.value)
|
|
results[name] = score_val
|
|
elapsed = time.monotonic() - t0
|
|
logger.info(
|
|
"[metric OK ] sample=%-12s %-20s score=%.4f elapsed=%.1fs",
|
|
sid, name, score_val, elapsed,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
elapsed = time.monotonic() - t0
|
|
msg = f"timeout after {self.metric_timeout_seconds}s"
|
|
errors.append(f"{name}: {msg}")
|
|
logger.warning(
|
|
"[metric TMO] sample=%-12s %-20s TIMEOUT after %.1fs",
|
|
sid, name, elapsed,
|
|
)
|
|
except Exception as exc:
|
|
elapsed = time.monotonic() - t0
|
|
exc_type = type(exc).__name__
|
|
errors.append(f"{name}: {exc}")
|
|
logger.warning(
|
|
"[metric ERR] sample=%-12s %-20s %s: %s (elapsed=%.1fs)",
|
|
sid, name, exc_type, exc, elapsed,
|
|
)
|
|
|
|
return MetricScore(metrics=results, error=" | ".join(errors))
|
|
|
|
async def _run_metric(self, name: str, metric: Any, sample: NormalizedSample) -> Any:
|
|
"""Dispatch one metric call with the argument shape expected by that metric."""
|
|
timeout = None
|
|
if self.metric_timeout_seconds is not None:
|
|
timeout = max(1.0, float(self.metric_timeout_seconds))
|
|
|
|
if name == "faithfulness":
|
|
coroutine = metric.ascore(
|
|
user_input=sample.question,
|
|
response=sample.answer,
|
|
retrieved_contexts=sample.contexts,
|
|
)
|
|
elif name == "answer_relevancy":
|
|
coroutine = metric.ascore(
|
|
user_input=sample.question,
|
|
response=sample.answer,
|
|
)
|
|
elif name == "context_recall":
|
|
coroutine = metric.ascore(
|
|
user_input=sample.question,
|
|
retrieved_contexts=sample.contexts,
|
|
reference=sample.ground_truth,
|
|
)
|
|
elif name == "context_precision":
|
|
coroutine = metric.ascore(
|
|
user_input=sample.question,
|
|
reference=sample.ground_truth,
|
|
retrieved_contexts=sample.contexts,
|
|
)
|
|
elif name == "noise_sensitivity":
|
|
coroutine = metric.ascore(
|
|
user_input=sample.question,
|
|
response=sample.answer,
|
|
reference=sample.ground_truth,
|
|
retrieved_contexts=sample.contexts,
|
|
)
|
|
elif name == "factual_correctness":
|
|
coroutine = metric.ascore(
|
|
response=sample.answer,
|
|
reference=sample.ground_truth,
|
|
)
|
|
elif name == "semantic_similarity":
|
|
coroutine = metric.ascore(
|
|
reference=sample.ground_truth,
|
|
response=sample.answer,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported metric: {name}")
|
|
|
|
if timeout is None:
|
|
return await coroutine
|
|
return await asyncio.wait_for(coroutine, timeout=timeout)
|
|
|
|
async def score_samples(
|
|
self,
|
|
samples: list[NormalizedSample],
|
|
max_concurrency: int,
|
|
) -> list[MetricScore]:
|
|
"""Score all samples while respecting the configured concurrency limit."""
|
|
total = len(samples)
|
|
logger.info("[pipeline] scoring %d samples concurrency=%d timeout=%ss",
|
|
total, max_concurrency, self.metric_timeout_seconds)
|
|
semaphore = asyncio.Semaphore(max(1, max_concurrency))
|
|
completed = 0
|
|
|
|
async def guarded(idx: int, sample: NormalizedSample) -> MetricScore:
|
|
"""Throttle a single sample-scoring coroutine with the shared semaphore."""
|
|
nonlocal completed
|
|
async with semaphore:
|
|
result = await self.score_sample(sample)
|
|
completed += 1
|
|
nan_metrics = [k for k, v in result.metrics.items() if math.isnan(v)]
|
|
status = f"NaN={nan_metrics}" if nan_metrics else "all OK"
|
|
logger.info("[pipeline] progress %d/%d sample=%-12s %s",
|
|
completed, total, sample.sample_id[:12], status)
|
|
return result
|
|
|
|
return await asyncio.gather(*(guarded(i, s) for i, s in enumerate(samples)))
|