Files
siemens_ragas/webapp/api/llm_profiles.py

246 lines
9.7 KiB
Python
Raw Permalink Normal View History

"""CRUD routes for LLM profiles plus the scenario-patching apply endpoint."""
from __future__ import annotations
import logging
import time
from fastapi import APIRouter, HTTPException
from openai import OpenAI
from webapp.models import (
CreateProfileRequest,
LLMProfile,
ProfileApplyRequest,
ProfileApplyResponse,
ProfileProbeRequest,
ProfileTestResponse,
)
from webapp.services.profile_manager import profile_manager
from webapp.services.yaml_patcher import apply_profiles_to_scenario
logger = logging.getLogger("webapp.api.llm_profiles")
router = APIRouter(tags=["llm-profiles"], prefix="/api/llm-profiles")

# Substrings found in common embedding-model names. A model whose lowercased
# name contains any of them is probed via the /embeddings endpoint instead of
# /chat/completions (see _is_embedding_model / _do_connectivity_test).
_EMBEDDING_MODEL_KEYWORDS = (
    "embedding",
    "embed",
    "text-search",
    "text-similarity",
    "code-search",
    "ada-002",
)
def _is_embedding_model(model: str) -> bool:
    """Guess whether *model* names an embedding model.

    This is a case-insensitive substring match against
    ``_EMBEDDING_MODEL_KEYWORDS`` — a naming heuristic, not a registry lookup.
    """
    lowered = model.lower()
    for keyword in _EMBEDDING_MODEL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``time.monotonic()`` value)."""
    return int((time.monotonic() - start) * 1000)


def _asks_for_max_completion_tokens(err_str: str) -> bool:
    """Return True if an API error text asks to use ``max_completion_tokens`` instead of ``max_tokens``."""
    return "max_tokens" in err_str and "max_completion_tokens" in err_str


def _do_connectivity_test(
    model: str,
    base_url: str,
    api_key: str,
    timeout_seconds: int,
) -> ProfileTestResponse:
    """Send a minimal request and return the connectivity test result.

    - Embedding models (per ``_is_embedding_model``) POST /embeddings with a
      short text.
    - Chat models POST /chat/completions with ``max_tokens=8`` first. If the
      server rejects that and asks for ``max_completion_tokens`` (as newer
      models such as gpt-5.x do), the request is retried once with
      ``max_completion_tokens=8``.

    Never raises: any failure — including building the client from bad
    settings — is reported as ``ok=False`` with the exception text as
    ``message``. ``latency_ms`` times only the final request attempt, and is
    0 when the client could not be constructed.
    """
    try:
        client = OpenAI(
            api_key=api_key,
            base_url=base_url.rstrip("/"),
            timeout=float(timeout_seconds),
        )
    except Exception as exc:  # noqa: BLE001
        # Bad settings (e.g. a missing API key) are a failed test, not a server error.
        return ProfileTestResponse(ok=False, message=str(exc), latency_ms=0)

    try:
        t0 = time.monotonic()

        if _is_embedding_model(model):
            try:
                client.embeddings.create(model=model, input="test")
            except Exception as exc:  # noqa: BLE001
                return ProfileTestResponse(ok=False, message=str(exc), latency_ms=_elapsed_ms(t0))
            return ProfileTestResponse(ok=True, message="连接成功embedding", latency_ms=_elapsed_ms(t0))

        messages = [{"role": "user", "content": "hi"}]
        try:
            # 8 tokens keeps cost negligible while avoiding the minimum-output
            # limit that max_tokens=1 can trigger on some models (gpt-5.x).
            client.chat.completions.create(model=model, messages=messages, max_tokens=8)
        except Exception as exc:  # noqa: BLE001
            err_str = str(exc)
            if not _asks_for_max_completion_tokens(err_str):
                return ProfileTestResponse(ok=False, message=err_str, latency_ms=_elapsed_ms(t0))
            # The model rejects max_tokens outright: retry once with the
            # replacement parameter, timing only this attempt.
            t0 = time.monotonic()
            try:
                client.chat.completions.create(
                    model=model,
                    messages=messages,
                    max_completion_tokens=8,
                )
            except Exception as exc2:  # noqa: BLE001
                return ProfileTestResponse(ok=False, message=str(exc2), latency_ms=_elapsed_ms(t0))
        return ProfileTestResponse(ok=True, message="连接成功", latency_ms=_elapsed_ms(t0))
    finally:
        # A fresh client is built per test; release its HTTP connection pool.
        client.close()
@router.post("/probe", response_model=ProfileTestResponse)
def probe_connectivity(request: ProfileProbeRequest) -> ProfileTestResponse:
    """Test LLM connectivity with inline credentials (no saved profile required).

    The route relies on the router-level ``tags`` like every other route here;
    repeating the tag on the decorator would list it twice on the operation.
    The API key is passed through but never logged.
    """
    logger.info("[probe] model=%s base_url=%s", request.model, request.base_url)
    result = _do_connectivity_test(
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
    logger.info("[probe] ok=%s latency=%sms msg=%s", result.ok, result.latency_ms, result.message)
    return result
@router.get("", response_model=dict)
def list_profiles() -> dict:
    """List every saved LLM profile.

    Returns:
        ``{"profiles": [...]}``, each entry being a profile's ``model_dump()``.
    """
    saved = profile_manager.list_all()
    logger.info("[list_profiles] count=%d", len(saved))
    payload = {"profiles": [entry.model_dump() for entry in saved]}
    return payload
@router.post("", status_code=201, response_model=LLMProfile)
def create_profile(request: CreateProfileRequest) -> LLMProfile:
    """Persist a new LLM profile built from *request* and return it (HTTP 201)."""
    logger.info(
        "[create_profile] name=%r model=%s base_url=%s",
        request.name,
        request.model,
        request.base_url,
    )
    settings = dict(
        name=request.name,
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
    new_profile = profile_manager.create(**settings)
    logger.info("[create_profile] created id=%s", new_profile.profile_id)
    return new_profile
@router.put("/{profile_id}", response_model=LLMProfile)
def update_profile(profile_id: str, request: CreateProfileRequest) -> LLMProfile:
    """Update an existing LLM profile by id.

    After a successful update the inline scorer's cache is invalidated so the
    next request picks up the new profile settings. Invalidation is
    best-effort: a failure there is logged but does not fail the update.

    Raises:
        HTTPException: 404 if no profile with *profile_id* exists.
    """
    logger.info("[update_profile] id=%s name=%r model=%s", profile_id, request.name, request.model)
    updated = profile_manager.update(
        profile_id=profile_id,
        name=request.name,
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
    if updated is None:
        logger.warning("[update_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
    try:
        # Imported lazily; presumably to avoid an import cycle or startup cost — confirm.
        from webapp.services.inline_scorer import inline_scorer
        inline_scorer.invalidate_cache()
    except Exception:  # noqa: BLE001
        # Keep best-effort semantics, but don't hide the failure: a stale cache
        # would silently keep serving the old profile settings.
        logger.warning(
            "[update_profile] scorer cache invalidation failed id=%s",
            profile_id,
            exc_info=True,
        )
    else:
        logger.info("[update_profile] scorer cache invalidated id=%s", profile_id)
    logger.info("[update_profile] updated id=%s", profile_id)
    return updated
@router.delete("/{profile_id}", response_model=dict)
def delete_profile(profile_id: str) -> dict:
    """Delete an LLM profile by id.

    After a successful delete the inline scorer's cache is invalidated in case
    the deleted profile was in use. Invalidation is best-effort: a failure
    there is logged but does not fail the delete.

    Returns:
        ``{"deleted": True}``.

    Raises:
        HTTPException: 404 if no profile with *profile_id* exists.
    """
    logger.info("[delete_profile] id=%s", profile_id)
    deleted = profile_manager.delete(profile_id)
    if not deleted:
        logger.warning("[delete_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
    try:
        # Imported lazily; presumably to avoid an import cycle or startup cost — confirm.
        from webapp.services.inline_scorer import inline_scorer
        inline_scorer.invalidate_cache()
    except Exception:  # noqa: BLE001
        # Keep best-effort semantics, but surface the failure instead of hiding it.
        logger.warning(
            "[delete_profile] scorer cache invalidation failed id=%s",
            profile_id,
            exc_info=True,
        )
    else:
        logger.info("[delete_profile] scorer cache invalidated id=%s", profile_id)
    logger.info("[delete_profile] deleted id=%s", profile_id)
    return {"deleted": True}
@router.post("/{profile_id}/test", response_model=ProfileTestResponse)
def test_profile(profile_id: str) -> ProfileTestResponse:
    """Run a connectivity check using the stored settings of profile *profile_id*.

    Raises:
        HTTPException: 404 if the profile does not exist.
    """
    stored = profile_manager.get(profile_id)
    if stored is None:
        logger.warning("[test_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")

    logger.info(
        "[test_profile] id=%s model=%s base_url=%s",
        profile_id,
        stored.model,
        stored.base_url,
    )
    outcome = _do_connectivity_test(
        model=stored.model,
        base_url=stored.base_url,
        api_key=stored.api_key,
        timeout_seconds=stored.timeout_seconds,
    )
    logger.info("[test_profile] ok=%s latency=%sms", outcome.ok, outcome.latency_ms)
    return outcome
@router.post("/apply", response_model=ProfileApplyResponse)
def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
    """Patch selected LLM profiles into the target scenario YAML file.

    Each role (judge / answer / dataset) is optional; a role whose id is
    falsy is passed to the patcher as ``None``.

    Raises:
        HTTPException: 400 if any requested profile id does not exist.
    """
    logger.info(
        "[apply_profiles] scenario=%s judge=%s answer=%s dataset=%s metric_weights=%s doc_weights=%s",
        request.scenario_path,
        request.judge_profile_id,
        request.answer_profile_id,
        request.dataset_profile_id,
        bool(request.metric_weights),
        bool(request.doc_weights),
    )

    # role -> requested profile id, in a fixed order used for lookups and error text.
    requested_ids = {
        "judge": request.judge_profile_id,
        "answer": request.answer_profile_id,
        "dataset": request.dataset_profile_id,
    }
    role_profiles: dict[str, LLMProfile | None] = {
        role: (profile_manager.get(pid) if pid else None)
        for role, pid in requested_ids.items()
    }
    # A role is "missing" only if an id was given but no profile was found.
    missing = [
        role
        for role, pid in requested_ids.items()
        if pid and role_profiles[role] is None
    ]
    if missing:
        logger.warning("[apply_profiles] missing profiles for roles: %s", missing)
        raise HTTPException(
            status_code=400,
            detail=f"Profile(s) not found for roles: {', '.join(missing)}",
        )

    # NOTE(review): scenario_path comes straight from the client; confirm that
    # apply_profiles_to_scenario restricts it to the scenarios directory.
    patched = apply_profiles_to_scenario(
        scenario_path=request.scenario_path,
        judge_profile=role_profiles["judge"],
        answer_profile=role_profiles["answer"],
        dataset_profile=role_profiles["dataset"],
        metric_weights=request.metric_weights,
        doc_weights=request.doc_weights,
    )
    logger.info("[apply_profiles] patched fields: %s", patched)
    return ProfileApplyResponse(scenario_path=request.scenario_path, patched_fields=patched)