Files
siemens_ragas/webapp/services/yaml_patcher.py

85 lines
2.7 KiB
Python
Raw Normal View History

"""Patch LLM profile settings into scenario YAML files in-place.
Only the fields that correspond to a provided (non-None) profile are touched.
All other fields and structure are preserved as much as PyYAML allows
(comments are lost on round-trip, which is an accepted trade-off).
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import yaml
from webapp.models import LLMProfile
def _repo_root() -> Path:
    """Return the directory treated as the repository root.

    This is the third ancestor of this module's resolved file path, i.e. two
    levels above the directory that contains the module.
    """
    module_file = Path(__file__).resolve()
    ancestors = module_file.parents
    return ancestors[2]
def _resolve_scenario_path(path_str: str) -> Path:
    """Turn *path_str* into a concrete scenario path.

    An absolute path is handed back untouched (it is not normalised); a
    relative path is joined onto ``_repo_root()`` and then resolved.
    """
    scenario = Path(path_str)
    if not scenario.is_absolute():
        # Relative inputs are interpreted relative to the repository root.
        scenario = (_repo_root() / scenario).resolve()
    return scenario
def apply_profiles_to_scenario(
    scenario_path: str,
    judge_profile: LLMProfile | None,
    answer_profile: LLMProfile | None,
    dataset_profile: LLMProfile | None,
    metric_weights: dict[str, float] | None = None,
    doc_weights: dict[str, float] | None = None,
    _resolve_absolute: bool = False,
) -> list[str]:
    """Patch the YAML file at *scenario_path* with the supplied profiles and weights.

    The file is only rewritten when at least one field was actually patched,
    so a no-op call leaves the file (including its comments) byte-for-byte
    untouched.

    Args:
        scenario_path: Path to the scenario YAML file.
        judge_profile: If given, its ``model`` is written to ``judge_model``.
        answer_profile: If given and ``app_adapter`` is a mapping, its ``model``
            is written to ``app_adapter.static_kwargs.model`` (``static_kwargs``
            is created when missing or null).
        dataset_profile: If given and ``generation`` is a mapping, its ``model``
            is written to ``generation.model``.
        metric_weights: If given, a copy replaces the ``metric_weights`` key.
        doc_weights: If given, a copy replaces the ``doc_weights`` key.
        _resolve_absolute: When true, *scenario_path* is used as-is instead of
            being resolved against the repository root (used in tests).

    Returns:
        The dotted field names that were actually patched, in patch order.

    Raises:
        FileNotFoundError: If the resolved scenario file does not exist.
        TypeError: If the YAML document's top level is not a mapping.
    """
    if _resolve_absolute:
        resolved = Path(scenario_path)
    else:
        resolved = _resolve_scenario_path(scenario_path)
    if not resolved.exists():
        raise FileNotFoundError(f"Scenario file not found: {resolved}")

    # Nothing requested: do not parse or rewrite the file at all, since a
    # YAML round-trip would needlessly drop comments and formatting.
    requested = (
        judge_profile,
        answer_profile,
        dataset_profile,
        metric_weights,
        doc_weights,
    )
    if all(item is None for item in requested):
        return []

    # An empty document loads as None; treat it as an empty mapping.
    data: Any = yaml.safe_load(resolved.read_text(encoding="utf-8")) or {}
    if not isinstance(data, dict):
        raise TypeError(
            f"Scenario file must contain a YAML mapping at top level: {resolved}"
        )

    patched: list[str] = []

    if judge_profile is not None:
        data["judge_model"] = judge_profile.model
        patched.append("judge_model")

    if answer_profile is not None:
        adapter = data.get("app_adapter")
        # Only patch an existing adapter section; never invent one.
        if isinstance(adapter, dict):
            static_kwargs = adapter.get("static_kwargs")
            if static_kwargs is None:
                # Covers both a missing key and an explicit YAML null.
                static_kwargs = adapter["static_kwargs"] = {}
            static_kwargs["model"] = answer_profile.model
            patched.append("app_adapter.static_kwargs.model")

    if dataset_profile is not None:
        generation = data.get("generation")
        # Only patch an existing generation section; never invent one.
        if isinstance(generation, dict):
            generation["model"] = dataset_profile.model
            patched.append("generation.model")

    if metric_weights is not None:
        # Copy so later mutation of the caller's dict cannot leak in.
        data["metric_weights"] = dict(metric_weights)
        patched.append("metric_weights")

    if doc_weights is not None:
        data["doc_weights"] = dict(doc_weights)
        patched.append("doc_weights")

    # Skip the write when no field changed (e.g. the target section is absent).
    if not patched:
        return patched

    resolved.write_text(
        yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False),
        encoding="utf-8",
    )
    return patched