88 lines
2.5 KiB
Python
88 lines
2.5 KiB
Python
|
|
import os
|
||
|
|
import logging
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from fastapi import FastAPI, HTTPException
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Service configuration, read once from environment variables at import ---
# Model identifier passed to BGEM3FlagModel.
MODEL_NAME = os.getenv("MODEL_NAME", "BAAI/bge-m3")
# Directory passed to BGEM3FlagModel as cache_dir.
MODEL_CACHE = os.getenv("HF_HOME", "/app/models")
# Any value other than "cpu" enables fp16 when the model is loaded.
DEVICE = os.getenv("DEVICE", "cpu")
# Upper bound for a request's batch_size. Clamped to >= 1 so that the
# request model's ``ge=1, le=MAX_BATCH`` range can never be empty.
MAX_BATCH = max(1, int(os.getenv("MAX_BATCH_SIZE", "16")))

# HuggingFace mirror: HF_ENDPOINT needs no handling here. huggingface_hub
# reads it straight from the process environment, and re-assigning
# os.environ["HF_ENDPOINT"] to its own value would be a no-op.

# Set by lifespan() once the model has loaded; None means "not ready".
model = None
|
||
|
|
|
||
|
|
|
||
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the BGE-M3 model on startup and drop it on shutdown.

    Code before ``yield`` runs before the application serves requests; code
    after it runs at shutdown. The loaded model is stored in the module-level
    ``model`` global, which the request handlers read.

    Args:
        app: The FastAPI application (not used here).

    Raises:
        Exception: Any failure while importing FlagEmbedding or loading the
            model is logged with its traceback and re-raised, aborting startup.
    """
    global model
    logger.info(f"加载模型 {MODEL_NAME},设备:{DEVICE}")
    try:
        # Imported here rather than at module top so that importing this
        # module does not require FlagEmbedding to be installed.
        from FlagEmbedding import BGEM3FlagModel

        # NOTE(review): DEVICE is not passed to BGEM3FlagModel, so it only
        # selects fp16 vs fp32; actual device placement is left to
        # FlagEmbedding. Confirm this is intended.
        model = BGEM3FlagModel(
            MODEL_NAME,
            use_fp16=(DEVICE != "cpu"),
            cache_dir=MODEL_CACHE,
        )
        logger.info("BGE-M3 模型加载完成")
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception(f"模型加载失败:{e}")
        raise
    yield
    # Drop the reference so the model can be garbage-collected and /health
    # reports ready=false once the service is shutting down.
    model = None
    logger.info("服务关闭")
|
||
|
|
|
||
|
|
|
||
|
|
# The ASGI application. The ``lifespan`` context manager loads the embedding
# model before any request is handled.
app = FastAPI(lifespan=lifespan, title="BGE-M3 嵌入服务")
|
||
|
|
|
||
|
|
|
||
|
|
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``.

    Attributes:
        texts: Between 1 and 100 input strings to embed.
        batch_size: Encoding batch size, constrained to ``1..MAX_BATCH``.
        return_dense: Whether dense vectors should be returned.
        return_sparse: Whether sparse lexical weights should be returned.
    """

    texts: list[str] = Field(..., min_length=1, max_length=100)
    # Pydantic does not validate defaults, so a fixed default of 12 would
    # silently bypass the ``le=MAX_BATCH`` cap whenever MAX_BATCH < 12.
    batch_size: int = Field(default=min(12, MAX_BATCH), ge=1, le=MAX_BATCH)
    return_dense: bool = True
    return_sparse: bool = True
|
||
|
|
|
||
|
|
|
||
|
|
class EmbedResponse(BaseModel):
    """Response body for ``POST /embed``.

    ``dense`` and ``sparse`` are ``None`` when the caller did not request
    that output.
    """

    # One dense vector per input text.
    dense: Optional[list[list[float]]] = Field(default=None)
    # One lexical-weight mapping per input text (presumably token -> weight).
    sparse: Optional[list[dict]] = Field(default=None)
    # Name of the model that produced the embeddings.
    model: str
    # Number of input texts that were embedded.
    count: int
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest) -> EmbedResponse:
    """Compute BGE-M3 dense and/or sparse embeddings for a list of texts.

    This is a plain ``def``, so FastAPI runs it in its threadpool and the
    blocking ``encode`` call does not stall the event loop.

    Args:
        req: Validated request body.

    Returns:
        EmbedResponse with one dense vector and/or one sparse weight mapping
        per input text. Outputs that were not requested are ``None``.

    Raises:
        HTTPException: 503 if the model is not loaded yet, 400 if more than
            100 texts are supplied, 500 if encoding fails.
    """
    # Read the global once so the readiness check and the encode call are
    # guaranteed to use the same object.
    embedder = model
    if embedder is None:
        raise HTTPException(status_code=503, detail="模型未就绪")
    # Defensive duplicate of EmbedRequest's max_length=100 constraint.
    if len(req.texts) > 100:
        raise HTTPException(status_code=400, detail="单次最多 100 条文本")

    # Pydantic does not validate field defaults, so req.batch_size may exceed
    # the configured cap. Clamp it into [1, MAX_BATCH].
    batch_size = max(1, min(req.batch_size, MAX_BATCH))

    try:
        output = embedder.encode(
            req.texts,
            batch_size=batch_size,
            return_dense=req.return_dense,
            return_sparse=req.return_sparse,
        )
        dense = output["dense_vecs"].tolist() if req.return_dense else None
        if req.return_sparse:
            # Lexical weights may be NumPy scalars, which the JSON response
            # encoder cannot serialize. Convert them to built-in floats.
            sparse = [
                {str(token): float(weight) for token, weight in weights.items()}
                for weights in output["lexical_weights"]
            ]
        else:
            sparse = None
    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logger.exception(f"嵌入生成失败:{e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return EmbedResponse(
        dense=dense,
        sparse=sparse,
        model=MODEL_NAME,
        count=len(req.texts),
    )
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/health")
def health():
    """Report liveness, the configured model/device, and model readiness.

    ``status`` is always ``"ok"`` while the process can answer. ``ready`` is
    true only after the model has been loaded into the module-level global.
    """
    is_ready = model is not None
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "device": DEVICE,
        "ready": is_ready,
    }
|