diff --git a/webapp/api/llm_profiles.py b/webapp/api/llm_profiles.py index cd71be3..b2bca65 100644 --- a/webapp/api/llm_profiles.py +++ b/webapp/api/llm_profiles.py @@ -23,16 +23,29 @@ router = APIRouter(prefix="/api/llm-profiles", tags=["llm-profiles"]) logger = logging.getLogger("webapp.api.llm_profiles") +# 常见 embedding 模型名称关键词,用于自动判断走 /embeddings 端点 +_EMBEDDING_MODEL_KEYWORDS = ( + "embedding", "embed", "text-search", "text-similarity", + "code-search", "ada-002", +) + + +def _is_embedding_model(model: str) -> bool: + """Heuristic: return True if the model name looks like an embedding model.""" + return any(kw in model.lower() for kw in _EMBEDDING_MODEL_KEYWORDS) + + def _do_connectivity_test( model: str, base_url: str, api_key: str, timeout_seconds: int, ) -> ProfileTestResponse: - """Send a minimal chat completion request and return the test result. + """Send a minimal request and return the connectivity test result. - Tries max_completion_tokens first (required by newer OpenAI models like gpt-5.x), - then falls back to max_tokens for older models / compatible APIs. + - Embedding models → POST /embeddings with a short text + - Chat models → POST /chat/completions, tries max_completion_tokens first + (required by newer models like gpt-5.x), falls back to max_tokens. """ client = OpenAI( api_key=api_key, @@ -40,7 +53,18 @@ def _do_connectivity_test( timeout=float(timeout_seconds), ) t0 = time.monotonic() - # Try newer parameter first, fall back to legacy max_tokens on failure + + if _is_embedding_model(model): + # Embedding 模型走 /embeddings 端点 + try: + client.embeddings.create(model=model, input="test") + latency_ms = int((time.monotonic() - t0) * 1000) + return ProfileTestResponse(ok=True, message="连接成功(embedding)", latency_ms=latency_ms) + except Exception as exc: # noqa: BLE001 + latency_ms = int((time.monotonic() - t0) * 1000) + return ProfileTestResponse(ok=False, message=str(exc), latency_ms=latency_ms) + + # Chat 模型:先用 max_completion_tokens,失败时 fallback 到 max_tokens for kwargs in [{"max_completion_tokens": 1}, {"max_tokens": 1}]: try: client.chat.completions.create( @@ -52,11 +76,12 @@ def _do_connectivity_test( return ProfileTestResponse(ok=True, message="连接成功", latency_ms=latency_ms) except Exception as exc: # noqa: BLE001 err_str = str(exc) - # Only retry if the error is specifically about the token parameter name + # 仅当错误明确提示参数名称问题时才重试 if "max_tokens" in err_str and "max_completion_tokens" in err_str and kwargs.get("max_completion_tokens"): continue latency_ms = int((time.monotonic() - t0) * 1000) return ProfileTestResponse(ok=False, message=err_str, latency_ms=latency_ms) + latency_ms = int((time.monotonic() - t0) * 1000) return ProfileTestResponse(ok=False, message="连接测试失败", latency_ms=latency_ms)