2026-05-18 16:32:42 +08:00
""" Provide service-layer logic for bge m3 embedder. """
2026-05-14 15:07:34 +08:00
import numpy as np
from typing import List , Dict , Optional , Union
from dataclasses import dataclass , field
from loguru import logger
import torch
import os
2026-05-18 16:32:42 +08:00
# Keep service responsibilities explicit so downstream behavior stays predictable.
2026-05-14 15:07:34 +08:00
2026-05-18 16:32:42 +08:00
# Keep service responsibilities explicit so downstream behavior stays predictable.
2026-05-14 15:07:34 +08:00
# Route Hugging Face downloads through the hf-mirror.com endpoint unless the
# caller has already chosen one: an existing HF_ENDPOINT value is left untouched.
# This runs at import time, so it takes effect before any model is loaded below.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
2026-05-18 16:32:42 +08:00
# Keep service responsibilities explicit so downstream behavior stays predictable.
2026-05-14 15:07:34 +08:00
# Candidate directories that are probed, in order, for an already-downloaded
# copy of the model before falling back to a remote model name.
LOCAL_MODEL_PATHS = [
    # ModelScope cache location.
    os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"),
    # Hugging Face hub cache location.
    # NOTE(review): hub snapshot directories are normally named after a commit
    # hash rather than "main", so this entry may never exist on disk - confirm.
    os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"),
]
@dataclass
class EmbeddingResult:
    """Bundle the outputs of one embedding call.

    Attributes:
        dense_embeddings: Dense vectors as a NumPy array.
        sparse_embeddings: List of sparse weight mappings (presumably one
            mapping per input text - confirm against the producer).
        texts: The texts the embeddings were computed for.
        dim: Dense vector dimensionality; defaults to 1024 when not supplied.
    """

    dense_embeddings: np.ndarray
    sparse_embeddings: List[Dict[int, float]]
    texts: List[str]
    dim: int = 1024
