Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,185 +1,197 @@
|
||||
# tests/test_embedding.py
|
||||
"""嵌入和分块测试"""
|
||||
"""新架构下的文档编排与 embedding 边界测试。"""
|
||||
|
||||
import pytest
|
||||
from loguru import logger
|
||||
import sys
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
|
||||
sys.path.insert(0, os.path.join(PROJECT_ROOT, "backend"))
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.services.embedding.text_chunker import RegulationChunker, TextChunk, ChunkMetadata
|
||||
from app.services.embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
|
||||
from app.application.documents.services import DocumentCommandService
|
||||
from app.domain.documents import Chunk, Document, DocumentStatus, ParsedDocument
|
||||
from app.shared import bootstrap
|
||||
|
||||
|
||||
class TestRegulationChunker:
|
||||
"""法规分块器测试"""
|
||||
class FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.documents: dict[str, Document] = {}
|
||||
|
||||
@pytest.fixture
|
||||
def chunker(self):
|
||||
"""创建分块器实例"""
|
||||
return RegulationChunker(chunk_size=512)
|
||||
def create(self, document: Document) -> Document:
|
||||
self.documents[document.doc_id] = document
|
||||
return document
|
||||
|
||||
@pytest.fixture
|
||||
def sample_regulation(self):
|
||||
"""示例法规文档"""
|
||||
return """
|
||||
# GB 7258-2017 机动车运行安全技术条件
|
||||
def update(self, document: Document) -> Document:
|
||||
self.documents[document.doc_id] = document
|
||||
return document
|
||||
|
||||
第一章 范围
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
return self.documents.get(doc_id)
|
||||
|
||||
第一条 本标准规定了机动车运行安全技术条件。
|
||||
def list(self, limit: int | None = None) -> list[Document]:
|
||||
values = list(self.documents.values())
|
||||
return values[:limit] if limit is not None else values
|
||||
|
||||
第二条 本标准适用于在我国道路上行驶的所有机动车。
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
status: DocumentStatus,
|
||||
*,
|
||||
error_message: str = "",
|
||||
chunk_count: int | None = None,
|
||||
summary: str | None = None,
|
||||
summary_latency_ms: int | None = None,
|
||||
parser_name: str | None = None,
|
||||
index_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Document | None:
|
||||
document = self.documents.get(doc_id)
|
||||
if not document:
|
||||
return None
|
||||
document.status = status
|
||||
document.error_message = error_message
|
||||
if chunk_count is not None:
|
||||
document.chunk_count = chunk_count
|
||||
if summary is not None:
|
||||
document.summary = summary
|
||||
if summary_latency_ms is not None:
|
||||
document.summary_latency_ms = summary_latency_ms
|
||||
if parser_name is not None:
|
||||
document.parser_name = parser_name
|
||||
if index_name is not None:
|
||||
document.index_name = index_name
|
||||
if metadata:
|
||||
document.metadata.update(metadata)
|
||||
return document
|
||||
|
||||
第二章 术语和定义
|
||||
|
||||
第三条 下列术语和定义适用于本标准。
|
||||
class FakeBinaryStore:
|
||||
def __init__(self) -> None:
|
||||
self.saved: dict[str, bytes] = {}
|
||||
|
||||
(一)机动车:以动力装置驱动或者牵引,上道路行驶的供人员乘用或者用于运送物品以及进行工程专项作业的轮式车辆。
|
||||
def save(self, *, object_name: str, data: bytes, content_type: str, metadata: dict[str, str] | None = None) -> None:
|
||||
self.saved[object_name] = data
|
||||
|
||||
(二)整车:完整的机动车,包括所有必要的部件和系统。
|
||||
def read(self, object_name: str) -> bytes:
|
||||
return self.saved[object_name]
|
||||
|
||||
第三章 技术要求
|
||||
def delete(self, object_name: str) -> None:
|
||||
self.saved.pop(object_name, None)
|
||||
|
||||
第四条 机动车应满足以下基本要求:
|
||||
|
||||
1. 车辆应具有唯一的产品标识;
|
||||
2. 车辆结构应安全可靠;
|
||||
3. 车辆应配备必要的安全装置。
|
||||
"""
|
||||
|
||||
def test_chunk_document(self, chunker, sample_regulation):
|
||||
"""测试文档分块"""
|
||||
chunks = chunker.chunk_document(
|
||||
sample_regulation,
|
||||
doc_id="gb7258",
|
||||
doc_name="GB 7258-2017",
|
||||
regulation_type="车辆安全"
|
||||
class FakeParser:
|
||||
def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
|
||||
return ParsedDocument(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
structure_nodes=[{"title": "第一章"}],
|
||||
semantic_blocks=[{"semantic_id": "semantic-1", "text": "法规正文", "section_title": "第一章"}],
|
||||
vector_chunks=[
|
||||
{
|
||||
"chunk_id": f"{doc_id}-chunk-1",
|
||||
"semantic_id": "semantic-1",
|
||||
"chunk_type": "section_text",
|
||||
"section_title": "第一章",
|
||||
"section_path": ["第一章"],
|
||||
"page_start": 1,
|
||||
"text": "法规正文",
|
||||
"embedding_text": "标准:测试\n章节:第一章\n\n法规正文",
|
||||
}
|
||||
],
|
||||
parser_name="fake_parser",
|
||||
)
|
||||
|
||||
# 应该有多个分块
|
||||
assert len(chunks) > 3
|
||||
|
||||
# 每个分块应该有内容
|
||||
for chunk in chunks:
|
||||
assert len(chunk.content) > 0
|
||||
assert chunk.metadata.doc_id == "gb7258"
|
||||
|
||||
def test_section_detection(self, chunker, sample_regulation):
|
||||
"""测试章节检测"""
|
||||
chunks = chunker.chunk_document(
|
||||
sample_regulation,
|
||||
doc_id="test",
|
||||
doc_name="测试"
|
||||
)
|
||||
|
||||
# 应该检测到章节
|
||||
section_numbers = [c.metadata.section_number for c in chunks]
|
||||
assert any(s for s in section_numbers) # 至少有一个章节编号
|
||||
|
||||
def test_clause_detection(self, chunker, sample_regulation):
|
||||
"""测试条款检测"""
|
||||
chunks = chunker.chunk_document(
|
||||
sample_regulation,
|
||||
doc_id="test",
|
||||
doc_name="测试"
|
||||
)
|
||||
|
||||
# 应该检测到条款
|
||||
clause_numbers = [c.metadata.clause_number for c in chunks]
|
||||
assert any(c for c in clause_numbers) # 至少有一个条款编号
|
||||
|
||||
def test_long_clause_split(self, chunker):
|
||||
"""测试长条款分割"""
|
||||
long_clause = """
|
||||
第一条 本条款内容很长,需要进行分割处理。
|
||||
|
||||
本条款包含以下多项内容:
|
||||
1. 第一项内容,这是一个非常长的子项,包含了大量的文字描述,需要进行适当的处理。
|
||||
2. 第二项内容,这也是一个较长的子项,包含了相关的技术要求和规范说明。
|
||||
3. 第三项内容,继续描述相关要求和注意事项,确保文档的完整性和规范性。
|
||||
4. 第四项内容,补充说明其他相关事项,保证内容的全面性。
|
||||
"""
|
||||
|
||||
chunks = chunker.chunk_document(
|
||||
long_clause,
|
||||
doc_id="test",
|
||||
doc_name="测试"
|
||||
)
|
||||
|
||||
# 长条款应该被分割成多个chunk
|
||||
assert len(chunks) >= 1
|
||||
|
||||
|
||||
class TestBGEM3Embedder:
|
||||
"""BGE-M3嵌入模型测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def embedder(self):
|
||||
"""创建嵌入模型实例"""
|
||||
try:
|
||||
return BGEM3Embedder()
|
||||
except Exception as e:
|
||||
pytest.skip(f"嵌入模型加载失败: {e}")
|
||||
|
||||
def test_embed_single(self, embedder):
|
||||
"""测试单文本嵌入"""
|
||||
text = "这是一条测试文本"
|
||||
result = embedder.embed_single(text)
|
||||
|
||||
# 应该包含dense和sparse向量
|
||||
assert 'dense' in result
|
||||
assert 'sparse' in result
|
||||
|
||||
# dense向量维度应该是1024
|
||||
assert len(result['dense']) == 1024
|
||||
|
||||
def test_embed_batch(self, embedder):
|
||||
"""测试批量嵌入"""
|
||||
texts = [
|
||||
"第一条 本标准规定了机动车安全要求",
|
||||
"第二条 机动车应符合技术条件",
|
||||
"第三条 生产企业应建立管理体系"
|
||||
class FakeChunkBuilder:
|
||||
def build(self, *, parsed_document: ParsedDocument, regulation_type: str, version: str) -> list[Chunk]:
|
||||
return [
|
||||
Chunk(
|
||||
chunk_id=f"{parsed_document.doc_id}-chunk-1",
|
||||
doc_id=parsed_document.doc_id,
|
||||
doc_name=parsed_document.doc_name,
|
||||
content="法规正文",
|
||||
embedding_text="标准:测试\n章节:第一章\n\n法规正文",
|
||||
section_title="第一章",
|
||||
section_path=["第一章"],
|
||||
page_number=1,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
semantic_id="semantic-1",
|
||||
block_type="section_text",
|
||||
metadata={"source": "aliyun_vector_chunk"},
|
||||
)
|
||||
]
|
||||
|
||||
result = embedder.embed(texts)
|
||||
|
||||
# 应该返回正确数量的向量
|
||||
assert len(result.dense_embeddings) == 3
|
||||
class FakeEmbeddingProvider:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[list[str]] = []
|
||||
|
||||
# 维度应该是1024
|
||||
assert result.dense_embeddings.shape[1] == 1024
|
||||
def embed_texts(self, texts: list[str]) -> list[list[float]]:
|
||||
self.calls.append(texts)
|
||||
return [[0.1] * 1536 for _ in texts]
|
||||
|
||||
def test_embed_empty_list(self, embedder):
|
||||
"""测试空列表嵌入"""
|
||||
result = embedder.embed([])
|
||||
|
||||
# 应该返回空结果
|
||||
assert len(result.dense_embeddings) == 0
|
||||
|
||||
def test_similarity(self, embedder):
|
||||
"""测试相似度计算"""
|
||||
import numpy as np
|
||||
|
||||
texts = [
|
||||
"机动车安全标准要求",
|
||||
"汽车安全技术规范",
|
||||
"食品安全管理规定" # 不相关文本
|
||||
]
|
||||
|
||||
result = embedder.embed(texts)
|
||||
|
||||
# 计算第一个文本与其他文本的相似度
|
||||
query = result.dense_embeddings[0]
|
||||
docs = result.dense_embeddings[1:]
|
||||
|
||||
similarities = embedder.compute_similarity(query, docs)
|
||||
|
||||
# 相关文档的相似度应该更高
|
||||
assert similarities[0] > similarities[1] # 车辆安全 > 食品安全
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return [0.2] * 1536
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
class FakeVectorIndex:
|
||||
def __init__(self) -> None:
|
||||
self.upserts: list[tuple[list[Chunk], list[list[float]]]] = []
|
||||
|
||||
def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
|
||||
self.upserts.append((chunks, vectors))
|
||||
return len(chunks)
|
||||
|
||||
def delete_by_document(self, doc_id: str) -> int:
|
||||
return 0
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None):
|
||||
return []
|
||||
|
||||
def health(self) -> dict:
|
||||
return {"collection_name": "regulations_dense_1536"}
|
||||
|
||||
|
||||
def test_document_command_service_uses_1536_dense_embedding_and_updates_status():
|
||||
repository = FakeRepository()
|
||||
binary_store = FakeBinaryStore()
|
||||
embedding_provider = FakeEmbeddingProvider()
|
||||
vector_index = FakeVectorIndex()
|
||||
service = DocumentCommandService(
|
||||
document_repository=repository,
|
||||
binary_store=binary_store,
|
||||
parser=FakeParser(),
|
||||
chunk_builder=FakeChunkBuilder(),
|
||||
embedding_provider=embedding_provider,
|
||||
vector_index=vector_index,
|
||||
)
|
||||
|
||||
result = service.upload_and_process(
|
||||
doc_id="doc12345",
|
||||
file_name="test.pdf",
|
||||
content=b"dummy pdf bytes",
|
||||
content_type="application/pdf",
|
||||
doc_name="测试法规",
|
||||
regulation_type="车辆安全",
|
||||
version="2026",
|
||||
generate_summary=False,
|
||||
)
|
||||
|
||||
assert result.status == "indexed"
|
||||
assert result.num_chunks == 1
|
||||
assert embedding_provider.calls == [["标准:测试\n章节:第一章\n\n法规正文"]]
|
||||
assert len(vector_index.upserts) == 1
|
||||
stored = repository.get("doc12345")
|
||||
assert stored is not None
|
||||
assert stored.status == DocumentStatus.INDEXED
|
||||
assert stored.chunk_count == 1
|
||||
assert stored.parser_name == "fake_parser"
|
||||
assert stored.index_name == "regulations_dense_1536"
|
||||
|
||||
|
||||
def test_bootstrap_defaults_to_local_parser_and_chunk_builder():
|
||||
bootstrap.get_parser.cache_clear()
|
||||
bootstrap.get_chunk_builder.cache_clear()
|
||||
|
||||
parser = bootstrap.get_parser()
|
||||
chunk_builder = bootstrap.get_chunk_builder()
|
||||
|
||||
assert parser.__class__.__name__ == "LocalDocumentParser"
|
||||
assert chunk_builder.__class__.__name__ == "LocalRegulationChunkBuilder"
|
||||
|
||||
Reference in New Issue
Block a user