Files
2026-04-23 09:58:47 +08:00

88 lines
2.5 KiB
Python

import os
import logging
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Runtime configuration, read once from the environment at import time.
# Model identifier handed to BGEM3FlagModel when the app starts.
MODEL_NAME = os.getenv("MODEL_NAME", "BAAI/bge-m3")
# Passed to the model loader as cache_dir (note: only the value is reused here;
# this module does not itself export HF_HOME).
MODEL_CACHE = os.getenv("HF_HOME", "/app/models")
# Any value other than "cpu" enables fp16 when the model is loaded.
DEVICE = os.getenv("DEVICE", "cpu")
# Upper bound accepted for a request's batch_size.
MAX_BATCH = int(os.getenv("MAX_BATCH_SIZE", "16"))

# HuggingFace mirror: an HF_ENDPOINT already present in the environment is read
# directly by the HuggingFace download client, so nothing needs to be copied
# into os.environ here (reassigning the variable to its own value is a no-op).

# Loaded model instance; None until the lifespan handler has loaded it.
model = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the BGE-M3 model on startup and release it on shutdown.

    Used as the FastAPI ``lifespan`` handler: code before ``yield`` runs
    before the app serves requests, code after it runs on shutdown.

    Args:
        app: The FastAPI application (unused; required by the lifespan API).

    Raises:
        Exception: Any error from importing FlagEmbedding or constructing the
            model is logged with its traceback and re-raised, aborting startup.
    """
    global model
    logger.info(f"加载模型 {MODEL_NAME},设备:{DEVICE}")
    try:
        # Imported inside the try so a missing/broken FlagEmbedding install is
        # logged by the handler below instead of failing at module import.
        from FlagEmbedding import BGEM3FlagModel

        model = BGEM3FlagModel(
            MODEL_NAME,
            # Half precision only off-CPU.
            use_fp16=(DEVICE != "cpu"),
            cache_dir=MODEL_CACHE,
            # Pin the model to DEVICE; without this the library picks its own
            # device(s) and DEVICE would only toggle fp16.
            # NOTE(review): `devices` is the FlagEmbedding >= 1.3 keyword (the
            # same API line that accepts cache_dir) — confirm against the
            # pinned version.
            devices=DEVICE,
        )
        logger.info("BGE-M3 模型加载完成")
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"模型加载失败:{e}")
        raise
    try:
        yield
    finally:
        # Drop the reference so the model can be garbage-collected on shutdown.
        model = None
        logger.info("服务关闭")


# The application object; model loading is tied to its lifespan.
app = FastAPI(title="BGE-M3 嵌入服务", lifespan=lifespan)
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # Texts to encode; pydantic rejects an empty list or more than 100 items.
    texts: list[str] = Field(..., min_length=1, max_length=100)
    # Encoding batch size in [1, MAX_BATCH]. Pydantic does not validate
    # defaults, so a fixed default of 12 would bypass `le=MAX_BATCH` whenever
    # MAX_BATCH_SIZE < 12; clamp the default into the allowed range instead.
    batch_size: int = Field(default=max(1, min(12, MAX_BATCH)), ge=1, le=MAX_BATCH)
    # Whether to compute and return the dense vectors.
    return_dense: bool = True
    # Whether to compute and return the sparse lexical weights.
    return_sparse: bool = True
class EmbedResponse(BaseModel):
    """Response body for ``POST /embed``."""

    # One dense vector per input text, or None when dense output was not requested.
    dense: Optional[list[list[float]]] = Field(default=None)
    # One token -> weight mapping per input text, or None when sparse output
    # was not requested.
    sparse: Optional[list[dict]] = Field(default=None)
    # Name of the model that produced the embeddings (required).
    model: str = Field(...)
    # Number of texts that were encoded (required).
    count: int = Field(...)
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest) -> EmbedResponse:
    """Encode ``req.texts`` with the loaded BGE-M3 model.

    Args:
        req: Validated request body (texts, batch size, which outputs to return).

    Returns:
        EmbedResponse whose ``dense`` / ``sparse`` fields are filled according
        to ``req.return_dense`` / ``req.return_sparse`` (None otherwise).

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 400 if more
            than 100 texts are supplied; 500 if encoding or building the
            response fails (the detail carries the error message).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="模型未就绪")
    # Defensive duplicate of EmbedRequest's max_length=100; request validation
    # normally rejects oversized lists before this handler runs.
    if len(req.texts) > 100:
        raise HTTPException(status_code=400, detail="单次最多 100 条文本")
    try:
        output = model.encode(
            req.texts,
            batch_size=req.batch_size,
            return_dense=req.return_dense,
            return_sparse=req.return_sparse,
        )
        dense = output["dense_vecs"].tolist() if req.return_dense else None
        sparse = None
        if req.return_sparse:
            # NOTE(review): the weights are presumably numpy scalars (e.g.
            # float32), which pydantic cannot serialize to JSON. Coerce to
            # plain str -> float here so any failure is caught and logged below
            # rather than surfacing later during response serialization.
            sparse = [
                {str(token): float(weight) for token, weight in dict(weights).items()}
                for weights in output["lexical_weights"]
            ]
        return EmbedResponse(
            dense=dense,
            sparse=sparse,
            model=MODEL_NAME,
            count=len(req.texts),
        )
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception(f"嵌入生成失败:{e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
def health():
    """Report service status, configured model/device, and model readiness.

    ``status`` is always "ok" when this handler runs; ``ready`` becomes true
    once the lifespan handler has loaded the model.
    """
    is_ready = model is not None
    report = {
        "status": "ok",
        "model": MODEL_NAME,
        "device": DEVICE,
        "ready": is_ready,
    }
    return report