Files
siemens_ragas/webapp/api/llm_profiles.py

246 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""CRUD routes for LLM profiles plus the scenario-patching apply endpoint."""
from __future__ import annotations
import logging
import time
from fastapi import APIRouter, HTTPException
from openai import OpenAI
from webapp.models import (
CreateProfileRequest,
LLMProfile,
ProfileApplyRequest,
ProfileApplyResponse,
ProfileProbeRequest,
ProfileTestResponse,
)
from webapp.services.profile_manager import profile_manager
from webapp.services.yaml_patcher import apply_profiles_to_scenario
router = APIRouter(tags=["llm-profiles"], prefix="/api/llm-profiles")
logger = logging.getLogger("webapp.api.llm_profiles")

# Substrings (matched against the lower-cased model name) that mark a model as
# an embedding model, so the connectivity probe uses the /embeddings endpoint
# instead of /chat/completions.
_EMBEDDING_MODEL_KEYWORDS = (
    "embedding",
    "embed",
    "text-search",
    "text-similarity",
    "code-search",
    "ada-002",
)
def _is_embedding_model(model: str) -> bool:
    """Guess from its name whether *model* is an embedding model.

    This is a name-based heuristic only: the name is lower-cased and checked
    for any substring listed in ``_EMBEDDING_MODEL_KEYWORDS``.
    """
    lowered = model.lower()
    for keyword in _EMBEDDING_MODEL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def _elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since *started* (a ``time.monotonic()`` value)."""
    return int((time.monotonic() - started) * 1000)


def _error_message(exc: BaseException) -> str:
    """Return ``str(exc)``, or the exception class name when that string is empty."""
    # An exception whose str() is "" would otherwise surface as a blank
    # failure message in the probe result.
    return str(exc) or type(exc).__name__


def _probe_chat(client: OpenAI, model: str) -> None:
    """Send a minimal chat completion to *model*; raise if it ultimately fails.

    The first attempt passes ``max_tokens=8``. If that fails with an error
    whose text mentions both ``max_tokens`` and ``max_completion_tokens``
    (newer models such as gpt-5.x reject ``max_tokens`` and require
    ``max_completion_tokens``), the request is retried once with
    ``max_completion_tokens=8``. Any other error, or a failure of the retry,
    propagates to the caller.
    """
    messages = [{"role": "user", "content": "hi"}]
    try:
        # 8 tokens keeps the probe cheap while avoiding the minimum-output
        # limit that max_tokens=1 can trigger on some models (e.g. gpt-5.x).
        client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=8,
        )
    except Exception as exc:  # noqa: BLE001 - inspected below, re-raised if not the max_tokens case
        err_str = str(exc)
        if "max_tokens" not in err_str or "max_completion_tokens" not in err_str:
            raise
        client.chat.completions.create(
            model=model,
            messages=messages,
            max_completion_tokens=8,
        )


def _do_connectivity_test(
    model: str,
    base_url: str,
    api_key: str,
    timeout_seconds: int,
) -> ProfileTestResponse:
    """Send a minimal request and return the connectivity test result.

    - Embedding models (as judged by ``_is_embedding_model``) -> POST
      /embeddings with a short input text.
    - Chat models -> POST /chat/completions with ``max_tokens=8``, retried
      once with ``max_completion_tokens=8`` when the provider rejects
      ``max_tokens`` (as newer models such as gpt-5.x do).

    Request failures are not raised. They are reported as ``ok=False`` with
    the exception text as ``message``. ``latency_ms`` is measured from just
    before the first request, so it includes a retry if one happened. The
    client's HTTP resources are released before returning.
    """
    client = OpenAI(
        api_key=api_key,
        base_url=base_url.rstrip("/"),
        timeout=float(timeout_seconds),
    )
    try:
        t0 = time.monotonic()
        is_embedding = _is_embedding_model(model)
        try:
            if is_embedding:
                client.embeddings.create(model=model, input="test")
                ok_message = "连接成功embedding"
            else:
                _probe_chat(client, model)
                ok_message = "连接成功"
        except Exception as exc:  # noqa: BLE001 - any failure means the probe failed
            return ProfileTestResponse(
                ok=False,
                message=_error_message(exc),
                latency_ms=_elapsed_ms(t0),
            )
        return ProfileTestResponse(ok=True, message=ok_message, latency_ms=_elapsed_ms(t0))
    finally:
        # A fresh client (with its own connection pool) is built for every
        # probe; close it instead of leaving it for the garbage collector.
        client.close()
