106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
|
|
"""Normalize raw dataset records into NormalizedSample and InvalidSample objects.
|
||
|
|
|
||
|
|
Handles both offline mode (records already contain answer + contexts) and online
|
||
|
|
mode (records only contain question + ground_truth; adapter fills the rest).
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import uuid
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from rag_eval.shared.models import InvalidSample, NormalizedSample
|
||
|
|
from rag_eval.shared.utils import parse_contexts
|
||
|
|
|
||
|
|
# Fields we always strip from the raw record before storing it in metadata.
|
||
|
|
_CORE_FIELDS = {
|
||
|
|
"sample_id",
|
||
|
|
"question",
|
||
|
|
"contexts",
|
||
|
|
"answer",
|
||
|
|
"ground_truth",
|
||
|
|
"scenario",
|
||
|
|
"language",
|
||
|
|
"retrieval_config",
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def _get_str(record: dict[str, Any], key: str, default: str = "") -> str:
|
||
|
|
"""Return a string field from the record, coercing None/NaN to the default."""
|
||
|
|
value = record.get(key)
|
||
|
|
if value is None:
|
||
|
|
return default
|
||
|
|
text = str(value).strip()
|
||
|
|
return default if text.lower() == "nan" else text
|
||
|
|
|
||
|
|
|
||
|
|
def _offline_errors(ground_truth: str, answer: str, contexts: Any) -> list[str]:
    """Return the validation messages for a record checked in offline mode."""
    errors: list[str] = []
    if not ground_truth:
        errors.append("missing ground_truth")
    if not answer:
        errors.append("missing answer")
    if not contexts:
        errors.append("missing or empty contexts")
    return errors


def _extra_metadata(raw: dict[str, Any]) -> dict[str, Any]:
    """Return the raw columns not covered by ``_CORE_FIELDS`` as a new dict."""
    return {key: value for key, value in raw.items() if key not in _CORE_FIELDS}


def normalize_records(
    records: list[dict[str, Any]],
    mode: str = "offline",
    max_samples: int | None = None,
) -> tuple[list[NormalizedSample], list[InvalidSample]]:
    """Convert raw dicts into NormalizedSample / InvalidSample collections.

    Every record needs a non-empty ``question``; records without one become
    InvalidSample entries. In offline mode every record must also contain a
    ``ground_truth``, an ``answer`` and non-empty ``contexts``. In online mode
    those checks are skipped; the missing fields are expected to be filled in
    later by the adapter.

    Args:
        records: Raw dataset records.
        mode: ``"offline"`` enables the ground_truth/answer/contexts checks.
            Any other value skips them.
        max_samples: If given, only the first ``max_samples`` records are
            processed. Must be non-negative.

    Returns:
        A ``(valid, invalid)`` tuple. Both lists preserve the input order.

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    # NOTE(review): ``mode`` is not validated, so a typo such as "Offline"
    # silently disables the offline checks. Consider restricting it to the
    # documented values once all callers are confirmed.
    if max_samples is not None:
        if max_samples < 0:
            # A negative slice bound would silently drop records from the end
            # of the list instead of capping how many are processed.
            raise ValueError(f"max_samples must be >= 0, got {max_samples}")
        records = records[:max_samples]

    valid: list[NormalizedSample] = []
    invalid: list[InvalidSample] = []

    for raw in records:
        # Records without an id get a short random one so that every sample,
        # valid or invalid, carries an identifier.
        sample_id = _get_str(raw, "sample_id") or uuid.uuid4().hex[:12]

        question = _get_str(raw, "question")
        if not question:
            invalid.append(InvalidSample(
                sample_id=sample_id,
                error="missing required field: question",
                raw=raw,
            ))
            continue

        ground_truth = _get_str(raw, "ground_truth")
        contexts = parse_contexts(raw.get("contexts"))
        answer = _get_str(raw, "answer")

        if mode == "offline":
            errors = _offline_errors(ground_truth, answer, contexts)
            if errors:
                invalid.append(InvalidSample(
                    sample_id=sample_id,
                    error="; ".join(errors),
                    raw=raw,
                ))
                continue

        valid.append(NormalizedSample(
            sample_id=sample_id,
            question=question,
            contexts=contexts,
            answer=answer,
            ground_truth=ground_truth,
            scenario=_get_str(raw, "scenario"),
            language=_get_str(raw, "language"),
            retrieval_config=_get_str(raw, "retrieval_config"),
            # Extra columns are kept as opaque metadata for adapters and
            # reporting.
            metadata=_extra_metadata(raw),
            raw=raw,
        ))

    return valid, invalid
|