Files
AIRegulation-DocAnalysis/backend/app/services/storage/milvus_client.py

488 lines
16 KiB
Python
Raw Permalink Normal View History

"""Provide service-layer logic for milvus client."""
2026-05-14 15:07:34 +08:00
from pymilvus import (
connections,
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility
)
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
import time
import numpy as np
from ..embedding.text_chunker import TextChunk
from ..embedding.bge_m3_embedder import EmbeddingResult
from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
2026-05-14 15:07:34 +08:00
@dataclass
class SearchResult:
    """One hit returned by a vector search against the collection.

    Instances are produced by the search methods of ``MilvusClient`` and
    re-created with a combined score when dense and sparse hits are merged.
    """

    # Primary key of the matching entity.
    id: int
    # Text stored in the entity's ``content`` field.
    content: str
    # Similarity score reported by the search, or the weighted sum after a merge.
    score: float
    # Remaining scalar fields of the hit; every instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MilvusDocument:
    """Plain record carrying one document chunk and both of its embeddings.

    The field names match the non-primary-key fields declared in
    ``MilvusClient.SCHEMA_FIELDS``.
    """

    # --- identity ---
    doc_id: str
    chunk_id: str

    # --- payload and vectors ---
    content: str
    dense_vector: List[float]
    sparse_vector: Dict[int, float]

    # --- descriptive scalar fields ---
    doc_name: str
    section_title: str
    clause_number: str
    page_number: int
    regulation_type: str
    version: str

    # Integer timestamp; unit is not fixed by this class (insert_chunks in this
    # file writes epoch seconds into the schema field of the same name).
    create_time: int
