Files

228 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provide service-layer logic for bge m3 embedder."""
import numpy as np
from typing import List, Dict, Optional, Union
from dataclasses import dataclass, field
from loguru import logger
import torch
import os
# Route Hugging Face downloads through a mirror unless the environment already
# names an endpoint. This runs at import time, before FlagEmbedding is imported
# lazily by the embedder; presumably the hub client reads the variable when it
# is first imported, so setting it here is what makes it take effect.
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')

# Candidate on-disk copies of the model, checked in order by the embedder before
# falling back to a remote model id.
# NOTE(review): Hugging Face hub caches normally store snapshots under a commit
# hash (with refs/main holding that hash), so a literal "snapshots/main"
# directory may never exist — verify against the actual cache layout.
LOCAL_MODEL_PATHS = [
    os.path.expanduser(candidate)
    for candidate in (
        "~/.cache/modelscope/Xorbits/bge-m3",
        "~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main",
    )
]
@dataclass
class EmbeddingResult:
    """Bundle the outputs of a single embedding call.

    Fields, in constructor order:
        dense_embeddings: dense vectors produced by the encoder.
        sparse_embeddings: one token-weight mapping per embedded text.
        texts: the texts that were embedded.
        dim: dense vector width; 1024 unless supplied.
    """

    # Dense vectors; presumably one row per text — TODO confirm shape with the encoder.
    dense_embeddings: np.ndarray
    # Sparse lexical weights, one mapping per text. Annotated with int keys;
    # NOTE(review): the encoder may emit string token ids — confirm.
    sparse_embeddings: List[Dict[int, float]]
    # The input texts, in the same order as the embeddings.
    texts: List[str]
    # Width of the dense vectors.
    dim: int = field(default=1024)
class BGEM3Embedder:
    """Wrap FlagEmbedding's ``BGEM3FlagModel`` for dense and sparse text embeddings.

    The model is loaded eagerly by ``__init__``. Its location is resolved in this
    order: an explicit ``local_model_path`` that exists on disk, then the first
    entry of ``LOCAL_MODEL_PATHS`` containing a ``config.json``, and finally
    ``model_name`` passed through to FlagEmbedding unchanged.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        use_fp16: bool = True,
        device: Optional[str] = None,
        batch_size: int = 12,
        max_length: int = 8192,
        local_model_path: Optional[str] = None
    ):
        """Configure the embedder and load the model.

        Args:
            model_name: Model id used when no local copy is found.
            use_fp16: Forwarded to ``BGEM3FlagModel``.
            device: Target device; defaults to ``"cuda"`` when torch sees a GPU,
                otherwise ``"cpu"``.
            batch_size: Batch size passed to ``encode``.
            max_length: Maximum token length passed to ``encode``.
            local_model_path: Explicit model directory; used only if it exists.

        Raises:
            ImportError: If FlagEmbedding is not installed.
            Exception: Any error raised while constructing the model (re-raised).
        """
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.max_length = max_length
        self.model_path, self.model_name = self._resolve_model_path(
            model_name, local_model_path
        )
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info(f"初始化BGE-M3模型, 设备: {self.device}")
        self.model = None
        self._load_model()

    @staticmethod
    def _resolve_model_path(model_name: str, local_model_path: Optional[str]):
        """Return ``(model_path, model_name)`` for the first usable model location.

        ``model_name`` is reported as ``"local"`` whenever a local directory is used.
        """
        if local_model_path:
            if os.path.exists(local_model_path):
                logger.info(f"使用本地模型路径: {local_model_path}")
                return local_model_path, "local"
            # A mistyped explicit path used to be ignored without a trace; say so
            # before falling back to the default search.
            logger.warning(f"指定的本地模型路径不存在, 回退到默认搜索: {local_model_path}")
        for path in LOCAL_MODEL_PATHS:
            # Only accept directories that actually contain a model config.
            if os.path.exists(os.path.join(path, "config.json")):
                logger.info(f"使用本地模型路径: {path}")
                return path, "local"
        logger.info(f"使用远程模型: {model_name}")
        return model_name, model_name

    def _load_model(self):
        """Instantiate ``BGEM3FlagModel`` from ``self.model_path`` into ``self.model``.

        Raises:
            ImportError: If the ``FlagEmbedding`` package cannot be imported.
            Exception: Any model-construction error, after logging it.
        """
        try:
            # Imported lazily so this module can be imported without FlagEmbedding.
            from FlagEmbedding import BGEM3FlagModel
        except ImportError:
            logger.warning("FlagEmbedding库未安装请运行: pip install FlagEmbedding")
            raise
        # Kept separate from the import so an ImportError raised while building the
        # model is not misreported as "FlagEmbedding not installed".
        try:
            self.model = BGEM3FlagModel(
                self.model_path,
                use_fp16=self.use_fp16,
                device=self.device
            )
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise
        logger.success("BGE-M3模型加载成功")

    def embed(
        self,
        texts: List[str],
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False
    ) -> "EmbeddingResult":
        """Encode ``texts`` with the loaded model.

        Args:
            texts: Texts to embed. A single ``str`` is treated as a one-item list.
            return_dense: Forwarded to ``encode``.
            return_sparse: Forwarded to ``encode``.
            return_colbert_vecs: Forwarded to ``encode``. ColBERT vectors are not
                stored in the result.

        Returns:
            An ``EmbeddingResult`` containing:
                - the dense vectors, or an empty array if none were produced;
                - the sparse weights, or an empty list if none were produced;
                - the input texts;
                - the dense width, 1024 when it cannot be read from the vectors.
            Empty input yields an empty result with ``dim=0``.

        Raises:
            Exception: Any encoding error, after logging it.
        """
        if isinstance(texts, str):
            # A bare string is almost certainly a caller mistake; normalise it so
            # the result always has one entry per text.
            texts = [texts]
        if not texts:
            logger.warning("输入文本列表为空")
            return EmbeddingResult(
                dense_embeddings=np.array([]),
                sparse_embeddings=[],
                texts=[],
                dim=0
            )
        logger.info(f"开始嵌入{len(texts)}个文本块")
        try:
            embeddings = self.model.encode(
                texts,
                batch_size=self.batch_size,
                max_length=self.max_length,
                return_dense=return_dense,
                return_sparse=return_sparse,
                return_colbert_vecs=return_colbert_vecs
            )
        except Exception as e:
            logger.error(f"嵌入失败: {e}")
            raise
        # The encoder may report outputs that were not requested as present-but-None,
        # so a ``.get`` default alone is not enough; normalise None explicitly.
        # (Previously ``len(None)`` raised TypeError, breaking embed_sparse.)
        dense_embeddings = embeddings.get('dense_vecs')
        if dense_embeddings is None:
            dense_embeddings = np.array([])
        sparse_embeddings = embeddings.get('lexical_weights')
        if sparse_embeddings is None:
            sparse_embeddings = []
        # Read the width from the vectors; fall back to this module's default of 1024.
        dim = dense_embeddings.shape[-1] if len(dense_embeddings) > 0 else 1024
        logger.success(f"嵌入完成,向量维度: {dim}")
        return EmbeddingResult(
            dense_embeddings=dense_embeddings,
            sparse_embeddings=sparse_embeddings,
            texts=texts,
            dim=dim
        )

    def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
        """Embed one text and return a dict with ``'dense'``, ``'sparse'`` and ``'dim'``.

        ``'sparse'`` is ``{}`` when no lexical weights were returned.
        """
        result = self.embed([text])
        return {
            'dense': result.dense_embeddings[0],
            'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {},
            'dim': result.dim
        }

    def embed_dense(self, texts: List[str]) -> np.ndarray:
        """Return only the dense vectors for ``texts`` (sparse output not requested)."""
        result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
        return result.dense_embeddings

    def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
        """Return only the sparse lexical weights for ``texts`` (dense output not requested)."""
        result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
        return result.sparse_embeddings

    def embed_query(self, query: str) -> Dict:
        """Embed a query string; identical to ``embed_single``."""
        return self.embed_single(query)

    def compute_similarity(
        self,
        query_embedding: np.ndarray,
        doc_embeddings: np.ndarray,
        metric: str = "cosine"
    ) -> np.ndarray:
        """Score each row of ``doc_embeddings`` against ``query_embedding``.

        Args:
            query_embedding: A single query vector.
            doc_embeddings: Document vectors, one per row.
            metric: ``"cosine"`` or ``"dot"``.

        Returns:
            One score per document row. With ``"cosine"``, a zero-norm query or
            document scores 0 instead of producing NaN/inf.

        Raises:
            ValueError: For an unsupported ``metric``.
        """
        if metric == "cosine":
            dots = np.dot(doc_embeddings, query_embedding)
            denom = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
            # A zero vector has no direction; report 0 rather than dividing by zero.
            with np.errstate(divide="ignore", invalid="ignore"):
                similarities = dots / denom
            similarities = np.where(denom > 0, similarities, 0.0)
        elif metric == "dot":
            similarities = np.dot(doc_embeddings, query_embedding)
        else:
            raise ValueError(f"不支持的相似度度量: {metric}")
        return similarities

    def sparse_similarity(
        self,
        query_sparse: Dict[int, float],
        doc_sparse: Dict[int, float]
    ) -> float:
        """Return the lexical-match score: the sum of weight products over shared tokens."""
        # Walk the smaller mapping and probe the larger one. ``in`` checks never
        # insert, so defaultdict inputs are left unchanged.
        if len(query_sparse) <= len(doc_sparse):
            small, large = query_sparse, doc_sparse
        else:
            small, large = doc_sparse, query_sparse
        return float(sum(weight * large[token] for token, weight in small.items() if token in large))
def embed_texts(
    texts: List[str],
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> EmbeddingResult:
    """One-shot helper: build a ``BGEM3Embedder`` and embed ``texts`` with default options.

    Every call constructs a fresh embedder, and its constructor loads the model;
    keep a ``BGEM3Embedder`` instance around when embedding repeatedly.

    Args:
        texts: Texts to embed.
        model_name: Forwarded to ``BGEM3Embedder``.
        **kwargs: Any other ``BGEM3Embedder`` constructor options.
    """
    return BGEM3Embedder(model_name=model_name, **kwargs).embed(texts)
def embed_single_text(
    text: str,
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> Dict:
    """One-shot helper: build a ``BGEM3Embedder`` and embed a single ``text``.

    Returns the same dict as ``BGEM3Embedder.embed_single``. A new embedder is
    constructed, and the model loaded, on every call.

    Args:
        text: The text to embed.
        model_name: Forwarded to ``BGEM3Embedder``.
        **kwargs: Any other ``BGEM3Embedder`` constructor options.
    """
    return BGEM3Embedder(model_name=model_name, **kwargs).embed_single(text)