feat: add metric_weights and doc_weights to Scenario schema and dataclass
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -62,6 +62,8 @@ def load_scenario(path: str | Path) -> Scenario:
|
||||
),
|
||||
source_path=scenario_path,
|
||||
optimization_advisor=model.optimization_advisor,
|
||||
metric_weights=dict(model.metric_weights),
|
||||
doc_weights=dict(model.doc_weights),
|
||||
)
|
||||
# Run cross-field checks after all relative paths have been resolved.
|
||||
validate_scenario(scenario)
|
||||
|
||||
@@ -55,6 +55,8 @@ class ScenarioModel(BaseModel):
|
||||
output_dir: str
|
||||
runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel)
|
||||
optimization_advisor: bool = False
|
||||
metric_weights: dict[str, float] = Field(default_factory=dict)
|
||||
doc_weights: dict[str, float] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("metrics")
|
||||
@classmethod
|
||||
|
||||
@@ -77,6 +77,8 @@ class Scenario:
|
||||
app_adapter: AppAdapterConfig | None = None
|
||||
source_path: Path | None = None
|
||||
optimization_advisor: bool = False
|
||||
metric_weights: dict[str, float] = field(default_factory=dict)
|
||||
doc_weights: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def snapshot(self) -> dict[str, Any]:
|
||||
"""Serialize the scenario into a reporting-friendly dictionary snapshot."""
|
||||
|
||||
@@ -80,6 +80,64 @@ class ScenarioAndDatasetTests(unittest.TestCase):
|
||||
self.assertTrue(scenario.dataset.path.name.endswith(".csv"))
|
||||
self.assertTrue(scenario.output_dir.name == "sample-offline-baseline")
|
||||
|
||||
def test_load_scenario_metric_and_doc_weights(self) -> None:
|
||||
"""load_scenario passes metric_weights and doc_weights into Scenario."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import yaml
|
||||
|
||||
from rag_eval.config.loader import load_scenario
|
||||
|
||||
payload = {
|
||||
"scenario_name": "w-test",
|
||||
"mode": "offline",
|
||||
"dataset": "nonexistent.csv",
|
||||
"judge_model": "m",
|
||||
"embedding_model": "e",
|
||||
"metrics": ["faithfulness"],
|
||||
"output_dir": "out",
|
||||
"metric_weights": {"faithfulness": 0.7},
|
||||
"doc_weights": {"doc.pdf": 2.0},
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
|
||||
yaml.dump(payload, f, allow_unicode=True)
|
||||
tmp_path = f.name
|
||||
try:
|
||||
scenario = load_scenario(tmp_path)
|
||||
assert scenario.metric_weights == {"faithfulness": 0.7}
|
||||
assert scenario.doc_weights == {"doc.pdf": 2.0}
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
def test_load_scenario_defaults_to_empty_weights(self) -> None:
|
||||
"""load_scenario defaults metric_weights and doc_weights to empty dicts."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import yaml
|
||||
|
||||
from rag_eval.config.loader import load_scenario
|
||||
|
||||
payload = {
|
||||
"scenario_name": "no-w",
|
||||
"mode": "offline",
|
||||
"dataset": "nonexistent.csv",
|
||||
"judge_model": "m",
|
||||
"embedding_model": "e",
|
||||
"metrics": ["faithfulness"],
|
||||
"output_dir": "out",
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
|
||||
yaml.dump(payload, f, allow_unicode=True)
|
||||
tmp_path = f.name
|
||||
try:
|
||||
scenario = load_scenario(tmp_path)
|
||||
assert scenario.metric_weights == {}
|
||||
assert scenario.doc_weights == {}
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
def test_scenario_snapshot_serializes_path_static_kwargs(self) -> None:
|
||||
scenario = load_scenario("scenarios/online/sample-pdf-question-bank-online.yaml")
|
||||
snapshot = scenario.snapshot()
|
||||
|
||||
Reference in New Issue
Block a user