Files
siemens_ragas/webapp/api/llm_profiles.py
wangwei fb420656ec fix: use /embeddings endpoint for embedding models in connectivity test
text-embedding-* and other embedding models must call /embeddings not
/chat/completions. Added _is_embedding_model() heuristic that checks model
name keywords to route to the correct endpoint automatically.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-06-23 14:53:32 +08:00

223 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""CRUD routes for LLM profiles, connectivity-test endpoints, and the scenario-patching apply endpoint."""
from __future__ import annotations
import logging
import time
from fastapi import APIRouter, HTTPException
from openai import OpenAI
from webapp.models import (
CreateProfileRequest,
LLMProfile,
ProfileApplyRequest,
ProfileApplyResponse,
ProfileProbeRequest,
ProfileTestResponse,
)
from webapp.services.profile_manager import profile_manager
from webapp.services.yaml_patcher import apply_profiles_to_scenario
# Router for every endpoint in this module; all paths are mounted under
# /api/llm-profiles and tagged "llm-profiles".
router = APIRouter(prefix="/api/llm-profiles", tags=["llm-profiles"])
# Module logger (explicit dotted name).
logger = logging.getLogger("webapp.api.llm_profiles")
# Keywords commonly found in embedding model names. A model whose lowercased
# name contains any of them is treated as an embedding model and probed via
# the /embeddings endpoint instead of /chat/completions.
# Matching is by substring, so "embedding" is already covered by "embed".
_EMBEDDING_MODEL_KEYWORDS = (
"embedding", "embed", "text-search", "text-similarity",
"code-search", "ada-002",
)
def _is_embedding_model(model: str) -> bool:
    """Guess from its name whether ``model`` is an embedding model.

    The name is lowercased and checked for any substring listed in
    ``_EMBEDDING_MODEL_KEYWORDS``. This is a naming heuristic only; the
    endpoint is never queried.
    """
    lowered = model.lower()
    for keyword in _EMBEDDING_MODEL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def _elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since ``started`` (a ``time.monotonic()`` reading)."""
    return int((time.monotonic() - started) * 1000)


def _probe_embedding(client: OpenAI, model: str) -> ProfileTestResponse:
    """Probe an embedding model with one short ``POST /embeddings`` request."""
    started = time.monotonic()
    try:
        client.embeddings.create(model=model, input="test")
    except Exception as exc:  # noqa: BLE001 - every failure is reported, not raised
        return ProfileTestResponse(ok=False, message=str(exc), latency_ms=_elapsed_ms(started))
    return ProfileTestResponse(ok=True, message="连接成功embedding", latency_ms=_elapsed_ms(started))


def _probe_chat_once(client: OpenAI, model: str, **token_limit: int) -> ProfileTestResponse:
    """Send one minimal ``POST /chat/completions`` request and report the outcome.

    ``token_limit`` carries the single token-budget keyword
    (``max_completion_tokens`` or ``max_tokens``) forwarded to the API.
    """
    started = time.monotonic()
    try:
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "hi"}],
            **token_limit,
        )
    except Exception as exc:  # noqa: BLE001 - every failure is reported, not raised
        return ProfileTestResponse(ok=False, message=str(exc), latency_ms=_elapsed_ms(started))
    return ProfileTestResponse(ok=True, message="连接成功", latency_ms=_elapsed_ms(started))


def _do_connectivity_test(
    model: str,
    base_url: str,
    api_key: str,
    timeout_seconds: int,
) -> ProfileTestResponse:
    """Send a minimal request to the model endpoint and return the test result.

    - Embedding models (per ``_is_embedding_model``) → ``POST /embeddings``
      with a short text.
    - Chat models → ``POST /chat/completions`` with a one-token budget.
      ``max_completion_tokens`` is tried first (required by newer models such
      as gpt-5.x). If the server's error message names that parameter, the
      request is retried once with the legacy ``max_tokens``.

    Every failure, including client construction, is returned as an
    ``ok=False`` response carrying the exception text. Nothing is raised.

    Args:
        model: Model name sent to the API.
        base_url: API base URL; trailing slashes are stripped.
        api_key: API key for the endpoint.
        timeout_seconds: Client request timeout in seconds.

    Returns:
        ProfileTestResponse whose ``latency_ms`` covers only the request that
        produced the result. A rejected ``max_completion_tokens`` attempt is
        not counted.
    """
    started = time.monotonic()
    try:
        client = OpenAI(
            api_key=api_key,
            base_url=base_url.rstrip("/"),
            timeout=float(timeout_seconds),
        )
    except Exception as exc:  # noqa: BLE001 - the constructor can raise (e.g. no usable API key)
        return ProfileTestResponse(ok=False, message=str(exc), latency_ms=_elapsed_ms(started))

    # A fresh client is built per test, so close its HTTP resources
    # deterministically instead of waiting for garbage collection.
    with client:
        if _is_embedding_model(model):
            return _probe_embedding(client, model)

        result = _probe_chat_once(client, model, max_completion_tokens=1)
        # Fall back only when the error names the parameter we sent. This
        # includes servers that reject it outright without mentioning max_tokens.
        if result.ok or "max_completion_tokens" not in result.message:
            return result
        return _probe_chat_once(client, model, max_tokens=1)
