Files
siemens_ragas/webapp/models.py

373 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Pydantic response models for the evaluation console HTTP API."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
def _utcnow_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    now = datetime.now(tz=timezone.utc)
    return now.isoformat()
class RunSummary(BaseModel):
    """Compact description of a single evaluation run, used by list views.

    Only ``run_id`` and ``scenario_name`` are required.  Every other field
    falls back to an empty string, zero, or a fresh empty container.
    """

    # Required fields.
    run_id: str
    scenario_name: str

    # Optional text fields; an empty string means "not provided".
    mode: str = ""
    judge_model: str = ""
    embedding_model: str = ""
    started_at: str = ""
    finished_at: str = ""
    dataset: str = ""

    # Sample counters; all default to zero.
    total_samples: int = 0
    valid_samples: int = 0
    invalid_samples: int = 0

    # Metric names and their per-metric means (a mean may be None).
    metrics: list[str] = Field(default_factory=list)
    metric_means: dict[str, float | None] = Field(default_factory=dict)

    output_path: str = ""
class GroupStat(BaseModel):
    """Mean metric values for one group of samples sharing a metadata value.

    ``key`` and ``count`` are required; ``means`` defaults to an empty dict
    and maps a metric name to a float or ``None``.
    """

    key: str
    count: int
    means: dict[str, float | None] = Field(default_factory=dict)
class DistributionBin(BaseModel):
    """A single histogram bucket holding the sample count for one metric.

    All four fields are required.
    """

    label: str
    # Bucket bounds; whether each bound is inclusive is decided by the
    # producer of this model, not here.
    lower: float
    upper: float
    count: int
class SampleScore(BaseModel):
    """One per-sample row for the lowest-score review table.

    Only ``sample_id`` is required.  Text fields default to an empty string,
    containers to a fresh empty instance, and ``mean_score`` to ``None``.
    """

    sample_id: str

    # Sample content.
    question: str = ""
    contexts: list[str] = Field(default_factory=list)
    answer: str = ""
    ground_truth: str = ""

    # Sample metadata.
    language: str = ""
    difficulty: str = ""
    question_type: str = ""

    # Per-metric scores (each may be None) and their overall mean.
    metrics: dict[str, float | None] = Field(default_factory=dict)
    mean_score: float | None = None

    error: str = ""
class ReportData(BaseModel):
    """Aggregated report payload rendered by the report detail page.

    Every field is optional: containers default to a fresh empty instance,
    markdown fields to an empty string, and ``weighted_score_mean`` to
    ``None``.  The ``description`` texts are emitted in the generated JSON
    schema, so they are kept well-formed (balanced ASCII parentheses).
    """

    metrics: list[str] = Field(default_factory=list)
    metric_means: dict[str, float | None] = Field(default_factory=dict)
    # Keyed by string; each value is a list of histogram buckets.
    distributions: dict[str, list[DistributionBin]] = Field(default_factory=dict)
    # Keyed by string; each value is a list of per-group statistics.
    groupings: dict[str, list[GroupStat]] = Field(default_factory=dict)
    lowest_samples: list[SampleScore] = Field(default_factory=list)
    summary_markdown: str = ""
    advice_markdown: str = ""  # optimization_advice.md content (empty if not generated)
    weighted_score_mean: float | None = Field(
        default=None,
        description="加权综合得分均值(metric_weights × doc_weights 共同作用)。",
    )
    metric_weights: dict[str, float] = Field(
        default_factory=dict,
        description="该次运行使用的指标权重配置(来自 scenario.snapshot.yaml)",
    )
    doc_weights: dict[str, float] = Field(
        default_factory=dict,
        description="该次运行使用的文档权重配置(来自 scenario.snapshot.yaml)",
    )
class RunDetail(BaseModel):
    """Complete payload for one run: its summary metadata and its report.

    Both fields are required.
    """

    summary: RunSummary
    report: ReportData
