feat: add metric_weights and doc_weights to Scenario schema and dataclass

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-06-18 16:50:33 +08:00
parent e0b064587f
commit 8617eaa5aa
4 changed files with 64 additions and 0 deletions

View File

@@ -62,6 +62,8 @@ def load_scenario(path: str | Path) -> Scenario:
), ),
source_path=scenario_path, source_path=scenario_path,
optimization_advisor=model.optimization_advisor, optimization_advisor=model.optimization_advisor,
metric_weights=dict(model.metric_weights),
doc_weights=dict(model.doc_weights),
) )
# Run cross-field checks after all relative paths have been resolved. # Run cross-field checks after all relative paths have been resolved.
validate_scenario(scenario) validate_scenario(scenario)

View File

@@ -55,6 +55,8 @@ class ScenarioModel(BaseModel):
output_dir: str output_dir: str
runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel) runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel)
optimization_advisor: bool = False optimization_advisor: bool = False
metric_weights: dict[str, float] = Field(default_factory=dict)
doc_weights: dict[str, float] = Field(default_factory=dict)
@field_validator("metrics") @field_validator("metrics")
@classmethod @classmethod

View File

@@ -77,6 +77,8 @@ class Scenario:
app_adapter: AppAdapterConfig | None = None app_adapter: AppAdapterConfig | None = None
source_path: Path | None = None source_path: Path | None = None
optimization_advisor: bool = False optimization_advisor: bool = False
metric_weights: dict[str, float] = field(default_factory=dict)
doc_weights: dict[str, float] = field(default_factory=dict)
def snapshot(self) -> dict[str, Any]: def snapshot(self) -> dict[str, Any]:
"""Serialize the scenario into a reporting-friendly dictionary snapshot.""" """Serialize the scenario into a reporting-friendly dictionary snapshot."""

View File

@@ -80,6 +80,64 @@ class ScenarioAndDatasetTests(unittest.TestCase):
self.assertTrue(scenario.dataset.path.name.endswith(".csv")) self.assertTrue(scenario.dataset.path.name.endswith(".csv"))
self.assertTrue(scenario.output_dir.name == "sample-offline-baseline") self.assertTrue(scenario.output_dir.name == "sample-offline-baseline")
def test_load_scenario_metric_and_doc_weights(self) -> None:
"""load_scenario passes metric_weights and doc_weights into Scenario."""
import os
import tempfile
import yaml
from rag_eval.config.loader import load_scenario
payload = {
"scenario_name": "w-test",
"mode": "offline",
"dataset": "nonexistent.csv",
"judge_model": "m",
"embedding_model": "e",
"metrics": ["faithfulness"],
"output_dir": "out",
"metric_weights": {"faithfulness": 0.7},
"doc_weights": {"doc.pdf": 2.0},
}
with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
yaml.dump(payload, f, allow_unicode=True)
tmp_path = f.name
try:
scenario = load_scenario(tmp_path)
assert scenario.metric_weights == {"faithfulness": 0.7}
assert scenario.doc_weights == {"doc.pdf": 2.0}
finally:
os.unlink(tmp_path)
def test_load_scenario_defaults_to_empty_weights(self) -> None:
"""load_scenario defaults metric_weights and doc_weights to empty dicts."""
import os
import tempfile
import yaml
from rag_eval.config.loader import load_scenario
payload = {
"scenario_name": "no-w",
"mode": "offline",
"dataset": "nonexistent.csv",
"judge_model": "m",
"embedding_model": "e",
"metrics": ["faithfulness"],
"output_dir": "out",
}
with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
yaml.dump(payload, f, allow_unicode=True)
tmp_path = f.name
try:
scenario = load_scenario(tmp_path)
assert scenario.metric_weights == {}
assert scenario.doc_weights == {}
finally:
os.unlink(tmp_path)
def test_scenario_snapshot_serializes_path_static_kwargs(self) -> None: def test_scenario_snapshot_serializes_path_static_kwargs(self) -> None:
scenario = load_scenario("scenarios/online/sample-pdf-question-bank-online.yaml") scenario = load_scenario("scenarios/online/sample-pdf-question-bank-online.yaml")
snapshot = scenario.snapshot() snapshot = scenario.snapshot()