feat: yaml_patcher and ProfileApplyRequest support metric_weights and doc_weights
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -137,3 +137,104 @@ def test_apply_no_profiles_returns_empty(tmp_path):
|
|||||||
_resolve_absolute=True,
|
_resolve_absolute=True,
|
||||||
)
|
)
|
||||||
assert patched == []
|
assert patched == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_metric_weights_patches_yaml(tmp_path):
    """Applying metric_weights writes every supplied weight into the YAML."""
    import pytest
    import yaml as yaml_lib

    from webapp.services.yaml_patcher import apply_profiles_to_scenario

    scenario_file = tmp_path / "w-scenario.yaml"
    scenario_file.write_text(
        "scenario_name: test\nmode: offline\njudge_model: m\nembedding_model: e\n"
        "dataset: d.csv\nmetrics:\n- faithfulness\noutput_dir: out\n",
        encoding="utf-8",
    )
    weights = {"faithfulness": 0.7, "context_recall": 0.3}

    patched = apply_profiles_to_scenario(
        scenario_path=str(scenario_file),
        judge_profile=None,
        answer_profile=None,
        dataset_profile=None,
        metric_weights=weights,
        _resolve_absolute=True,
    )

    assert "metric_weights" in patched
    # Read back with the same encoding the fixture was written with, so the
    # test does not depend on the platform's default locale encoding.
    data = yaml_lib.safe_load(scenario_file.read_text(encoding="utf-8"))
    # Compare the whole mapping (not just one key) with float tolerance.
    assert data["metric_weights"] == pytest.approx(weights)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_doc_weights_patches_yaml(tmp_path):
    """Applying doc_weights writes them into the YAML."""
    import yaml as yaml_lib

    from webapp.services.yaml_patcher import apply_profiles_to_scenario

    scenario_file = tmp_path / "dw-scenario.yaml"
    scenario_file.write_text(
        "scenario_name: test\nmode: offline\njudge_model: m\nembedding_model: e\n"
        "dataset: d.csv\nmetrics:\n- faithfulness\noutput_dir: out\n",
        encoding="utf-8",
    )

    patched = apply_profiles_to_scenario(
        scenario_path=str(scenario_file),
        judge_profile=None,
        answer_profile=None,
        dataset_profile=None,
        doc_weights={"doc.pdf": 2.0},
        _resolve_absolute=True,
    )

    assert "doc_weights" in patched
    # Read back with the same encoding the fixture was written with, so the
    # test does not depend on the platform's default locale encoding.
    data = yaml_lib.safe_load(scenario_file.read_text(encoding="utf-8"))
    assert abs(data["doc_weights"]["doc.pdf"] - 2.0) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Connectivity test endpoint tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def test_probe_connectivity_success(client):
    """POST /api/llm-profiles/probe reports ok=True when the completion call succeeds."""
    payload = {
        "model": "test-model",
        "base_url": "http://x/v1",
        "api_key": "sk-test",
    }
    completion = MagicMock()
    completion.choices = [MagicMock()]

    # Replace the OpenAI class used by the router so no network call is made.
    with patch("webapp.api.llm_profiles.OpenAI") as openai_cls:
        openai_cls.return_value.chat.completions.create.return_value = completion
        response = client.post("/api/llm-profiles/probe", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["ok"] is True
    assert body["latency_ms"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_probe_connectivity_failure(client):
    """POST /api/llm-profiles/probe reports ok=False when the LLM call raises."""
    payload = {
        "model": "test-model",
        "base_url": "http://x/v1",
        "api_key": "sk-test",
    }

    with patch("webapp.api.llm_profiles.OpenAI") as openai_cls:
        # Make the mocked completion call blow up to simulate an unreachable server.
        create_call = openai_cls.return_value.chat.completions.create
        create_call.side_effect = Exception("connection refused")
        response = client.post("/api/llm-profiles/probe", json=payload)

    # The endpoint still answers 200; the failure is carried in the body.
    assert response.status_code == 200
    body = response.json()
    assert body["ok"] is False
    assert "connection refused" in body["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_test_saved_profile_success(client):
    """POST /api/llm-profiles/{id}/test reports ok=True for a saved profile."""
    # Create a profile first so there is a real id to test against.
    new_profile = {"name": "T", "model": "m1", "base_url": "http://x/v1", "api_key": "k"}
    created = client.post("/api/llm-profiles", json=new_profile)
    profile_id = created.json()["profile_id"]

    completion = MagicMock()
    completion.choices = [MagicMock()]

    with patch("webapp.api.llm_profiles.OpenAI") as openai_cls:
        openai_cls.return_value.chat.completions.create.return_value = completion
        response = client.post(f"/api/llm-profiles/{profile_id}/test")

    assert response.status_code == 200
    assert response.json()["ok"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_test_nonexistent_profile_returns_404(client):
    """POST /api/llm-profiles/{id}/test answers 404 when the profile id is unknown."""
    missing_url = "/api/llm-profiles/nonexistent/test"
    response = client.post(missing_url)
    assert response.status_code == 404
|
||||||
|
|||||||
@@ -2,13 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
from webapp.models import (
|
from webapp.models import (
|
||||||
CreateProfileRequest,
|
CreateProfileRequest,
|
||||||
LLMProfile,
|
LLMProfile,
|
||||||
ProfileApplyRequest,
|
ProfileApplyRequest,
|
||||||
ProfileApplyResponse,
|
ProfileApplyResponse,
|
||||||
|
ProfileProbeRequest,
|
||||||
|
ProfileTestResponse,
|
||||||
)
|
)
|
||||||
from webapp.services.profile_manager import profile_manager
|
from webapp.services.profile_manager import profile_manager
|
||||||
from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
||||||
@@ -16,6 +21,43 @@ from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
|||||||
router = APIRouter(prefix="/api/llm-profiles", tags=["llm-profiles"])
|
router = APIRouter(prefix="/api/llm-profiles", tags=["llm-profiles"])
|
||||||
|
|
||||||
|
|
||||||
|
def _do_connectivity_test(
    model: str,
    base_url: str,
    api_key: str,
    timeout_seconds: int,
) -> ProfileTestResponse:
    """Send a minimal chat completion request and return the test result.

    Args:
        model: Model name passed to the chat completion call.
        base_url: Endpoint base URL; trailing slashes are stripped before use.
        api_key: API key handed to the OpenAI client.
        timeout_seconds: Request timeout, converted to ``float`` for the client.

    Returns:
        A ``ProfileTestResponse`` with ``ok=True`` when the request completed
        without raising, otherwise ``ok=False`` with the exception text as the
        message. ``latency_ms`` is the elapsed wall time in both cases.

    This function never raises for a failed probe: every exception from
    building the client or issuing the request is converted into a response.
    """
    started = time.monotonic()
    client = None
    try:
        # Build the client inside the ``try`` so that a failure during client
        # construction is reported as ok=False instead of escaping as a 500.
        client = OpenAI(
            api_key=api_key,
            base_url=base_url.rstrip("/"),
            timeout=float(timeout_seconds),
        )
        # One-token completion: the cheapest request that still exercises
        # authentication, routing and the model name.
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "hi"}],
            max_tokens=1,
        )
    except Exception as exc:  # noqa: BLE001 - any failure is a negative probe result
        ok = False
        # Some exceptions stringify to "", fall back to the class name so the
        # caller always gets a non-empty diagnostic.
        message = str(exc) or type(exc).__name__
    else:
        ok = True
        message = "连接成功"
    finally:
        # Measure before closing so teardown time is not counted as latency.
        latency_ms = int((time.monotonic() - started) * 1000)
        if client is not None:
            # Each probe creates its own client; close it so its HTTP
            # connection pool is released instead of lingering until GC.
            client.close()
    return ProfileTestResponse(ok=ok, message=message, latency_ms=latency_ms)
|
||||||
|
|
||||||
|
|
||||||
|
# The router already carries tags=["llm-profiles"]; repeating the tag on the
# route would list it twice on this operation, so it is not set here.
@router.post("/probe", response_model=ProfileTestResponse)
def probe_connectivity(request: ProfileProbeRequest) -> ProfileTestResponse:
    """Test LLM connectivity with inline credentials (no saved profile required).

    Args:
        request: Model name, base URL, API key and timeout to probe with.

    Returns:
        The connectivity test result produced by ``_do_connectivity_test``.
    """
    return _do_connectivity_test(
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=dict)
|
@router.get("", response_model=dict)
|
||||||
def list_profiles() -> dict:
|
def list_profiles() -> dict:
|
||||||
"""Return all saved LLM profiles."""
|
"""Return all saved LLM profiles."""
|
||||||
@@ -59,6 +101,20 @@ def delete_profile(profile_id: str) -> dict:
|
|||||||
return {"deleted": True}
|
return {"deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{profile_id}/test", response_model=ProfileTestResponse)
def test_profile(profile_id: str) -> ProfileTestResponse:
    """Run the connectivity test against a saved profile.

    Raises:
        HTTPException: 404 when no profile exists for ``profile_id``.
    """
    saved = profile_manager.get(profile_id)
    if saved is not None:
        # Reuse the stored connection settings for the probe.
        return _do_connectivity_test(
            model=saved.model,
            base_url=saved.base_url,
            api_key=saved.api_key,
            timeout_seconds=saved.timeout_seconds,
        )
    raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/apply", response_model=ProfileApplyResponse)
|
@router.post("/apply", response_model=ProfileApplyResponse)
|
||||||
def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
|
def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
|
||||||
"""Patch selected LLM profiles into the target scenario YAML file."""
|
"""Patch selected LLM profiles into the target scenario YAML file."""
|
||||||
@@ -89,6 +145,8 @@ def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
|
|||||||
judge_profile=role_profiles["judge"],
|
judge_profile=role_profiles["judge"],
|
||||||
answer_profile=role_profiles["answer"],
|
answer_profile=role_profiles["answer"],
|
||||||
dataset_profile=role_profiles["dataset"],
|
dataset_profile=role_profiles["dataset"],
|
||||||
|
metric_weights=request.metric_weights,
|
||||||
|
doc_weights=request.doc_weights,
|
||||||
)
|
)
|
||||||
return ProfileApplyResponse(
|
return ProfileApplyResponse(
|
||||||
scenario_path=request.scenario_path,
|
scenario_path=request.scenario_path,
|
||||||
|
|||||||
180
webapp/models.py
180
webapp/models.py
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
def _utcnow_iso() -> str:
|
def _utcnow_iso() -> str:
|
||||||
@@ -150,6 +150,14 @@ class ProfileApplyRequest(BaseModel):
|
|||||||
judge_profile_id: str | None = None
|
judge_profile_id: str | None = None
|
||||||
answer_profile_id: str | None = None
|
answer_profile_id: str | None = None
|
||||||
dataset_profile_id: str | None = None
|
dataset_profile_id: str | None = None
|
||||||
|
metric_weights: dict[str, float] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="指标权重映射,如 {\"faithfulness\": 0.35}。为 null 时不修改 YAML。",
|
||||||
|
)
|
||||||
|
doc_weights: dict[str, float] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="文档权重映射,如 {\"doc.pdf\": 2.0}。为 null 时不修改 YAML。",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProfileApplyResponse(BaseModel):
|
class ProfileApplyResponse(BaseModel):
|
||||||
@@ -159,6 +167,23 @@ class ProfileApplyResponse(BaseModel):
|
|||||||
patched_fields: list[str] = Field(default_factory=list)
|
patched_fields: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileProbeRequest(BaseModel):
    """Inline credentials for testing LLM connectivity without saving a profile."""

    # Model name sent with the probe request.
    model: str
    # Base URL of the OpenAI-compatible endpoint.
    base_url: str
    # API key used for the probe; it is not persisted by this model.
    api_key: str
    # Request timeout in seconds. Constrained to positive values (matching the
    # gt=0 convention used by the other numeric request fields in this module)
    # so a zero or negative timeout is rejected at validation time.
    timeout_seconds: int = Field(default=30, gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileTestResponse(BaseModel):
    """Outcome of an LLM connectivity test."""

    # Whether the test succeeded.
    ok: bool
    # Human-readable status or error text.
    message: str
    # Elapsed time of the test in milliseconds; optional.
    latency_ms: int | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
def jsonable(value: Any) -> Any:
|
def jsonable(value: Any) -> Any:
|
||||||
"""Convert NaN/inf floats into None so the payload stays valid JSON."""
|
"""Convert NaN/inf floats into None so the payload stays valid JSON."""
|
||||||
import math
|
import math
|
||||||
@@ -172,3 +197,156 @@ def jsonable(value: Any) -> Any:
|
|||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return [jsonable(item) for item in value]
|
return [jsonable(item) for item in value]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Full pipeline (build + eval) job models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class PipelineJobRequest(BaseModel):
    """Request body for launching an end-to-end build + evaluation pipeline job."""

    # JSON Schema ``examples`` must be a list of raw example instances. The
    # OpenAPI ``{"summary": ..., "value": ...}`` wrapper is only understood by
    # ``openapi_examples`` on a parameter; used here it would make the wrapper
    # itself show up as the example request body. Each entry below is
    # therefore the payload itself, with its former summary kept as a comment.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                # Siemens CT document evaluation (full parameter set).
                {
                    "docs_path": "datasets/siemens-pdfs",
                    "job_name": "siemens-ct-eval-2026",
                    "generation_model": "qwen3.6-plus",
                    "answer_model": "deepseek-v4-flash",
                    "judge_model": "deepseek-v4-flash",
                    "embedding_model": "text-embedding-v3",
                    "max_questions_per_document": 10,
                    "max_source_chunks_per_question": 3,
                    "max_documents": None,
                    "max_samples": None,
                    "metrics": [
                        "faithfulness",
                        "answer_relevancy",
                        "context_recall",
                        "context_precision",
                    ],
                    "optimization_advisor": False,
                    "failure_mode": "skip",
                },
                # Quick smoke test (only 2 documents, 5 questions).
                {
                    "docs_path": "datasets/siemens-pdfs",
                    "job_name": "smoke-test",
                    "generation_model": "qwen3.6-plus",
                    "answer_model": "deepseek-v4-flash",
                    "judge_model": "deepseek-v4-flash",
                    "embedding_model": "text-embedding-v3",
                    "max_questions_per_document": 5,
                    "max_source_chunks_per_question": 3,
                    "max_documents": 2,
                    "max_samples": 10,
                    "metrics": ["faithfulness", "answer_relevancy"],
                    "optimization_advisor": False,
                    "failure_mode": "skip",
                },
            ]
        }
    )

    # Folder containing the PDF documents (absolute, or relative to the repo root).
    docs_path: str = Field(
        description="PDF 文档所在文件夹的绝对路径或相对于仓库根目录的相对路径。"
    )
    # Display name; an empty value means the system generates an identifier.
    job_name: str = Field(
        default="",
        description="任务显示名称;留空时系统自动生成唯一标识。",
    )
    # LLM used to generate the draft question set from document chunks.
    generation_model: str = Field(
        default="qwen3.6-plus",
        description="用于从文档片段生成草稿题库的 LLM 模型名称。",
    )
    # LLM that answers questions during online evaluation.
    answer_model: str = Field(
        default="deepseek-v4-flash",
        description="在线评估时调用的答题 LLM 模型名称(siemens_pdf_qa adapter)。",
    )
    # Judge LLM used when scoring RAGAS metrics.
    judge_model: str = Field(
        default="deepseek-v4-flash",
        description="RAGAS 指标评分时使用的 Judge LLM 模型名称。",
    )
    # Embedding model used by the context-recall / context-precision metrics.
    embedding_model: str = Field(
        default="text-embedding-v3",
        description="RAGAS context-recall / context-precision 使用的 Embedding 模型名称。",
    )
    # Upper bound on draft questions generated per PDF; must be positive.
    max_questions_per_document: int = Field(
        default=10, gt=0,
        description="每份 PDF 文档最多生成的草稿题目数量。",
    )
    # Upper bound on source chunks referenced per question; must be positive.
    max_source_chunks_per_question: int = Field(
        default=3, gt=0,
        description="每道题目最多引用的文档片段(source chunk)数量。",
    )
    # Optional cap on the number of PDFs processed (used for smoke tests).
    max_documents: int | None = Field(
        default=None, gt=0,
        description="限制处理的 PDF 文件数量上限(冒烟测试时使用)。",
    )
    # Optional cap on the number of questions evaluated (used for smoke tests).
    max_samples: int | None = Field(
        default=None, gt=0,
        description="限制评估的题目数量上限(冒烟测试时使用)。",
    )
    # RAGAS metrics to compute; a fresh list is built per instance.
    metrics: list[str] = Field(
        default_factory=lambda: [
            "faithfulness",
            "answer_relevancy",
            "context_recall",
            "context_precision",
        ],
        description=(
            "需要计算的 RAGAS 指标列表。"
            "可选值:faithfulness, answer_relevancy, context_recall, "
            "context_precision, noise_sensitivity, factual_correctness, semantic_similarity。"
        ),
    )
    # Enables the optimization-advice module when True.
    optimization_advisor: bool = Field(
        default=False,
        description="为 True 时启用 RAGAS 优化建议模块,生成 optimization_advice.md。",
    )
    # Strategy on PDF parse failure; the description documents "skip" and "fail".
    # NOTE(review): the value is not validated here - confirm whether the
    # consumer rejects anything else before tightening this to a closed set.
    failure_mode: str = Field(
        default="skip",
        description="PDF 解析失败时的处理策略:skip(跳过继续)或 fail(立即中止)。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineResult(BaseModel):
    """Where a finished pipeline run put its artifacts, plus summary counts."""

    # --- Build (question generation) stage ---
    build_artifact_dir: str = Field(
        description="题库生成阶段的产物根目录路径。",
    )
    dataset_csv: str = Field(
        description="生成的草稿题库 CSV 文件路径(评估输入)。",
    )
    source_chunks_jsonl: str = Field(
        description="文档片段索引文件路径(在线评估 adapter 使用)。",
    )
    total_questions: int = Field(
        description="成功生成的有效题目总数。",
    )
    parse_failures: int = Field(
        description="文档解析失败的 PDF 数量。",
    )

    # --- Evaluation stage ---
    eval_run_id: str = Field(
        description="RAGAS 评估运行 ID。",
    )
    eval_output_dir: str = Field(
        description="RAGAS 评估产物根目录路径。",
    )
    scores_csv: str = Field(
        description="每道题目逐项评分的 CSV 文件路径。",
    )
    summary_md: str = Field(
        description="评估结果摘要 Markdown 文件路径。",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineJobStatus(BaseModel):
    """Current state of a single end-to-end pipeline job."""

    # --- Identity ---
    job_id: str = Field(
        description="任务唯一标识符。",
    )
    job_name: str = Field(
        description="任务显示名称。",
    )

    # --- Progress ---
    status: str = Field(
        description="任务状态:queued | running | completed | failed。",
    )
    phase: str = Field(
        description="当前执行阶段:idle | parsing_documents | generating_questions | evaluating | done。",
        default="idle",
    )
    logs: list[str] = Field(
        description="实时日志行列表。",
        default_factory=list,
    )

    # --- Outcome (populated once the job ends) ---
    result: PipelineResult | None = Field(
        description="任务完成后填充的产物路径与统计信息。",
        default=None,
    )
    error: str | None = Field(
        description="失败时的错误信息。",
        default=None,
    )

    # --- Timestamps (empty string until set) ---
    created_at: str = Field(
        description="任务创建时间(ISO 8601 UTC)。",
        default="",
    )
    finished_at: str = Field(
        description="任务结束时间(ISO 8601 UTC)。",
        default="",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineJobResponse(BaseModel):
    """Acknowledgement sent back right after a pipeline job has been queued."""

    # Identifier the caller uses to poll the job afterwards.
    job_id: str = Field(
        description="任务唯一标识符,用于后续轮询状态。",
    )
    job_name: str = Field(
        description="任务显示名称。",
    )
    # Defaults to "queued" when not supplied.
    status: str = Field(
        description="初始状态,通常为 queued。",
        default="queued",
    )
|
||||||
|
|||||||
@@ -32,9 +32,11 @@ def apply_profiles_to_scenario(
|
|||||||
judge_profile: LLMProfile | None,
|
judge_profile: LLMProfile | None,
|
||||||
answer_profile: LLMProfile | None,
|
answer_profile: LLMProfile | None,
|
||||||
dataset_profile: LLMProfile | None,
|
dataset_profile: LLMProfile | None,
|
||||||
|
metric_weights: dict[str, float] | None = None,
|
||||||
|
doc_weights: dict[str, float] | None = None,
|
||||||
_resolve_absolute: bool = False,
|
_resolve_absolute: bool = False,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Patch the YAML file at *scenario_path* with the supplied profiles.
|
"""Patch the YAML file at *scenario_path* with the supplied profiles and weights.
|
||||||
|
|
||||||
Returns a list of dotted field names that were actually patched.
|
Returns a list of dotted field names that were actually patched.
|
||||||
Setting *_resolve_absolute=True* skips repo-root resolution (used in tests).
|
Setting *_resolve_absolute=True* skips repo-root resolution (used in tests).
|
||||||
@@ -67,6 +69,14 @@ def apply_profiles_to_scenario(
|
|||||||
generation["model"] = dataset_profile.model
|
generation["model"] = dataset_profile.model
|
||||||
patched.append("generation.model")
|
patched.append("generation.model")
|
||||||
|
|
||||||
|
if metric_weights is not None:
|
||||||
|
data["metric_weights"] = dict(metric_weights)
|
||||||
|
patched.append("metric_weights")
|
||||||
|
|
||||||
|
if doc_weights is not None:
|
||||||
|
data["doc_weights"] = dict(doc_weights)
|
||||||
|
patched.append("doc_weights")
|
||||||
|
|
||||||
resolved.write_text(
|
resolved.write_text(
|
||||||
yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False),
|
yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False),
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
|
|||||||
Reference in New Issue
Block a user