feat: add ScoreRequest/ScoreResponse models and SCORE_API_TOKEN setting
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
134
webapp/models.py
134
webapp/models.py
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
def _utcnow_iso() -> str:
|
||||
@@ -370,3 +370,135 @@ class PipelineJobResponse(BaseModel):
|
||||
job_id: str = Field(description="任务唯一标识符,用于后续轮询状态。")
|
||||
job_name: str = Field(description="任务显示名称。")
|
||||
status: str = Field(default="queued", description="初始状态,通常为 queued。")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dify 实时评分 API 模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 需要 ground_truth 才能计算的指标集合
|
||||
_GT_DEPENDENT_METRICS: frozenset[str] = frozenset({
|
||||
"context_recall",
|
||||
"factual_correctness",
|
||||
"semantic_similarity",
|
||||
"noise_sensitivity",
|
||||
})
|
||||
|
||||
# 所有合法指标名称
|
||||
_VALID_METRICS: frozenset[str] = frozenset({
|
||||
"faithfulness",
|
||||
"answer_relevancy",
|
||||
"context_recall",
|
||||
"context_precision",
|
||||
"noise_sensitivity",
|
||||
"factual_correctness",
|
||||
"semantic_similarity",
|
||||
})
|
||||
|
||||
_DEFAULT_SCORE_METRICS: list[str] = [
|
||||
"faithfulness",
|
||||
"answer_relevancy",
|
||||
"context_recall",
|
||||
"context_precision",
|
||||
]
|
||||
|
||||
|
||||
class ScoreRequest(BaseModel):
|
||||
"""Request body for the real-time single-sample scoring endpoint."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"examples": [
|
||||
{
|
||||
"summary": "基础评分请求",
|
||||
"value": {
|
||||
"question": "双源CT的时间分辨率是多少?",
|
||||
"answer": "双源CT的单扇区时间分辨率为75ms。",
|
||||
"contexts": "双源CT采用两套管-探测器系统 |||| 单扇区采集旋转135度",
|
||||
"ground_truth": "双源CT单扇区时间分辨率为75ms,需旋转135度。",
|
||||
"context_separator": " |||| ",
|
||||
"metrics": [
|
||||
"faithfulness",
|
||||
"answer_relevancy",
|
||||
"context_recall",
|
||||
"context_precision",
|
||||
],
|
||||
"judge_model": "deepseek-v4-flash",
|
||||
"embedding_model": "text-embedding-v3",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
question: str = Field(description="问题文本。")
|
||||
answer: str = Field(description="待评分的回答。")
|
||||
contexts: str = Field(
|
||||
description="检索上下文字符串,多段之间用 context_separator 拼接。"
|
||||
)
|
||||
ground_truth: str | None = Field(
|
||||
default=None,
|
||||
description="标准参考答案(可选)。缺失时自动跳过需要它的指标。",
|
||||
)
|
||||
context_separator: str = Field(
|
||||
default=" |||| ",
|
||||
description="contexts 字段中段落分隔符,默认为四个竖线两侧各一空格。",
|
||||
)
|
||||
metrics: list[str] = Field(
|
||||
default_factory=lambda: list(_DEFAULT_SCORE_METRICS),
|
||||
description="需要计算的 RAGAS 指标列表。",
|
||||
)
|
||||
judge_model: str | None = Field(
|
||||
default=None,
|
||||
description="Judge LLM 模型名称;为 null 时使用 .env 中的 RAGAS_JUDGE_MODEL。",
|
||||
)
|
||||
embedding_model: str | None = Field(
|
||||
default=None,
|
||||
description="Embedding 模型名称;为 null 时使用 .env 中的 RAGAS_EMBEDDING_MODEL。",
|
||||
)
|
||||
|
||||
@field_validator("metrics")
|
||||
@classmethod
|
||||
def validate_metric_names(cls, value: list[str]) -> list[str]:
|
||||
"""Reject any metric name not in the supported registry."""
|
||||
invalid = [metric_name for metric_name in value if metric_name not in _VALID_METRICS]
|
||||
if invalid:
|
||||
raise ValueError(
|
||||
f"不支持的指标名称:{invalid}。"
|
||||
f"合法值:{sorted(_VALID_METRICS)}"
|
||||
)
|
||||
if not value:
|
||||
raise ValueError("metrics 不能为空列表。")
|
||||
return value
|
||||
|
||||
def contexts_as_list(self) -> list[str]:
|
||||
"""Split the contexts string into a list of non-empty fragments."""
|
||||
separator = self.context_separator or " |||| "
|
||||
return [part.strip() for part in self.contexts.split(separator) if part.strip()]
|
||||
|
||||
def effective_metrics(self) -> list[str]:
|
||||
"""Return metrics filtered to exclude GT-dependent ones when ground_truth is absent."""
|
||||
if self.ground_truth is not None:
|
||||
return list(self.metrics)
|
||||
return [metric_name for metric_name in self.metrics if metric_name not in _GT_DEPENDENT_METRICS]
|
||||
|
||||
|
||||
class ScoreResponse(BaseModel):
|
||||
"""Response payload for the real-time scoring endpoint."""
|
||||
|
||||
scores: dict[str, float | None] = Field(
|
||||
description="各指标得分(NaN 或计算失败时为 null)。"
|
||||
)
|
||||
weighted_score: float | None = Field(
|
||||
default=None,
|
||||
description="等权加权综合得分(仅对非 null 指标求均值)。",
|
||||
)
|
||||
latency_ms: int = Field(description="服务端打分耗时(毫秒)。")
|
||||
skipped_metrics: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="因缺少 ground_truth 而跳过的指标名称列表。",
|
||||
)
|
||||
error: str | None = Field(
|
||||
default=None,
|
||||
description="打分异常时的错误信息(HTTP 200 仍返回,scores 为空)。",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user