from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)
from typing import List, Optional


class MilvusService:
    """Wrapper around the ``default`` pymilvus connection and the regulations collection.

    Host, port and collection names are read from ``app.core.config.settings``
    when the instance is constructed.
    """

    def __init__(self):
        # Imported inside the method rather than at module top — presumably to
        # avoid an import cycle with app.core.config (TODO confirm).
        from app.core.config import settings

        self.host = settings.milvus_host
        self.port = settings.milvus_port
        self.regulations_collection_name = settings.regulations_collection
        self.compliance_collection_name = settings.compliance_collection
        # Tracks whether this instance has opened the "default" connection.
        self._connected = False

    def connect(self):
        """Connect to Milvus under the ``default`` alias if not already connected."""
        if not self._connected:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
            )
            self._connected = True

    def disconnect(self):
        """Close the ``default`` connection if this instance opened it."""
        if self._connected:
            connections.disconnect("default")
            self._connected = False

    def create_regulations_collection(self):
        """Create the regulations collection and its vector index.

        If the collection already exists it is returned as-is (no schema or
        index changes are made).

        Returns:
            The pymilvus ``Collection`` for the regulations collection.
        """
        from app.core.config import settings

        self.connect()

        if utility.has_collection(self.regulations_collection_name):
            return Collection(self.regulations_collection_name)

        # Field order matters: insert_chunks() sends column-based data in this
        # same order (minus the auto-generated "id").
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
            FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="clause_id", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="chapter", dtype=DataType.VARCHAR, max_length=128),
            FieldSchema(name="source_file", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="chunk_index", dtype=DataType.INT64),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="token_count", dtype=DataType.INT64),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="法规文档向量集合",
        )

        collection = Collection(
            name=self.regulations_collection_name,
            schema=schema,
        )

        # The metric here must match the metric_type used in search().
        index_params = {
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128},
        }
        collection.create_index(field_name="embedding", index_params=index_params)

        return collection

    def insert_chunks(
        self,
        embeddings: List[List[float]],
        metadata: List[dict],
    ) -> List[int]:
        """Insert chunk vectors and their metadata into the regulations collection.

        Args:
            embeddings: One vector per chunk.
            metadata: One dict per chunk, parallel to ``embeddings``. Missing
                string fields default to ``""`` and missing int fields to ``0``.

        Returns:
            The primary keys reported by Milvus for the inserted rows, or an
            empty list when there is nothing to insert.
        """
        if not embeddings:
            return []

        # Previously missing: without this, calling insert first fails because
        # no "default" connection exists yet.
        self.connect()
        collection = Collection(self.regulations_collection_name)
        # No collection.load() here: loading is only needed for search/query,
        # and reloading on every insert is wasted work.

        # Column-based insert; column order must match the schema field order
        # defined in create_regulations_collection() (excluding auto_id "id").
        data = [
            embeddings,
            [m.get("doc_name", "") for m in metadata],
            [m.get("clause_id", "") for m in metadata],
            [m.get("chapter", "") for m in metadata],
            [m.get("source_file", "") for m in metadata],
            [m.get("chunk_index", 0) for m in metadata],
            [m.get("content", "") for m in metadata],
            [m.get("token_count", 0) for m in metadata],
        ]

        result = collection.insert(data)
        # Flush so the inserted rows are persisted and counted by num_entities.
        collection.flush()

        return result.primary_keys

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
    ) -> List[dict]:
        """Return the ``top_k`` nearest regulation chunks to ``query_embedding``.

        Args:
            query_embedding: The query vector.
            top_k: Maximum number of hits to return; ``<= 0`` yields ``[]``.

        Returns:
            A list of dicts with keys ``id``, ``score``, ``doc_name``,
            ``clause_id``, ``chapter``, ``content`` and ``chunk_index``, in the
            order Milvus returns the hits.
        """
        if top_k <= 0:
            return []

        self.connect()
        collection = Collection(self.regulations_collection_name)
        # Vector search requires the collection to be loaded.
        collection.load()

        search_params = {"metric_type": "COSINE", "params": {"nprobe": 16}}

        results = collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=["doc_name", "clause_id", "chapter", "content", "chunk_index"],
        )

        # Only one query vector was sent, so results[0] holds all its hits.
        return [
            {
                "id": hit.id,
                "score": hit.score,
                "doc_name": hit.entity.get("doc_name"),
                "clause_id": hit.entity.get("clause_id"),
                "chapter": hit.entity.get("chapter"),
                "content": hit.entity.get("content"),
                "chunk_index": hit.entity.get("chunk_index"),
            }
            for hit in results[0]
        ]

    def get_collection_stats(self) -> dict:
        """Return basic statistics for the regulations collection.

        Returns:
            ``{"exists": False}`` if the collection is missing, otherwise
            ``{"exists": True, "name": ..., "count": ...}``.
        """
        self.connect()

        if not utility.has_collection(self.regulations_collection_name):
            return {"exists": False}

        collection = Collection(self.regulations_collection_name)
        # num_entities does not need the collection to be loaded into memory.
        return {
            "exists": True,
            "name": self.regulations_collection_name,
            "count": collection.num_entities,
        }

    def health_check(self) -> bool:
        """Return True if the Milvus server is reachable, False otherwise."""
        try:
            self.connect()
            # connect() is a no-op once _connected is set, so make a cheap
            # round-trip to confirm the server is actually still up.
            utility.get_server_version()
            return True
        except Exception:
            # Health checks report status rather than propagate errors.
            return False


milvus_service = MilvusService()