class ScenarioInfo(BaseModel):
    """A discoverable scenario YAML file that can be evaluated from the UI.

    Only ``path`` is required.  Text fields default to an empty string and
    containers to a fresh empty instance.
    """

    path: str

    scenario_name: str = ""
    mode: str = ""
    dataset: str = ""
    judge_model: str = ""
    metrics: list[str] = Field(default_factory=list)
    error: str = ""

    # Weight mappings; both default to an empty dict.
    metric_weights: dict[str, float] = Field(
        default_factory=dict,
        description="从场景 YAML 读取的指标权重配置,供前端权重面板预填。",
    )
    doc_weights: dict[str, float] = Field(
        default_factory=dict,
        description="从场景 YAML 读取的文档权重配置,供前端权重面板预填。",
    )
class TaskStatus(BaseModel):
    """State of a background evaluation task tracked by the task manager.

    ``task_id``, ``scenario_path`` and ``status`` are required.  ``status`` is
    a free-form string here; no set of allowed values is enforced by this
    model.
    """

    task_id: str
    scenario_path: str
    status: str

    logs: list[str] = Field(default_factory=list)

    # None until a value is assigned.
    run_id: str | None = None
    error: str | None = None

    # Timestamps as strings; empty string when unset.
    created_at: str = ""
    finished_at: str = ""
class TriggerEvaluationRequest(BaseModel):
    """Request body used by the UI to launch an evaluation run.

    ``scenario_path`` is the only field and is required.
    """

    scenario_path: str
class TriggerEvaluationResponse(BaseModel):
    """Response sent back right after an evaluation task has been queued.

    ``task_id`` is the only field and is required.
    """

    task_id: str
class LLMProfile(BaseModel):
    """A named, reusable LLM connection configuration.

    All connection fields are required except ``timeout_seconds`` (default
    30).  ``created_at`` and ``updated_at`` are filled by ``_utcnow_iso`` when
    the model is instantiated without explicit values.
    """

    profile_id: str
    name: str
    model: str
    base_url: str
    # NOTE(review): stored as a plain ``str``, so it is included verbatim
    # whenever this model is serialized — confirm it is never echoed to clients.
    api_key: str
    timeout_seconds: int = 30

    # Each timestamp calls the factory independently, so the two values may
    # differ slightly on a freshly created instance.
    created_at: str = Field(default_factory=_utcnow_iso)
    updated_at: str = Field(default_factory=_utcnow_iso)
class CreateProfileRequest(BaseModel):
    """Request body for creating or updating an LLM profile.

    ``name``, ``model``, ``base_url`` and ``api_key`` are required;
    ``timeout_seconds`` defaults to 30.
    """

    name: str
    model: str
    base_url: str
    api_key: str
    timeout_seconds: int = 30
class ProfileApplyRequest(BaseModel):
    """Request body for patching LLM profile selections into a scenario YAML.

    Only ``scenario_path`` is required.  Every other field defaults to
    ``None``; per the field descriptions, a ``None`` weight mapping means the
    YAML is left unmodified for that setting.
    """

    scenario_path: str

    # Optional profile selections.
    judge_profile_id: str | None = None
    answer_profile_id: str | None = None
    dataset_profile_id: str | None = None

    # Optional weight overrides.
    metric_weights: dict[str, float] | None = Field(
        default=None,
        description="指标权重映射,如 {\"faithfulness\": 0.35}。为 null 时不修改 YAML。",
    )
    doc_weights: dict[str, float] | None = Field(
        default=None,
        description="文档权重映射,如 {\"doc.pdf\": 2.0}。为 null 时不修改 YAML。",
    )
class ProfileApplyResponse(BaseModel):
    """Response returned after a scenario YAML was patched with profile settings.

    ``scenario_path`` is required; ``patched_fields`` defaults to an empty
    list.
    """

    scenario_path: str
    patched_fields: list[str] = Field(default_factory=list)
class ProfileProbeRequest(BaseModel):
    """Inline credentials for testing LLM connectivity without saving a profile.

    ``model``, ``base_url`` and ``api_key`` are required; ``timeout_seconds``
    defaults to 30.
    """

    model: str
    base_url: str
    api_key: str
    timeout_seconds: int = 30
class ProfileTestResponse(BaseModel):
    """Outcome of an LLM connectivity test.

    ``ok`` and ``message`` are required; ``latency_ms`` defaults to ``None``.
    """

    ok: bool
    message: str
    latency_ms: int | None = None