class MilvusClient:
    """Wrapper around a pymilvus ``Collection`` that stores regulation chunks.

    The client manages the connection (alias ``"default"``), creates the
    collection and its two vector indexes, inserts chunk embeddings, and runs
    dense, sparse and client-side merged ("hybrid") searches.

    Most methods are best-effort: failures are logged and reported through
    the return value (``False`` / ``[]`` / ``0`` / ``{}``) instead of raising.
    ``load_collection`` and ``release_collection`` are the exception and let
    pymilvus errors propagate.
    """

    # NOTE: not read by __init__, which falls back to settings.milvus_collection.
    COLLECTION_NAME = "regulations"

    SCHEMA_FIELDS = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192),
        FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
        FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="page_number", dtype=DataType.INT64),
        FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="create_time", dtype=DataType.INT64),
    ]

    # Scalar fields requested with every search hit.
    _OUTPUT_FIELDS = [
        "doc_id", "chunk_id", "content",
        "doc_name", "section_title", "clause_number",
        "page_number", "regulation_type", "version",
    ]

    # Metadata keys copied from a hit, with the default used when a key is absent.
    _METADATA_DEFAULTS = {
        "doc_id": "",
        "chunk_id": "",
        "doc_name": "",
        "section_title": "",
        "clause_number": "",
        "page_number": 0,
        "regulation_type": "",
        "version": "",
    }

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
        db_name: Optional[str] = None,
    ):
        """Store connection parameters; no network call is made here.

        Args:
            host: Milvus host. Falsy values fall back to ``settings.milvus_host``.
            port: Milvus port. Falsy values fall back to ``settings.milvus_port``.
            collection_name: Collection to use. Falsy values fall back to
                ``settings.milvus_collection``.
            db_name: Database name. Falsy values fall back to
                ``settings.milvus_db_name``.
        """
        self.host = host or settings.milvus_host
        self.port = port or settings.milvus_port
        self.collection_name = collection_name or settings.milvus_collection
        self.db_name = db_name or settings.milvus_db_name
        self.collection: Optional[Collection] = None
        self.connected = False
        logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")

    def connect(self) -> bool:
        """Open the ``"default"`` pymilvus connection.

        Returns:
            ``True`` on success, ``False`` if pymilvus raised (the error is
            logged). ``self.connected`` mirrors the return value.
        """
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                db_name=self.db_name,
            )
            self.connected = True
            logger.success(f"Milvus连接成功: {self.host}:{self.port}")
            return True
        except Exception as e:
            logger.error(f"Milvus连接失败: {e}")
            self.connected = False
            return False

    def disconnect(self):
        """Close the ``"default"`` connection; errors are logged, not raised.

        ``self.connected`` is cleared only when the disconnect call succeeds.
        """
        try:
            connections.disconnect("default")
            self.connected = False
            logger.info("Milvus连接已断开")
        except Exception as e:
            logger.warning(f"断开连接时出错: {e}")

    def create_collection(self, recreate: bool = False) -> bool:
        """Bind ``self.collection``, creating the collection if necessary.

        Args:
            recreate: When the collection already exists, drop it and build a
                fresh one instead of reusing it.

        Returns:
            ``True`` when ``self.collection`` is bound to an existing or newly
            created collection; ``False`` when not connected or when pymilvus
            raised. Index creation failures are only logged by
            ``_create_indexes``, so ``True`` does not guarantee that both
            indexes exist.
        """
        if not self.connected:
            logger.warning("未连接到Milvus请先调用connect()")
            return False
        try:
            if utility.has_collection(self.collection_name):
                if recreate:
                    logger.info(f"删除已存在的Collection: {self.collection_name}")
                    utility.drop_collection(self.collection_name)
                else:
                    logger.info(f"Collection已存在: {self.collection_name}")
                    self.collection = Collection(self.collection_name)
                    return True
            schema = CollectionSchema(
                fields=self.SCHEMA_FIELDS,
                description="法规文档向量存储",
                enable_dynamic_field=True,
            )
            self.collection = Collection(
                name=self.collection_name,
                schema=schema,
            )
            self._create_indexes()
            logger.success(f"Collection创建成功: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Collection创建失败: {e}")
            return False

    def _create_indexes(self):
        """Create the dense (IVF_FLAT / COSINE) and sparse (inverted / IP) indexes.

        Does nothing when no collection is bound. Any pymilvus error is logged
        as a warning and swallowed.
        """
        if not self.collection:
            return
        try:
            dense_index_params = {
                "metric_type": "COSINE",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 128},
            }
            self.collection.create_index(
                field_name="dense_vector",
                index_params=dense_index_params,
            )
            sparse_index_params = {
                "metric_type": "IP",
                "index_type": "SPARSE_INVERTED_INDEX",
                "params": {"drop_ratio_build": 0.2},
            }
            self.collection.create_index(
                field_name="sparse_vector",
                index_params=sparse_index_params,
            )
            logger.success("向量索引创建成功")
        except Exception as e:
            logger.warning(f"创建索引时出错: {e}")

    def load_collection(self):
        """Load the bound collection; no-op when none is bound.

        Unlike most methods of this class, pymilvus errors propagate.
        """
        if self.collection:
            self.collection.load()
            logger.info(f"Collection已加载: {self.collection_name}")

    def release_collection(self):
        """Release the bound collection; no-op when none is bound.

        Unlike most methods of this class, pymilvus errors propagate.
        """
        if self.collection:
            self.collection.release()
            logger.info(f"Collection已释放: {self.collection_name}")

    def insert_chunks(
        self,
        chunks: List[TextChunk],
        embeddings: EmbeddingResult
    ) -> List[int]:
        """Insert one row per chunk, pairing each chunk with its embeddings.

        Args:
            chunks: Chunks to store. Each must expose ``content`` and a
                ``metadata`` object with ``doc_id``, ``chunk_id``, ``doc_name``,
                ``section_title``, ``clause_number``, ``page_number``,
                ``regulation_type`` and ``version``.
            embeddings: Must expose ``texts``, ``dense_embeddings`` and
                ``sparse_embeddings``, each aligned index-by-index with
                ``chunks``.

        Returns:
            Primary keys reported by Milvus, or ``[]`` when no collection is
            bound, the inputs are empty or misaligned, or the insert failed.
        """
        if not self.collection:
            logger.warning("Collection未初始化")
            return []
        if len(chunks) != len(embeddings.texts):
            logger.warning("Chunks数量与嵌入数量不匹配")
            return []
        if not chunks:
            # Nothing to do; avoid sending an empty batch to Milvus.
            logger.info("没有需要插入的文档分块")
            return []
        logger.info(f"准备插入{len(chunks)}个文档分块")
        try:
            dense_embeddings = embeddings.dense_embeddings
            sparse_embeddings = embeddings.sparse_embeddings
            # zip() would silently drop rows if either vector list were
            # shorter than the chunk list, so check both explicitly.
            if len(dense_embeddings) != len(chunks) or len(sparse_embeddings) != len(chunks):
                logger.warning("Chunks数量与嵌入数量不匹配")
                return []

            current_time = int(time.time())  # one timestamp (epoch seconds) per batch
            data = []
            for chunk, dense_emb, sparse_emb in zip(chunks, dense_embeddings, sparse_embeddings):
                meta = chunk.metadata
                # Accept numpy rows as well as plain Python sequences.
                dense_vector = dense_emb.tolist() if hasattr(dense_emb, "tolist") else list(dense_emb)
                data.append({
                    "doc_id": meta.doc_id,
                    "chunk_id": meta.chunk_id,
                    "content": chunk.content,
                    "dense_vector": dense_vector,
                    "sparse_vector": sparse_emb,
                    "doc_name": meta.doc_name,
                    "section_title": meta.section_title,
                    "clause_number": meta.clause_number,
                    "page_number": meta.page_number,
                    "regulation_type": meta.regulation_type,
                    "version": meta.version,
                    "create_time": current_time,
                })
            result = self.collection.insert(data)
            self.collection.flush()
            logger.success(f"插入完成,共{len(result.primary_keys)}条记录")
            return result.primary_keys
        except Exception as e:
            logger.error(f"插入数据失败: {e}")
            return []

    def hybrid_search(
        self,
        query_dense: List[float],
        query_sparse: Dict[int, float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Run a dense search and, when a sparse query is given, merge in a sparse search.

        Args:
            query_dense: Dense query vector.
            query_sparse: Sparse query vector; when empty or ``None`` only the
                dense search is run and its results are returned unchanged.
            top_k: Maximum number of results per search and in the merged list.
            filters: Optional Milvus boolean expression applied to both searches.

        Returns:
            Up to ``top_k`` results ordered by score (merged score when both
            searches ran), or ``[]`` when no collection is bound or an error
            occurred.
        """
        if not self.collection:
            logger.warning("Collection未初始化")
            return []
        try:
            # Fail fast here: the per-field searches swallow their own errors.
            self.collection.load()
            dense_results = self.dense_search(query_dense, top_k, filters)
            if query_sparse:
                sparse_results = self.sparse_search(query_sparse, top_k, filters)
                merged = self._merge_results(dense_results, sparse_results, top_k)
                logger.success(f"混合检索完成,返回{len(merged)}条结果")
                return merged
            return dense_results
        except Exception as e:
            logger.error(f"混合检索失败: {e}")
            return []

    def _merge_results(
        self,
        dense_results: List[SearchResult],
        sparse_results: List[SearchResult],
        top_k: int,
        dense_weight: float = 0.6
    ) -> List[SearchResult]:
        """Combine dense and sparse hits into one list ranked by weighted score.

        Each id's final score is ``dense_score * dense_weight +
        sparse_score * (1 - dense_weight)``; a hit missing from one list
        contributes 0 for that side. When an id appears in both lists the
        content and metadata of the dense hit are kept.

        NOTE(review): raw scores are summed without normalisation, although
        the dense search uses COSINE and the sparse search uses IP. Confirm
        that the two score ranges are comparable for the embedder in use.

        Args:
            dense_results: Hits from the dense search.
            sparse_results: Hits from the sparse search.
            top_k: Maximum number of merged results to return.
            dense_weight: Weight of the dense score; the sparse weight is its
                complement.

        Returns:
            New ``SearchResult`` objects sorted by descending combined score,
            truncated to ``top_k``.
        """
        sparse_weight = 1 - dense_weight
        merged_dict = {}
        for r in dense_results:
            merged_dict[r.id] = {
                "result": r,
                "dense_score": r.score * dense_weight,
                "sparse_score": 0,
            }
        for r in sparse_results:
            entry = merged_dict.get(r.id)
            if entry is not None:
                entry["sparse_score"] = r.score * sparse_weight
            else:
                merged_dict[r.id] = {
                    "result": r,
                    "dense_score": 0,
                    "sparse_score": r.score * sparse_weight,
                }
        final_results = [
            SearchResult(
                id=entry["result"].id,
                content=entry["result"].content,
                score=entry["dense_score"] + entry["sparse_score"],
                metadata=entry["result"].metadata,
            )
            for entry in merged_dict.values()
        ]
        final_results.sort(key=lambda x: x.score, reverse=True)
        return final_results[:top_k]

    def _to_search_results(self, results) -> List[SearchResult]:
        """Flatten pymilvus search output (one hit list per query) into ``SearchResult`` objects."""
        search_results = []
        for hits in results:
            for hit in hits:
                entity = hit.entity
                search_results.append(SearchResult(
                    id=hit.id,
                    content=entity.get("content", ""),
                    score=hit.score,
                    metadata={
                        key: entity.get(key, default)
                        for key, default in self._METADATA_DEFAULTS.items()
                    },
                ))
        return search_results

    def _run_search(
        self,
        query: Any,
        anns_field: str,
        search_params: Dict[str, Any],
        top_k: int,
        filters: Optional[str],
    ) -> List[SearchResult]:
        """Load the collection and search one vector field; pymilvus errors propagate."""
        self.collection.load()
        results = self.collection.search(
            data=[query],
            anns_field=anns_field,
            param=search_params,
            limit=top_k,
            # Collection.search names its filter parameter ``expr``; an unknown
            # ``filter=`` keyword is not applied as a filter.
            expr=filters,
            output_fields=list(self._OUTPUT_FIELDS),
        )
        return self._to_search_results(results)

    def dense_search(
        self,
        query_dense: List[float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Search the ``dense_vector`` field (COSINE, ``nprobe=16``).

        Args:
            query_dense: Dense query vector.
            top_k: Maximum number of hits.
            filters: Optional Milvus boolean expression restricting candidates.

        Returns:
            Hits as ``SearchResult`` objects, or ``[]`` when no collection is
            bound or the search failed (the error is logged).
        """
        if not self.collection:
            return []
        try:
            return self._run_search(
                query_dense,
                "dense_vector",
                {"metric_type": "COSINE", "params": {"nprobe": 16}},
                top_k,
                filters,
            )
        except Exception as e:
            logger.error(f"Dense检索失败: {e}")
            return []

    def sparse_search(
        self,
        query_sparse: Dict[int, float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Search the ``sparse_vector`` field (IP, ``drop_ratio_search=0.2``).

        Args:
            query_sparse: Sparse query vector as ``{dimension: weight}``.
            top_k: Maximum number of hits.
            filters: Optional Milvus boolean expression restricting candidates.

        Returns:
            Hits as ``SearchResult`` objects, or ``[]`` when no collection is
            bound or the search failed (the error is logged).
        """
        if not self.collection:
            return []
        try:
            return self._run_search(
                query_sparse,
                "sparse_vector",
                {"metric_type": "IP", "params": {"drop_ratio_search": 0.2}},
                top_k,
                filters,
            )
        except Exception as e:
            logger.error(f"Sparse检索失败: {e}")
            return []

    def delete_by_doc_id(self, doc_id: str) -> int:
        """Delete every entity whose ``doc_id`` equals the given value.

        Args:
            doc_id: Document identifier to match exactly.

        Returns:
            Number of deleted entities as reported by Milvus, or ``0`` when no
            collection is bound or the delete failed (the error is logged).
        """
        if not self.collection:
            return 0
        try:
            # Escape backslashes and quotes so the value cannot terminate the
            # string literal and change the meaning of the expression.
            escaped = str(doc_id).replace("\\", "\\\\").replace('"', '\\"')
            expr = f'doc_id=="{escaped}"'
            result = self.collection.delete(expr)
            # Prefer the server-reported delete count; fall back to the
            # primary-key list if the result object does not expose one.
            deleted = getattr(result, "delete_count", None)
            if deleted is None:
                deleted = len(result.primary_keys)
            logger.info(f"删除记录: doc_id={doc_id}, 数量={deleted}")
            return deleted
        except Exception as e:
            logger.error(f"删除失败: {e}")
            return 0

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return the name, entity count and description of the bound collection.

        Returns:
            A dict with keys ``name``, ``num_entities`` and ``description``,
            or ``{}`` when no collection is bound or reading the stats failed.
        """
        if not self.collection:
            return {}
        try:
            return {
                "name": self.collection_name,
                "num_entities": self.collection.num_entities,
                "description": self.collection.description,
            }
        except Exception as e:
            logger.warning(f"获取统计信息失败: {e}")
            return {}
def create_milvus_client() -> MilvusClient:
    """Build a ``MilvusClient`` from settings, connect it and bind its collection.

    The boolean results of ``connect()`` and ``create_collection()`` are not
    inspected here; callers that need to know whether setup succeeded should
    check ``connected`` and ``collection`` on the returned client.
    """
    milvus = MilvusClient()
    milvus.connect()
    # Reuse an existing collection rather than dropping it.
    milvus.create_collection(recreate=False)
    return milvus
def insert_documents(
    client: MilvusClient,
    chunks: List[TextChunk],
    embeddings: EmbeddingResult
) -> List[int]:
    """Insert chunks and their embeddings through ``client.insert_chunks``.

    Args:
        client: Client whose collection receives the rows.
        chunks: Chunks to store.
        embeddings: Embeddings aligned with ``chunks``.

    Returns:
        Whatever ``client.insert_chunks`` returns (the inserted primary keys,
        or an empty list on failure).
    """
    inserted_ids = client.insert_chunks(chunks, embeddings)
    return inserted_ids
def search_regulations(
    client: MilvusClient,
    query_dense: List[float],
    query_sparse: Dict[int, float],
    top_k: int = 10,
    filters: Optional[str] = None
) -> List[SearchResult]:
    """Run a hybrid (dense + sparse) search through ``client.hybrid_search``.

    Args:
        client: Client whose collection is searched.
        query_dense: Dense query vector.
        query_sparse: Sparse query vector; may be empty for dense-only search.
        top_k: Maximum number of results.
        filters: Optional Milvus boolean expression forwarded to
            ``hybrid_search``. Defaults to ``None`` (no filtering), which
            matches the behaviour before this parameter existed.

    Returns:
        The list returned by ``client.hybrid_search``.
    """
    return client.hybrid_search(query_dense, query_sparse, top_k, filters)