136 lines
3.9 KiB
Python
136 lines
3.9 KiB
Python
# tests/test_milvus.py
|
|
"""Milvus集成测试"""
|
|
|
|
import pytest
|
|
from loguru import logger
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
|
|
|
from src.services.storage.milvus_client import MilvusClient, SearchResult
|
|
from src.services.embedding.bge_m3_embedder import BGEM3Embedder
|
|
from src.config.settings import settings
|
|
|
|
|
|
class TestMilvusConnection:
    """Milvus connection tests.

    Each test builds ``MilvusClient()`` with no arguments and calls
    ``connect()``, so these presumably require a reachable Milvus server.
    """

    def test_connection(self):
        """Connecting reports success, and the client is always disconnected."""
        client = MilvusClient()

        result = client.connect()
        try:
            assert result is True
        finally:
            # Release the connection even when the assertion fails.
            client.disconnect()

    def test_create_collection(self):
        """Creating the collection succeeds and its stats report the configured name."""
        client = MilvusClient()
        client.connect()
        try:
            # NOTE(review): recreate=True presumably drops any existing data in
            # the collection named by settings.milvus_collection; consider a
            # dedicated test collection — confirm against MilvusClient.
            result = client.create_collection(recreate=True)
            assert result is True

            # The collection should now exist under the configured name.
            stats = client.get_collection_stats()
            assert stats["name"] == settings.milvus_collection
        finally:
            # Release the connection even when an assertion or call fails.
            client.disconnect()
|
|
|
|
|
|
class TestMilvusOperations:
    """Milvus data-operation tests (insert followed by hybrid search)."""

    @pytest.fixture
    def client(self):
        """Yield a connected client whose collection was recreated and loaded.

        The client is disconnected on teardown, including when
        ``create_collection`` or ``load_collection`` raises during setup.
        Pytest runs a yield-fixture's teardown only if setup reached ``yield``,
        hence the explicit try/finally.
        """
        client = MilvusClient()
        client.connect()
        try:
            client.create_collection(recreate=True)
            client.load_collection()
            yield client
        finally:
            client.disconnect()

    def test_insert_and_search(self, client):
        """Inserted chunks are retrievable by a related hybrid (dense + sparse) query."""
        from src.services.embedding.text_chunker import TextChunk, ChunkMetadata

        # (content, chunk_id, clause_number) per test chunk; the remaining
        # metadata fields are shared by all chunks.
        samples = [
            (
                "第一条 为保障机动车安全技术性能,预防和减少机动车交通事故,保护人身安全,制定本标准。",
                "test_chunk_1",
                "第一条",
            ),
            (
                "第二条 本标准适用于在我国道路上行驶的所有机动车。",
                "test_chunk_2",
                "第二条",
            ),
        ]
        chunks = [
            TextChunk(
                content=content,
                metadata=ChunkMetadata(
                    doc_id="test_doc",
                    doc_name="测试文档",
                    chunk_id=chunk_id,
                    clause_number=clause_number,
                    regulation_type="车辆安全",
                ),
            )
            for content, chunk_id, clause_number in samples
        ]

        # Embed the chunk texts and insert them together with their embeddings.
        embedder = BGEM3Embedder()
        embeddings = embedder.embed([c.content for c in chunks])

        inserted_ids = client.insert_chunks(chunks, embeddings)
        assert len(inserted_ids) == len(chunks)

        # Query with both the dense vector and the sparse representation.
        query_embedding = embedder.embed_single("机动车安全标准")
        results = client.hybrid_search(
            query_dense=query_embedding['dense'].tolist(),
            query_sparse=query_embedding['sparse'],
            top_k=2,
        )

        assert results
        top_content = results[0].content
        assert "机动车" in top_content or "安全" in top_content
|
|
|
|
|
|
class TestEmbedding:
    """Tests for the BGE-M3 embedding wrapper (single text and batch)."""

    def test_embed_single_text(self):
        """A single text produces both a dense and a sparse representation."""
        output = BGEM3Embedder().embed_single("这是一条测试文本")

        for required_key in ('dense', 'sparse'):
            assert required_key in output
        # 1024 is the BGE-M3 default dense dimension.
        assert len(output['dense']) == 1024

    def test_embed_batch(self):
        """A batch of texts produces one 1024-wide dense vector per text."""
        embedder = BGEM3Embedder()
        clauses = [
            "第一条 本标准规定了机动车安全要求",
            "第二条 机动车应符合以下技术条件",
            "第三条 生产企业应建立质量管理体系",
        ]

        dense = embedder.embed(clauses).dense_embeddings

        assert len(dense) == len(clauses)
        assert dense.shape[1] == 1024
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so running this
    # file directly reports failures to the shell/CI instead of exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))