97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
"""Discover scenario YAML files that can be launched from the console.
|
|
|
|
Scanning is intentionally tolerant: a malformed scenario file is reported with
|
|
an error string rather than aborting the whole listing, so the UI can show the
|
|
user which files are runnable and which need fixing.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
from webapp.models import ScenarioInfo
|
|
|
|
|
|
def _repo_root() -> Path:
    """Locate the siemens_ragas repository root.

    The root is taken to be ``parents[2]`` of this module's resolved path,
    i.e. two directories above the folder that contains this file. The
    layout intends this to be the directory that holds the ``webapp`` package.
    """
    module_path = Path(__file__).resolve()
    return module_path.parents[2]
|
|
|
|
|
|
def _scenarios_root() -> Path:
    """Return ``<repository root>/scenarios``, where scenario files conventionally live.

    Whether the directory actually exists is not checked here.
    """
    repo = _repo_root()
    return repo.joinpath("scenarios")
|
|
|
|
|
|
def _numeric_weights(raw: object) -> dict[str, float]:
    """Return the numeric entries of a YAML weights mapping as ``{str: float}``.

    Anything that is not a mapping (missing key, ``null``, a list, a string)
    yields an empty dict instead of raising, so one malformed field cannot
    abort the whole scenario listing. Non-numeric values are dropped, and so
    are booleans: ``bool`` is a subclass of ``int``, and a YAML ``true`` must
    not silently become a weight of ``1.0``.
    """
    if not isinstance(raw, dict):
        return {}
    return {
        str(key): float(value)
        for key, value in raw.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }


def _text_field(payload: dict, key: str) -> str:
    """Return ``payload[key]`` as text, treating a missing key and ``null`` alike as ``""``."""
    value = payload.get(key)
    return "" if value is None else str(value)


def _summarize_scenario(path: Path) -> ScenarioInfo:
    """Read a scenario file into a compact info object, capturing parse errors.

    Args:
        path: A scenario file located under the repository root. Its
            repository-relative POSIX path is stored on the result.

    Returns:
        A ``ScenarioInfo``. When the file cannot be read or decoded, is not
        valid YAML, or its top level is not a mapping, only ``path`` and
        ``error`` are populated. Otherwise the descriptive fields are filled
        in, and malformed optional fields degrade to empty values rather
        than raising.
    """
    relative = path.relative_to(_repo_root()).as_posix()
    try:
        # safe_load never constructs arbitrary Python objects from the file.
        payload = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        # UnicodeDecodeError is a ValueError, not an OSError. Without it, a
        # single non-UTF-8 file would abort the whole listing.
        return ScenarioInfo(path=relative, error=f"无法解析: {exc}")

    if not isinstance(payload, dict):
        return ScenarioInfo(path=relative, error="场景文件格式不是 YAML 映射。")

    metrics = payload.get("metrics")
    metric_list = [str(item) for item in metrics] if isinstance(metrics, list) else []

    return ScenarioInfo(
        path=relative,
        scenario_name=_text_field(payload, "scenario_name"),
        mode=_text_field(payload, "mode"),
        dataset=_text_field(payload, "dataset"),
        judge_model=_text_field(payload, "judge_model"),
        metrics=metric_list,
        metric_weights=_numeric_weights(payload.get("metric_weights")),
        doc_weights=_numeric_weights(payload.get("doc_weights")),
    )
|
|
|
|
|
|
def list_scenarios() -> list[ScenarioInfo]:
    """Return a summary of every scenario YAML under ``scenarios/``, sorted by path.

    Both ``*.yaml`` and ``*.yml`` files are collected recursively and sorted
    together. The result is therefore ordered by path as a whole, rather
    than listing all ``.yaml`` files before all ``.yml`` files. Only regular
    files are included, so a directory whose name ends in ``.yaml`` is
    skipped instead of being reported as a broken scenario. Returns an empty
    list when ``scenarios/`` is missing or is not a directory.
    """
    root = _scenarios_root()
    if not root.is_dir():
        return []

    scenario_files = sorted(
        path
        for pattern in ("*.yaml", "*.yml")
        for path in root.rglob(pattern)
        if path.is_file()
    )
    return [_summarize_scenario(path) for path in scenario_files]
|
|
|
|
|
|
def resolve_scenario_path(relative_or_absolute: str) -> Path | None:
    """Resolve a user-supplied scenario path safely within the repository.

    Only existing files that live inside the repository's ``scenarios/``
    directory are accepted. Containment is checked after symlinks and ``..``
    segments are resolved, which prevents the trigger endpoint from reading
    arbitrary files.

    Args:
        relative_or_absolute: A path relative to the repository root, or an
            absolute path.

    Returns:
        The resolved absolute path of the scenario file, or ``None`` if the
        input is malformed, escapes ``scenarios/``, or does not name a file.
    """
    root = _repo_root()
    candidate = Path(relative_or_absolute)
    if not candidate.is_absolute():
        candidate = root / candidate
    scenarios_root = _scenarios_root().resolve()

    try:
        # resolve() collapses ".." and follows symlinks, so the containment
        # check sees the real target. This input is untrusted, so every
        # failure mode becomes a rejection rather than a server error:
        # an embedded NUL byte raises ValueError, a symlink loop raises
        # RuntimeError before Python 3.13, and is_file() can raise OSError.
        resolved = candidate.resolve()
        inside = resolved == scenarios_root or scenarios_root in resolved.parents
        if not inside or not resolved.is_file():
            return None
    except (OSError, RuntimeError, ValueError):
        return None
    return resolved
|