186 lines
5.6 KiB
Python
186 lines
5.6 KiB
Python
# tests/test_embedding.py
|
|
"""嵌入和分块测试"""
|
|
|
|
import pytest
|
|
from loguru import logger
|
|
import sys
|
|
import os
|
|
|
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
|
|
sys.path.insert(0, os.path.join(PROJECT_ROOT, "backend"))
|
|
|
|
from app.services.embedding.text_chunker import RegulationChunker, TextChunk, ChunkMetadata
|
|
from app.services.embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
|
|
|
|
|
|
class TestRegulationChunker:
|
|
"""法规分块器测试"""
|
|
|
|
@pytest.fixture
|
|
def chunker(self):
|
|
"""创建分块器实例"""
|
|
return RegulationChunker(chunk_size=512)
|
|
|
|
@pytest.fixture
|
|
def sample_regulation(self):
|
|
"""示例法规文档"""
|
|
return """
|
|
# GB 7258-2017 机动车运行安全技术条件
|
|
|
|
第一章 范围
|
|
|
|
第一条 本标准规定了机动车运行安全技术条件。
|
|
|
|
第二条 本标准适用于在我国道路上行驶的所有机动车。
|
|
|
|
第二章 术语和定义
|
|
|
|
第三条 下列术语和定义适用于本标准。
|
|
|
|
(一)机动车:以动力装置驱动或者牵引,上道路行驶的供人员乘用或者用于运送物品以及进行工程专项作业的轮式车辆。
|
|
|
|
(二)整车:完整的机动车,包括所有必要的部件和系统。
|
|
|
|
第三章 技术要求
|
|
|
|
第四条 机动车应满足以下基本要求:
|
|
|
|
1. 车辆应具有唯一的产品标识;
|
|
2. 车辆结构应安全可靠;
|
|
3. 车辆应配备必要的安全装置。
|
|
"""
|
|
|
|
def test_chunk_document(self, chunker, sample_regulation):
|
|
"""测试文档分块"""
|
|
chunks = chunker.chunk_document(
|
|
sample_regulation,
|
|
doc_id="gb7258",
|
|
doc_name="GB 7258-2017",
|
|
regulation_type="车辆安全"
|
|
)
|
|
|
|
# 应该有多个分块
|
|
assert len(chunks) > 3
|
|
|
|
# 每个分块应该有内容
|
|
for chunk in chunks:
|
|
assert len(chunk.content) > 0
|
|
assert chunk.metadata.doc_id == "gb7258"
|
|
|
|
def test_section_detection(self, chunker, sample_regulation):
|
|
"""测试章节检测"""
|
|
chunks = chunker.chunk_document(
|
|
sample_regulation,
|
|
doc_id="test",
|
|
doc_name="测试"
|
|
)
|
|
|
|
# 应该检测到章节
|
|
section_numbers = [c.metadata.section_number for c in chunks]
|
|
assert any(s for s in section_numbers) # 至少有一个章节编号
|
|
|
|
def test_clause_detection(self, chunker, sample_regulation):
|
|
"""测试条款检测"""
|
|
chunks = chunker.chunk_document(
|
|
sample_regulation,
|
|
doc_id="test",
|
|
doc_name="测试"
|
|
)
|
|
|
|
# 应该检测到条款
|
|
clause_numbers = [c.metadata.clause_number for c in chunks]
|
|
assert any(c for c in clause_numbers) # 至少有一个条款编号
|
|
|
|
def test_long_clause_split(self, chunker):
|
|
"""测试长条款分割"""
|
|
long_clause = """
|
|
第一条 本条款内容很长,需要进行分割处理。
|
|
|
|
本条款包含以下多项内容:
|
|
1. 第一项内容,这是一个非常长的子项,包含了大量的文字描述,需要进行适当的处理。
|
|
2. 第二项内容,这也是一个较长的子项,包含了相关的技术要求和规范说明。
|
|
3. 第三项内容,继续描述相关要求和注意事项,确保文档的完整性和规范性。
|
|
4. 第四项内容,补充说明其他相关事项,保证内容的全面性。
|
|
"""
|
|
|
|
chunks = chunker.chunk_document(
|
|
long_clause,
|
|
doc_id="test",
|
|
doc_name="测试"
|
|
)
|
|
|
|
# 长条款应该被分割成多个chunk
|
|
assert len(chunks) >= 1
|
|
|
|
|
|
class TestBGEM3Embedder:
|
|
"""BGE-M3嵌入模型测试"""
|
|
|
|
@pytest.fixture
|
|
def embedder(self):
|
|
"""创建嵌入模型实例"""
|
|
try:
|
|
return BGEM3Embedder()
|
|
except Exception as e:
|
|
pytest.skip(f"嵌入模型加载失败: {e}")
|
|
|
|
def test_embed_single(self, embedder):
|
|
"""测试单文本嵌入"""
|
|
text = "这是一条测试文本"
|
|
result = embedder.embed_single(text)
|
|
|
|
# 应该包含dense和sparse向量
|
|
assert 'dense' in result
|
|
assert 'sparse' in result
|
|
|
|
# dense向量维度应该是1024
|
|
assert len(result['dense']) == 1024
|
|
|
|
def test_embed_batch(self, embedder):
|
|
"""测试批量嵌入"""
|
|
texts = [
|
|
"第一条 本标准规定了机动车安全要求",
|
|
"第二条 机动车应符合技术条件",
|
|
"第三条 生产企业应建立管理体系"
|
|
]
|
|
|
|
result = embedder.embed(texts)
|
|
|
|
# 应该返回正确数量的向量
|
|
assert len(result.dense_embeddings) == 3
|
|
|
|
# 维度应该是1024
|
|
assert result.dense_embeddings.shape[1] == 1024
|
|
|
|
def test_embed_empty_list(self, embedder):
|
|
"""测试空列表嵌入"""
|
|
result = embedder.embed([])
|
|
|
|
# 应该返回空结果
|
|
assert len(result.dense_embeddings) == 0
|
|
|
|
def test_similarity(self, embedder):
|
|
"""测试相似度计算"""
|
|
import numpy as np
|
|
|
|
texts = [
|
|
"机动车安全标准要求",
|
|
"汽车安全技术规范",
|
|
"食品安全管理规定" # 不相关文本
|
|
]
|
|
|
|
result = embedder.embed(texts)
|
|
|
|
# 计算第一个文本与其他文本的相似度
|
|
query = result.dense_embeddings[0]
|
|
docs = result.dense_embeddings[1:]
|
|
|
|
similarities = embedder.compute_similarity(query, docs)
|
|
|
|
# 相关文档的相似度应该更高
|
|
assert similarities[0] > similarities[1] # 车辆安全 > 食品安全
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|