feat: add metric_weights and doc_weights to Scenario schema and dataclass
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -62,6 +62,8 @@ def load_scenario(path: str | Path) -> Scenario:
|
|||||||
),
|
),
|
||||||
source_path=scenario_path,
|
source_path=scenario_path,
|
||||||
optimization_advisor=model.optimization_advisor,
|
optimization_advisor=model.optimization_advisor,
|
||||||
|
metric_weights=dict(model.metric_weights),
|
||||||
|
doc_weights=dict(model.doc_weights),
|
||||||
)
|
)
|
||||||
# Run cross-field checks after all relative paths have been resolved.
|
# Run cross-field checks after all relative paths have been resolved.
|
||||||
validate_scenario(scenario)
|
validate_scenario(scenario)
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ class ScenarioModel(BaseModel):
|
|||||||
output_dir: str
|
output_dir: str
|
||||||
runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel)
|
runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel)
|
||||||
optimization_advisor: bool = False
|
optimization_advisor: bool = False
|
||||||
|
metric_weights: dict[str, float] = Field(default_factory=dict)
|
||||||
|
doc_weights: dict[str, float] = Field(default_factory=dict)
|
||||||
|
|
||||||
@field_validator("metrics")
|
@field_validator("metrics")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ class Scenario:
|
|||||||
app_adapter: AppAdapterConfig | None = None
|
app_adapter: AppAdapterConfig | None = None
|
||||||
source_path: Path | None = None
|
source_path: Path | None = None
|
||||||
optimization_advisor: bool = False
|
optimization_advisor: bool = False
|
||||||
|
metric_weights: dict[str, float] = field(default_factory=dict)
|
||||||
|
doc_weights: dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
def snapshot(self) -> dict[str, Any]:
|
def snapshot(self) -> dict[str, Any]:
|
||||||
"""Serialize the scenario into a reporting-friendly dictionary snapshot."""
|
"""Serialize the scenario into a reporting-friendly dictionary snapshot."""
|
||||||
|
|||||||
@@ -80,6 +80,64 @@ class ScenarioAndDatasetTests(unittest.TestCase):
|
|||||||
self.assertTrue(scenario.dataset.path.name.endswith(".csv"))
|
self.assertTrue(scenario.dataset.path.name.endswith(".csv"))
|
||||||
self.assertTrue(scenario.output_dir.name == "sample-offline-baseline")
|
self.assertTrue(scenario.output_dir.name == "sample-offline-baseline")
|
||||||
|
|
||||||
|
def test_load_scenario_metric_and_doc_weights(self) -> None:
|
||||||
|
"""load_scenario passes metric_weights and doc_weights into Scenario."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from rag_eval.config.loader import load_scenario
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"scenario_name": "w-test",
|
||||||
|
"mode": "offline",
|
||||||
|
"dataset": "nonexistent.csv",
|
||||||
|
"judge_model": "m",
|
||||||
|
"embedding_model": "e",
|
||||||
|
"metrics": ["faithfulness"],
|
||||||
|
"output_dir": "out",
|
||||||
|
"metric_weights": {"faithfulness": 0.7},
|
||||||
|
"doc_weights": {"doc.pdf": 2.0},
|
||||||
|
}
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
|
||||||
|
yaml.dump(payload, f, allow_unicode=True)
|
||||||
|
tmp_path = f.name
|
||||||
|
try:
|
||||||
|
scenario = load_scenario(tmp_path)
|
||||||
|
assert scenario.metric_weights == {"faithfulness": 0.7}
|
||||||
|
assert scenario.doc_weights == {"doc.pdf": 2.0}
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
|
def test_load_scenario_defaults_to_empty_weights(self) -> None:
|
||||||
|
"""load_scenario defaults metric_weights and doc_weights to empty dicts."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from rag_eval.config.loader import load_scenario
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"scenario_name": "no-w",
|
||||||
|
"mode": "offline",
|
||||||
|
"dataset": "nonexistent.csv",
|
||||||
|
"judge_model": "m",
|
||||||
|
"embedding_model": "e",
|
||||||
|
"metrics": ["faithfulness"],
|
||||||
|
"output_dir": "out",
|
||||||
|
}
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
|
||||||
|
yaml.dump(payload, f, allow_unicode=True)
|
||||||
|
tmp_path = f.name
|
||||||
|
try:
|
||||||
|
scenario = load_scenario(tmp_path)
|
||||||
|
assert scenario.metric_weights == {}
|
||||||
|
assert scenario.doc_weights == {}
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
def test_scenario_snapshot_serializes_path_static_kwargs(self) -> None:
|
def test_scenario_snapshot_serializes_path_static_kwargs(self) -> None:
|
||||||
scenario = load_scenario("scenarios/online/sample-pdf-question-bank-online.yaml")
|
scenario = load_scenario("scenarios/online/sample-pdf-question-bank-online.yaml")
|
||||||
snapshot = scenario.snapshot()
|
snapshot = scenario.snapshot()
|
||||||
|
|||||||
Reference in New Issue
Block a user