Files
siemens_ragas/rag_eval/datasets/normalizers.py

106 lines
3.2 KiB
Python
Raw Normal View History

"""Normalize raw dataset records into NormalizedSample and InvalidSample objects.
Handles both offline mode (records already contain answer + contexts) and online
mode (records only contain question + ground_truth; adapter fills the rest).
"""
from __future__ import annotations
import uuid
from typing import Any
from rag_eval.shared.models import InvalidSample, NormalizedSample
from rag_eval.shared.utils import parse_contexts
# Fields we always strip from the raw record before storing it in metadata.
_CORE_FIELDS = {
"sample_id",
"question",
"contexts",
"answer",
"ground_truth",
"scenario",
"language",
"retrieval_config",
}
def _get_str(record: dict[str, Any], key: str, default: str = "") -> str:
"""Return a string field from the record, coercing None/NaN to the default."""
value = record.get(key)
if value is None:
return default
text = str(value).strip()
return default if text.lower() == "nan" else text
def normalize_records(
    records: list[dict[str, Any]],
    mode: str = "offline",
    max_samples: int | None = None,
) -> tuple[list[NormalizedSample], list[InvalidSample]]:
    """Split raw record dicts into normalized samples and rejected samples.

    Every record needs a non-empty ``question``. When ``mode`` is exactly
    ``"offline"`` a record must additionally provide ``ground_truth``,
    ``answer`` and non-empty ``contexts``; for any other mode value those
    checks are skipped and the fields are passed through as parsed.

    Args:
        records: Raw records to normalize.
        mode: ``"offline"`` enables the strict field checks described above.
        max_samples: When not ``None``, only ``records[:max_samples]`` is
            processed.

    Returns:
        A ``(valid, invalid)`` pair. Each rejected record is reported as an
        ``InvalidSample`` carrying the sample id, an error message and the
        untouched raw record.
    """
    selected = records if max_samples is None else records[:max_samples]

    accepted: list[NormalizedSample] = []
    rejected: list[InvalidSample] = []

    for raw in selected:
        # Fall back to a random 12-hex-character id when none is supplied.
        sample_id = _get_str(raw, "sample_id")
        if not sample_id:
            sample_id = uuid.uuid4().hex[:12]

        question = _get_str(raw, "question")
        if not question:
            rejected.append(
                InvalidSample(
                    sample_id=sample_id,
                    error="missing required field: question",
                    raw=raw,
                )
            )
            continue

        ground_truth = _get_str(raw, "ground_truth")
        contexts = parse_contexts(raw.get("contexts"))
        answer = _get_str(raw, "answer")

        if mode == "offline":
            # Gather every problem so one InvalidSample reports all of them.
            checks = (
                (not ground_truth, "missing ground_truth"),
                (not answer, "missing answer"),
                (not contexts, "missing or empty contexts"),
            )
            problems = [message for failed, message in checks if failed]
            if problems:
                rejected.append(
                    InvalidSample(
                        sample_id=sample_id,
                        error="; ".join(problems),
                        raw=raw,
                    )
                )
                continue

        # Non-core columns travel along as opaque metadata.
        extras = {
            name: item
            for name, item in raw.items()
            if name not in _CORE_FIELDS
        }

        sample = NormalizedSample(
            sample_id=sample_id,
            question=question,
            contexts=contexts,
            answer=answer,
            ground_truth=ground_truth,
            scenario=_get_str(raw, "scenario"),
            language=_get_str(raw, "language"),
            retrieval_config=_get_str(raw, "retrieval_config"),
            metadata=extras,
            raw=raw,
        )
        accepted.append(sample)

    return accepted, rejected