first commit
This commit is contained in:
5
rag_eval/config/__init__.py
Normal file
5
rag_eval/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Scenario configuration loading utilities."""
|
||||
|
||||
from .loader import load_scenario
|
||||
|
||||
__all__ = ["load_scenario"]
|
||||
67
rag_eval/config/loader.py
Normal file
67
rag_eval/config/loader.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Scenario file loading and conversion into internal runtime models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from rag_eval.shared.models import AppAdapterConfig, DatasetConfig, RuntimeConfig, Scenario
|
||||
|
||||
from .schema import ScenarioModel
|
||||
from .validators import validate_scenario
|
||||
|
||||
|
||||
def _resolve_static_kwargs_paths(base_dir: Path, raw_kwargs: dict[str, object]) -> dict[str, object]:
|
||||
"""Resolve adapter static kwargs that look like relative file-system paths."""
|
||||
resolved: dict[str, object] = {}
|
||||
for key, value in raw_kwargs.items():
|
||||
if key.endswith("_path") and isinstance(value, str):
|
||||
candidate = Path(value)
|
||||
resolved[key] = candidate if candidate.is_absolute() else (base_dir / candidate).resolve()
|
||||
continue
|
||||
resolved[key] = value
|
||||
return resolved
|
||||
|
||||
|
||||
def load_scenario(path: str | Path) -> Scenario:
|
||||
"""Load, validate, and resolve a scenario file into the internal scenario model."""
|
||||
scenario_path = Path(path).resolve()
|
||||
payload = yaml.safe_load(scenario_path.read_text(encoding="utf-8")) or {}
|
||||
model = ScenarioModel.model_validate(payload)
|
||||
base_dir = scenario_path.parent
|
||||
|
||||
app_adapter = None
|
||||
if model.app_adapter is not None:
|
||||
# Convert the validated Pydantic model into the lightweight runtime dataclass.
|
||||
app_adapter = AppAdapterConfig(
|
||||
type=model.app_adapter.type,
|
||||
endpoint=model.app_adapter.endpoint,
|
||||
method=model.app_adapter.method,
|
||||
timeout_seconds=model.app_adapter.timeout_seconds,
|
||||
callable=model.app_adapter.callable,
|
||||
request_template=model.app_adapter.request_template,
|
||||
response_mapping=model.app_adapter.response_mapping,
|
||||
static_kwargs=_resolve_static_kwargs_paths(base_dir, model.app_adapter.static_kwargs),
|
||||
)
|
||||
|
||||
scenario = Scenario(
|
||||
scenario_name=model.scenario_name,
|
||||
mode=model.mode,
|
||||
app_adapter=app_adapter,
|
||||
dataset=DatasetConfig(path=model.resolve_path(base_dir, model.dataset)),
|
||||
judge_model=model.judge_model,
|
||||
embedding_model=model.embedding_model,
|
||||
metrics=model.metrics,
|
||||
output_dir=model.resolve_path(base_dir, model.output_dir),
|
||||
runtime=RuntimeConfig(
|
||||
batch_size=model.runtime.batch_size,
|
||||
app_concurrency=model.runtime.app_concurrency,
|
||||
metric_concurrency=model.runtime.metric_concurrency,
|
||||
max_samples=model.runtime.max_samples,
|
||||
),
|
||||
source_path=scenario_path,
|
||||
)
|
||||
# Run cross-field checks after all relative paths have been resolved.
|
||||
validate_scenario(scenario)
|
||||
return scenario
|
||||
78
rag_eval/config/schema.py
Normal file
78
rag_eval/config/schema.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Pydantic schemas used to validate raw scenario configuration files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class RuntimeConfigModel(BaseModel):
|
||||
"""Schema for runtime concurrency and sampling settings."""
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
batch_size: int = 4
|
||||
app_concurrency: int | None = None
|
||||
metric_concurrency: int | None = None
|
||||
max_samples: int | None = None
|
||||
|
||||
|
||||
class AppAdapterConfigModel(BaseModel):
|
||||
"""Schema for adapter-specific configuration in online scenarios."""
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
type: Literal["http", "python"]
|
||||
endpoint: str | None = None
|
||||
method: str = "POST"
|
||||
timeout_seconds: int = 30
|
||||
callable: str | None = None
|
||||
request_template: dict[str, Any] = Field(default_factory=dict)
|
||||
response_mapping: dict[str, str] = Field(default_factory=dict)
|
||||
static_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_shape(self) -> "AppAdapterConfigModel":
|
||||
"""Enforce the fields required by each adapter type."""
|
||||
if self.type == "http" and not self.endpoint:
|
||||
raise ValueError("HTTP adapter requires endpoint.")
|
||||
if self.type == "python" and not self.callable:
|
||||
raise ValueError("Python adapter requires callable.")
|
||||
return self
|
||||
|
||||
|
||||
class ScenarioModel(BaseModel):
|
||||
"""Schema for a user-authored evaluation scenario file."""
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
scenario_name: str
|
||||
mode: Literal["offline", "online"]
|
||||
app_adapter: AppAdapterConfigModel | None = None
|
||||
dataset: str
|
||||
judge_model: str
|
||||
embedding_model: str
|
||||
metrics: list[str]
|
||||
output_dir: str
|
||||
runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel)
|
||||
|
||||
@field_validator("metrics")
|
||||
@classmethod
|
||||
def ensure_metrics_not_empty(cls, value: list[str]) -> list[str]:
|
||||
"""Reject scenarios that do not request any metrics."""
|
||||
if not value:
|
||||
raise ValueError("metrics must not be empty.")
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode_requirements(self) -> "ScenarioModel":
|
||||
"""Ensure online scenarios define the adapter they depend on."""
|
||||
if self.mode == "online" and self.app_adapter is None:
|
||||
raise ValueError("online mode requires app_adapter.")
|
||||
return self
|
||||
|
||||
def resolve_path(self, base_dir: Path, raw_path: str) -> Path:
|
||||
"""Resolve relative paths against the scenario file directory."""
|
||||
candidate = Path(raw_path)
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
return (base_dir / candidate).resolve()
|
||||
20
rag_eval/config/validators.py
Normal file
20
rag_eval/config/validators.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Cross-field validation helpers for resolved runtime scenarios."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rag_eval.metrics.registry import SUPPORTED_METRICS
|
||||
from rag_eval.shared.models import Scenario
|
||||
|
||||
|
||||
def validate_scenario(scenario: Scenario) -> None:
|
||||
"""Validate metric selection and mode-specific runtime constraints."""
|
||||
unsupported = [name for name in scenario.metrics if name not in SUPPORTED_METRICS]
|
||||
if unsupported:
|
||||
supported = ", ".join(sorted(SUPPORTED_METRICS))
|
||||
raise ValueError(
|
||||
f"Unsupported metrics: {', '.join(unsupported)}. Supported metrics: {supported}"
|
||||
)
|
||||
if scenario.mode == "offline" and scenario.app_adapter is not None:
|
||||
raise ValueError("offline mode should not define app_adapter.")
|
||||
if scenario.runtime.batch_size < 1:
|
||||
raise ValueError("runtime.batch_size must be >= 1.")
|
||||
Reference in New Issue
Block a user