feat: yaml_patcher and ProfileApplyRequest support metric_weights and doc_weights
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -137,3 +137,104 @@ def test_apply_no_profiles_returns_empty(tmp_path):
|
||||
_resolve_absolute=True,
|
||||
)
|
||||
assert patched == []
|
||||
|
||||
|
||||
def test_apply_metric_weights_patches_yaml(tmp_path):
|
||||
"""Applying metric_weights writes them into the YAML."""
|
||||
import yaml as yaml_lib
|
||||
import pytest
|
||||
scenario_file = tmp_path / "w-scenario.yaml"
|
||||
scenario_file.write_text(
|
||||
"scenario_name: test\nmode: offline\njudge_model: m\nembedding_model: e\n"
|
||||
"dataset: d.csv\nmetrics:\n- faithfulness\noutput_dir: out\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
||||
patched = apply_profiles_to_scenario(
|
||||
scenario_path=str(scenario_file),
|
||||
judge_profile=None, answer_profile=None, dataset_profile=None,
|
||||
metric_weights={"faithfulness": 0.7, "context_recall": 0.3},
|
||||
_resolve_absolute=True,
|
||||
)
|
||||
assert "metric_weights" in patched
|
||||
data = yaml_lib.safe_load(scenario_file.read_text())
|
||||
assert abs(data["metric_weights"]["faithfulness"] - 0.7) < 1e-9
|
||||
|
||||
|
||||
def test_apply_doc_weights_patches_yaml(tmp_path):
|
||||
"""Applying doc_weights writes them into the YAML."""
|
||||
import yaml as yaml_lib
|
||||
scenario_file = tmp_path / "dw-scenario.yaml"
|
||||
scenario_file.write_text(
|
||||
"scenario_name: test\nmode: offline\njudge_model: m\nembedding_model: e\n"
|
||||
"dataset: d.csv\nmetrics:\n- faithfulness\noutput_dir: out\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
||||
patched = apply_profiles_to_scenario(
|
||||
scenario_path=str(scenario_file),
|
||||
judge_profile=None, answer_profile=None, dataset_profile=None,
|
||||
doc_weights={"doc.pdf": 2.0},
|
||||
_resolve_absolute=True,
|
||||
)
|
||||
assert "doc_weights" in patched
|
||||
data = yaml_lib.safe_load(scenario_file.read_text())
|
||||
assert abs(data["doc_weights"]["doc.pdf"] - 2.0) < 1e-9
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connectivity test endpoint tests
|
||||
# ---------------------------------------------------------------------------
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_probe_connectivity_success(client):
|
||||
"""POST /api/llm-profiles/probe returns ok=True on successful completion."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
with patch("webapp.api.llm_profiles.OpenAI") as MockOpenAI:
|
||||
MockOpenAI.return_value.chat.completions.create.return_value = mock_response
|
||||
resp = client.post("/api/llm-profiles/probe", json={
|
||||
"model": "test-model",
|
||||
"base_url": "http://x/v1",
|
||||
"api_key": "sk-test",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert data["latency_ms"] is not None
|
||||
|
||||
|
||||
def test_probe_connectivity_failure(client):
|
||||
"""POST /api/llm-profiles/probe returns ok=False when the LLM call raises."""
|
||||
with patch("webapp.api.llm_profiles.OpenAI") as MockOpenAI:
|
||||
MockOpenAI.return_value.chat.completions.create.side_effect = Exception("connection refused")
|
||||
resp = client.post("/api/llm-profiles/probe", json={
|
||||
"model": "test-model",
|
||||
"base_url": "http://x/v1",
|
||||
"api_key": "sk-test",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
assert "connection refused" in data["message"]
|
||||
|
||||
|
||||
def test_test_saved_profile_success(client):
|
||||
"""POST /api/llm-profiles/{id}/test returns ok=True for a saved profile."""
|
||||
body = {"name": "T", "model": "m1", "base_url": "http://x/v1", "api_key": "k"}
|
||||
pid = client.post("/api/llm-profiles", json=body).json()["profile_id"]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
with patch("webapp.api.llm_profiles.OpenAI") as MockOpenAI:
|
||||
MockOpenAI.return_value.chat.completions.create.return_value = mock_response
|
||||
resp = client.post(f"/api/llm-profiles/{pid}/test")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["ok"] is True
|
||||
|
||||
|
||||
def test_test_nonexistent_profile_returns_404(client):
|
||||
"""POST /api/llm-profiles/{id}/test returns 404 for unknown profile id."""
|
||||
resp = client.post("/api/llm-profiles/nonexistent/test")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -2,13 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from openai import OpenAI
|
||||
|
||||
from webapp.models import (
|
||||
CreateProfileRequest,
|
||||
LLMProfile,
|
||||
ProfileApplyRequest,
|
||||
ProfileApplyResponse,
|
||||
ProfileProbeRequest,
|
||||
ProfileTestResponse,
|
||||
)
|
||||
from webapp.services.profile_manager import profile_manager
|
||||
from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
||||
@@ -16,6 +21,43 @@ from webapp.services.yaml_patcher import apply_profiles_to_scenario
|
||||
router = APIRouter(prefix="/api/llm-profiles", tags=["llm-profiles"])
|
||||
|
||||
|
||||
def _do_connectivity_test(
|
||||
model: str,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
timeout_seconds: int,
|
||||
) -> ProfileTestResponse:
|
||||
"""Send a minimal chat completion request and return the test result."""
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url.rstrip("/"),
|
||||
timeout=float(timeout_seconds),
|
||||
)
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
latency_ms = int((time.monotonic() - t0) * 1000)
|
||||
return ProfileTestResponse(ok=True, message="连接成功", latency_ms=latency_ms)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
latency_ms = int((time.monotonic() - t0) * 1000)
|
||||
return ProfileTestResponse(ok=False, message=str(exc), latency_ms=latency_ms)
|
||||
|
||||
|
||||
@router.post("/probe", response_model=ProfileTestResponse, tags=["llm-profiles"])
|
||||
def probe_connectivity(request: ProfileProbeRequest) -> ProfileTestResponse:
|
||||
"""Test LLM connectivity with inline credentials (no saved profile required)."""
|
||||
return _do_connectivity_test(
|
||||
model=request.model,
|
||||
base_url=request.base_url,
|
||||
api_key=request.api_key,
|
||||
timeout_seconds=request.timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=dict)
|
||||
def list_profiles() -> dict:
|
||||
"""Return all saved LLM profiles."""
|
||||
@@ -59,6 +101,20 @@ def delete_profile(profile_id: str) -> dict:
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.post("/{profile_id}/test", response_model=ProfileTestResponse)
|
||||
def test_profile(profile_id: str) -> ProfileTestResponse:
|
||||
"""Test LLM connectivity for a saved profile."""
|
||||
profile = profile_manager.get(profile_id)
|
||||
if profile is None:
|
||||
raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
|
||||
return _do_connectivity_test(
|
||||
model=profile.model,
|
||||
base_url=profile.base_url,
|
||||
api_key=profile.api_key,
|
||||
timeout_seconds=profile.timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/apply", response_model=ProfileApplyResponse)
|
||||
def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
|
||||
"""Patch selected LLM profiles into the target scenario YAML file."""
|
||||
@@ -89,6 +145,8 @@ def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
|
||||
judge_profile=role_profiles["judge"],
|
||||
answer_profile=role_profiles["answer"],
|
||||
dataset_profile=role_profiles["dataset"],
|
||||
metric_weights=request.metric_weights,
|
||||
doc_weights=request.doc_weights,
|
||||
)
|
||||
return ProfileApplyResponse(
|
||||
scenario_path=request.scenario_path,
|
||||
|
||||
180
webapp/models.py
180
webapp/models.py
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
def _utcnow_iso() -> str:
|
||||
@@ -150,6 +150,14 @@ class ProfileApplyRequest(BaseModel):
|
||||
judge_profile_id: str | None = None
|
||||
answer_profile_id: str | None = None
|
||||
dataset_profile_id: str | None = None
|
||||
metric_weights: dict[str, float] | None = Field(
|
||||
default=None,
|
||||
description="指标权重映射,如 {\"faithfulness\": 0.35}。为 null 时不修改 YAML。",
|
||||
)
|
||||
doc_weights: dict[str, float] | None = Field(
|
||||
default=None,
|
||||
description="文档权重映射,如 {\"doc.pdf\": 2.0}。为 null 时不修改 YAML。",
|
||||
)
|
||||
|
||||
|
||||
class ProfileApplyResponse(BaseModel):
|
||||
@@ -159,6 +167,23 @@ class ProfileApplyResponse(BaseModel):
|
||||
patched_fields: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ProfileProbeRequest(BaseModel):
|
||||
"""Inline credentials for testing LLM connectivity without saving a profile."""
|
||||
|
||||
model: str
|
||||
base_url: str
|
||||
api_key: str
|
||||
timeout_seconds: int = 30
|
||||
|
||||
|
||||
class ProfileTestResponse(BaseModel):
|
||||
"""Result of a LLM connectivity test."""
|
||||
|
||||
ok: bool
|
||||
message: str
|
||||
latency_ms: int | None = None
|
||||
|
||||
|
||||
def jsonable(value: Any) -> Any:
|
||||
"""Convert NaN/inf floats into None so the payload stays valid JSON."""
|
||||
import math
|
||||
@@ -172,3 +197,156 @@ def jsonable(value: Any) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [jsonable(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full pipeline (build + eval) job models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PipelineJobRequest(BaseModel):
|
||||
"""Request body for launching an end-to-end build + evaluation pipeline job."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"summary": "西门子 CT 文档评估(完整参数)",
|
||||
"value": {
|
||||
"docs_path": "datasets/siemens-pdfs",
|
||||
"job_name": "siemens-ct-eval-2026",
|
||||
"generation_model": "qwen3.6-plus",
|
||||
"answer_model": "deepseek-v4-flash",
|
||||
"judge_model": "deepseek-v4-flash",
|
||||
"embedding_model": "text-embedding-v3",
|
||||
"max_questions_per_document": 10,
|
||||
"max_source_chunks_per_question": 3,
|
||||
"max_documents": None,
|
||||
"max_samples": None,
|
||||
"metrics": [
|
||||
"faithfulness",
|
||||
"answer_relevancy",
|
||||
"context_recall",
|
||||
"context_precision",
|
||||
],
|
||||
"optimization_advisor": False,
|
||||
"failure_mode": "skip",
|
||||
},
|
||||
},
|
||||
{
|
||||
"summary": "快速冒烟测试(仅 2 份文档、5 道题)",
|
||||
"value": {
|
||||
"docs_path": "datasets/siemens-pdfs",
|
||||
"job_name": "smoke-test",
|
||||
"generation_model": "qwen3.6-plus",
|
||||
"answer_model": "deepseek-v4-flash",
|
||||
"judge_model": "deepseek-v4-flash",
|
||||
"embedding_model": "text-embedding-v3",
|
||||
"max_questions_per_document": 5,
|
||||
"max_source_chunks_per_question": 3,
|
||||
"max_documents": 2,
|
||||
"max_samples": 10,
|
||||
"metrics": ["faithfulness", "answer_relevancy"],
|
||||
"optimization_advisor": False,
|
||||
"failure_mode": "skip",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
docs_path: str = Field(
|
||||
description="PDF 文档所在文件夹的绝对路径或相对于仓库根目录的相对路径。"
|
||||
)
|
||||
job_name: str = Field(
|
||||
default="",
|
||||
description="任务显示名称;留空时系统自动生成唯一标识。",
|
||||
)
|
||||
generation_model: str = Field(
|
||||
default="qwen3.6-plus",
|
||||
description="用于从文档片段生成草稿题库的 LLM 模型名称。",
|
||||
)
|
||||
answer_model: str = Field(
|
||||
default="deepseek-v4-flash",
|
||||
description="在线评估时调用的答题 LLM 模型名称(siemens_pdf_qa adapter)。",
|
||||
)
|
||||
judge_model: str = Field(
|
||||
default="deepseek-v4-flash",
|
||||
description="RAGAS 指标评分时使用的 Judge LLM 模型名称。",
|
||||
)
|
||||
embedding_model: str = Field(
|
||||
default="text-embedding-v3",
|
||||
description="RAGAS context-recall / context-precision 使用的 Embedding 模型名称。",
|
||||
)
|
||||
max_questions_per_document: int = Field(
|
||||
default=10, gt=0,
|
||||
description="每份 PDF 文档最多生成的草稿题目数量。",
|
||||
)
|
||||
max_source_chunks_per_question: int = Field(
|
||||
default=3, gt=0,
|
||||
description="每道题目最多引用的文档片段(source chunk)数量。",
|
||||
)
|
||||
max_documents: int | None = Field(
|
||||
default=None, gt=0,
|
||||
description="限制处理的 PDF 文件数量上限(冒烟测试时使用)。",
|
||||
)
|
||||
max_samples: int | None = Field(
|
||||
default=None, gt=0,
|
||||
description="限制评估的题目数量上限(冒烟测试时使用)。",
|
||||
)
|
||||
metrics: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"faithfulness",
|
||||
"answer_relevancy",
|
||||
"context_recall",
|
||||
"context_precision",
|
||||
],
|
||||
description=(
|
||||
"需要计算的 RAGAS 指标列表。"
|
||||
"可选值:faithfulness, answer_relevancy, context_recall, "
|
||||
"context_precision, noise_sensitivity, factual_correctness, semantic_similarity。"
|
||||
),
|
||||
)
|
||||
optimization_advisor: bool = Field(
|
||||
default=False,
|
||||
description="为 True 时启用 RAGAS 优化建议模块,生成 optimization_advice.md。",
|
||||
)
|
||||
failure_mode: str = Field(
|
||||
default="skip",
|
||||
description="PDF 解析失败时的处理策略:skip(跳过继续)或 fail(立即中止)。",
|
||||
)
|
||||
|
||||
|
||||
class PipelineResult(BaseModel):
|
||||
"""Artifact locations and statistics for a completed pipeline run."""
|
||||
|
||||
build_artifact_dir: str = Field(description="题库生成阶段的产物根目录路径。")
|
||||
dataset_csv: str = Field(description="生成的草稿题库 CSV 文件路径(评估输入)。")
|
||||
source_chunks_jsonl: str = Field(description="文档片段索引文件路径(在线评估 adapter 使用)。")
|
||||
total_questions: int = Field(description="成功生成的有效题目总数。")
|
||||
parse_failures: int = Field(description="文档解析失败的 PDF 数量。")
|
||||
eval_run_id: str = Field(description="RAGAS 评估运行 ID。")
|
||||
eval_output_dir: str = Field(description="RAGAS 评估产物根目录路径。")
|
||||
scores_csv: str = Field(description="每道题目逐项评分的 CSV 文件路径。")
|
||||
summary_md: str = Field(description="评估结果摘要 Markdown 文件路径。")
|
||||
|
||||
|
||||
class PipelineJobStatus(BaseModel):
|
||||
"""State of one end-to-end pipeline job."""
|
||||
|
||||
job_id: str = Field(description="任务唯一标识符。")
|
||||
job_name: str = Field(description="任务显示名称。")
|
||||
status: str = Field(description="任务状态:queued | running | completed | failed。")
|
||||
phase: str = Field(default="idle", description="当前执行阶段:idle | parsing_documents | generating_questions | evaluating | done。")
|
||||
logs: list[str] = Field(default_factory=list, description="实时日志行列表。")
|
||||
result: PipelineResult | None = Field(default=None, description="任务完成后填充的产物路径与统计信息。")
|
||||
error: str | None = Field(default=None, description="失败时的错误信息。")
|
||||
created_at: str = Field(default="", description="任务创建时间(ISO 8601 UTC)。")
|
||||
finished_at: str = Field(default="", description="任务结束时间(ISO 8601 UTC)。")
|
||||
|
||||
|
||||
class PipelineJobResponse(BaseModel):
|
||||
"""Immediate response returned after a pipeline job is queued."""
|
||||
|
||||
job_id: str = Field(description="任务唯一标识符,用于后续轮询状态。")
|
||||
job_name: str = Field(description="任务显示名称。")
|
||||
status: str = Field(default="queued", description="初始状态,通常为 queued。")
|
||||
|
||||
@@ -32,9 +32,11 @@ def apply_profiles_to_scenario(
|
||||
judge_profile: LLMProfile | None,
|
||||
answer_profile: LLMProfile | None,
|
||||
dataset_profile: LLMProfile | None,
|
||||
metric_weights: dict[str, float] | None = None,
|
||||
doc_weights: dict[str, float] | None = None,
|
||||
_resolve_absolute: bool = False,
|
||||
) -> list[str]:
|
||||
"""Patch the YAML file at *scenario_path* with the supplied profiles.
|
||||
"""Patch the YAML file at *scenario_path* with the supplied profiles and weights.
|
||||
|
||||
Returns a list of dotted field names that were actually patched.
|
||||
Setting *_resolve_absolute=True* skips repo-root resolution (used in tests).
|
||||
@@ -67,6 +69,14 @@ def apply_profiles_to_scenario(
|
||||
generation["model"] = dataset_profile.model
|
||||
patched.append("generation.model")
|
||||
|
||||
if metric_weights is not None:
|
||||
data["metric_weights"] = dict(metric_weights)
|
||||
patched.append("metric_weights")
|
||||
|
||||
if doc_weights is not None:
|
||||
data["doc_weights"] = dict(doc_weights)
|
||||
patched.append("doc_weights")
|
||||
|
||||
resolved.write_text(
|
||||
yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
|
||||
Reference in New Issue
Block a user