def jsonable(value: Any) -> Any:
    """Recursively replace non-finite floats with ``None``.

    ``NaN`` and ``±inf`` have no representation in strict JSON, so each one is
    mapped to ``None``.  Dicts, lists and tuples are walked recursively and
    rebuilt; dict keys are left untouched.  Any other value is returned
    unchanged.

    Args:
        value: An arbitrary, possibly nested, Python value.

    Returns:
        A sanitized copy for floats and supported containers, or ``value``
        itself for every other type.
    """
    import math

    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: jsonable(item) for key, item in value.items()}
    if isinstance(value, list):
        return [jsonable(item) for item in value]
    if isinstance(value, tuple):
        # Tuples serialize as JSON arrays too, so their items need the same
        # treatment as list items.
        cleaned = [jsonable(item) for item in value]
        if hasattr(value, "_fields"):
            # Namedtuples take their fields positionally; keep the type so
            # attribute access by callers keeps working.
            return type(value)(*cleaned)
        return tuple(cleaned)
    return value
# ---------------------------------------------------------------------------
# Full pipeline (build + eval) job models
# ---------------------------------------------------------------------------
class PipelineJobRequest(BaseModel):
    """Request body for launching an end-to-end build + evaluation pipeline job.

    Only ``docs_path`` is required.  The four numeric limits must be strictly
    positive (``gt=0``); ``max_documents`` and ``max_samples`` may also be
    ``None``.  ``failure_mode`` is a plain string: the values named in its
    description are not enforced by this model.

    The ``description`` texts are emitted in the generated JSON schema, so
    they are kept well-formed (balanced ASCII parentheses, explicit colon
    before enumerated values).
    """

    # NOTE(review): each entry below wraps the payload in OpenAPI-style
    # ``summary``/``value`` keys, whereas JSON Schema ``examples`` normally
    # lists bare instances — confirm the docs UI renders these as intended.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "summary": "西门子 CT 文档评估(完整参数)",
                    "value": {
                        "docs_path": "datasets/siemens-pdfs",
                        "job_name": "siemens-ct-eval-2026",
                        "generation_model": "qwen3.6-plus",
                        "answer_model": "deepseek-v4-flash",
                        "judge_model": "deepseek-v4-flash",
                        "embedding_model": "text-embedding-v3",
                        "max_questions_per_document": 10,
                        "max_source_chunks_per_question": 3,
                        "max_documents": None,
                        "max_samples": None,
                        "metrics": [
                            "faithfulness",
                            "answer_relevancy",
                            "context_recall",
                            "context_precision",
                        ],
                        "optimization_advisor": False,
                        "failure_mode": "skip",
                    },
                },
                {
                    "summary": "快速冒烟测试(仅 2 份文档、5 道题)",
                    "value": {
                        "docs_path": "datasets/siemens-pdfs",
                        "job_name": "smoke-test",
                        "generation_model": "qwen3.6-plus",
                        "answer_model": "deepseek-v4-flash",
                        "judge_model": "deepseek-v4-flash",
                        "embedding_model": "text-embedding-v3",
                        "max_questions_per_document": 5,
                        "max_source_chunks_per_question": 3,
                        "max_documents": 2,
                        "max_samples": 10,
                        "metrics": ["faithfulness", "answer_relevancy"],
                        "optimization_advisor": False,
                        "failure_mode": "skip",
                    },
                },
            ]
        }
    )

    docs_path: str = Field(
        description="PDF 文档所在文件夹的绝对路径或相对于仓库根目录的相对路径。"
    )
    job_name: str = Field(
        default="",
        description="任务显示名称;留空时系统自动生成唯一标识。",
    )

    # Model selection.
    generation_model: str = Field(
        default="qwen3.6-plus",
        description="用于从文档片段生成草稿题库的 LLM 模型名称。",
    )
    answer_model: str = Field(
        default="deepseek-v4-flash",
        description="在线评估时调用的答题 LLM 模型名称(siemens_pdf_qa adapter)",
    )
    judge_model: str = Field(
        default="deepseek-v4-flash",
        description="RAGAS 指标评分时使用的 Judge LLM 模型名称。",
    )
    embedding_model: str = Field(
        default="text-embedding-v3",
        description="RAGAS context-recall / context-precision 使用的 Embedding 模型名称。",
    )

    # Size limits; every non-None value must be > 0.
    max_questions_per_document: int = Field(
        default=10, gt=0,
        description="每份 PDF 文档最多生成的草稿题目数量。",
    )
    max_source_chunks_per_question: int = Field(
        default=3, gt=0,
        description="每道题目最多引用的文档片段(source chunk)数量。",
    )
    max_documents: int | None = Field(
        default=None, gt=0,
        description="限制处理的 PDF 文件数量上限(冒烟测试时使用)。",
    )
    max_samples: int | None = Field(
        default=None, gt=0,
        description="限制评估的题目数量上限(冒烟测试时使用)。",
    )

    # A factory is used so each instance gets its own list.
    metrics: list[str] = Field(
        default_factory=lambda: [
            "faithfulness",
            "answer_relevancy",
            "context_recall",
            "context_precision",
        ],
        description=(
            "需要计算的 RAGAS 指标列表。"
            "可选值: faithfulness, answer_relevancy, context_recall, "
            "context_precision, noise_sensitivity, factual_correctness, semantic_similarity。"
        ),
    )
    optimization_advisor: bool = Field(
        default=False,
        description="为 True 时启用 RAGAS 优化建议模块,生成 optimization_advice.md。",
    )
    failure_mode: str = Field(
        default="skip",
        description="PDF 解析失败时的处理策略: skip(跳过继续)或 fail(立即中止)",
    )
