"""Provide service-layer logic for milvus client.""" from pymilvus import ( connections, Collection, FieldSchema, CollectionSchema, DataType, utility ) from typing import List, Dict, Optional, Any from dataclasses import dataclass, field from loguru import logger import time import numpy as np from ..embedding.text_chunker import TextChunk from ..embedding.bge_m3_embedder import EmbeddingResult from app.config.settings import settings # Keep service responsibilities explicit so downstream behavior stays predictable. @dataclass class SearchResult: """Represent the Search Result type.""" id: int content: str score: float metadata: Dict[str, Any] = field(default_factory=dict) @dataclass class MilvusDocument: """Represent the Milvus Document type.""" doc_id: str chunk_id: str content: str dense_vector: List[float] sparse_vector: Dict[int, float] doc_name: str section_title: str clause_number: str page_number: int regulation_type: str version: str create_time: int class MilvusClient: """Represent the Milvus Client type.""" COLLECTION_NAME = "regulations" SCHEMA_FIELDS = [ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64), FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128), FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192), FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024), FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR), FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256), FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512), FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64), FieldSchema(name="page_number", dtype=DataType.INT64), FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32), FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32), FieldSchema(name="create_time", dtype=DataType.INT64), ] def __init__( self, host: str = None, port: int = None, collection_name: str = None, db_name: str = None ): """Initialize the Milvus Client instance.""" self.host = host or settings.milvus_host self.port = port or settings.milvus_port self.collection_name = collection_name or settings.milvus_collection self.db_name = db_name or settings.milvus_db_name self.collection: Optional[Collection] = None self.connected = False logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}") def connect(self) -> bool: """Handle connect for the Milvus Client instance.""" try: connections.connect( alias="default", host=self.host, port=self.port, db_name=self.db_name ) self.connected = True logger.success(f"Milvus连接成功: {self.host}:{self.port}") return True except Exception as e: logger.error(f"Milvus连接失败: {e}") self.connected = False return False def disconnect(self): """Handle disconnect for the Milvus Client instance.""" try: connections.disconnect("default") self.connected = False logger.info("Milvus连接已断开") except Exception as e: logger.warning(f"断开连接时出错: {e}") def create_collection(self, recreate: bool = False) -> bool: """Create collection for the Milvus Client instance.""" if not self.connected: logger.warning("未连接到Milvus,请先调用connect()") return False try: if utility.has_collection(self.collection_name): if recreate: logger.info(f"删除已存在的Collection: {self.collection_name}") utility.drop_collection(self.collection_name) else: logger.info(f"Collection已存在: {self.collection_name}") self.collection = Collection(self.collection_name) return True schema = CollectionSchema( fields=self.SCHEMA_FIELDS, description="法规文档向量存储", enable_dynamic_field=True ) self.collection = Collection( name=self.collection_name, schema=schema ) self._create_indexes() logger.success(f"Collection创建成功: {self.collection_name}") return True except Exception as e: logger.error(f"Collection创建失败: {e}") return False def _create_indexes(self): """Handle create indexes for this module for the Milvus Client instance.""" if not self.collection: return try: dense_index_params = { "metric_type": "COSINE", "index_type": "IVF_FLAT", "params": {"nlist": 128} } self.collection.create_index( field_name="dense_vector", index_params=dense_index_params ) sparse_index_params = { "metric_type": "IP", "index_type": "SPARSE_INVERTED_INDEX", "params": {"drop_ratio_build": 0.2} } self.collection.create_index( field_name="sparse_vector", index_params=sparse_index_params ) logger.success("向量索引创建成功") except Exception as e: logger.warning(f"创建索引时出错: {e}") def load_collection(self): """Load collection for the Milvus Client instance.""" if self.collection: self.collection.load() logger.info(f"Collection已加载: {self.collection_name}") def release_collection(self): """Handle release collection for the Milvus Client instance.""" if self.collection: self.collection.release() logger.info(f"Collection已释放: {self.collection_name}") def insert_chunks( self, chunks: List[TextChunk], embeddings: EmbeddingResult ) -> List[int]: """Handle insert chunks for the Milvus Client instance.""" if not self.collection: logger.warning("Collection未初始化") return [] if len(chunks) != len(embeddings.texts): logger.warning(f"Chunks数量与嵌入数量不匹配") return [] logger.info(f"准备插入{len(chunks)}个文档分块") try: data = [] current_time = int(time.time()) for chunk, dense_emb, sparse_emb in zip( chunks, embeddings.dense_embeddings, embeddings.sparse_embeddings ): row = { "doc_id": chunk.metadata.doc_id, "chunk_id": chunk.metadata.chunk_id, "content": chunk.content, "dense_vector": dense_emb.tolist(), "sparse_vector": sparse_emb, "doc_name": chunk.metadata.doc_name, "section_title": chunk.metadata.section_title, "clause_number": chunk.metadata.clause_number, "page_number": chunk.metadata.page_number, "regulation_type": chunk.metadata.regulation_type, "version": chunk.metadata.version, "create_time": current_time } data.append(row) result = self.collection.insert(data) self.collection.flush() logger.success(f"插入完成,共{len(result.primary_keys)}条记录") return result.primary_keys except Exception as e: logger.error(f"插入数据失败: {e}") return [] def hybrid_search( self, query_dense: List[float], query_sparse: Dict[int, float], top_k: int = 10, filters: Optional[str] = None ) -> List[SearchResult]: """Handle hybrid search for the Milvus Client instance.""" if not self.collection: logger.warning("Collection未初始化") return [] try: self.collection.load() # Keep service responsibilities explicit so downstream behavior stays predictable. dense_results = self.dense_search(query_dense, top_k, filters) # Keep service responsibilities explicit so downstream behavior stays predictable. if query_sparse: sparse_results = self.sparse_search(query_sparse, top_k, filters) merged = self._merge_results(dense_results, sparse_results, top_k) logger.success(f"混合检索完成,返回{len(merged)}条结果") return merged return dense_results except Exception as e: logger.error(f"混合检索失败: {e}") return [] def _merge_results( self, dense_results: List[SearchResult], sparse_results: List[SearchResult], top_k: int, dense_weight: float = 0.6 ) -> List[SearchResult]: """Handle merge results for this module for the Milvus Client instance.""" sparse_weight = 1 - dense_weight merged_dict = {} for r in dense_results: merged_dict[r.id] = { "result": r, "dense_score": r.score * dense_weight, "sparse_score": 0 } for r in sparse_results: if r.id in merged_dict: merged_dict[r.id]["sparse_score"] = r.score * sparse_weight else: merged_dict[r.id] = { "result": r, "dense_score": 0, "sparse_score": r.score * sparse_weight } final_results = [] for id_, data in merged_dict.items(): result = data["result"] final_score = data["dense_score"] + data["sparse_score"] final_results.append(SearchResult( id=result.id, content=result.content, score=final_score, metadata=result.metadata )) final_results.sort(key=lambda x: x.score, reverse=True) return final_results[:top_k] def dense_search( self, query_dense: List[float], top_k: int = 10, filters: Optional[str] = None ) -> List[SearchResult]: """Handle dense search for the Milvus Client instance.""" if not self.collection: return [] try: self.collection.load() search_params = { "metric_type": "COSINE", "params": {"nprobe": 16} } results = self.collection.search( data=[query_dense], anns_field="dense_vector", param=search_params, limit=top_k, filter=filters, output_fields=[ "doc_id", "chunk_id", "content", "doc_name", "section_title", "clause_number", "page_number", "regulation_type", "version" ] ) search_results = [] for hits in results: for hit in hits: result = SearchResult( id=hit.id, content=hit.entity.get("content", ""), score=hit.score, metadata={ "doc_id": hit.entity.get("doc_id", ""), "chunk_id": hit.entity.get("chunk_id", ""), "doc_name": hit.entity.get("doc_name", ""), "section_title": hit.entity.get("section_title", ""), "clause_number": hit.entity.get("clause_number", ""), "page_number": hit.entity.get("page_number", 0), "regulation_type": hit.entity.get("regulation_type", ""), "version": hit.entity.get("version", ""), } ) search_results.append(result) return search_results except Exception as e: logger.error(f"Dense检索失败: {e}") return [] def sparse_search( self, query_sparse: Dict[int, float], top_k: int = 10, filters: Optional[str] = None ) -> List[SearchResult]: """Handle sparse search for the Milvus Client instance.""" if not self.collection: return [] try: self.collection.load() search_params = { "metric_type": "IP", "params": {"drop_ratio_search": 0.2} } results = self.collection.search( data=[query_sparse], anns_field="sparse_vector", param=search_params, limit=top_k, filter=filters, output_fields=[ "doc_id", "chunk_id", "content", "doc_name", "section_title", "clause_number", "page_number", "regulation_type", "version" ] ) search_results = [] for hits in results: for hit in hits: result = SearchResult( id=hit.id, content=hit.entity.get("content", ""), score=hit.score, metadata={ "doc_id": hit.entity.get("doc_id", ""), "chunk_id": hit.entity.get("chunk_id", ""), "doc_name": hit.entity.get("doc_name", ""), "section_title": hit.entity.get("section_title", ""), "clause_number": hit.entity.get("clause_number", ""), "page_number": hit.entity.get("page_number", 0), "regulation_type": hit.entity.get("regulation_type", ""), "version": hit.entity.get("version", ""), } ) search_results.append(result) return search_results except Exception as e: logger.error(f"Sparse检索失败: {e}") return [] def delete_by_doc_id(self, doc_id: str) -> int: """Delete by doc id for the Milvus Client instance.""" if not self.collection: return 0 try: expr = f'doc_id=="{doc_id}"' result = self.collection.delete(expr) logger.info(f"删除记录: doc_id={doc_id}, 数量={len(result.primary_keys)}") return len(result.primary_keys) except Exception as e: logger.error(f"删除失败: {e}") return 0 def get_collection_stats(self) -> Dict[str, Any]: """Return collection stats for the Milvus Client instance.""" if not self.collection: return {} try: stats = { "name": self.collection_name, "num_entities": self.collection.num_entities, "description": self.collection.description, } return stats except Exception as e: logger.warning(f"获取统计信息失败: {e}") return {} def create_milvus_client() -> MilvusClient: """Create milvus client.""" client = MilvusClient() client.connect() client.create_collection(recreate=False) return client def insert_documents( client: MilvusClient, chunks: List[TextChunk], embeddings: EmbeddingResult ) -> List[int]: """Handle insert documents.""" return client.insert_chunks(chunks, embeddings) def search_regulations( client: MilvusClient, query_dense: List[float], query_sparse: Dict[int, float], top_k: int = 10 ) -> List[SearchResult]: """Search regulations.""" return client.hybrid_search(query_dense, query_sparse, top_k)