228 lines
8.2 KiB
Python
228 lines
8.2 KiB
Python
"""BGE-M3 embedding service: dense and sparse (lexical) text embeddings via FlagEmbedding."""
|
||
|
||
import numpy as np
|
||
from typing import List, Dict, Optional, Union
|
||
from dataclasses import dataclass, field
|
||
from loguru import logger
|
||
import torch
|
||
import os
|
||
# Module-level configuration: Hugging Face endpoint default and local model lookup paths.

# Route Hugging Face downloads through hf-mirror.com unless the environment
# already specifies an endpoint.
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')

# On-disk copies of BGE-M3 that BGEM3Embedder checks, in order, before
# falling back to loading the model by name. Each entry only counts if it
# contains a config.json.
LOCAL_MODEL_PATHS = [
    os.path.expanduser(candidate)
    for candidate in (
        # ModelScope cache location.
        "~/.cache/modelscope/Xorbits/bge-m3",
        # Hugging Face hub cache location.
        # NOTE(review): huggingface_hub names snapshot directories by commit
        # hash (refs/main stores the hash), so a literal "snapshots/main"
        # directory may never exist — confirm.
        "~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main",
    )
]
|
||
|
||
|
||
@dataclass
|
||
class EmbeddingResult:
|
||
"""Represent the Embedding Result type."""
|
||
dense_embeddings: np.ndarray # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||
sparse_embeddings: List[Dict[int, float]] # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||
texts: List[str]
|
||
dim: int = 1024
|
||
|
||
|
||
class BGEM3Embedder:
    """Wrap FlagEmbedding's ``BGEM3FlagModel`` to produce dense and sparse embeddings.

    The model location is resolved in ``__init__`` in this order: an existing
    explicit ``local_model_path``, then the first ``LOCAL_MODEL_PATHS`` entry
    containing a ``config.json``, then the remote ``model_name``. The model is
    loaded eagerly.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        use_fp16: bool = True,
        device: Optional[str] = None,
        batch_size: int = 12,
        max_length: int = 8192,
        local_model_path: Optional[str] = None
    ):
        """Resolve the model location, choose a device and load the model.

        Args:
            model_name: Model identifier used when no local copy is found.
            use_fp16: Forwarded to ``BGEM3FlagModel`` as ``use_fp16``.
            device: Explicit device string. When ``None``, ``"cuda"`` is used if
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            batch_size: Batch size passed to ``encode``.
            max_length: Maximum length passed to ``encode``.
            local_model_path: Preferred on-disk model directory; used only if it
                exists.

        Raises:
            ImportError: If FlagEmbedding is not installed.
            Exception: Any error raised while constructing the model (re-raised).
        """
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.max_length = max_length

        resolved_path = None
        if local_model_path:
            if os.path.exists(local_model_path):
                resolved_path = local_model_path
            else:
                # Warn instead of falling back silently so a mistyped path is noticed.
                logger.warning(f"指定的本地模型路径不存在, 将自动查找: {local_model_path}")

        if resolved_path is None:
            # A cached copy only counts if it contains a config.json.
            resolved_path = next(
                (
                    path
                    for path in LOCAL_MODEL_PATHS
                    if os.path.exists(os.path.join(path, "config.json"))
                ),
                None,
            )

        if resolved_path is not None:
            self.model_path = resolved_path
            self.model_name = "local"
            logger.info(f"使用本地模型路径: {resolved_path}")
        else:
            self.model_path = model_name
            self.model_name = model_name
            logger.info(f"使用远程模型: {model_name}")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        logger.info(f"初始化BGE-M3模型, 设备: {self.device}")

        self.model = None
        self._load_model()

    def _load_model(self):
        """Construct ``BGEM3FlagModel`` from ``self.model_path`` into ``self.model``.

        Raises:
            ImportError: If the FlagEmbedding package cannot be imported.
            Exception: Any model construction error, after logging it.
        """
        # Only the import itself is guarded here, so an ImportError raised while
        # building the model is not misreported as a missing package.
        try:
            # Imported lazily so this module can be imported without FlagEmbedding.
            from FlagEmbedding import BGEM3FlagModel
        except ImportError:
            logger.error("FlagEmbedding库未安装,请运行: pip install FlagEmbedding")
            raise

        try:
            self.model = BGEM3FlagModel(
                self.model_path,
                use_fp16=self.use_fp16,
                device=self.device
            )
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise

        logger.success("BGE-M3模型加载成功")

    @staticmethod
    def _unpack_encode_output(embeddings: Dict, default_dim: int = 1024):
        """Split an ``encode`` result dict into ``(dense, sparse, dim)``.

        Missing keys and ``None`` values become an empty array / empty list, so
        callers never receive ``None``. Some FlagEmbedding releases may return
        ``None`` for outputs that were not requested.
        ``dim`` is the width of the dense vectors, or ``default_dim`` when there
        are none.
        """
        dense = embeddings.get('dense_vecs')
        dense = np.array([]) if dense is None else np.asarray(dense)
        sparse = embeddings.get('lexical_weights')
        if sparse is None:
            sparse = []
        # shape[-1] is the vector width for both an (n, d) batch and a single (d,) vector.
        dim = dense.shape[-1] if dense.ndim > 0 and dense.size > 0 else default_dim
        return dense, sparse, dim

    def embed(
        self,
        texts: List[str],
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False
    ) -> "EmbeddingResult":
        """Encode ``texts`` with the loaded model.

        Args:
            texts: Input strings. An empty list returns immediately without
                calling the model.
            return_dense: Forwarded to ``encode``.
            return_sparse: Forwarded to ``encode``.
            return_colbert_vecs: Forwarded to ``encode``. ColBERT vectors are not
                included in the result.

        Returns:
            An ``EmbeddingResult``. Outputs that were not produced come back as
            an empty array / empty list. ``dim`` is 0 for empty input and 1024
            when no dense vectors were produced.

        Raises:
            Exception: Whatever ``encode`` raises, after logging it.
        """
        if not texts:
            logger.warning("输入文本列表为空")
            return EmbeddingResult(
                dense_embeddings=np.array([]),
                sparse_embeddings=[],
                texts=[],
                dim=0
            )

        logger.info(f"开始嵌入{len(texts)}个文本块")

        try:
            embeddings = self.model.encode(
                texts,
                batch_size=self.batch_size,
                max_length=self.max_length,
                return_dense=return_dense,
                return_sparse=return_sparse,
                return_colbert_vecs=return_colbert_vecs
            )
            dense_embeddings, sparse_embeddings, dim = self._unpack_encode_output(embeddings)
        except Exception as e:
            logger.error(f"嵌入失败: {e}")
            raise

        logger.success(f"嵌入完成,向量维度: {dim}")

        return EmbeddingResult(
            dense_embeddings=dense_embeddings,
            sparse_embeddings=sparse_embeddings,
            texts=texts,
            dim=dim
        )

    def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
        """Embed one text and return ``{'dense', 'sparse', 'dim'}`` for it.

        ``'sparse'`` is an empty dict when the model produced no sparse output.
        """
        result = self.embed([text])
        return {
            'dense': result.dense_embeddings[0],
            'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {},
            'dim': result.dim
        }

    def embed_dense(self, texts: List[str]) -> np.ndarray:
        """Return only the dense vectors; sparse and ColBERT outputs are not requested."""
        result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
        return result.dense_embeddings

    def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
        """Return only the per-text sparse weights; dense and ColBERT outputs are not requested."""
        result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
        return result.sparse_embeddings

    def embed_query(self, query: str) -> Dict:
        """Embed a query string; identical to ``embed_single``."""
        return self.embed_single(query)

    def compute_similarity(
        self,
        query_embedding: np.ndarray,
        doc_embeddings: np.ndarray,
        metric: str = "cosine"
    ) -> np.ndarray:
        """Score document vectors against a query vector.

        Args:
            query_embedding: A single (d,) query vector.
            doc_embeddings: An (n, d) matrix of document vectors, or a single (d,) vector.
            metric: ``"cosine"`` or ``"dot"``.

        Returns:
            One score per document, or a scalar for a single (d,) document.
            Cosine scores that involve a zero-norm vector are 0 rather than NaN.

        Raises:
            ValueError: For an unsupported ``metric``.
        """
        if metric == "cosine":
            query_norm = np.linalg.norm(query_embedding)
            # axis=-1 handles both an (n, d) batch and a single (d,) document.
            doc_norms = np.linalg.norm(doc_embeddings, axis=-1)
            denom = doc_norms * query_norm
            with np.errstate(divide="ignore", invalid="ignore"):
                raw = np.dot(doc_embeddings, query_embedding) / denom
            # A zero-norm vector has no direction; report 0 instead of NaN/inf.
            similarities = np.where(denom > 0, raw, 0.0)
        elif metric == "dot":
            similarities = np.dot(doc_embeddings, query_embedding)
        else:
            raise ValueError(f"不支持的相似度度量: {metric}")

        return similarities

    def sparse_similarity(
        self,
        query_sparse: Dict[int, float],
        doc_sparse: Dict[int, float]
    ) -> float:
        """Return the lexical-match score: sum of weight products over shared tokens.

        Tokens present in only one mapping contribute nothing. Disjoint or
        empty inputs score 0.0.
        """
        # Walk the smaller mapping and probe the larger one.
        if len(query_sparse) <= len(doc_sparse):
            smaller, larger = query_sparse, doc_sparse
        else:
            smaller, larger = doc_sparse, query_sparse
        return float(sum(
            weight * larger[token]
            for token, weight in smaller.items()
            if token in larger
        ))
||
|
||
def embed_texts(
    texts: List[str],
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> EmbeddingResult:
    """Embed ``texts`` with a freshly constructed ``BGEM3Embedder``.

    Convenience wrapper: each call constructs a new embedder, and its
    constructor loads the model. Reuse one ``BGEM3Embedder`` instance when
    embedding repeatedly.

    Args:
        texts: Strings to embed.
        model_name: Forwarded to ``BGEM3Embedder``.
        **kwargs: Additional ``BGEM3Embedder`` constructor arguments.

    Returns:
        The ``EmbeddingResult`` of ``BGEM3Embedder.embed`` with its default flags.
    """
    return BGEM3Embedder(model_name=model_name, **kwargs).embed(texts)
||
|
||
|
||
def embed_single_text(
    text: str,
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> Dict:
    """Embed one string with a freshly constructed ``BGEM3Embedder``.

    Convenience wrapper: each call constructs a new embedder, and its
    constructor loads the model.

    Args:
        text: The string to embed.
        model_name: Forwarded to ``BGEM3Embedder``.
        **kwargs: Additional ``BGEM3Embedder`` constructor arguments.

    Returns:
        The ``{'dense', 'sparse', 'dim'}`` dict from ``BGEM3Embedder.embed_single``.
    """
    return BGEM3Embedder(model_name=model_name, **kwargs).embed_single(text)