初始化
This commit is contained in:
158
app/services/milvus.py
Normal file
158
app/services/milvus.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from pymilvus import (
|
||||
connections,
|
||||
Collection,
|
||||
FieldSchema,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
utility,
|
||||
)
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class MilvusService:
    """Thin wrapper around the pymilvus ORM API for the regulations collection.

    Host, port and collection names are read from ``app.core.config.settings``
    when the instance is constructed. The connection itself is opened lazily
    by :meth:`connect` under the pymilvus alias ``"default"``.
    """

    # Metadata keys in the positional order the collection schema expects
    # (after the auto-generated ``id`` and the ``embedding`` column), paired
    # with the fallback value used when a metadata dict lacks the key.
    _METADATA_DEFAULTS = (
        ("doc_name", ""),
        ("clause_id", ""),
        ("chapter", ""),
        ("source_file", ""),
        ("chunk_index", 0),
        ("content", ""),
        ("token_count", 0),
    )

    def __init__(self):
        """Read connection settings; no network connection is made here."""
        # Kept as a function-scope import, as in the original
        # (presumably to avoid an import cycle with the config module —
        # TODO confirm).
        from app.core.config import settings

        self.host = settings.milvus_host
        self.port = settings.milvus_port
        self.regulations_collection_name = settings.regulations_collection
        self.compliance_collection_name = settings.compliance_collection
        self._connected = False

    def connect(self):
        """Connect to Milvus under the ``"default"`` alias if not yet connected."""
        if not self._connected:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
            )
            self._connected = True

    def disconnect(self):
        """Close the ``"default"`` connection if this service opened it."""
        if self._connected:
            connections.disconnect("default")
            self._connected = False

    def create_regulations_collection(self):
        """Return the regulations collection, creating it first if missing.

        When the collection does not exist it is created with the schema
        below and an IVF_FLAT / COSINE index on the ``embedding`` field.

        Returns:
            The pymilvus ``Collection`` for the regulations collection.
        """
        from app.core.config import settings

        self.connect()

        if utility.has_collection(self.regulations_collection_name):
            return Collection(self.regulations_collection_name)

        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=settings.embedding_dim),
            FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="clause_id", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="chapter", dtype=DataType.VARCHAR, max_length=128),
            FieldSchema(name="source_file", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="chunk_index", dtype=DataType.INT64),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="token_count", dtype=DataType.INT64),
        ]

        schema = CollectionSchema(
            fields=fields,
            description="法规文档向量集合",
        )

        collection = Collection(
            name=self.regulations_collection_name,
            schema=schema,
        )

        index_params = {
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128},
        }
        collection.create_index(field_name="embedding", index_params=index_params)

        return collection

    def insert_chunks(
        self,
        embeddings: List[List[float]],
        metadata: List[dict],
    ) -> List[int]:
        """Insert embedding vectors with their metadata into the collection.

        Args:
            embeddings: One vector per chunk.
            metadata: One dict per chunk, in the same order as ``embeddings``.
                Missing keys fall back to ``""`` (string fields) or ``0``
                (``chunk_index``, ``token_count``).

        Returns:
            The primary keys reported by the insert result, or an empty list
            when both inputs are empty.
        """
        # Nothing to insert: skip the round trip instead of sending empty
        # columns to Milvus.
        if not embeddings and not metadata:
            return []

        # Connect here too, so this method does not depend on a sibling
        # method having been called first.
        self.connect()

        collection = Collection(self.regulations_collection_name)
        collection.load()

        # Column-oriented payload: the embeddings column followed by one
        # column per metadata field, in schema order.
        data = [embeddings]
        data.extend(
            [m.get(key, default) for m in metadata]
            for key, default in self._METADATA_DEFAULTS
        )

        result = collection.insert(data)
        collection.flush()
        return result.primary_keys

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
    ) -> List[dict]:
        """Run a vector similarity search against the regulations collection.

        Args:
            query_embedding: The query vector.
            top_k: Maximum number of hits to request.

        Returns:
            One dict per hit with the keys ``id``, ``score``, ``doc_name``,
            ``clause_id``, ``chapter``, ``content`` and ``chunk_index``.
        """
        # Connect here too, for consistency with the other entry points.
        self.connect()

        collection = Collection(self.regulations_collection_name)
        collection.load()

        search_params = {"metric_type": "COSINE", "params": {"nprobe": 16}}
        output_fields = ["doc_name", "clause_id", "chapter", "content", "chunk_index"]

        results = collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=output_fields,
        )

        # A single query vector was sent, so only results[0] is relevant.
        return [
            {
                "id": hit.id,
                "score": hit.score,
                **{field: hit.entity.get(field) for field in output_fields},
            }
            for hit in results[0]
        ]

    def get_collection_stats(self) -> dict:
        """Return existence, name and entity count of the regulations collection.

        Returns:
            ``{"exists": False}`` when the collection is absent, otherwise a
            dict with ``exists``, ``name`` and ``count``.
        """
        self.connect()

        if not utility.has_collection(self.regulations_collection_name):
            return {"exists": False}

        collection = Collection(self.regulations_collection_name)
        collection.load()

        return {
            "exists": True,
            "name": self.regulations_collection_name,
            "count": collection.num_entities,
        }

    def health_check(self) -> bool:
        """Return ``True`` if Milvus answers a request, ``False`` otherwise."""
        try:
            self.connect()
            # The cached ``_connected`` flag says nothing about whether the
            # server is still reachable, so issue one cheap request.
            utility.has_collection(self.regulations_collection_name)
        except Exception:
            # Deliberately broad: any failure means "unhealthy".
            return False
        return True
|
||||
|
||||
|
||||
# Shared module-level instance. Constructing it only reads settings; no
# connection to Milvus is opened until connect() is called.
milvus_service = MilvusService()
|
||||
Reference in New Issue
Block a user