first commit
This commit is contained in:
87
services/embedding/main.py
Normal file
87
services/embedding/main.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Runtime configuration, read once from the environment at import time.
MODEL_NAME = os.getenv("MODEL_NAME", "BAAI/bge-m3")
MODEL_CACHE = os.getenv("HF_HOME", "/app/models")
DEVICE = os.getenv("DEVICE", "cpu")
# A non-integer MAX_BATCH_SIZE raises ValueError here, failing fast at import.
MAX_BATCH = int(os.getenv("MAX_BATCH_SIZE", "16"))

# HuggingFace mirror. The previous code copied HF_ENDPOINT from the
# environment back onto itself, which had no effect: the variable is already
# part of the process environment. Just surface the value in the logs.
_hf_endpoint = os.getenv("HF_ENDPOINT")
if _hf_endpoint:
    logger.info("使用 HuggingFace 镜像:%s", _hf_endpoint)

# Assigned by the lifespan hook; stays None until the model has been loaded.
model = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the embedding model on startup and drop it on shutdown.

    The loaded model is stored in the module-level ``model`` variable, which
    the request handlers read.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        Exception: Whatever the import or model construction raises; it is
            logged with its traceback and re-raised so startup fails.
    """
    global model
    logger.info("加载模型 %s,设备:%s", MODEL_NAME, DEVICE)
    try:
        # Imported inside the try block so that an import failure is logged
        # and re-raised by the same handler as a model-loading failure.
        from FlagEmbedding import BGEM3FlagModel

        # NOTE(review): DEVICE only decides whether fp16 is enabled; it is
        # not passed to the constructor. Confirm device selection is handled
        # elsewhere (e.g. by the library's own defaults).
        model = BGEM3FlagModel(
            MODEL_NAME,
            use_fp16=(DEVICE != "cpu"),
            cache_dir=MODEL_CACHE,
        )
        logger.info("BGE-M3 模型加载完成")
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("模型加载失败:%s", e)
        raise

    try:
        yield
    finally:
        # Drop the reference so /health stops reporting the service as ready
        # and the model can be garbage-collected.
        model = None
        logger.info("服务关闭")
|
||||
|
||||
|
||||
# Application instance; the model is loaded and released by `lifespan`.
app = FastAPI(
    title="BGE-M3 嵌入服务",
    lifespan=lifespan,
)
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
    # Request body for POST /embed.

    # Texts to embed; the Field constraints ask for between 1 and 100 items.
    texts: list[str] = Field(..., min_length=1, max_length=100)
    # Batch size forwarded to model.encode, bounded by MAX_BATCH. The default
    # is capped as well: a bare default of 12 would bypass the `le` bound
    # whenever MAX_BATCH is configured below 12.
    batch_size: int = Field(default=min(12, MAX_BATCH), ge=1, le=MAX_BATCH)
    # Whether the response should include dense vectors.
    return_dense: bool = True
    # Whether the response should include sparse (lexical) weights.
    return_sparse: bool = True
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
    # Response body for POST /embed.

    # One dense vector per input text; None when dense output was not requested.
    dense: Optional[list[list[float]]] = None

    # One weight mapping per input text; None when sparse output was not requested.
    sparse: Optional[list[dict]] = None

    # Name of the model that produced the embeddings.
    model: str

    # Number of input texts in the request.
    count: int
|
||||
|
||||
|
||||
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest) -> EmbedResponse:
    """Embed the request texts with the loaded model.

    Args:
        req: Texts plus encoding options (batch size, which outputs to return).

    Returns:
        An EmbedResponse holding the requested dense vectors and/or sparse
        weights, the model name and the number of input texts.

    Raises:
        HTTPException: 503 when the model has not been loaded, 400 when more
            than 100 texts are submitted, 500 when encoding or response
            construction fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="模型未就绪")
    # Kept as an explicit guard in addition to the request model's constraint.
    if len(req.texts) > 100:
        raise HTTPException(status_code=400, detail="单次最多 100 条文本")

    try:
        output = model.encode(
            req.texts,
            batch_size=req.batch_size,
            return_dense=req.return_dense,
            return_sparse=req.return_sparse,
        )

        dense = output["dense_vecs"].tolist() if req.return_dense else None

        sparse = None
        if req.return_sparse:
            # Coerce every weight to a built-in float. The values are
            # presumably numpy scalars (TODO confirm against the installed
            # FlagEmbedding version), which are not guaranteed to serialize
            # to JSON; float() is a no-op for values that are already floats.
            sparse = [
                {token: float(weight) for token, weight in dict(weights).items()}
                for weights in output["lexical_weights"]
            ]

        return EmbedResponse(
            dense=dense,
            sparse=sparse,
            model=MODEL_NAME,
            count=len(req.texts),
        )
    except Exception as e:
        # logger.exception records the traceback; `from e` keeps the cause
        # attached to the HTTP error.
        logger.exception("嵌入生成失败:%s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Return a status payload with the model name, device and readiness."""
    # Ready means the module-level model has been assigned.
    is_ready = model is not None
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "device": DEVICE,
        "ready": is_ready,
    }
|
||||
Reference in New Issue
Block a user