# The router already applies the "llm-profiles" tag. Repeating it here would
# list the tag twice on this operation in the OpenAPI schema.
@router.post("/probe", response_model=ProfileTestResponse)
def probe_connectivity(request: ProfileProbeRequest) -> ProfileTestResponse:
    """Test LLM connectivity with inline credentials (no saved profile required).

    Only the model, base URL and the probe outcome are logged; the API key
    is passed to the probe but never written to the log.
    """
    logger.info("[probe] model=%s base_url=%s", request.model, request.base_url)
    result = _do_connectivity_test(
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
    logger.info("[probe] ok=%s latency=%sms msg=%s", result.ok, result.latency_ms, result.message)
    return result
@router.get("", response_model=dict)
def list_profiles() -> dict:
    """List every stored LLM profile.

    Returns:
        ``{"profiles": [...]}``, one ``model_dump()`` dict per profile.
    """
    stored = profile_manager.list_all()
    logger.info("[list_profiles] count=%d", len(stored))
    dumped = [item.model_dump() for item in stored]
    return {"profiles": dumped}
@router.post("", status_code=201, response_model=LLMProfile)
def create_profile(request: CreateProfileRequest) -> LLMProfile:
    """Store a new LLM profile built from *request* and return it (HTTP 201)."""
    logger.info(
        "[create_profile] name=%r model=%s base_url=%s",
        request.name,
        request.model,
        request.base_url,
    )
    fields = {
        "name": request.name,
        "model": request.model,
        "base_url": request.base_url,
        "api_key": request.api_key,
        "timeout_seconds": request.timeout_seconds,
    }
    new_profile = profile_manager.create(**fields)
    logger.info("[create_profile] created id=%s", new_profile.profile_id)
    return new_profile
@router.put("/{profile_id}", response_model=LLMProfile)
def update_profile(profile_id: str, request: CreateProfileRequest) -> LLMProfile:
    """Update an existing LLM profile by id.

    Passes name, model, base_url, api_key and timeout_seconds from *request*
    to ``profile_manager.update``, then invalidates the inline scorer cache.

    Raises:
        HTTPException: 404 if ``profile_manager.update`` returns ``None``.
    """
    logger.info("[update_profile] id=%s name=%r model=%s", profile_id, request.name, request.model)
    updated = profile_manager.update(
        profile_id=profile_id,
        name=request.name,
        model=request.model,
        base_url=request.base_url,
        api_key=request.api_key,
        timeout_seconds=request.timeout_seconds,
    )
    if updated is None:
        logger.warning("[update_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
    # Invalidate the scorer cache so the next request picks up the new profile
    # settings. This is best-effort: a failure must not fail the update itself,
    # but it is logged instead of being silently ignored.
    # NOTE(review): the lazy import presumably avoids an import cycle or a
    # heavy startup import; confirm before hoisting it.
    try:
        from webapp.services.inline_scorer import inline_scorer

        inline_scorer.invalidate_cache()
    except Exception:  # noqa: BLE001
        logger.warning(
            "[update_profile] scorer cache invalidation failed id=%s",
            profile_id,
            exc_info=True,
        )
    else:
        logger.info("[update_profile] scorer cache invalidated id=%s", profile_id)
    logger.info("[update_profile] updated id=%s", profile_id)
    return updated
@router.delete("/{profile_id}", response_model=dict)
def delete_profile(profile_id: str) -> dict:
    """Delete an LLM profile by id.

    Returns:
        ``{"deleted": True}`` on success.

    Raises:
        HTTPException: 404 if ``profile_manager.delete`` returns a falsy value.
    """
    logger.info("[delete_profile] id=%s", profile_id)
    deleted = profile_manager.delete(profile_id)
    if not deleted:
        logger.warning("[delete_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")
    # Invalidate the scorer cache in case the deleted profile was in use.
    # This is best-effort: a failure must not turn a successful delete into an
    # error, but it is logged rather than silently discarded.
    try:
        from webapp.services.inline_scorer import inline_scorer

        inline_scorer.invalidate_cache()
    except Exception:  # noqa: BLE001
        logger.warning(
            "[delete_profile] scorer cache invalidation failed id=%s",
            profile_id,
            exc_info=True,
        )
    logger.info("[delete_profile] deleted id=%s", profile_id)
    return {"deleted": True}
@router.post("/{profile_id}/test", response_model=ProfileTestResponse)
def test_profile(profile_id: str) -> ProfileTestResponse:
    """Run the connectivity probe using a saved profile's credentials.

    Raises:
        HTTPException: 404 if ``profile_manager.get`` finds no such profile.
    """
    saved = profile_manager.get(profile_id)
    if saved is None:
        logger.warning("[test_profile] not found id=%s", profile_id)
        raise HTTPException(status_code=404, detail=f"Profile not found: {profile_id}")

    logger.info(
        "[test_profile] id=%s model=%s base_url=%s",
        profile_id,
        saved.model,
        saved.base_url,
    )
    outcome = _do_connectivity_test(
        model=saved.model,
        base_url=saved.base_url,
        api_key=saved.api_key,
        timeout_seconds=saved.timeout_seconds,
    )
    logger.info("[test_profile] ok=%s latency=%sms", outcome.ok, outcome.latency_ms)
    return outcome
@router.post("/apply", response_model=ProfileApplyResponse)
def apply_profiles(request: ProfileApplyRequest) -> ProfileApplyResponse:
    """Patch the selected LLM profiles into the target scenario YAML file.

    Each role (judge / answer / dataset) is optional. A falsy profile id is
    passed to the patcher as ``None`` for that role.

    Raises:
        HTTPException: 400 if any non-empty profile id does not resolve to a
            stored profile.
    """
    logger.info(
        "[apply_profiles] scenario=%s judge=%s answer=%s dataset=%s metric_weights=%s doc_weights=%s",
        request.scenario_path,
        request.judge_profile_id,
        request.answer_profile_id,
        request.dataset_profile_id,
        bool(request.metric_weights),
        bool(request.doc_weights),
    )
    # One source of truth for the role -> requested profile id pairing.
    requested_ids = {
        "judge": request.judge_profile_id,
        "answer": request.answer_profile_id,
        "dataset": request.dataset_profile_id,
    }
    role_profiles: dict[str, LLMProfile | None] = {
        role: (profile_manager.get(pid) if pid else None)
        for role, pid in requested_ids.items()
    }
    # A role is "missing" only if an id was supplied but did not resolve.
    missing = [
        role
        for role, pid in requested_ids.items()
        if pid and role_profiles[role] is None
    ]
    if missing:
        logger.warning("[apply_profiles] missing profiles for roles: %s", missing)
        raise HTTPException(
            status_code=400,
            detail=f"Profile(s) not found for roles: {', '.join(missing)}",
        )

    # NOTE(review): scenario_path comes straight from the request body; confirm
    # that apply_profiles_to_scenario confines it to the expected directory.
    patched = apply_profiles_to_scenario(
        scenario_path=request.scenario_path,
        judge_profile=role_profiles["judge"],
        answer_profile=role_profiles["answer"],
        dataset_profile=role_profiles["dataset"],
        metric_weights=request.metric_weights,
        doc_weights=request.doc_weights,
    )
    logger.info("[apply_profiles] patched fields: %s", patched)
    return ProfileApplyResponse(scenario_path=request.scenario_path, patched_fields=patched)