class BGEM3Embedder:
    """Wrap a FlagEmbedding ``BGEM3FlagModel`` for dense and sparse text embedding.

    The model is loaded eagerly in ``__init__``. A local copy of the model is
    preferred over a remote model name: first the explicit ``local_model_path``,
    then the module-level ``LOCAL_MODEL_PATHS`` candidates.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        use_fp16: bool = True,
        device: Optional[str] = None,
        batch_size: int = 12,
        max_length: int = 8192,
        local_model_path: Optional[str] = None,
    ):
        """Resolve the model location, pick a device and load the model.

        Args:
            model_name: Remote model identifier used when no local copy is found.
            use_fp16: Forwarded to ``BGEM3FlagModel`` as ``use_fp16``.
            device: Device string; when ``None``, "cuda" is used if
                ``torch.cuda.is_available()`` and "cpu" otherwise.
            batch_size: Batch size passed to ``encode``.
            max_length: Maximum length passed to ``encode``.
            local_model_path: Optional directory holding the model. It is used
                only if it exists; otherwise a warning is logged and the
                lookup continues with ``LOCAL_MODEL_PATHS``.

        Raises:
            ImportError: If FlagEmbedding cannot be imported.
            Exception: Any error raised while constructing the model.
        """
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.max_length = max_length

        self.model_path, self.model_name = self._resolve_model_path(
            model_name, local_model_path
        )

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info(f"初始化BGE-M3模型, 设备: {self.device}")

        self.model = None
        self._load_model()

    @staticmethod
    def _resolve_model_path(model_name: str, local_model_path: Optional[str]) -> tuple:
        """Return ``(model_path, model_name)`` for the first usable model source.

        ``model_name`` in the result is the literal "local" whenever a directory
        on disk is selected, and the remote identifier otherwise.
        """
        if local_model_path:
            if os.path.exists(local_model_path):
                logger.info(f"使用本地模型路径: {local_model_path}")
                return local_model_path, "local"
            # An explicit path that does not exist is most likely a caller
            # mistake; report it instead of silently falling back.
            logger.warning(f"指定的本地模型路径不存在: {local_model_path}")

        # A cached directory only counts if it actually contains config.json.
        cached = next(
            (
                candidate
                for candidate in LOCAL_MODEL_PATHS
                if os.path.exists(os.path.join(candidate, "config.json"))
            ),
            None,
        )
        if cached is not None:
            logger.info(f"使用本地模型路径: {cached}")
            return cached, "local"

        logger.info(f"使用远程模型: {model_name}")
        return model_name, model_name

    def _load_model(self):
        """Import FlagEmbedding lazily and build ``self.model``.

        Raises:
            ImportError: If FlagEmbedding is not importable (logged as a warning).
            Exception: Any failure while constructing the model (logged as an error).
        """
        # Keep the import in its own try so that only a genuinely missing
        # package produces the "please install" hint.
        try:
            from FlagEmbedding import BGEM3FlagModel
        except ImportError:
            logger.warning("FlagEmbedding库未安装, 请运行: pip install FlagEmbedding")
            raise

        try:
            self.model = BGEM3FlagModel(
                self.model_path,
                use_fp16=self.use_fp16,
                device=self.device,
            )
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise
        logger.success("BGE-M3模型加载成功")

    def embed(
        self,
        texts: List[str],
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False,
    ) -> "EmbeddingResult":
        """Embed ``texts`` and package the outputs as an ``EmbeddingResult``.

        Args:
            texts: Texts to embed. An empty list short-circuits with an empty
                result whose ``dim`` is 0.
            return_dense: Forwarded to ``encode``.
            return_sparse: Forwarded to ``encode``.
            return_colbert_vecs: Forwarded to ``encode``; the ColBERT vectors
                are not included in the returned ``EmbeddingResult``.

        Returns:
            An ``EmbeddingResult``. When no dense vectors are available the
            dense array is empty and ``dim`` falls back to 1024.

        Raises:
            Exception: Any error raised by ``encode`` is logged and re-raised.
        """
        if not texts:
            logger.warning("输入文本列表为空")
            return EmbeddingResult(
                dense_embeddings=np.array([]),
                sparse_embeddings=[],
                texts=[],
                dim=0,
            )

        logger.info(f"开始嵌入 {len(texts)} 个文本块")
        try:
            embeddings = self.model.encode(
                texts,
                batch_size=self.batch_size,
                max_length=self.max_length,
                return_dense=return_dense,
                return_sparse=return_sparse,
                return_colbert_vecs=return_colbert_vecs,
            )
        except Exception as e:
            logger.error(f"嵌入失败: {e}")
            raise

        # The backend may report an output that was not requested either by
        # omitting the key or by storing None under it; a plain
        # ``.get(key, default)`` does not cover the None case.
        dense = embeddings.get("dense_vecs")
        dense_embeddings = np.array([]) if dense is None else np.asarray(dense)
        sparse = embeddings.get("lexical_weights")
        sparse_embeddings = [] if sparse is None else sparse

        has_dense = dense_embeddings.ndim == 2 and len(dense_embeddings) > 0
        dim = dense_embeddings.shape[1] if has_dense else 1024
        logger.success(f"嵌入完成,向量维度: {dim}")
        return EmbeddingResult(
            dense_embeddings=dense_embeddings,
            sparse_embeddings=sparse_embeddings,
            texts=texts,
            dim=dim,
        )

    def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict, int]]:
        """Embed one text and return its ``dense``, ``sparse`` and ``dim`` entries.

        ``sparse`` is an empty dict when no sparse weights were produced.
        """
        result = self.embed([text])
        sparse = result.sparse_embeddings[0] if result.sparse_embeddings else {}
        return {
            "dense": result.dense_embeddings[0],
            "sparse": sparse,
            "dim": result.dim,
        }

    def embed_dense(self, texts: List[str]) -> np.ndarray:
        """Return only the dense embeddings for ``texts``."""
        result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
        return result.dense_embeddings

    def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
        """Return only the sparse weight mappings for ``texts``."""
        result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
        return result.sparse_embeddings

    def embed_query(self, query: str) -> Dict:
        """Embed a query string; identical to ``embed_single``."""
        return self.embed_single(query)

    def compute_similarity(
        self,
        query_embedding: np.ndarray,
        doc_embeddings: np.ndarray,
        metric: str = "cosine",
    ) -> np.ndarray:
        """Score every row of ``doc_embeddings`` against ``query_embedding``.

        Args:
            query_embedding: 1-D query vector.
            doc_embeddings: 2-D array with one document vector per row.
            metric: "cosine" or "dot".

        Returns:
            1-D array with one score per document row. For "cosine", any pair
            involving a zero-norm vector scores 0 instead of NaN/inf.

        Raises:
            ValueError: If ``metric`` is neither "cosine" nor "dot".
        """
        if metric not in ("cosine", "dot"):
            raise ValueError(f"不支持的相似度度量: {metric}")

        dots = np.dot(doc_embeddings, query_embedding)
        if metric == "dot":
            return dots

        denom = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        # Divide only where the denominator is non-zero; the remaining slots
        # keep the 0 they were initialised with.
        similarities = np.zeros_like(dots, dtype=np.result_type(dots, denom))
        np.divide(dots, denom, out=similarities, where=denom != 0)
        return similarities

    def sparse_similarity(
        self,
        query_sparse: Dict[int, float],
        doc_sparse: Dict[int, float],
    ) -> float:
        """Return the dot product of two sparse weight mappings.

        Only keys present in both mappings contribute; with no shared keys the
        result is 0.0.
        """
        shared_keys = query_sparse.keys() & doc_sparse.keys()
        return float(sum(query_sparse[k] * doc_sparse[k] for k in shared_keys))
def embed_texts(
    texts: List[str],
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> EmbeddingResult:
    """Embed ``texts`` with a newly constructed ``BGEM3Embedder``.

    Args:
        texts: Texts to embed.
        model_name: Model identifier forwarded to ``BGEM3Embedder``.
        **kwargs: Extra keyword arguments forwarded to ``BGEM3Embedder``.

    Returns:
        The ``EmbeddingResult`` produced by ``BGEM3Embedder.embed``.

    Note:
        A new embedder is built on every call, which loads the model each
        time; callers embedding repeatedly should keep their own
        ``BGEM3Embedder`` instance.
    """
    embedder = BGEM3Embedder(model_name=model_name, **kwargs)
    return embedder.embed(texts)
def embed_single_text(
    text: str,
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> Dict:
    """Embed one text with a newly constructed ``BGEM3Embedder``.

    Args:
        text: Text to embed.
        model_name: Model identifier forwarded to ``BGEM3Embedder``.
        **kwargs: Extra keyword arguments forwarded to ``BGEM3Embedder``.

    Returns:
        The dict produced by ``BGEM3Embedder.embed_single``.

    Note:
        A new embedder is built on every call, which loads the model each
        time; callers embedding repeatedly should keep their own
        ``BGEM3Embedder`` instance.
    """
    embedder = BGEM3Embedder(model_name=model_name, **kwargs)
    return embedder.embed_single(text)