85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
"""Patch LLM profile settings into scenario YAML files in-place.
|
|
|
|
Only the fields that correspond to a provided (non-None) profile are touched.
|
|
All other fields and structure are preserved as much as PyYAML allows
|
|
(comments are lost on round-trip, which is an accepted trade-off).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from webapp.models import LLMProfile
|
|
|
|
|
|
def _repo_root() -> Path:
    """Return the directory used as the base for relative scenario paths.

    This is the third ancestor of this module's fully resolved file path,
    i.e. two levels above the directory that contains the module.
    """
    module_file = Path(__file__).resolve()
    ancestors = module_file.parents
    return ancestors[2]
|
|
|
|
|
|
def _resolve_scenario_path(path_str: str) -> Path:
    """Turn *path_str* into the scenario file path to operate on.

    An absolute path is returned unchanged (it is not normalized or
    resolved). A relative path is joined onto ``_repo_root()`` and the
    result is resolved.
    """
    scenario = Path(path_str)
    if not scenario.is_absolute():
        # Relative paths are interpreted against the repository root.
        scenario = (_repo_root() / scenario).resolve()
    return scenario
|
|
|
|
|
|
def apply_profiles_to_scenario(
    scenario_path: str,
    judge_profile: LLMProfile | None,
    answer_profile: LLMProfile | None,
    dataset_profile: LLMProfile | None,
    metric_weights: dict[str, float] | None = None,
    doc_weights: dict[str, float] | None = None,
    _resolve_absolute: bool = False,
) -> list[str]:
    """Patch the YAML file at *scenario_path* with the supplied profiles and weights.

    Each non-None argument is written to one location in the document:

    * ``judge_profile.model``   -> ``judge_model``
    * ``answer_profile.model``  -> ``app_adapter.static_kwargs.model``
      (skipped unless ``app_adapter`` is already a mapping)
    * ``dataset_profile.model`` -> ``generation.model``
      (skipped unless ``generation`` is already a mapping)
    * ``metric_weights``        -> ``metric_weights`` (stored as a copy)
    * ``doc_weights``           -> ``doc_weights`` (stored as a copy)

    The file is rewritten only when at least one field was actually
    patched; otherwise it is left byte-for-byte untouched so that comments
    and formatting are not lost for a no-op call.

    Setting *_resolve_absolute=True* skips repo-root resolution (used in tests).

    Returns:
        A list of dotted field names that were actually patched, in the
        order they were applied.

    Raises:
        FileNotFoundError: If the resolved scenario file does not exist.
        TypeError: If the YAML document's top level is not a mapping.
    """
    if _resolve_absolute:
        resolved = Path(scenario_path)
    else:
        resolved = _resolve_scenario_path(scenario_path)

    if not resolved.exists():
        raise FileNotFoundError(f"Scenario file not found: {resolved}")

    requested = (
        judge_profile,
        answer_profile,
        dataset_profile,
        metric_weights,
        doc_weights,
    )
    if all(item is None for item in requested):
        # Nothing to patch: do not even parse the file, let alone rewrite it.
        return []

    data: dict[str, Any] = yaml.safe_load(resolved.read_text(encoding="utf-8")) or {}
    if not isinstance(data, dict):
        # Fail with a clear message instead of an incidental error from the
        # first item assignment or ``.get`` call below.
        raise TypeError(
            f"Scenario file must contain a YAML mapping at the top level: {resolved}"
        )

    patched: list[str] = []

    if judge_profile is not None:
        data["judge_model"] = judge_profile.model
        patched.append("judge_model")

    if answer_profile is not None:
        adapter = data.get("app_adapter")
        if isinstance(adapter, dict):
            static_kwargs = adapter.get("static_kwargs")
            if static_kwargs is None:
                # Covers both a missing key and an empty ``static_kwargs:``
                # entry, which YAML parses as None.
                static_kwargs = adapter["static_kwargs"] = {}
            static_kwargs["model"] = answer_profile.model
            patched.append("app_adapter.static_kwargs.model")

    if dataset_profile is not None:
        generation = data.get("generation")
        if isinstance(generation, dict):
            generation["model"] = dataset_profile.model
            patched.append("generation.model")

    if metric_weights is not None:
        # Copy so later mutation of the caller's dict cannot affect ``data``.
        data["metric_weights"] = dict(metric_weights)
        patched.append("metric_weights")

    if doc_weights is not None:
        data["doc_weights"] = dict(doc_weights)
        patched.append("doc_weights")

    if patched:
        # Round-tripping through PyYAML drops comments, so only rewrite the
        # file when its content really changed.
        resolved.write_text(
            yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False),
            encoding="utf-8",
        )
    return patched
|