158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
|
|
from pymilvus import (
|
||
|
|
connections,
|
||
|
|
Collection,
|
||
|
|
FieldSchema,
|
||
|
|
CollectionSchema,
|
||
|
|
DataType,
|
||
|
|
utility,
|
||
|
|
)
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
|
||
|
|
class MilvusService:
    """Thin wrapper around pymilvus for the regulations vector collection.

    The service lazily opens a single connection under the ``default`` alias
    and exposes helpers to create the collection, insert embedded chunks,
    run similarity search, and report basic statistics.
    """

    # pymilvus connection alias used by every call in this class.
    _ALIAS = "default"

    # The same metric must be used when building the index and when searching.
    _METRIC_TYPE = "COSINE"

    # Scalar fields in schema order (after the auto-generated ``id`` and the
    # ``embedding`` vector), paired with the value used when a metadata dict
    # does not provide the key.
    _METADATA_DEFAULTS = (
        ("doc_name", ""),
        ("clause_id", ""),
        ("chapter", ""),
        ("source_file", ""),
        ("chunk_index", 0),
        ("content", ""),
        ("token_count", 0),
    )

    # Scalar fields returned with every search hit.
    _SEARCH_OUTPUT_FIELDS = (
        "doc_name",
        "clause_id",
        "chapter",
        "content",
        "chunk_index",
    )

    def __init__(self):
        """Read connection settings; no connection is opened here."""
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle with the config module -- TODO confirm).
        from app.core.config import settings

        self.host = settings.milvus_host
        self.port = settings.milvus_port
        self.regulations_collection_name = settings.regulations_collection
        self.compliance_collection_name = settings.compliance_collection
        self._connected = False

    def connect(self) -> None:
        """Open the Milvus connection if this instance has not done so yet."""
        if self._connected:
            return
        connections.connect(
            alias=self._ALIAS,
            host=self.host,
            port=self.port,
        )
        # Only flipped after connect() succeeded, so a failed attempt is
        # retried on the next call.
        self._connected = True

    def disconnect(self) -> None:
        """Close the Milvus connection if this instance opened it."""
        if not self._connected:
            return
        connections.disconnect(self._ALIAS)
        self._connected = False

    def create_regulations_collection(self) -> "Collection":
        """Return the regulations collection, creating it (and its index) if absent.

        Returns:
            The existing collection when one with the configured name already
            exists; otherwise a newly created collection with an IVF_FLAT
            index on the ``embedding`` field.
        """
        from app.core.config import settings

        self.connect()

        if utility.has_collection(self.regulations_collection_name):
            return Collection(self.regulations_collection_name)

        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
            FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="clause_id", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="chapter", dtype=DataType.VARCHAR, max_length=128),
            FieldSchema(name="source_file", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="chunk_index", dtype=DataType.INT64),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="token_count", dtype=DataType.INT64),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="法规文档向量集合",
        )

        collection = Collection(
            name=self.regulations_collection_name,
            schema=schema,
        )

        index_params = {
            "metric_type": self._METRIC_TYPE,
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128},
        }
        collection.create_index(field_name="embedding", index_params=index_params)

        return collection

    @classmethod
    def _build_columns(cls, embeddings: List[List[float]], metadata: List[dict]) -> list:
        """Arrange embeddings and metadata as column lists in schema order."""
        columns = [embeddings]
        for key, default in cls._METADATA_DEFAULTS:
            columns.append([item.get(key, default) for item in metadata])
        return columns

    def insert_chunks(
        self,
        embeddings: List[List[float]],
        metadata: List[dict],
    ) -> List[int]:
        """Insert embedded chunks into the regulations collection.

        Args:
            embeddings: One vector per chunk.
            metadata: One dict per chunk, aligned with ``embeddings``. Missing
                keys fall back to ``""`` (string fields) or ``0`` (integer
                fields).

        Returns:
            Primary keys assigned to the inserted rows; ``[]`` when there is
            nothing to insert.

        Raises:
            ValueError: If ``embeddings`` and ``metadata`` differ in length.
        """
        # Validate before touching the server so misaligned input fails fast
        # with a clear message.
        if len(embeddings) != len(metadata):
            raise ValueError(
                "embeddings and metadata must have the same length "
                f"(got {len(embeddings)} and {len(metadata)})"
            )
        if not embeddings:
            return []

        # Previously this method relied on some other call having connected.
        self.connect()
        collection = Collection(self.regulations_collection_name)

        # No load() here: loading into query memory is only needed for search.
        result = collection.insert(self._build_columns(embeddings, metadata))
        collection.flush()
        return list(result.primary_keys)

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
    ) -> List[dict]:
        """Run a vector similarity search over the regulations collection.

        Args:
            query_embedding: The query vector.
            top_k: Maximum number of hits to return.

        Returns:
            One dict per hit with ``id``, ``score`` and the scalar fields
            ``doc_name``, ``clause_id``, ``chapter``, ``content`` and
            ``chunk_index``.
        """
        # Previously this method relied on some other call having connected.
        self.connect()
        collection = Collection(self.regulations_collection_name)
        collection.load()

        search_params = {"metric_type": self._METRIC_TYPE, "params": {"nprobe": 16}}
        output_fields = list(self._SEARCH_OUTPUT_FIELDS)

        results = collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=output_fields,
        )

        # A single query vector was sent, so only the first result set is used.
        return [
            {
                "id": hit.id,
                "score": hit.score,
                **{name: hit.entity.get(name) for name in output_fields},
            }
            for hit in results[0]
        ]

    def get_collection_stats(self) -> dict:
        """Report whether the regulations collection exists and its entity count.

        Returns:
            ``{"exists": False}`` when the collection is missing, otherwise a
            dict with ``exists``, ``name`` and ``count``.
        """
        self.connect()

        if not utility.has_collection(self.regulations_collection_name):
            return {"exists": False}

        # No load() here: the count is read without loading the collection
        # into query memory.
        collection = Collection(self.regulations_collection_name)
        return {
            "exists": True,
            "name": self.regulations_collection_name,
            "count": collection.num_entities,
        }

    def health_check(self) -> bool:
        """Return True if the Milvus server answers a request, else False.

        ``connect()`` is a no-op once ``_connected`` is set, so a request is
        issued afterwards to probe the server instead of trusting the cached
        flag.
        """
        try:
            self.connect()
            # Only the fact that the call completes matters, not its result.
            utility.has_collection(self.regulations_collection_name)
        except Exception:
            # Deliberate best-effort: any failure means "unhealthy".
            return False
        return True
|
||
|
|
|
||
|
|
|
||
|
|
# Shared module-level MilvusService instance. Constructing it only reads the
# connection settings; no Milvus connection is opened until connect() runs.
milvus_service = MilvusService()
|