update
This commit is contained in:
7
backend/app/services/storage/__init__.py
Normal file
7
backend/app/services/storage/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# src/services/storage/__init__.py
|
||||
"""存储服务"""
|
||||
|
||||
from .milvus_client import MilvusClient
|
||||
from .minio_client import MinIOClient
|
||||
|
||||
__all__ = ["MilvusClient", "MinIOClient"]
|
||||
485
backend/app/services/storage/milvus_client.py
Normal file
485
backend/app/services/storage/milvus_client.py
Normal file
@@ -0,0 +1,485 @@
|
||||
# src/services/storage/milvus_client.py
|
||||
"""Milvus向量数据库客户端 - 存储与检索服务"""
|
||||
|
||||
from pymilvus import (
|
||||
connections,
|
||||
Collection,
|
||||
FieldSchema,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
utility
|
||||
)
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from ..embedding.text_chunker import TextChunk
|
||||
from ..embedding.bge_m3_embedder import EmbeddingResult
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""检索结果"""
|
||||
id: int
|
||||
content: str
|
||||
score: float
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MilvusDocument:
|
||||
"""Milvus文档数据结构"""
|
||||
doc_id: str
|
||||
chunk_id: str
|
||||
content: str
|
||||
dense_vector: List[float]
|
||||
sparse_vector: Dict[int, float]
|
||||
doc_name: str
|
||||
section_title: str
|
||||
clause_number: str
|
||||
page_number: int
|
||||
regulation_type: str
|
||||
version: str
|
||||
create_time: int
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
"""Milvus向量数据库客户端"""
|
||||
|
||||
COLLECTION_NAME = "regulations"
|
||||
|
||||
SCHEMA_FIELDS = [
|
||||
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
|
||||
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192),
|
||||
FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
|
||||
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
|
||||
FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
|
||||
FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
|
||||
FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64),
|
||||
FieldSchema(name="page_number", dtype=DataType.INT64),
|
||||
FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32),
|
||||
FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32),
|
||||
FieldSchema(name="create_time", dtype=DataType.INT64),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = None,
|
||||
port: int = None,
|
||||
collection_name: str = None,
|
||||
db_name: str = None
|
||||
):
|
||||
self.host = host or settings.milvus_host
|
||||
self.port = port or settings.milvus_port
|
||||
self.collection_name = collection_name or settings.milvus_collection
|
||||
self.db_name = db_name or settings.milvus_db_name
|
||||
|
||||
self.collection: Optional[Collection] = None
|
||||
self.connected = False
|
||||
|
||||
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""连接到Milvus服务器"""
|
||||
try:
|
||||
connections.connect(
|
||||
alias="default",
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
db_name=self.db_name
|
||||
)
|
||||
self.connected = True
|
||||
logger.success(f"Milvus连接成功: {self.host}:{self.port}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Milvus连接失败: {e}")
|
||||
self.connected = False
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
connections.disconnect("default")
|
||||
self.connected = False
|
||||
logger.info("Milvus连接已断开")
|
||||
except Exception as e:
|
||||
logger.warning(f"断开连接时出错: {e}")
|
||||
|
||||
def create_collection(self, recreate: bool = False) -> bool:
|
||||
"""创建Collection"""
|
||||
if not self.connected:
|
||||
logger.warning("未连接到Milvus,请先调用connect()")
|
||||
return False
|
||||
|
||||
try:
|
||||
if utility.has_collection(self.collection_name):
|
||||
if recreate:
|
||||
logger.info(f"删除已存在的Collection: {self.collection_name}")
|
||||
utility.drop_collection(self.collection_name)
|
||||
else:
|
||||
logger.info(f"Collection已存在: {self.collection_name}")
|
||||
self.collection = Collection(self.collection_name)
|
||||
return True
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields=self.SCHEMA_FIELDS,
|
||||
description="法规文档向量存储",
|
||||
enable_dynamic_field=True
|
||||
)
|
||||
|
||||
self.collection = Collection(
|
||||
name=self.collection_name,
|
||||
schema=schema
|
||||
)
|
||||
|
||||
self._create_indexes()
|
||||
|
||||
logger.success(f"Collection创建成功: {self.collection_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Collection创建失败: {e}")
|
||||
return False
|
||||
|
||||
def _create_indexes(self):
|
||||
"""创建向量索引"""
|
||||
if not self.collection:
|
||||
return
|
||||
|
||||
try:
|
||||
dense_index_params = {
|
||||
"metric_type": "COSINE",
|
||||
"index_type": "IVF_FLAT",
|
||||
"params": {"nlist": 128}
|
||||
}
|
||||
self.collection.create_index(
|
||||
field_name="dense_vector",
|
||||
index_params=dense_index_params
|
||||
)
|
||||
|
||||
sparse_index_params = {
|
||||
"metric_type": "IP",
|
||||
"index_type": "SPARSE_INVERTED_INDEX",
|
||||
"params": {"drop_ratio_build": 0.2}
|
||||
}
|
||||
self.collection.create_index(
|
||||
field_name="sparse_vector",
|
||||
index_params=sparse_index_params
|
||||
)
|
||||
|
||||
logger.success("向量索引创建成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"创建索引时出错: {e}")
|
||||
|
||||
def load_collection(self):
|
||||
"""加载Collection到内存"""
|
||||
if self.collection:
|
||||
self.collection.load()
|
||||
logger.info(f"Collection已加载: {self.collection_name}")
|
||||
|
||||
def release_collection(self):
|
||||
"""释放Collection内存"""
|
||||
if self.collection:
|
||||
self.collection.release()
|
||||
logger.info(f"Collection已释放: {self.collection_name}")
|
||||
|
||||
def insert_chunks(
|
||||
self,
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""插入文档分块和嵌入向量"""
|
||||
if not self.collection:
|
||||
logger.warning("Collection未初始化")
|
||||
return []
|
||||
|
||||
if len(chunks) != len(embeddings.texts):
|
||||
logger.warning(f"Chunks数量与嵌入数量不匹配")
|
||||
return []
|
||||
|
||||
logger.info(f"准备插入{len(chunks)}个文档分块")
|
||||
|
||||
try:
|
||||
data = []
|
||||
current_time = int(time.time())
|
||||
|
||||
for chunk, dense_emb, sparse_emb in zip(
|
||||
chunks,
|
||||
embeddings.dense_embeddings,
|
||||
embeddings.sparse_embeddings
|
||||
):
|
||||
row = {
|
||||
"doc_id": chunk.metadata.doc_id,
|
||||
"chunk_id": chunk.metadata.chunk_id,
|
||||
"content": chunk.content,
|
||||
"dense_vector": dense_emb.tolist(),
|
||||
"sparse_vector": sparse_emb,
|
||||
"doc_name": chunk.metadata.doc_name,
|
||||
"section_title": chunk.metadata.section_title,
|
||||
"clause_number": chunk.metadata.clause_number,
|
||||
"page_number": chunk.metadata.page_number,
|
||||
"regulation_type": chunk.metadata.regulation_type,
|
||||
"version": chunk.metadata.version,
|
||||
"create_time": current_time
|
||||
}
|
||||
data.append(row)
|
||||
|
||||
result = self.collection.insert(data)
|
||||
self.collection.flush()
|
||||
|
||||
logger.success(f"插入完成,共{len(result.primary_keys)}条记录")
|
||||
return result.primary_keys
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"插入数据失败: {e}")
|
||||
return []
|
||||
|
||||
def hybrid_search(
|
||||
self,
|
||||
query_dense: List[float],
|
||||
query_sparse: Dict[int, float],
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""混合检索:Dense + Sparse"""
|
||||
if not self.collection:
|
||||
logger.warning("Collection未初始化")
|
||||
return []
|
||||
|
||||
try:
|
||||
self.collection.load()
|
||||
|
||||
# 使用简单的Dense检索(兼容所有版本)
|
||||
dense_results = self.dense_search(query_dense, top_k, filters)
|
||||
|
||||
# 可选:合并Sparse结果
|
||||
if query_sparse:
|
||||
sparse_results = self.sparse_search(query_sparse, top_k, filters)
|
||||
merged = self._merge_results(dense_results, sparse_results, top_k)
|
||||
logger.success(f"混合检索完成,返回{len(merged)}条结果")
|
||||
return merged
|
||||
|
||||
return dense_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"混合检索失败: {e}")
|
||||
return []
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
dense_results: List[SearchResult],
|
||||
sparse_results: List[SearchResult],
|
||||
top_k: int,
|
||||
dense_weight: float = 0.6
|
||||
) -> List[SearchResult]:
|
||||
"""手动融合Dense和Sparse结果"""
|
||||
sparse_weight = 1 - dense_weight
|
||||
merged_dict = {}
|
||||
|
||||
for r in dense_results:
|
||||
merged_dict[r.id] = {
|
||||
"result": r,
|
||||
"dense_score": r.score * dense_weight,
|
||||
"sparse_score": 0
|
||||
}
|
||||
|
||||
for r in sparse_results:
|
||||
if r.id in merged_dict:
|
||||
merged_dict[r.id]["sparse_score"] = r.score * sparse_weight
|
||||
else:
|
||||
merged_dict[r.id] = {
|
||||
"result": r,
|
||||
"dense_score": 0,
|
||||
"sparse_score": r.score * sparse_weight
|
||||
}
|
||||
|
||||
final_results = []
|
||||
for id_, data in merged_dict.items():
|
||||
result = data["result"]
|
||||
final_score = data["dense_score"] + data["sparse_score"]
|
||||
final_results.append(SearchResult(
|
||||
id=result.id,
|
||||
content=result.content,
|
||||
score=final_score,
|
||||
metadata=result.metadata
|
||||
))
|
||||
|
||||
final_results.sort(key=lambda x: x.score, reverse=True)
|
||||
return final_results[:top_k]
|
||||
|
||||
def dense_search(
|
||||
self,
|
||||
query_dense: List[float],
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""纯Dense向量检索"""
|
||||
if not self.collection:
|
||||
return []
|
||||
|
||||
try:
|
||||
self.collection.load()
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {"nprobe": 16}
|
||||
}
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_dense],
|
||||
anns_field="dense_vector",
|
||||
param=search_params,
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
output_fields=[
|
||||
"doc_id", "chunk_id", "content",
|
||||
"doc_name", "section_title", "clause_number",
|
||||
"page_number", "regulation_type", "version"
|
||||
]
|
||||
)
|
||||
|
||||
search_results = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
result = SearchResult(
|
||||
id=hit.id,
|
||||
content=hit.entity.get("content", ""),
|
||||
score=hit.score,
|
||||
metadata={
|
||||
"doc_id": hit.entity.get("doc_id", ""),
|
||||
"chunk_id": hit.entity.get("chunk_id", ""),
|
||||
"doc_name": hit.entity.get("doc_name", ""),
|
||||
"section_title": hit.entity.get("section_title", ""),
|
||||
"clause_number": hit.entity.get("clause_number", ""),
|
||||
"page_number": hit.entity.get("page_number", 0),
|
||||
"regulation_type": hit.entity.get("regulation_type", ""),
|
||||
"version": hit.entity.get("version", ""),
|
||||
}
|
||||
)
|
||||
search_results.append(result)
|
||||
|
||||
return search_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dense检索失败: {e}")
|
||||
return []
|
||||
|
||||
def sparse_search(
|
||||
self,
|
||||
query_sparse: Dict[int, float],
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""纯Sparse向量检索"""
|
||||
if not self.collection:
|
||||
return []
|
||||
|
||||
try:
|
||||
self.collection.load()
|
||||
|
||||
search_params = {
|
||||
"metric_type": "IP",
|
||||
"params": {"drop_ratio_search": 0.2}
|
||||
}
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_sparse],
|
||||
anns_field="sparse_vector",
|
||||
param=search_params,
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
output_fields=[
|
||||
"doc_id", "chunk_id", "content",
|
||||
"doc_name", "section_title", "clause_number",
|
||||
"page_number", "regulation_type", "version"
|
||||
]
|
||||
)
|
||||
|
||||
search_results = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
result = SearchResult(
|
||||
id=hit.id,
|
||||
content=hit.entity.get("content", ""),
|
||||
score=hit.score,
|
||||
metadata={
|
||||
"doc_id": hit.entity.get("doc_id", ""),
|
||||
"chunk_id": hit.entity.get("chunk_id", ""),
|
||||
"doc_name": hit.entity.get("doc_name", ""),
|
||||
"section_title": hit.entity.get("section_title", ""),
|
||||
"clause_number": hit.entity.get("clause_number", ""),
|
||||
"page_number": hit.entity.get("page_number", 0),
|
||||
"regulation_type": hit.entity.get("regulation_type", ""),
|
||||
"version": hit.entity.get("version", ""),
|
||||
}
|
||||
)
|
||||
search_results.append(result)
|
||||
|
||||
return search_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Sparse检索失败: {e}")
|
||||
return []
|
||||
|
||||
def delete_by_doc_id(self, doc_id: str) -> int:
|
||||
"""根据doc_id删除记录"""
|
||||
if not self.collection:
|
||||
return 0
|
||||
|
||||
try:
|
||||
expr = f'doc_id=="{doc_id}"'
|
||||
result = self.collection.delete(expr)
|
||||
logger.info(f"删除记录: doc_id={doc_id}, 数量={len(result.primary_keys)}")
|
||||
return len(result.primary_keys)
|
||||
except Exception as e:
|
||||
logger.error(f"删除失败: {e}")
|
||||
return 0
|
||||
|
||||
def get_collection_stats(self) -> Dict[str, Any]:
|
||||
"""获取Collection统计信息"""
|
||||
if not self.collection:
|
||||
return {}
|
||||
|
||||
try:
|
||||
stats = {
|
||||
"name": self.collection_name,
|
||||
"num_entities": self.collection.num_entities,
|
||||
"description": self.collection.description,
|
||||
}
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.warning(f"获取统计信息失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def create_milvus_client() -> MilvusClient:
|
||||
"""便捷函数:创建Milvus客户端"""
|
||||
client = MilvusClient()
|
||||
client.connect()
|
||||
client.create_collection(recreate=False)
|
||||
return client
|
||||
|
||||
|
||||
def insert_documents(
|
||||
client: MilvusClient,
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""便捷函数:插入文档"""
|
||||
return client.insert_chunks(chunks, embeddings)
|
||||
|
||||
|
||||
def search_regulations(
|
||||
client: MilvusClient,
|
||||
query_dense: List[float],
|
||||
query_sparse: Dict[int, float],
|
||||
top_k: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""便捷函数:检索法规"""
|
||||
return client.hybrid_search(query_dense, query_sparse, top_k)
|
||||
352
backend/app/services/storage/minio_client.py
Normal file
352
backend/app/services/storage/minio_client.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# src/services/storage/minio_client.py
|
||||
"""MinIO对象存储客户端 - 文档文件存储"""
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
from io import BytesIO
|
||||
import os
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class MinIOClient:
|
||||
"""MinIO对象存储客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str = None,
|
||||
access_key: str = None,
|
||||
secret_key: str = None,
|
||||
bucket: str = None,
|
||||
secure: bool = None
|
||||
):
|
||||
"""
|
||||
初始化MinIO客户端
|
||||
|
||||
Args:
|
||||
endpoint: MinIO服务地址
|
||||
access_key: 访问密钥
|
||||
secret_key: 秘密密钥
|
||||
bucket: 存储桶名称
|
||||
secure: 是否使用HTTPS
|
||||
"""
|
||||
self.endpoint = endpoint or settings.minio_endpoint
|
||||
self.access_key = access_key or settings.minio_access_key
|
||||
self.secret_key = secret_key or settings.minio_secret_key
|
||||
self.bucket = bucket or settings.minio_bucket
|
||||
self.secure = secure or settings.minio_secure
|
||||
|
||||
self.client: Optional[Minio] = None
|
||||
self.connected = False
|
||||
|
||||
logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""连接MinIO服务"""
|
||||
try:
|
||||
self.client = Minio(
|
||||
self.endpoint,
|
||||
access_key=self.access_key,
|
||||
secret_key=self.secret_key,
|
||||
secure=self.secure
|
||||
)
|
||||
self.connected = True
|
||||
logger.success(f"MinIO连接成功: {self.endpoint}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"MinIO连接失败: {e}")
|
||||
self.connected = False
|
||||
return False
|
||||
|
||||
def ensure_bucket(self) -> bool:
|
||||
"""确保存储桶存在"""
|
||||
if not self.connected:
|
||||
logger.warning("未连接MinIO,请先调用connect()")
|
||||
return False
|
||||
|
||||
try:
|
||||
if not self.client.bucket_exists(self.bucket):
|
||||
self.client.make_bucket(self.bucket)
|
||||
logger.success(f"创建存储桶: {self.bucket}")
|
||||
else:
|
||||
logger.info(f"存储桶已存在: {self.bucket}")
|
||||
return True
|
||||
except S3Error as e:
|
||||
logger.error(f"存储桶操作失败: {e}")
|
||||
return False
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
object_name: str,
|
||||
metadata: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上传本地文件到MinIO
|
||||
|
||||
Args:
|
||||
file_path: 本地文件路径
|
||||
object_name: MinIO对象名称
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
self.ensure_bucket()
|
||||
|
||||
try:
|
||||
file_size = os.stat(file_path).st_size
|
||||
content_type = self._get_content_type(file_path)
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
self.client.put_object(
|
||||
self.bucket,
|
||||
object_name,
|
||||
f,
|
||||
file_size,
|
||||
content_type=content_type,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
logger.success(f"文件上传成功: {object_name}, 大小={file_size}")
|
||||
return True
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"文件上传失败: {e}")
|
||||
return False
|
||||
|
||||
def upload_bytes(
|
||||
self,
|
||||
data: bytes,
|
||||
object_name: str,
|
||||
content_type: str = "application/octet-stream",
|
||||
metadata: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上传字节数据到MinIO
|
||||
|
||||
Args:
|
||||
data: 文件字节数据
|
||||
object_name: MinIO对象名称
|
||||
content_type: 内容类型
|
||||
metadata: 元数据(注意:MinIO仅支持US-ASCII字符)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
self.ensure_bucket()
|
||||
|
||||
try:
|
||||
data_stream = BytesIO(data)
|
||||
|
||||
# 处理metadata:仅保留ASCII安全字符
|
||||
safe_metadata = None
|
||||
if metadata:
|
||||
safe_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
# 只保留ASCII字符或转换为安全格式
|
||||
try:
|
||||
value.encode('ascii')
|
||||
safe_metadata[key] = value
|
||||
except UnicodeEncodeError:
|
||||
# 中文字符跳过或用占位符
|
||||
safe_metadata[key] = ""
|
||||
else:
|
||||
safe_metadata[key] = str(value)
|
||||
|
||||
self.client.put_object(
|
||||
self.bucket,
|
||||
object_name,
|
||||
data_stream,
|
||||
len(data),
|
||||
content_type=content_type,
|
||||
metadata=safe_metadata
|
||||
)
|
||||
|
||||
logger.success(f"数据上传成功: {object_name}, 大小={len(data)}")
|
||||
return True
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"数据上传失败: {e}")
|
||||
return False
|
||||
|
||||
def download_file(
|
||||
self,
|
||||
object_name: str,
|
||||
file_path: str
|
||||
) -> bool:
|
||||
"""
|
||||
从MinIO下载文件到本地
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
file_path: 本地保存路径
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
try:
|
||||
self.client.fget_object(
|
||||
self.bucket,
|
||||
object_name,
|
||||
file_path
|
||||
)
|
||||
logger.success(f"文件下载成功: {object_name} -> {file_path}")
|
||||
return True
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
return False
|
||||
|
||||
def get_object_url(
|
||||
self,
|
||||
object_name: str,
|
||||
expires: int = 3600
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取对象下载URL(临时URL)
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
expires: URL有效期(秒)
|
||||
|
||||
Returns:
|
||||
str: 下载URL
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
try:
|
||||
url = self.client.presigned_get_object(
|
||||
self.bucket,
|
||||
object_name,
|
||||
expires=expires
|
||||
)
|
||||
return url
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"获取URL失败: {e}")
|
||||
return None
|
||||
|
||||
def get_object_data(self, object_name: str) -> Optional[bytes]:
|
||||
"""
|
||||
获取对象数据(字节)
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bytes: 文件数据
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
try:
|
||||
response = self.client.get_object(self.bucket, object_name)
|
||||
data = response.read()
|
||||
response.close()
|
||||
response.release_conn()
|
||||
return data
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"获取对象数据失败: {e}")
|
||||
return None
|
||||
|
||||
def delete_object(self, object_name: str) -> bool:
|
||||
"""
|
||||
删除对象
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
try:
|
||||
self.client.remove_object(self.bucket, object_name)
|
||||
logger.info(f"对象删除成功: {object_name}")
|
||||
return True
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"对象删除失败: {e}")
|
||||
return False
|
||||
|
||||
def list_objects(self, prefix: str = "") -> list:
|
||||
"""
|
||||
列出存储桶中的对象
|
||||
|
||||
Args:
|
||||
prefix: 对象名称前缀
|
||||
|
||||
Returns:
|
||||
list: 对象列表
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
try:
|
||||
objects = self.client.list_objects(self.bucket, prefix=prefix)
|
||||
return [obj.object_name for obj in objects]
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"列出对象失败: {e}")
|
||||
return []
|
||||
|
||||
def object_exists(self, object_name: str) -> bool:
|
||||
"""
|
||||
检查对象是否存在
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bool: 是否存在
|
||||
"""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
try:
|
||||
self.client.stat_object(self.bucket, object_name)
|
||||
return True
|
||||
|
||||
except S3Error:
|
||||
return False
|
||||
|
||||
def _get_content_type(self, file_path: str) -> str:
|
||||
"""根据文件扩展名获取Content-Type"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
content_types = {
|
||||
'.pdf': 'application/pdf',
|
||||
'.doc': 'application/msword',
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.txt': 'text/plain',
|
||||
'.json': 'application/json',
|
||||
'.xml': 'application/xml',
|
||||
}
|
||||
return content_types.get(ext, 'application/octet-stream')
|
||||
|
||||
def close(self):
|
||||
"""关闭连接(MinIO客户端无需显式关闭)"""
|
||||
self.connected = False
|
||||
logger.info("MinIO客户端已关闭")
|
||||
|
||||
|
||||
def create_minio_client() -> MinIOClient:
|
||||
"""便捷函数:创建MinIO客户端"""
|
||||
client = MinIOClient()
|
||||
client.connect()
|
||||
client.ensure_bucket()
|
||||
return client
|
||||
Reference in New Issue
Block a user