diff --git a/rag_eval/settings.py b/rag_eval/settings.py
index 9cbf17f..ad1db52 100644
--- a/rag_eval/settings.py
+++ b/rag_eval/settings.py
@@ -52,6 +52,11 @@ class EvaluationSettings(BaseSettings):
     )
     parser_failure_mode: str = Field(default="fail", alias="PARSER_FAILURE_MODE")
     dataset_generator_model: str | None = Field(default=None, alias="DATASET_GENERATOR_MODEL")
+    score_api_token: str | None = Field(
+        default=None,
+        alias="SCORE_API_TOKEN",
+        description="Bearer token for /api/score endpoint. Empty = no auth.",
+    )
 
     @property
     def openai_client_kwargs(self) -> dict[str, str | float]:
diff --git a/tests/webapp/test_score_api.py b/tests/webapp/test_score_api.py
new file mode 100644
index 0000000..8f52d5c
--- /dev/null
+++ b/tests/webapp/test_score_api.py
@@ -0,0 +1,128 @@
+"""Tests for POST /api/score endpoint."""
+from __future__ import annotations
+
+import pytest
+from pydantic import ValidationError
+
+from webapp.models import ScoreRequest, ScoreResponse
+
+
+class TestScoreRequest:
+    def test_minimal_valid_request(self):
+        """Only required fields — question, answer, contexts."""
+        req = ScoreRequest(
+            question="What is CT?",
+            answer="CT is imaging.",
+            contexts="CT uses X-rays.",
+        )
+        assert req.question == "What is CT?"
+        assert req.contexts == "CT uses X-rays."
+        assert req.ground_truth is None
+        assert req.context_separator == " |||| "
+        assert req.metrics == [
+            "faithfulness",
+            "answer_relevancy",
+            "context_recall",
+            "context_precision",
+        ]
+
+    def test_contexts_split_by_separator(self):
+        """contexts_as_list() splits on context_separator."""
+        req = ScoreRequest(
+            question="q",
+            answer="a",
+            contexts="ctx1 |||| ctx2 |||| ctx3",
+            context_separator=" |||| ",
+        )
+        assert req.contexts_as_list() == ["ctx1", "ctx2", "ctx3"]
+
+    def test_contexts_split_custom_separator(self):
+        req = ScoreRequest(
+            question="q",
+            answer="a",
+            contexts="a---b---c",
+            context_separator="---",
+        )
+        assert req.contexts_as_list() == ["a", "b", "c"]
+
+    def test_contexts_split_single_item(self):
+        req = ScoreRequest(question="q", answer="a", contexts="only one")
+        assert req.contexts_as_list() == ["only one"]
+
+    def test_missing_question_raises(self):
+        with pytest.raises(ValidationError):
+            ScoreRequest(answer="a", contexts="c")  # type: ignore[call-arg]
+
+    def test_missing_answer_raises(self):
+        with pytest.raises(ValidationError):
+            ScoreRequest(question="q", contexts="c")  # type: ignore[call-arg]
+
+    def test_missing_contexts_raises(self):
+        with pytest.raises(ValidationError):
+            ScoreRequest(question="q", answer="a")  # type: ignore[call-arg]
+
+    def test_custom_metrics_accepted(self):
+        req = ScoreRequest(
+            question="q",
+            answer="a",
+            contexts="c",
+            metrics=["faithfulness"],
+        )
+        assert req.metrics == ["faithfulness"]
+
+    def test_invalid_metric_name_raises(self):
+        with pytest.raises(ValidationError):
+            ScoreRequest(
+                question="q",
+                answer="a",
+                contexts="c",
+                metrics=["not_a_metric"],
+            )
+
+    def test_effective_metrics_drops_ground_truth_dependent_when_missing(self):
+        """Without ground_truth, GT-dependent metrics are excluded."""
+        req = ScoreRequest(
+            question="q",
+            answer="a",
+            contexts="c",
+            metrics=[
+                "faithfulness",
+                "context_recall",
+                "factual_correctness",
+                "semantic_similarity",
+                "noise_sensitivity",
+            ],
+        )
+        effective = req.effective_metrics()
+        assert "faithfulness" in effective
+        assert "context_recall" not in effective
+        assert "factual_correctness" not in effective
+        assert "semantic_similarity" not in effective
+        assert "noise_sensitivity" not in effective
+
+    def test_effective_metrics_keeps_all_when_ground_truth_present(self):
+        req = ScoreRequest(
+            question="q",
+            answer="a",
+            contexts="c",
+            ground_truth="gt",
+            metrics=["faithfulness", "context_recall", "factual_correctness"],
+        )
+        effective = req.effective_metrics()
+        assert effective == [
+            "faithfulness",
+            "context_recall",
+            "factual_correctness",
+        ]
+
+
+class TestScoreResponse:
+    def test_score_response_structure(self):
+        resp = ScoreResponse(
+            scores={"faithfulness": 0.85, "answer_relevancy": None},
+            weighted_score=0.85,
+            latency_ms=1200,
+        )
+        assert resp.scores["faithfulness"] == 0.85
+        assert resp.scores["answer_relevancy"] is None
+        assert resp.latency_ms == 1200
diff --git a/webapp/models.py b/webapp/models.py
index 6eaba75..0d9484d 100644
--- a/webapp/models.py
+++ b/webapp/models.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 from datetime import datetime, timezone
 from typing import Any
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 def _utcnow_iso() -> str:
@@ -370,3 +370,135 @@ class PipelineJobResponse(BaseModel):
     job_id: str = Field(description="任务唯一标识符,用于后续轮询状态。")
     job_name: str = Field(description="任务显示名称。")
     status: str = Field(default="queued", description="初始状态,通常为 queued。")