class PipelineResult(BaseModel):
    """Artifact locations and statistics of a completed pipeline run.

    Every field is required; none declares a default.
    """

    # Build-stage artifacts and counters.
    build_artifact_dir: str = Field(
        description="题库生成阶段的产物根目录路径。",
    )
    dataset_csv: str = Field(
        description="生成的草稿题库 CSV 文件路径(评估输入)。",
    )
    source_chunks_jsonl: str = Field(
        description="文档片段索引文件路径(在线评估 adapter 使用)。",
    )
    total_questions: int = Field(
        description="成功生成的有效题目总数。",
    )
    parse_failures: int = Field(
        description="文档解析失败的 PDF 数量。",
    )

    # Evaluation-stage artifacts.
    eval_run_id: str = Field(
        description="RAGAS 评估运行 ID。",
    )
    eval_output_dir: str = Field(
        description="RAGAS 评估产物根目录路径。",
    )
    scores_csv: str = Field(
        description="每道题目逐项评分的 CSV 文件路径。",
    )
    summary_md: str = Field(
        description="评估结果摘要 Markdown 文件路径。",
    )
class PipelineJobStatus(BaseModel):
    """State of one end-to-end pipeline job.

    ``job_id``, ``job_name`` and ``status`` are required.  ``status`` and
    ``phase`` are plain strings: the values listed in their descriptions are
    not enforced by this model.  ``result`` and ``error`` default to ``None``
    and the timestamps to an empty string.

    The ``description`` texts are emitted in the generated JSON schema, so
    they are kept well-formed (explicit colon before enumerated values,
    balanced ASCII parentheses).
    """

    job_id: str = Field(description="任务唯一标识符。")
    job_name: str = Field(description="任务显示名称。")
    status: str = Field(
        description="任务状态: queued | running | completed | failed。",
    )
    phase: str = Field(
        default="idle",
        description=(
            "当前执行阶段: idle | parsing_documents | generating_questions | evaluating | done。"
        ),
    )
    logs: list[str] = Field(default_factory=list, description="实时日志行列表。")
    result: PipelineResult | None = Field(
        default=None,
        description="任务完成后填充的产物路径与统计信息。",
    )
    error: str | None = Field(default=None, description="失败时的错误信息。")
    created_at: str = Field(default="", description="任务创建时间(ISO 8601 UTC)")
    finished_at: str = Field(default="", description="任务结束时间(ISO 8601 UTC)")
class PipelineJobResponse(BaseModel):
    """Immediate reply sent once a pipeline job has been queued.

    ``job_id`` and ``job_name`` are required; ``status`` defaults to
    ``"queued"``.
    """

    job_id: str = Field(
        description="任务唯一标识符,用于后续轮询状态。",
    )
    job_name: str = Field(
        description="任务显示名称。",
    )
    status: str = Field(
        default="queued",
        description="初始状态,通常为 queued。",
    )