From 8617eaa5aae0a71bfcb8436a60fa4ef2b2eb4736 Mon Sep 17 00:00:00 2001
From: wangwei
Date: Thu, 18 Jun 2026 16:50:33 +0800
Subject: [PATCH] feat: add metric_weights and doc_weights to Scenario schema and dataclass

Add two optional mappings, metric_weights and doc_weights (str -> float),
to ScenarioModel and to the Scenario dataclass. Both default to an empty
dict. load_scenario copies them from the validated model into the Scenario
it builds.

Tests cover a scenario file that sets both mappings and one that omits
them.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---
 rag_eval/config/loader.py  |  2 ++
 rag_eval/config/schema.py  |  2 ++
 rag_eval/shared/models.py  |  2 ++
 tests/test_offline_eval.py | 58 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 64 insertions(+)

diff --git a/rag_eval/config/loader.py b/rag_eval/config/loader.py
index f4ffd4b..b68f8bc 100644
--- a/rag_eval/config/loader.py
+++ b/rag_eval/config/loader.py
@@ -62,6 +62,8 @@ def load_scenario(path: str | Path) -> Scenario:
         ),
         source_path=scenario_path,
         optimization_advisor=model.optimization_advisor,
+        metric_weights=dict(model.metric_weights),
+        doc_weights=dict(model.doc_weights),
     )
     # Run cross-field checks after all relative paths have been resolved.
     validate_scenario(scenario)
diff --git a/rag_eval/config/schema.py b/rag_eval/config/schema.py
index f36e8ac..3fac72a 100644
--- a/rag_eval/config/schema.py
+++ b/rag_eval/config/schema.py
@@ -55,6 +55,8 @@ class ScenarioModel(BaseModel):
     output_dir: str
     runtime: RuntimeConfigModel = Field(default_factory=RuntimeConfigModel)
     optimization_advisor: bool = False
+    metric_weights: dict[str, float] = Field(default_factory=dict)
+    doc_weights: dict[str, float] = Field(default_factory=dict)
 
     @field_validator("metrics")
     @classmethod
diff --git a/rag_eval/shared/models.py b/rag_eval/shared/models.py
index 9284788..98c6fa3 100644
--- a/rag_eval/shared/models.py
+++ b/rag_eval/shared/models.py
@@ -77,6 +77,8 @@ class Scenario:
     app_adapter: AppAdapterConfig | None = None
     source_path: Path | None = None
     optimization_advisor: bool = False
+    metric_weights: dict[str, float] = field(default_factory=dict)
+    doc_weights: dict[str, float] = field(default_factory=dict)
 
     def snapshot(self) -> dict[str, Any]:
         """Serialize the scenario into a reporting-friendly dictionary snapshot."""
diff --git a/tests/test_offline_eval.py b/tests/test_offline_eval.py
index 29018ad..384b665 100644
--- a/tests/test_offline_eval.py
+++ b/tests/test_offline_eval.py
@@ -80,6 +80,64 @@ class ScenarioAndDatasetTests(unittest.TestCase):
         self.assertTrue(scenario.dataset.path.name.endswith(".csv"))
         self.assertTrue(scenario.output_dir.name == "sample-offline-baseline")
 
+    def test_load_scenario_metric_and_doc_weights(self) -> None:
+        """load_scenario carries metric_weights and doc_weights from the YAML file onto Scenario."""
+        import os
+        import tempfile
+
+        import yaml
+
+        from rag_eval.config.loader import load_scenario
+
+        payload = {
+            "scenario_name": "w-test",
+            "mode": "offline",
+            "dataset": "nonexistent.csv",
+            "judge_model": "m",
+            "embedding_model": "e",
+            "metrics": ["faithfulness"],
+            "output_dir": "out",
+            "metric_weights": {"faithfulness": 0.7},
+            "doc_weights": {"doc.pdf": 2.0},
+        }
+        with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
+            yaml.dump(payload, f, allow_unicode=True)
+            tmp_path = f.name
+        try:
+            scenario = load_scenario(tmp_path)
+            self.assertEqual(scenario.metric_weights, {"faithfulness": 0.7})
+            self.assertEqual(scenario.doc_weights, {"doc.pdf": 2.0})
+        finally:
+            os.unlink(tmp_path)
+
+    def test_load_scenario_defaults_to_empty_weights(self) -> None:
+        """load_scenario yields empty metric_weights and doc_weights when the YAML file omits them."""
+        import os
+        import tempfile
+
+        import yaml
+
+        from rag_eval.config.loader import load_scenario
+
+        payload = {
+            "scenario_name": "no-w",
+            "mode": "offline",
+            "dataset": "nonexistent.csv",
+            "judge_model": "m",
+            "embedding_model": "e",
+            "metrics": ["faithfulness"],
+            "output_dir": "out",
+        }
+        with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", encoding="utf-8", delete=False) as f:
+            yaml.dump(payload, f, allow_unicode=True)
+            tmp_path = f.name
+        try:
+            scenario = load_scenario(tmp_path)
+            self.assertEqual(scenario.metric_weights, {})
+            self.assertEqual(scenario.doc_weights, {})
+        finally:
+            os.unlink(tmp_path)
+
     def test_scenario_snapshot_serializes_path_static_kwargs(self) -> None:
         scenario = load_scenario("scenarios/online/sample-pdf-question-bank-online.yaml")
         snapshot = scenario.snapshot()