Files
AIRegulation-DocAnalysis/backend/app/services/storage/milvus_client.py

488 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provide service-layer logic for milvus client."""
from pymilvus import (
connections,
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility
)
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
import time
import numpy as np
from ..embedding.text_chunker import TextChunk
from ..embedding.bge_m3_embedder import EmbeddingResult
from app.config.settings import settings
# Data containers: SearchResult (ranked search hits) and MilvusDocument (per-chunk record layout).
@dataclass
class SearchResult:
    """A single ranked hit returned by a Milvus search.

    ``id`` is the Milvus primary key of the matched row, ``content`` its chunk
    text, ``score`` the similarity (or fused) score, and ``metadata`` the scalar
    fields that were requested alongside the hit.
    """

    id: int  # noqa: A003 - public field name, kept for compatibility
    content: str
    score: float
    # A fresh dict per instance; a shared mutable default would leak between hits.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MilvusDocument:
    """One stored chunk record, mirroring the non-id columns of MilvusClient.SCHEMA_FIELDS.

    Field order defines the generated ``__init__`` signature, so it must not change.
    NOTE(review): not referenced by the visible code (insert_chunks builds plain
    dict rows instead); confirm whether external callers still rely on it.
    """
    doc_id: str  # document identifier (VARCHAR max_length=64 in the schema)
    chunk_id: str  # chunk identifier (VARCHAR max_length=128 in the schema)
    content: str  # chunk text (VARCHAR max_length=8192 in the schema)
    dense_vector: List[float]  # dense embedding; the schema declares dim=1024
    sparse_vector: Dict[int, float]  # sparse embedding stored in a SPARSE_FLOAT_VECTOR column
    doc_name: str
    section_title: str
    clause_number: str
    page_number: int
    regulation_type: str
    version: str
    create_time: int  # presumably Unix seconds, as insert_chunks writes int(time.time()) -- confirm
class MilvusClient:
    """Wrapper around a pymilvus ORM ``Collection`` holding regulation chunks.

    Manages the ``"default"`` connection alias, creates the collection with a
    fixed schema plus dense/sparse indexes, inserts embedded chunks, and runs
    dense, sparse and fused ("hybrid") similarity searches. Most methods log
    failures and return an empty / false value instead of raising.
    """

    # NOTE(review): not used by this class; the active name comes from the
    # constructor argument or settings.milvus_collection.
    COLLECTION_NAME = "regulations"
    # Fixed collection schema. NOTE(review): Milvus VARCHAR max_length is a byte
    # limit, so long CJK chunks may exceed 8192 bytes of UTF-8 -- confirm chunk sizing.
    SCHEMA_FIELDS = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192),
        FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
        FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="page_number", dtype=DataType.INT64),
        FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="create_time", dtype=DataType.INT64),
    ]

    # Scalar fields fetched with every search hit.
    _OUTPUT_FIELDS = (
        "doc_id", "chunk_id", "content",
        "doc_name", "section_title", "clause_number",
        "page_number", "regulation_type", "version",
    )
    # (field name, fallback value) pairs copied into SearchResult.metadata, in order.
    _METADATA_DEFAULTS = (
        ("doc_id", ""),
        ("chunk_id", ""),
        ("doc_name", ""),
        ("section_title", ""),
        ("clause_number", ""),
        ("page_number", 0),
        ("regulation_type", ""),
        ("version", ""),
    )

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
        db_name: Optional[str] = None
    ):
        """Store connection settings; any argument left as None falls back to ``settings``.

        No network I/O happens here -- call :meth:`connect` afterwards.
        """
        self.host = host or settings.milvus_host
        self.port = port or settings.milvus_port
        self.collection_name = collection_name or settings.milvus_collection
        self.db_name = db_name or settings.milvus_db_name
        self.collection: Optional[Collection] = None
        self.connected = False
        logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")

    def connect(self) -> bool:
        """Open the ``"default"`` pymilvus connection.

        Returns:
            True on success; False (after logging the error) on any exception.
        """
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                db_name=self.db_name
            )
            self.connected = True
            logger.success(f"Milvus连接成功: {self.host}:{self.port}")
            return True
        except Exception as e:
            logger.error(f"Milvus连接失败: {e}")
            self.connected = False
            return False

    def disconnect(self):
        """Close the ``"default"`` connection; errors are logged as warnings, not raised."""
        try:
            connections.disconnect("default")
            self.connected = False
            logger.info("Milvus连接已断开")
        except Exception as e:
            logger.warning(f"断开连接时出错: {e}")

    def create_collection(self, recreate: bool = False) -> bool:
        """Bind ``self.collection``, creating the collection and its indexes if needed.

        Args:
            recreate: Drop and rebuild an existing collection of the same name.

        Returns:
            True when a collection is bound; False when not connected or on error.
            An existing collection is reused as-is -- its indexes are not re-checked.
        """
        if not self.connected:
            logger.warning("未连接到Milvus请先调用connect()")
            return False
        try:
            if utility.has_collection(self.collection_name):
                if recreate:
                    logger.info(f"删除已存在的Collection: {self.collection_name}")
                    utility.drop_collection(self.collection_name)
                else:
                    logger.info(f"Collection已存在: {self.collection_name}")
                    self.collection = Collection(self.collection_name)
                    return True
            schema = CollectionSchema(
                fields=self.SCHEMA_FIELDS,
                description="法规文档向量存储",
                enable_dynamic_field=True
            )
            self.collection = Collection(
                name=self.collection_name,
                schema=schema
            )
            self._create_indexes()
            logger.success(f"Collection创建成功: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Collection创建失败: {e}")
            return False

    def _create_indexes(self):
        """Create the dense (IVF_FLAT / COSINE) and sparse (inverted / IP) indexes.

        Failures are only logged as warnings, so create_collection still reports
        success; searches would then fail when the collection is loaded.
        """
        if not self.collection:
            return
        try:
            dense_index_params = {
                "metric_type": "COSINE",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 128}
            }
            self.collection.create_index(
                field_name="dense_vector",
                index_params=dense_index_params
            )
            sparse_index_params = {
                "metric_type": "IP",
                "index_type": "SPARSE_INVERTED_INDEX",
                "params": {"drop_ratio_build": 0.2}
            }
            self.collection.create_index(
                field_name="sparse_vector",
                index_params=sparse_index_params
            )
            logger.success("向量索引创建成功")
        except Exception as e:
            logger.warning(f"创建索引时出错: {e}")

    def load_collection(self):
        """Load the bound collection into memory (no-op when none is bound)."""
        if self.collection:
            self.collection.load()
            logger.info(f"Collection已加载: {self.collection_name}")

    def release_collection(self):
        """Release the bound collection from memory (no-op when none is bound)."""
        if self.collection:
            self.collection.release()
            logger.info(f"Collection已释放: {self.collection_name}")

    @staticmethod
    def _chunk_to_row(
        chunk: TextChunk,
        dense_emb: Any,
        sparse_emb: Any,
        create_time: int
    ) -> Dict[str, Any]:
        """Build one insert row (column -> value) covering every schema field except the auto id."""
        meta = chunk.metadata
        return {
            "doc_id": meta.doc_id,
            "chunk_id": meta.chunk_id,
            "content": chunk.content,
            # np.asarray accepts both ndarray rows and plain Python lists.
            "dense_vector": np.asarray(dense_emb).tolist(),
            "sparse_vector": sparse_emb,
            "doc_name": meta.doc_name,
            "section_title": meta.section_title,
            "clause_number": meta.clause_number,
            "page_number": meta.page_number,
            "regulation_type": meta.regulation_type,
            "version": meta.version,
            "create_time": create_time
        }

    def insert_chunks(
        self,
        chunks: List[TextChunk],
        embeddings: EmbeddingResult
    ) -> List[int]:
        """Insert text chunks together with their dense and sparse embeddings, then flush.

        Args:
            chunks: Chunks to store; ``chunk.metadata`` supplies the scalar columns.
            embeddings: Embeddings aligned index-by-index with ``chunks``.

        Returns:
            Primary keys of the inserted rows, or ``[]`` when no collection is bound,
            there is nothing to insert, chunks and embeddings are misaligned, or the
            insert fails (failures are logged, not raised).
        """
        if not self.collection:
            logger.warning("Collection未初始化")
            return []
        if len(chunks) != len(embeddings.texts):
            logger.warning("Chunks数量与嵌入数量不匹配")
            return []
        if not chunks:
            # Nothing to insert; skip an empty insert request.
            return []
        logger.info(f"准备插入{len(chunks)}个文档分块")
        try:
            # One timestamp for the whole batch so rows share the same create_time.
            current_time = int(time.time())
            data = [
                self._chunk_to_row(chunk, dense_emb, sparse_emb, current_time)
                for chunk, dense_emb, sparse_emb in zip(
                    chunks,
                    embeddings.dense_embeddings,
                    embeddings.sparse_embeddings
                )
            ]
            # zip() stops at the shortest input: fewer rows than chunks means some
            # chunks had no dense/sparse embedding -- refuse rather than drop them.
            if len(data) != len(chunks):
                logger.warning(f"Chunks数量与嵌入数量不匹配: chunks={len(chunks)}, rows={len(data)}")
                return []
            result = self.collection.insert(data)
            self.collection.flush()
            logger.success(f"插入完成,共{len(result.primary_keys)}条记录")
            return result.primary_keys
        except Exception as e:
            logger.error(f"插入数据失败: {e}")
            return []

    def hybrid_search(
        self,
        query_dense: List[float],
        query_sparse: Dict[int, float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Search with the dense query and, if given, the sparse query, fusing scores.

        Args:
            query_dense: Dense query embedding.
            query_sparse: Sparse query embedding; empty/None means dense-only search.
            top_k: Maximum number of results returned.
            filters: Optional Milvus boolean expression applied to both searches.

        Returns:
            Up to ``top_k`` results, best first; ``[]`` on failure (logged).
        """
        if not self.collection:
            logger.warning("Collection未初始化")
            return []
        try:
            self.collection.load()
            # Dense retrieval always runs and is the whole answer when no sparse query is given.
            dense_results = self.dense_search(query_dense, top_k, filters)
            # Sparse retrieval runs only for a non-empty query; both lists are then fused.
            if query_sparse:
                sparse_results = self.sparse_search(query_sparse, top_k, filters)
                merged = self._merge_results(dense_results, sparse_results, top_k)
                logger.success(f"混合检索完成,返回{len(merged)}条结果")
                return merged
            return dense_results
        except Exception as e:
            logger.error(f"混合检索失败: {e}")
            return []

    def _merge_results(
        self,
        dense_results: List[SearchResult],
        sparse_results: List[SearchResult],
        top_k: int,
        dense_weight: float = 0.6
    ) -> List[SearchResult]:
        """Fuse two ranked lists by a weighted sum of their raw scores.

        A hit missing from one list contributes 0 for that side. Dense scores come
        from the COSINE metric and sparse scores from IP, so the two are on different
        scales and are combined without normalization. Ties keep first-seen order
        (dense hits first).
        """
        sparse_weight = 1 - dense_weight
        # id -> representative hit plus its weighted dense and sparse contributions.
        merged: Dict[int, Dict[str, Any]] = {}
        for r in dense_results:
            merged[r.id] = {
                "result": r,
                "dense_score": r.score * dense_weight,
                "sparse_score": 0
            }
        for r in sparse_results:
            entry = merged.setdefault(r.id, {"result": r, "dense_score": 0, "sparse_score": 0})
            entry["sparse_score"] = r.score * sparse_weight
        final_results = [
            SearchResult(
                id=entry["result"].id,
                content=entry["result"].content,
                score=entry["dense_score"] + entry["sparse_score"],
                metadata=entry["result"].metadata
            )
            for entry in merged.values()
        ]
        final_results.sort(key=lambda x: x.score, reverse=True)
        return final_results[:top_k]

    def _hit_to_result(self, hit: Any) -> SearchResult:
        """Convert one pymilvus search hit into a SearchResult."""
        entity = hit.entity

        def read(name: str, default: Any) -> Any:
            # Single-argument get() with an explicit fallback: not every pymilvus
            # version's Hit.get accepts a default. NOTE(review): confirm for the pinned version.
            value = entity.get(name)
            return default if value is None else value

        return SearchResult(
            id=hit.id,
            content=read("content", ""),
            score=hit.score,
            metadata={name: read(name, default) for name, default in self._METADATA_DEFAULTS}
        )

    def _search(
        self,
        query: Any,
        anns_field: str,
        search_params: Dict[str, Any],
        top_k: int,
        filters: Optional[str]
    ) -> List[SearchResult]:
        """Run one single-query ANN search on ``anns_field``; exceptions propagate to the caller."""
        results = self.collection.search(
            data=[query],
            anns_field=anns_field,
            param=search_params,
            limit=top_k,
            # ORM Collection.search takes its boolean filter as ``expr``; ``filter=`` is the
            # spelling of the newer pymilvus.MilvusClient API and is not applied here.
            expr=filters,
            output_fields=list(self._OUTPUT_FIELDS)
        )
        return [self._hit_to_result(hit) for hits in results for hit in hits]

    def dense_search(
        self,
        query_dense: List[float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """COSINE search on ``dense_vector``; returns ``[]`` when unbound or on error (logged)."""
        if not self.collection:
            return []
        try:
            self.collection.load()
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 16}
            }
            return self._search(query_dense, "dense_vector", search_params, top_k, filters)
        except Exception as e:
            logger.error(f"Dense检索失败: {e}")
            return []

    def sparse_search(
        self,
        query_sparse: Dict[int, float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """IP search on ``sparse_vector``; returns ``[]`` when unbound or on error (logged)."""
        if not self.collection:
            return []
        try:
            self.collection.load()
            search_params = {
                "metric_type": "IP",
                "params": {"drop_ratio_search": 0.2}
            }
            return self._search(query_sparse, "sparse_vector", search_params, top_k, filters)
        except Exception as e:
            logger.error(f"Sparse检索失败: {e}")
            return []

    def delete_by_doc_id(self, doc_id: str) -> int:
        """Delete every row whose ``doc_id`` equals ``doc_id``.

        Returns:
            The number of deleted entities reported by Milvus; 0 when no collection
            is bound or on error (logged).
        """
        if not self.collection:
            return 0
        try:
            # Escape backslashes and double quotes so the id cannot terminate the
            # string literal and inject extra filter conditions.
            escaped = doc_id.replace("\\", "\\\\").replace('"', '\\"')
            expr = f'doc_id=="{escaped}"'
            result = self.collection.delete(expr)
            # delete_count is the documented count for deletes; primary_keys may be
            # empty when deleting by a non-primary-key expression.
            deleted = result.delete_count
            logger.info(f"删除记录: doc_id={doc_id}, 数量={deleted}")
            return deleted
        except Exception as e:
            logger.error(f"删除失败: {e}")
            return 0

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return name, entity count and description; ``{}`` when unbound or on error."""
        if not self.collection:
            return {}
        try:
            stats = {
                "name": self.collection_name,
                "num_entities": self.collection.num_entities,
                "description": self.collection.description,
            }
            return stats
        except Exception as e:
            logger.warning(f"获取统计信息失败: {e}")
            return {}
def create_milvus_client() -> MilvusClient:
    """Build a settings-configured MilvusClient, connect it, and bind its collection.

    Neither step raises: connect() and create_collection() log their own failures,
    so the returned client may be unconnected or have no collection bound. An
    existing collection is reused, never dropped.
    """
    instance = MilvusClient()
    instance.connect()
    instance.create_collection(recreate=False)
    return instance
def insert_documents(
    client: MilvusClient,
    chunks: List[TextChunk],
    embeddings: EmbeddingResult
) -> List[int]:
    """Store ``chunks`` with their ``embeddings`` through ``client``.

    Convenience wrapper over ``MilvusClient.insert_chunks``; returns the primary
    keys that method reports (an empty list when it fails).
    """
    inserted_ids = client.insert_chunks(chunks, embeddings)
    return inserted_ids
def search_regulations(
    client: MilvusClient,
    query_dense: List[float],
    query_sparse: Dict[int, float],
    top_k: int = 10,
    filters: Optional[str] = None,
) -> List[SearchResult]:
    """Run a hybrid dense + sparse search for regulation chunks.

    Args:
        client: A connected MilvusClient with a bound collection.
        query_dense: Dense query embedding.
        query_sparse: Sparse query embedding; empty means dense-only search.
        top_k: Maximum number of results.
        filters: Optional Milvus boolean expression restricting the candidates
            (forwarded to ``MilvusClient.hybrid_search``); None searches everything.

    Returns:
        Up to ``top_k`` SearchResult objects, best first.
    """
    return client.hybrid_search(query_dense, query_sparse, top_k, filters)