diff --git a/tests/webapp/test_llm_profiles_api.py b/tests/webapp/test_llm_profiles_api.py
index 6cc2c4a..31d5f7d 100644
--- a/tests/webapp/test_llm_profiles_api.py
+++ b/tests/webapp/test_llm_profiles_api.py
@@ -65,3 +65,101 @@ def test_update_nonexistent(client):
 def test_delete_nonexistent(client):
     resp = client.delete("/api/llm-profiles/nope")
     assert resp.status_code == 404
+
+
+# ---------------------------------------------------------------------------
+# YAML patcher tests
+# ---------------------------------------------------------------------------
+import yaml as yaml_lib
+from webapp.services.yaml_patcher import apply_profiles_to_scenario
+from webapp.models import LLMProfile
+
+
+def test_apply_judge_profile(tmp_path):
+    """Applying a judge profile patches judge_model in the YAML."""
+    scenario_file = tmp_path / "test-scenario.yaml"
+    scenario_file.write_text(
+        "scenario_name: test\nmode: offline\njudge_model: old-model\nembedding_model: emb\n"
+        "dataset: data.csv\nmetrics:\n- faithfulness\noutput_dir: outputs/test\n",
+        encoding="utf-8",
+    )
+    judge_p = LLMProfile(
+        profile_id="x", name="J", model="new-model",
+        base_url="http://x/v1", api_key="k", created_at="t", updated_at="t",
+    )
+    patched = apply_profiles_to_scenario(
+        scenario_path=str(scenario_file),
+        judge_profile=judge_p,
+        answer_profile=None,
+        dataset_profile=None,
+        _resolve_absolute=True,
+    )
+    assert "judge_model" in patched
+    data = yaml_lib.safe_load(scenario_file.read_text(encoding="utf-8"))
+    assert data["judge_model"] == "new-model"
+
+
+def test_apply_answer_profile(tmp_path):
+    """Applying an answer profile patches app_adapter.static_kwargs.model."""
+    scenario_file = tmp_path / "online.yaml"
+    scenario_file.write_text(
+        "scenario_name: online\nmode: online\njudge_model: j\nembedding_model: emb\n"
+        "dataset: d.csv\nmetrics:\n- faithfulness\noutput_dir: out\n"
+        "app_adapter:\n  type: python\n  callable: apps.foo:run\n"
+        "  static_kwargs:\n    model: old\n    source_chunks_path: chunks.jsonl\n",
+        encoding="utf-8",
+    )
+    answer_p = LLMProfile(
+        profile_id="y", name="A", model="new-answer-model",
+        base_url="http://x/v1", api_key="k", created_at="t", updated_at="t",
+    )
+    patched = apply_profiles_to_scenario(
+        scenario_path=str(scenario_file),
+        judge_profile=None,
+        answer_profile=answer_p,
+        dataset_profile=None,
+        _resolve_absolute=True,
+    )
+    assert "app_adapter.static_kwargs.model" in patched
+    data = yaml_lib.safe_load(scenario_file.read_text(encoding="utf-8"))
+    assert data["app_adapter"]["static_kwargs"]["model"] == "new-answer-model"
+
+
+def test_apply_answer_profile_with_empty_static_kwargs(tmp_path):
+    """A valueless ``static_kwargs:`` key (loaded as None) is still patched."""
+    scenario_file = tmp_path / "empty-kwargs.yaml"
+    scenario_file.write_text(
+        "scenario_name: online\napp_adapter:\n  type: python\n  static_kwargs:\n",
+        encoding="utf-8",
+    )
+    answer_p = LLMProfile(
+        profile_id="y", name="A", model="new-answer-model",
+        base_url="http://x/v1", api_key="k", created_at="t", updated_at="t",
+    )
+    patched = apply_profiles_to_scenario(
+        scenario_path=str(scenario_file),
+        judge_profile=None,
+        answer_profile=answer_p,
+        dataset_profile=None,
+        _resolve_absolute=True,
+    )
+    assert patched == ["app_adapter.static_kwargs.model"]
+    data = yaml_lib.safe_load(scenario_file.read_text(encoding="utf-8"))
+    assert data["app_adapter"]["static_kwargs"] == {"model": "new-answer-model"}
+
+
+def test_apply_no_profiles_returns_empty(tmp_path):
+    """When no profiles are given, nothing is patched and the file is untouched."""
+    scenario_file = tmp_path / "noop.yaml"
+    original = "# keep this comment\nscenario_name: noop\njudge_model: m\n"
+    scenario_file.write_text(original, encoding="utf-8")
+    patched = apply_profiles_to_scenario(
+        scenario_path=str(scenario_file),
+        judge_profile=None,
+        answer_profile=None,
+        dataset_profile=None,
+        _resolve_absolute=True,
+    )
+    assert patched == []
+    # A no-op must not round-trip the file through PyYAML (which drops comments).
+    assert scenario_file.read_text(encoding="utf-8") == original
diff --git a/webapp/services/yaml_patcher.py b/webapp/services/yaml_patcher.py
new file mode 100644
index 0000000..2dbae20
--- /dev/null
+++ b/webapp/services/yaml_patcher.py
@@ -0,0 +1,109 @@
+"""Patch LLM profile settings into scenario YAML files in-place.
+
+Only the fields that correspond to a provided (non-None) profile are touched.
+All other fields and structure are preserved as much as PyYAML allows
+(comments are lost on round-trip, which is an accepted trade-off). When no
+field ends up being patched, the file is not rewritten at all.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+from webapp.models import LLMProfile
+
+
+def _repo_root() -> Path:
+    """Return the repository root (the directory containing ``webapp/``)."""
+    return Path(__file__).resolve().parents[2]
+
+
+def _resolve_scenario_path(path_str: str) -> Path:
+    """Resolve a scenario path; absolute paths are used as-is.
+
+    Relative paths are resolved against the repository root.
+    """
+    candidate = Path(path_str)
+    if candidate.is_absolute():
+        return candidate
+    return (_repo_root() / candidate).resolve()
+
+
+def apply_profiles_to_scenario(
+    scenario_path: str,
+    judge_profile: LLMProfile | None,
+    answer_profile: LLMProfile | None,
+    dataset_profile: LLMProfile | None,
+    _resolve_absolute: bool = False,
+) -> list[str]:
+    """Patch the YAML file at *scenario_path* with the supplied profiles.
+
+    Each non-None profile overwrites one model field:
+
+    * *judge_profile* sets ``judge_model``.
+    * *answer_profile* sets ``app_adapter.static_kwargs.model``, but only when
+      ``app_adapter`` is already a mapping.
+    * *dataset_profile* sets ``generation.model``, but only when ``generation``
+      is already a mapping.
+
+    The file is rewritten only when at least one field was patched, so a no-op
+    call neither strips comments nor reformats the file.
+
+    Setting *_resolve_absolute=True* uses *scenario_path* exactly as given and
+    skips repo-root resolution (used in tests).
+
+    Returns:
+        The dotted names of the fields that were actually patched.
+
+    Raises:
+        FileNotFoundError: If the scenario file does not exist.
+        TypeError: If the YAML document is not a mapping at the top level.
+    """
+    if _resolve_absolute:
+        resolved = Path(scenario_path)
+    else:
+        resolved = _resolve_scenario_path(scenario_path)
+
+    if not resolved.exists():
+        raise FileNotFoundError(f"Scenario file not found: {resolved}")
+
+    # An empty document loads as None; fall back to an empty mapping.
+    data: Any = yaml.safe_load(resolved.read_text(encoding="utf-8")) or {}
+    if not isinstance(data, dict):
+        raise TypeError(f"Scenario YAML must be a mapping at the top level: {resolved}")
+    patched: list[str] = []
+
+    if judge_profile is not None:
+        data["judge_model"] = judge_profile.model
+        patched.append("judge_model")
+
+    if answer_profile is not None:
+        adapter = data.get("app_adapter")
+        if isinstance(adapter, dict):
+            static_kwargs = adapter.get("static_kwargs")
+            if static_kwargs is None:
+                # Covers both a missing key and a valueless ``static_kwargs:``;
+                # ``setdefault`` would return the stored None in the latter case.
+                static_kwargs = adapter["static_kwargs"] = {}
+            static_kwargs["model"] = answer_profile.model
+            patched.append("app_adapter.static_kwargs.model")
+
+    if dataset_profile is not None:
+        generation = data.get("generation")
+        if isinstance(generation, dict):
+            generation["model"] = dataset_profile.model
+            patched.append("generation.model")
+
+    if not patched:
+        # Nothing changed: skip the write, because the PyYAML round-trip would
+        # drop comments and reformat the file for no benefit.
+        return patched
+
+    resolved.write_text(
+        yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False),
+        encoding="utf-8",
+    )
+    return patched