"""Provide service-layer logic for bge m3 embedder.""" import numpy as np from typing import List, Dict, Optional, Union from dataclasses import dataclass, field from loguru import logger import torch import os # Keep service responsibilities explicit so downstream behavior stays predictable. # Keep service responsibilities explicit so downstream behavior stays predictable. if 'HF_ENDPOINT' not in os.environ: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' # Keep service responsibilities explicit so downstream behavior stays predictable. LOCAL_MODEL_PATHS = [ os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # Keep service responsibilities explicit so downstream behavior stays predictable. os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # Keep service responsibilities explicit so downstream behavior stays predictable. ] @dataclass class EmbeddingResult: """Represent the Embedding Result type.""" dense_embeddings: np.ndarray # Keep service responsibilities explicit so downstream behavior stays predictable. sparse_embeddings: List[Dict[int, float]] # Keep service responsibilities explicit so downstream behavior stays predictable. texts: List[str] dim: int = 1024 class BGEM3Embedder: """Represent the B G E M3 Embedder type.""" def __init__( self, model_name: str = "BAAI/bge-m3", use_fp16: bool = True, device: Optional[str] = None, batch_size: int = 12, max_length: int = 8192, local_model_path: Optional[str] = None ): """Initialize the B G E M3 Embedder instance.""" self.use_fp16 = use_fp16 self.batch_size = batch_size self.max_length = max_length # Keep service responsibilities explicit so downstream behavior stays predictable. if local_model_path and os.path.exists(local_model_path): self.model_path = local_model_path self.model_name = "local" logger.info(f"使用本地模型路径: {local_model_path}") else: # Keep service responsibilities explicit so downstream behavior stays predictable. found_local = False for path in LOCAL_MODEL_PATHS: if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")): self.model_path = path self.model_name = "local" logger.info(f"使用本地模型路径: {path}") found_local = True break if not found_local: self.model_path = model_name self.model_name = model_name logger.info(f"使用远程模型: {model_name}") # Keep service responsibilities explicit so downstream behavior stays predictable. if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = device logger.info(f"初始化BGE-M3模型, 设备: {self.device}") self.model = None self._load_model() def _load_model(self): """Handle load model for this module for the B G E M3 Embedder instance.""" try: from FlagEmbedding import BGEM3FlagModel self.model = BGEM3FlagModel( self.model_path, use_fp16=self.use_fp16, device=self.device ) logger.success(f"BGE-M3模型加载成功") except ImportError: logger.warning("FlagEmbedding库未安装,请运行: pip install FlagEmbedding") raise except Exception as e: logger.error(f"模型加载失败: {e}") raise def embed( self, texts: List[str], return_dense: bool = True, return_sparse: bool = True, return_colbert_vecs: bool = False ) -> EmbeddingResult: """Handle embed for the B G E M3 Embedder instance.""" if not texts: logger.warning("输入文本列表为空") return EmbeddingResult( dense_embeddings=np.array([]), sparse_embeddings=[], texts=[], dim=0 ) logger.info(f"开始嵌入{len(texts)}个文本块") try: # Keep service responsibilities explicit so downstream behavior stays predictable. embeddings = self.model.encode( texts, batch_size=self.batch_size, max_length=self.max_length, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs ) # Keep service responsibilities explicit so downstream behavior stays predictable. dense_embeddings = embeddings.get('dense_vecs', np.array([])) sparse_embeddings = embeddings.get('lexical_weights', []) # Keep service responsibilities explicit so downstream behavior stays predictable. dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024 logger.success(f"嵌入完成,向量维度: {dim}") return EmbeddingResult( dense_embeddings=dense_embeddings, sparse_embeddings=sparse_embeddings, texts=texts, dim=dim ) except Exception as e: logger.error(f"嵌入失败: {e}") raise def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]: """Embed single for the B G E M3 Embedder instance.""" result = self.embed([text]) return { 'dense': result.dense_embeddings[0], 'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {}, 'dim': result.dim } def embed_dense(self, texts: List[str]) -> np.ndarray: """Embed dense for the B G E M3 Embedder instance.""" result = self.embed(texts, return_sparse=False, return_colbert_vecs=False) return result.dense_embeddings def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]: """Embed sparse for the B G E M3 Embedder instance.""" result = self.embed(texts, return_dense=False, return_colbert_vecs=False) return result.sparse_embeddings def embed_query(self, query: str) -> Dict: """Embed query for the B G E M3 Embedder instance.""" return self.embed_single(query) def compute_similarity( self, query_embedding: np.ndarray, doc_embeddings: np.ndarray, metric: str = "cosine" ) -> np.ndarray: """Handle compute similarity for the B G E M3 Embedder instance.""" if metric == "cosine": # Keep service responsibilities explicit so downstream behavior stays predictable. query_norm = np.linalg.norm(query_embedding) doc_norms = np.linalg.norm(doc_embeddings, axis=1) similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm) elif metric == "dot": # Keep service responsibilities explicit so downstream behavior stays predictable. similarities = np.dot(doc_embeddings, query_embedding) else: raise ValueError(f"不支持的相似度度量: {metric}") return similarities def sparse_similarity( self, query_sparse: Dict[int, float], doc_sparse: Dict[int, float] ) -> float: """Handle sparse similarity for the B G E M3 Embedder instance.""" # Keep service responsibilities explicit so downstream behavior stays predictable. common_keys = set(query_sparse.keys()) & set(doc_sparse.keys()) score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys) return score def embed_texts( texts: List[str], model_name: str = "BAAI/bge-m3", **kwargs ) -> EmbeddingResult: """Embed texts.""" embedder = BGEM3Embedder(model_name=model_name, **kwargs) return embedder.embed(texts) def embed_single_text( text: str, model_name: str = "BAAI/bge-m3", **kwargs ) -> Dict: """Embed single text.""" embedder = BGEM3Embedder(model_name=model_name, **kwargs) return embedder.embed_single(text)