# The "llm-profiles" tag is inherited from ``router``. Passing it again here
# would list the tag twice on this operation in the OpenAPI schema.
@router.post("/probe", response_model=ProfileTestResponse)
def probe_connectivity(request: ProfileProbeRequest) -> ProfileTestResponse:
    """Test LLM connectivity with inline credentials (no saved profile required)."""
    # The API key is deliberately not logged.
    logger.info("[probe] model=%s base_url=%s", request.model, request.base_url)
    result = _do_connectivity_test(
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
    logger.info("[probe] ok=%s latency=%sms msg=%s", result.ok, result.latency_ms, result.message)
    return result
@router.get("", response_model=dict)
def list_profiles() -> dict:
    """Return all saved LLM profiles."""
    # Each entry is the profile's model_dump(), wrapped under a "profiles" key.
    saved = profile_manager.list_all()
    logger.info("[list_profiles] count=%d", len(saved))
    serialized = [entry.model_dump() for entry in saved]
    return {"profiles": serialized}
@router.post("", status_code=201, response_model=LLMProfile)
def create_profile(request: CreateProfileRequest) -> LLMProfile:
    """Create a new LLM profile."""
    logger.info(
        "[create_profile] name=%r model=%s base_url=%s",
        request.name,
        request.model,
        request.base_url,
    )
    # Forward the request body field-by-field to the profile store.
    fields = {
        "name": request.name,
        "model": request.model,
        "base_url": request.base_url,
        "api_key": request.api_key,
        "timeout_seconds": request.timeout_seconds,
    }
    new_profile = profile_manager.create(**fields)
    logger.info("[create_profile] created id=%s", new_profile.profile_id)
    return new_profile
@router.put("/{profile_id}", response_model=LLMProfile)
def update_profile(profile_id: str, request: CreateProfileRequest) -> LLMProfile:
    """Update an existing LLM profile by id."""
    logger.info("[update_profile] id=%s name=%r model=%s", profile_id, request.name, request.model)
    fields = {
        "name": request.name,
        "model": request.model,
        "base_url": request.base_url,
        "api_key": request.api_key,
        "timeout_seconds": request.timeout_seconds,
    }
    updated = profile_manager.update(profile_id=profile_id, **fields)
    # The store signals an unknown id by returning None; map that to 404.
    if updated is not None:
        logger.info("[update_profile] updated id=%s", profile_id)
        return updated
    logger.warning("[update_profile] not found id=%s", profile_id)
    raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
@router.delete("/{profile_id}", response_model=dict)
def delete_profile(profile_id: str) -> dict:
    """Delete an LLM profile by id."""
    logger.info("[delete_profile] id=%s", profile_id)
    # A falsy result from the store means nothing was deleted → 404.
    if profile_manager.delete(profile_id):
        logger.info("[delete_profile] deleted id=%s", profile_id)
        return {"deleted": True}
    logger.warning("[delete_profile] not found id=%s", profile_id)
    raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
@router.post("/{profile_id}/test", response_model=ProfileTestResponse)
def test_profile(profile_id: str) -> ProfileTestResponse:
    """Test LLM connectivity for a saved profile."""
    saved = profile_manager.get(profile_id)
    if saved is None:
        logger.warning("[test_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")

    logger.info("[test_profile] id=%s model=%s base_url=%s", profile_id, saved.model, saved.base_url)
    # Reuse the same probe as /probe, fed from the stored credentials.
    outcome = _do_connectivity_test(
        model=saved.model,
        base_url=saved.base_url,
        api_key=saved.api_key,
        timeout_seconds=saved.timeout_seconds,
    )
    logger.info("[test_profile] ok=%s latency=%sms", outcome.ok, outcome.latency_ms)
    return outcome
@router.post("/apply", response_model=ProfileApplyResponse)
def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
    """Patch selected LLM profiles into the target scenario YAML file."""
    logger.info(
        "[apply_profiles] scenario=%s judge=%s answer=%s dataset=%s metric_weights=%s doc_weights=%s",
        request.scenario_path,
        request.judge_profile_id,
        request.answer_profile_id,
        request.dataset_profile_id,
        bool(request.metric_weights),
        bool(request.doc_weights),
    )

    # Role -> requested profile id. A falsy id means no lookup is done and
    # None is passed to the patcher for that role.
    requested_ids = {
        "judge": request.judge_profile_id,
        "answer": request.answer_profile_id,
        "dataset": request.dataset_profile_id,
    }
    role_profiles: dict[str, LLMProfile | None] = {
        role: (profile_manager.get(pid) if pid else None)
        for role, pid in requested_ids.items()
    }

    # Roles whose id was supplied but did not resolve to a stored profile.
    missing = [
        role
        for role, pid in requested_ids.items()
        if pid and role_profiles[role] is None
    ]
    if missing:
        logger.warning("[apply_profiles] missing profiles for roles: %s", missing)
        raise HTTPException(
            status_code=400,
            detail=f"Profile(s) not found for roles: {', '.join(missing)}",
        )

    patched = apply_profiles_to_scenario(
        scenario_path=request.scenario_path,
        judge_profile=role_profiles["judge"],
        answer_profile=role_profiles["answer"],
        dataset_profile=role_profiles["dataset"],
        metric_weights=request.metric_weights,
        doc_weights=request.doc_weights,
    )
    logger.info("[apply_profiles] patched fields: %s", patched)
    return ProfileApplyResponse(scenario_path=request.scenario_path, patched_fields=patched)