Files

158 lines
4.9 KiB
Python
Raw Permalink Normal View History

2026-05-11 11:22:55 +08:00
from pymilvus import (
connections,
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility,
)
from typing import List, Optional
class MilvusService:
    """Wrapper around a pymilvus connection and the regulations collection.

    Connection settings and collection names are read from
    ``app.core.config.settings`` when the service is constructed; the
    connection itself is opened lazily by ``connect()``.
    """

    # Scalar columns written by ``insert_chunks`` as (field name, default).
    # The default is used when a metadata dict lacks the key or holds None.
    # Order MUST match the non-primary, non-vector fields declared in
    # ``create_regulations_collection``: rows are inserted column-wise.
    _METADATA_COLUMNS = (
        ("doc_name", ""),
        ("clause_id", ""),
        ("chapter", ""),
        ("source_file", ""),
        ("chunk_index", 0),
        ("content", ""),
        ("token_count", 0),
    )

    # Scalar fields requested for, and copied into, every search hit.
    _SEARCH_OUTPUT_FIELDS = ["doc_name", "clause_id", "chapter", "content", "chunk_index"]

    def __init__(self):
        # Imported lazily, presumably to avoid loading config at module import
        # time or an import cycle -- TODO confirm.
        from app.core.config import settings

        self.host = settings.milvus_host
        self.port = settings.milvus_port
        self.regulations_collection_name = settings.regulations_collection
        # NOTE(review): stored but not used by any method of this class.
        self.compliance_collection_name = settings.compliance_collection
        # Whether this instance has opened the "default" connection alias.
        self._connected = False

    def connect(self):
        """Open the "default" pymilvus connection if this instance has not already."""
        if not self._connected:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
            )
            # Set only after connect() succeeds, so a failed attempt is retried.
            self._connected = True

    def disconnect(self):
        """Close the "default" connection if this instance opened it."""
        if self._connected:
            connections.disconnect("default")
            self._connected = False

    def create_regulations_collection(self):
        """Create the regulations vector collection and its index if missing.

        If a collection with this name already exists it is returned as-is;
        its schema and index are NOT checked against the definition below.

        Returns:
            The pymilvus ``Collection`` for the regulations collection.
        """
        from app.core.config import settings

        self.connect()
        if utility.has_collection(self.regulations_collection_name):
            return Collection(self.regulations_collection_name)

        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
            FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="clause_id", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="chapter", dtype=DataType.VARCHAR, max_length=128),
            FieldSchema(name="source_file", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="chunk_index", dtype=DataType.INT64),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="token_count", dtype=DataType.INT64),
        ]
        schema = CollectionSchema(
            fields=fields,
            description="法规文档向量集合",
        )
        collection = Collection(
            name=self.regulations_collection_name,
            schema=schema,
        )
        # IVF_FLAT with 128 clusters; the metric must match the one used in search().
        index_params = {
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128},
        }
        collection.create_index(field_name="embedding", index_params=index_params)
        return collection

    @staticmethod
    def _field_or_default(meta: dict, key: str, default):
        """Return ``meta[key]``, or ``default`` when the key is missing or None.

        The schema's scalar fields are not nullable, so None must not reach
        Milvus.
        """
        value = meta.get(key)
        return default if value is None else value

    def insert_chunks(
        self,
        embeddings: List[List[float]],
        metadata: List[dict],
    ) -> List[int]:
        """Insert chunk vectors and their metadata into the regulations collection.

        Args:
            embeddings: One vector per chunk.
            metadata: One dict per chunk, aligned index-by-index with
                ``embeddings``. Missing keys or None values fall back to ""
                (text fields) or 0 (integer fields).

        Returns:
            The primary keys Milvus reports for the inserted rows, or an
            empty list when there is nothing to insert.

        Raises:
            ValueError: If ``embeddings`` and ``metadata`` differ in length.
        """
        # Validate before contacting the server so bad input fails fast.
        if len(embeddings) != len(metadata):
            raise ValueError(
                "embeddings and metadata length mismatch: "
                f"{len(embeddings)} != {len(metadata)}"
            )
        if not metadata:
            return []

        self.connect()
        collection = Collection(self.regulations_collection_name)
        # Column-wise payload: the vector column first, then the scalar
        # columns in schema order. Insertion does not require load().
        data = [embeddings]
        for key, default in self._METADATA_COLUMNS:
            data.append([self._field_or_default(m, key, default) for m in metadata])
        result = collection.insert(data)
        # Persist immediately so the rows show up in num_entities.
        collection.flush()
        return result.primary_keys

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
    ) -> List[dict]:
        """Return the ``top_k`` chunks nearest to ``query_embedding`` by cosine.

        Returns:
            One dict per hit with keys "id", "score", and each field in
            ``_SEARCH_OUTPUT_FIELDS``.
        """
        self.connect()
        collection = Collection(self.regulations_collection_name)
        # Vector search requires the collection to be loaded into memory.
        collection.load()
        # Probe 16 IVF clusters (collections created above use nlist=128).
        search_params = {"metric_type": "COSINE", "params": {"nprobe": 16}}
        results = collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=list(self._SEARCH_OUTPUT_FIELDS),
        )
        # A single query vector was sent, so only results[0] is relevant.
        return [
            {
                "id": hit.id,
                "score": hit.score,
                **{name: hit.entity.get(name) for name in self._SEARCH_OUTPUT_FIELDS},
            }
            for hit in results[0]
        ]

    def get_collection_stats(self) -> dict:
        """Report whether the regulations collection exists and how many rows it holds.

        Returns:
            ``{"exists": False}`` when the collection is absent, otherwise
            ``{"exists": True, "name": <collection name>, "count": <entities>}``.
        """
        self.connect()
        if not utility.has_collection(self.regulations_collection_name):
            return {"exists": False}
        collection = Collection(self.regulations_collection_name)
        # num_entities is answered from collection statistics and does not
        # need the (potentially large) collection loaded into memory.
        return {
            "exists": True,
            "name": self.regulations_collection_name,
            "count": collection.num_entities,
        }

    def health_check(self) -> bool:
        """Return True if the Milvus server answers, False otherwise.

        Once ``_connected`` is set, ``connect()`` is a no-op, so a real server
        call is made to detect a server that has gone away since connecting.
        """
        try:
            self.connect()
            utility.get_server_version()
        except Exception:
            # A health probe reports failure rather than raising.
            return False
        return True
# Module-level shared service instance. Construction only reads settings;
# no Milvus connection is opened until connect() is called.
milvus_service = MilvusService()