+
+
+# ---------------------------------------------------------------------------
+# Models for the Dify real-time scoring API
+# ---------------------------------------------------------------------------
+
+# Metrics that can only be computed when ground_truth is provided
+_GT_DEPENDENT_METRICS: frozenset[str] = frozenset({
+    "context_recall",
+    "factual_correctness",
+    "semantic_similarity",
+    "noise_sensitivity",
+})
+
+# All valid metric names
+_VALID_METRICS: frozenset[str] = frozenset({
+    "faithfulness",
+    "answer_relevancy",
+    "context_recall",
+    "context_precision",
+    "noise_sensitivity",
+    "factual_correctness",
+    "semantic_similarity",
+})
+
+_DEFAULT_SCORE_METRICS: list[str] = [
+    "faithfulness",
+    "answer_relevancy",
+    "context_recall",
+    "context_precision",
+]
+
+
+class ScoreRequest(BaseModel):
+    """Request body for the real-time single-sample scoring endpoint."""
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "examples": [
+                {
+                    "summary": "基础评分请求",
+                    "value": {
+                        "question": "双源CT的时间分辨率是多少?",
+                        "answer": "双源CT的单扇区时间分辨率为75ms。",
+                        "contexts": "双源CT采用两套管-探测器系统 |||| 单扇区采集旋转135度",
+                        "ground_truth": "双源CT单扇区时间分辨率为75ms,需旋转135度。",
+                        "context_separator": " |||| ",
+                        "metrics": [
+                            "faithfulness",
+                            "answer_relevancy",
+                            "context_recall",
+                            "context_precision",
+                        ],
+                        "judge_model": "deepseek-v4-flash",
+                        "embedding_model": "text-embedding-v3",
+                    },
+                }
+            ]
+        }
+    )
+
+    question: str = Field(description="问题文本。")
+    answer: str = Field(description="待评分的回答。")
+    contexts: str = Field(
+        description="检索上下文字符串,多段之间用 context_separator 拼接。"
+    )
+    ground_truth: str | None = Field(
+        default=None,
+        description="标准参考答案(可选)。缺失时自动跳过需要它的指标。",
+    )
+    context_separator: str = Field(
+        default=" |||| ",
+        description="contexts 字段中段落分隔符,默认为四个竖线两侧各一空格。",
+    )
+    metrics: list[str] = Field(
+        default_factory=lambda: list(_DEFAULT_SCORE_METRICS),
+        description="需要计算的 RAGAS 指标列表。",
+    )
+    judge_model: str | None = Field(
+        default=None,
+        description="Judge LLM 模型名称;为 null 时使用 .env 中的 RAGAS_JUDGE_MODEL。",
+    )
+    embedding_model: str | None = Field(
+        default=None,
+        description="Embedding 模型名称;为 null 时使用 .env 中的 RAGAS_EMBEDDING_MODEL。",
+    )
+
+    @field_validator("metrics")
+    @classmethod
+    def validate_metric_names(cls, value: list[str]) -> list[str]:
+        """Reject metric names that are not in _VALID_METRICS, and reject an empty list."""
+        invalid = [metric_name for metric_name in value if metric_name not in _VALID_METRICS]
+        if invalid:
+            raise ValueError(
+                f"不支持的指标名称:{invalid}。"
+                f"合法值:{sorted(_VALID_METRICS)}"
+            )
+        if not value:
+            raise ValueError("metrics 不能为空列表。")
+        return value
+
+    def contexts_as_list(self) -> list[str]:
+        """Split contexts on context_separator (or " |||| " if it is empty), returning stripped, non-empty fragments."""
+        separator = self.context_separator or " |||| "
+        return [part.strip() for part in self.contexts.split(separator) if part.strip()]
+
+    def effective_metrics(self) -> list[str]:
+        """Return a copy of metrics; when ground_truth is None, drop the GT-dependent ones (an empty string is not treated as missing)."""
+        if self.ground_truth is not None:
+            return list(self.metrics)
+        return [metric_name for metric_name in self.metrics if metric_name not in _GT_DEPENDENT_METRICS]
+
+
+class ScoreResponse(BaseModel):
+    """Response payload for the real-time scoring endpoint."""
+
+    scores: dict[str, float | None] = Field(
+        description="各指标得分(NaN 或计算失败时为 null)。"
+    )
+    weighted_score: float | None = Field(
+        default=None,
+        description="等权加权综合得分(仅对非 null 指标求均值)。",
+    )
+    latency_ms: int = Field(description="服务端打分耗时(毫秒)。")
+    skipped_metrics: list[str] = Field(
+        default_factory=list,
+        description="因缺少 ground_truth 而跳过的指标名称列表。",
+    )
+    error: str | None = Field(
+        default=None,
+        description="打分异常时的错误信息(HTTP 200 仍返回,scores 为空)。",
+    )