Files
siemens_ragas/webapp/services/scenario_scanner.py

97 lines
3.2 KiB
Python

"""Discover scenario YAML files that can be launched from the console.
Scanning is intentionally tolerant: a malformed scenario file is reported with
an error string rather than aborting the whole listing, so the UI can show the
user which files are runnable and which need fixing.
"""
from __future__ import annotations
from pathlib import Path
import yaml
from webapp.models import ScenarioInfo
def _repo_root() -> Path:
    """Locate the siemens_ragas repository root (the directory holding webapp/)."""
    this_module = Path(__file__).resolve()
    # parents[0] is this module's directory, parents[1] the directory above it,
    # and parents[2] one level higher still, which is the repository root.
    return this_module.parents[2]
def _scenarios_root() -> Path:
    """Build the path of the ``scenarios`` directory under the repository root.

    The path is only constructed here; whether it exists is not checked.
    """
    return _repo_root().joinpath("scenarios")
def _coerce_weights(raw: object) -> dict[str, float]:
    """Return the numeric entries of a weights mapping as ``{str: float}``.

    Anything that is not a mapping (missing key, null, a list, a scalar) yields
    an empty dict so that one odd field cannot abort the scan. Booleans are
    skipped even though ``bool`` subclasses ``int``: ``true`` is not a weight.
    """
    if not isinstance(raw, dict):
        return {}
    return {
        str(key): float(value)
        for key, value in raw.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }


def _text_field(payload: dict, key: str) -> str:
    """Return ``payload[key]`` as text; a missing or null value becomes ""."""
    value = payload.get(key)
    return "" if value is None else str(value)


def _summarize_scenario(path: Path) -> ScenarioInfo:
    """Read a scenario file into a compact info object, capturing parse errors.

    Args:
        path: Scenario file located inside the repository root.

    Returns:
        A ``ScenarioInfo`` carrying the summary fields, or one carrying only
        ``path`` and ``error`` when the file cannot be read, decoded or parsed,
        or when its top-level YAML value is not a mapping.

    Raises:
        ValueError: If ``path`` is not located under the repository root
            (raised by ``Path.relative_to``).
    """
    relative = path.relative_to(_repo_root()).as_posix()
    try:
        payload = yaml.safe_load(path.read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        # UnicodeDecodeError is a ValueError, not an OSError; without it a
        # single non-UTF-8 file would abort the whole listing.
        return ScenarioInfo(path=relative, error=f"无法解析: {exc}")
    if payload is None:
        # An empty document is treated as an empty scenario. Other falsy
        # documents ([], 0, false) fall through to the mapping check below.
        payload = {}
    if not isinstance(payload, dict):
        return ScenarioInfo(path=relative, error="场景文件格式不是 YAML 映射。")
    metrics = payload.get("metrics")
    metric_list = [str(item) for item in metrics] if isinstance(metrics, list) else []
    return ScenarioInfo(
        path=relative,
        scenario_name=_text_field(payload, "scenario_name"),
        mode=_text_field(payload, "mode"),
        dataset=_text_field(payload, "dataset"),
        judge_model=_text_field(payload, "judge_model"),
        metrics=metric_list,
        metric_weights=_coerce_weights(payload.get("metric_weights")),
        doc_weights=_coerce_weights(payload.get("doc_weights")),
    )
def list_scenarios() -> list[ScenarioInfo]:
    """Return every scenario YAML under scenarios/, sorted by path.

    Both ``*.yaml`` and ``*.yml`` files are collected recursively and sorted
    together, so the two extensions are interleaved in one path order.

    Returns:
        One ``ScenarioInfo`` per discovered file, or an empty list when the
        scenarios directory does not exist.
    """
    root = _scenarios_root()
    if not root.is_dir():
        return []
    # Merge the two patterns before sorting: sorting each pattern separately
    # would list every .yaml file ahead of every .yml file.
    # Directories whose names happen to match a pattern are not scenario files.
    candidates = sorted(
        path
        for pattern in ("*.yaml", "*.yml")
        for path in root.rglob(pattern)
        if not path.is_dir()
    )
    return [_summarize_scenario(path) for path in candidates]
def resolve_scenario_path(relative_or_absolute: str) -> Path | None:
    """Resolve a user-supplied scenario path safely within the repository.

    Only paths that live inside the repository's scenarios/ directory are
    accepted, which prevents the trigger endpoint from reading arbitrary files.

    Args:
        relative_or_absolute: Path text from the caller. A relative path is
            interpreted against the repository root; an absolute one is used
            as given.

    Returns:
        The fully resolved path of an existing file below scenarios/, or
        ``None`` when the input cannot be resolved, points outside that
        directory, or does not name a regular file.
    """
    candidate = Path(relative_or_absolute)
    resolved = candidate if candidate.is_absolute() else (_repo_root() / candidate)
    try:
        # resolve() collapses ".." segments and follows symlinks, so the
        # containment check below runs against the real target.
        resolved = resolved.resolve()
    except (OSError, RuntimeError, ValueError):
        # ValueError: embedded NUL byte in the supplied text.
        # RuntimeError: symlink loop on Python versions before 3.13.
        return None
    scenarios_root = _scenarios_root().resolve()
    # A file can only be accepted when it sits strictly below the root.
    if scenarios_root not in resolved.parents:
        return None
    if not resolved.is_file():
        return None
    return resolved