Files
AIRegulation-DocAnalysis/tests/test_embedding.py
2026-05-26 12:34:12 +08:00

628 lines
23 KiB
Python

"""Document orchestration and embedding boundary tests for the migrated backend."""
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

from app.application.documents.services import DocumentCommandService
from app.domain.documents import Chunk, Document, DocumentArtifact, DocumentProcessingRun, DocumentStatus, DocumentStatusEvent, ParsedDocument
from app.infrastructure.storage.json_document_processing_store import JsonDocumentProcessingStore
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
from app.shared import bootstrap
class FakeRepository:
    """Keep document rows in a plain dict so service tests need no real storage."""

    def __init__(self) -> None:
        """Start with an empty ``doc_id`` -> document mapping."""
        self.documents: dict[str, Document] = {}

    def create(self, document: Document) -> Document:
        """Store ``document`` under its ``doc_id`` and return the same object."""
        self.documents[document.doc_id] = document
        return document

    def update(self, document: Document) -> Document:
        """Overwrite the entry for ``document.doc_id`` and return the same object."""
        self.documents[document.doc_id] = document
        return document

    def get(self, doc_id: str) -> Document | None:
        """Return the stored document, or ``None`` when ``doc_id`` is unknown."""
        return self.documents.get(doc_id)

    def list(self, limit: int | None = None) -> list[Document]:
        """Return stored documents in insertion order, sliced to ``limit`` when given."""
        stored = [*self.documents.values()]
        if limit is None:
            return stored
        return stored[:limit]

    def delete(self, doc_id: str) -> bool:
        """Remove one document; report whether an entry was actually removed."""
        removed = self.documents.pop(doc_id, None)
        return removed is not None

    def update_status(
        self,
        doc_id: str,
        status: DocumentStatus,
        *,
        error_message: str = "",
        chunk_count: int | None = None,
        summary: str | None = None,
        summary_latency_ms: int | None = None,
        parser_name: str | None = None,
        index_name: str | None = None,
        metadata: dict | None = None,
    ) -> Document | None:
        """Mutate the stored document in place and return it.

        ``status`` and ``error_message`` are always written. The remaining
        keyword fields are written only when they are not ``None``, and a
        non-empty ``metadata`` is merged into the document's metadata dict.
        Returns ``None`` when no document is stored under ``doc_id``.
        """
        target = self.documents.get(doc_id)
        if not target:
            return None

        target.status = status
        target.error_message = error_message

        # Optional columns: ``None`` means "leave the current value alone".
        optional_fields = {
            "chunk_count": chunk_count,
            "summary": summary,
            "summary_latency_ms": summary_latency_ms,
            "parser_name": parser_name,
            "index_name": index_name,
        }
        for attribute, value in optional_fields.items():
            if value is not None:
                setattr(target, attribute, value)

        if metadata:
            target.metadata.update(metadata)
        return target
class FakeBinaryStore:
    """Hold uploaded payloads in a dict keyed by object name for upload and retry tests."""

    def __init__(self) -> None:
        """Start with no saved objects."""
        self.saved: dict[str, bytes] = {}

    def save(self, *, object_name: str, data: bytes, content_type: str, metadata: dict[str, str] | None = None) -> None:
        """Remember ``data`` under ``object_name``; ``content_type`` and ``metadata`` are not kept."""
        self.saved.update({object_name: data})

    def read(self, object_name: str) -> bytes:
        """Return the saved payload; an unknown ``object_name`` raises ``KeyError``."""
        payload = self.saved[object_name]
        return payload

    def delete(self, object_name: str) -> None:
        """Forget ``object_name`` if present; a missing name is not an error."""
        if object_name in self.saved:
            del self.saved[object_name]
class FakeParser:
    """Produce the same parsed document on every call so service tests stay deterministic."""

    def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
        """Build a ``ParsedDocument`` with one layout, one semantic block and one vector chunk.

        ``file_path`` is accepted but never read; only ``doc_id`` and
        ``doc_name`` flow into the result.
        """
        section = "第一章"
        body_text = "法规正文"

        only_chunk = {
            "chunk_id": f"{doc_id}-chunk-1",
            "semantic_id": "semantic-1",
            "chunk_type": "section_text",
            "section_title": section,
            "section_path": [section],
            "page_start": 1,
            "text": body_text,
            "embedding_text": "标准:测试\n章节:第一章\n\n法规正文",
        }
        only_block = {"semantic_id": "semantic-1", "text": body_text, "section_title": section}
        parse_metadata = {"task_id": "task-123", "artifact_prefix": "artifacts", "layout_count": 1}

        return ParsedDocument(
            doc_id=doc_id,
            doc_name=doc_name,
            raw_layouts=[{"uniqueId": "layout-1", "type": "text"}],
            structure_nodes=[{"title": section}],
            semantic_blocks=[only_block],
            vector_chunks=[only_chunk],
            parser_name="fake_parser",
            metadata=parse_metadata,
        )
class FakeChunkBuilder:
    """Turn any parsed document into exactly one fixed-content chunk."""

    def build(self, *, parsed_document: ParsedDocument, regulation_type: str, version: str) -> list[Chunk]:
        """Return a single-element list whose chunk copies ids from ``parsed_document``.

        Text fields are constants; ``regulation_type`` and ``version`` are
        passed through unchanged.
        """
        owner_id = parsed_document.doc_id
        section = "第一章"

        single_chunk = Chunk(
            chunk_id=f"{owner_id}-chunk-1",
            doc_id=owner_id,
            doc_name=parsed_document.doc_name,
            content="法规正文",
            embedding_text="标准:测试\n章节:第一章\n\n法规正文",
            section_title=section,
            section_path=[section],
            page_number=1,
            regulation_type=regulation_type,
            version=version,
            semantic_id="semantic-1",
            block_type="section_text",
            metadata={"source": "aliyun_vector_chunk"},
        )
        return [single_chunk]
class FakeEmbeddingProvider:
    """Capture embedding calls and return fixed-length constant vectors.

    Args:
        dimension: Length of every vector returned. Defaults to 1024, the
            width used by the tests in this module.
    """

    def __init__(self, dimension: int = 1024) -> None:
        """Start with an empty call log and the requested vector width."""
        self.dimension = dimension
        self.calls: list[list[str]] = []

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Record one batch and return one ``0.1``-filled vector per text.

        The batch is copied before it is logged, so a caller that mutates
        its list afterwards cannot rewrite what this fake captured.
        """
        batch = list(texts)
        self.calls.append(batch)
        return [[0.1] * self.dimension for _ in batch]

    def embed_query(self, text: str) -> list[float]:
        """Return a ``0.2``-filled vector; the query text itself is not recorded."""
        return [0.2] * self.dimension
class FakeVectorIndex:
    """Log every upsert so tests can assert on what the service tried to index."""

    def __init__(self) -> None:
        """Start with an empty upsert log."""
        self.upserts: list[tuple[list[Chunk], list[list[float]]]] = []

    def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
        """Append the ``(chunks, vectors)`` pair to the log and return the chunk count."""
        recorded = (chunks, vectors)
        self.upserts.append(recorded)
        return len(chunks)

    def delete_by_document(self, doc_id: str) -> int:
        """Always report zero deletions; nothing is tracked per document."""
        return 0

    def search(self, query_vector: list[float], top_k: int, filters: str | None = None):
        """Always return an empty result list."""
        return []

    def health(self) -> dict:
        """Return a fixed payload naming the collection."""
        return dict(collection_name="regulations_dense_1024_v1")
@dataclass
class FakeProcessingStore:
    """Record processing-history calls in memory for orchestration assertions.

    Each collection defaults to its own empty list per instance, and any of
    them may be pre-seeded through the constructor. The ``mark_run_*``
    methods log their arguments and return the matching recorded run, or
    ``None`` when no run with that id has been created.
    """

    # Runs handed to ``create_run``, in call order.
    runs: list[DocumentProcessingRun] = field(default_factory=list)
    # Events handed to ``append_status_event``, in call order.
    status_events: list[DocumentStatusEvent] = field(default_factory=list)
    # One entry per ``replace_artifacts_for_run`` call.
    artifact_batches: list[list[DocumentArtifact]] = field(default_factory=list)
    # Document ids passed to ``delete_by_document``.
    deleted_doc_ids: list[str] = field(default_factory=list)
    # Run ids passed to ``mark_run_stored``.
    stored_run_ids: list[str] = field(default_factory=list)
    # Argument payloads captured by ``mark_run_parsed``.
    parsed_calls: list[dict] = field(default_factory=list)
    # Argument payloads captured by ``mark_run_indexed``.
    indexed_calls: list[dict] = field(default_factory=list)
    # Argument payloads captured by ``mark_run_failed``.
    failed_calls: list[dict] = field(default_factory=list)

    def create_run(self, run: DocumentProcessingRun) -> DocumentProcessingRun:
        """Store the created run and return it unchanged."""
        self.runs.append(run)
        return run

    def mark_run_stored(self, run_id: str, *, stored_at=None, metadata: dict | None = None) -> DocumentProcessingRun | None:
        """Record that one run reached the stored stage.

        Only ``run_id`` is logged; ``stored_at`` and ``metadata`` are ignored.
        """
        self.stored_run_ids.append(run_id)
        return self.get_run(run_id)

    def mark_run_parsed(
        self,
        run_id: str,
        *,
        parser_backend: str,
        layout_count: int,
        structure_node_count: int,
        semantic_block_count: int,
        vector_chunk_count: int,
        parsed_at=None,
        metadata: dict | None = None,
    ) -> DocumentProcessingRun | None:
        """Record parse metrics for one run.

        ``parsed_at`` is not logged; a missing ``metadata`` is logged as ``{}``.
        """
        self.parsed_calls.append(
            {
                "run_id": run_id,
                "parser_backend": parser_backend,
                "layout_count": layout_count,
                "structure_node_count": structure_node_count,
                "semantic_block_count": semantic_block_count,
                "vector_chunk_count": vector_chunk_count,
                "metadata": metadata or {},
            }
        )
        return self.get_run(run_id)

    def mark_run_indexed(
        self,
        run_id: str,
        *,
        chunk_count: int,
        index_name: str,
        indexed_at=None,
        finished_at=None,
        metadata: dict | None = None,
    ) -> DocumentProcessingRun | None:
        """Record index completion for one run.

        Timestamps are not logged; a missing ``metadata`` is logged as ``{}``.
        """
        self.indexed_calls.append(
            {
                "run_id": run_id,
                "chunk_count": chunk_count,
                "index_name": index_name,
                "metadata": metadata or {},
            }
        )
        return self.get_run(run_id)

    def mark_run_failed(
        self,
        run_id: str,
        *,
        failure_stage: str,
        error_message: str,
        finished_at=None,
        metadata: dict | None = None,
    ) -> DocumentProcessingRun | None:
        """Record terminal failure details for one run.

        ``finished_at`` is not logged; a missing ``metadata`` is logged as ``{}``.
        """
        self.failed_calls.append(
            {
                "run_id": run_id,
                "failure_stage": failure_stage,
                "error_message": error_message,
                "metadata": metadata or {},
            }
        )
        return self.get_run(run_id)

    def append_status_event(self, event: DocumentStatusEvent) -> DocumentStatusEvent:
        """Store one status event and return it unchanged."""
        self.status_events.append(event)
        return event

    def replace_artifacts_for_run(self, run_id: str, artifacts: list[DocumentArtifact]) -> list[DocumentArtifact]:
        """Append one artifact batch to the log; earlier batches are kept, not replaced."""
        self.artifact_batches.append(artifacts)
        return artifacts

    def delete_by_document(self, doc_id: str) -> None:
        """Record a document-history delete request without removing any recorded data."""
        self.deleted_doc_ids.append(doc_id)

    def list_runs_by_document(self, doc_id: str) -> list[DocumentProcessingRun]:
        """Return recorded runs whose ``doc_id`` matches."""
        return [run for run in self.runs if run.doc_id == doc_id]

    def get_run(self, run_id: str) -> DocumentProcessingRun | None:
        """Return the first recorded run with ``run_id``, or ``None`` when there is none."""
        return next((run for run in self.runs if run.run_id == run_id), None)

    def list_status_events_by_document(self, doc_id: str) -> list[DocumentStatusEvent]:
        """Return recorded status events whose ``doc_id`` matches."""
        return [event for event in self.status_events if event.doc_id == doc_id]

    def list_status_events_by_run(self, run_id: str) -> list[DocumentStatusEvent]:
        """Return recorded status events whose ``run_id`` matches."""
        return [event for event in self.status_events if event.run_id == run_id]

    def list_artifacts_by_document(self, doc_id: str) -> list[DocumentArtifact]:
        """Return artifacts from every recorded batch whose ``doc_id`` matches."""
        return [artifact for batch in self.artifact_batches for artifact in batch if artifact.doc_id == doc_id]

    def list_artifacts_by_run(self, run_id: str) -> list[DocumentArtifact]:
        """Return artifacts from every recorded batch whose ``run_id`` matches."""
        return [artifact for batch in self.artifact_batches for artifact in batch if artifact.run_id == run_id]
class FailingParser:
    """Parser stand-in that always fails, used to assert on the recorded failure stage."""

    def parse(self, *, file_path: str, doc_id: str, doc_name: str) -> ParsedDocument:
        """Ignore every argument and raise ``RuntimeError("parser exploded")``."""
        failure = RuntimeError("parser exploded")
        raise failure
def test_document_command_service_uses_1024_dense_embedding_and_updates_status():
    """Upload through the service with fakes and check embedding, repository and history results."""
    documents = FakeRepository()
    blobs = FakeBinaryStore()
    embedder = FakeEmbeddingProvider()
    index = FakeVectorIndex()
    history = FakeProcessingStore()
    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FakeParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=embedder,
        vector_index=index,
        document_processing_store=history,
    )
    upload_request = {
        "doc_id": "doc12345",
        "file_name": "test.pdf",
        "content": b"dummy pdf bytes",
        "content_type": "application/pdf",
        "doc_name": "测试法规",
        "regulation_type": "车辆安全",
        "version": "2026",
        "generate_summary": False,
    }

    outcome = command_service.upload_and_process(**upload_request)

    # Service result plus the single embedding / upsert round trip.
    expected_embedding_text = "标准:测试\n章节:第一章\n\n法规正文"
    assert outcome.status == "indexed"
    assert outcome.num_chunks == 1
    assert embedder.calls == [[expected_embedding_text]]
    assert len(index.upserts) == 1

    # Document row as left in the repository.
    saved = documents.get("doc12345")
    assert saved is not None
    assert saved.status == DocumentStatus.INDEXED
    assert saved.chunk_count == 1
    assert saved.parser_name == "fake_parser"
    assert saved.index_name == "regulations_dense_1024_v1"
    assert saved.metadata["parse_task_id"] == "task-123"
    assert saved.metadata["artifact_keys"]["vector_chunks"].endswith("/vector_chunks.json")

    # Processing history captured by the fake store.
    assert len(history.runs) == 1
    only_run = history.runs[0]
    assert only_run.trigger_type == "upload"
    assert history.stored_run_ids == [only_run.run_id]
    assert history.parsed_calls[0]["vector_chunk_count"] == 1
    assert history.indexed_calls[0]["index_name"] == "regulations_dense_1024_v1"
    transitions = [event.to_status for event in history.status_events]
    assert transitions == ["pending", "stored", "parsed", "indexed"]
    artifact_types = {artifact.artifact_type for artifact in history.artifact_batches[0]}
    assert artifact_types == {
        "layouts",
        "structure_nodes",
        "semantic_blocks",
        "vector_chunks",
    }
def test_document_command_service_retry_marks_processing_run_as_retry():
    """Retry a pre-seeded document and check the recorded run carries the ``retry`` trigger."""
    object_name = "doc-retry/retry.pdf"
    documents = FakeRepository()
    blobs = FakeBinaryStore()
    embedder = FakeEmbeddingProvider()
    index = FakeVectorIndex()
    history = FakeProcessingStore()

    # Seed the repository row and its binary payload before the retry.
    seeded_document = Document(
        doc_id="doc-retry",
        doc_name="Retry Doc",
        file_name="retry.pdf",
        object_name=object_name,
        content_type="application/pdf",
        size_bytes=4,
        regulation_type="车辆安全",
        version="2026",
        metadata={"generate_summary": False},
    )
    documents.create(seeded_document)
    blobs.save(
        object_name=object_name,
        data=b"data",
        content_type="application/pdf",
        metadata={"doc_id": "doc-retry"},
    )

    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FakeParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=embedder,
        vector_index=index,
        document_processing_store=history,
    )

    outcome = command_service.retry("doc-retry")

    assert outcome.status == "indexed"
    assert history.runs[0].trigger_type == "retry"
def test_document_command_service_records_failed_processing_stage():
    """Upload with a parser that raises and check the failure is recorded as stage ``parse``."""
    documents = FakeRepository()
    blobs = FakeBinaryStore()
    embedder = FakeEmbeddingProvider()
    index = FakeVectorIndex()
    history = FakeProcessingStore()
    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FailingParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=embedder,
        vector_index=index,
        document_processing_store=history,
    )
    upload_request = {
        "doc_id": "doc-fail",
        "file_name": "test.pdf",
        "content": b"dummy pdf bytes",
        "content_type": "application/pdf",
        "doc_name": "测试法规",
        "regulation_type": "车辆安全",
        "version": "2026",
        "generate_summary": False,
    }

    outcome = command_service.upload_and_process(**upload_request)

    assert outcome.status == "failed"
    first_failure = history.failed_calls[0]
    assert first_failure["failure_stage"] == "parse"
    last_event = history.status_events[-1]
    assert last_event.to_status == "failed"
    saved = documents.get("doc-fail")
    assert saved.metadata["failure_stage"] == "parse"
def test_document_command_service_delete_cleans_processing_history_when_present():
    """Delete a seeded document and check the processing store received the delete request."""
    documents = FakeRepository()
    blobs = FakeBinaryStore()
    index = FakeVectorIndex()
    history = FakeProcessingStore()

    seeded_document = Document(
        doc_id="doc-delete",
        doc_name="Delete Doc",
        file_name="delete.pdf",
        object_name="doc-delete/delete.pdf",
        content_type="application/pdf",
        size_bytes=4,
    )
    documents.create(seeded_document)

    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FakeParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=FakeEmbeddingProvider(),
        vector_index=index,
        document_processing_store=history,
    )

    was_deleted = command_service.delete("doc-delete")

    assert was_deleted is True
    assert history.deleted_doc_ids == ["doc-delete"]
def test_document_command_service_persists_processing_history_with_json_store(tmp_path: Path):
    """Run one upload against the JSON-backed stores and read the history back from them."""
    documents_file = tmp_path / "documents.json"
    history_file = tmp_path / "document_processing.json"
    documents = JsonDocumentRepository(str(documents_file))
    history = JsonDocumentProcessingStore(str(history_file))
    blobs = FakeBinaryStore()
    embedder = FakeEmbeddingProvider()
    index = FakeVectorIndex()
    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FakeParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=embedder,
        vector_index=index,
        document_processing_store=history,
    )
    upload_request = {
        "doc_id": "doc-json-flow",
        "file_name": "test.pdf",
        "content": b"dummy pdf bytes",
        "content_type": "application/pdf",
        "doc_name": "测试法规",
        "regulation_type": "车辆安全",
        "version": "2026",
        "generate_summary": False,
    }

    outcome = command_service.upload_and_process(**upload_request)

    # Read everything back through the stores' own query methods.
    saved = documents.get("doc-json-flow")
    recorded_runs = history.list_runs_by_document("doc-json-flow")
    recorded_events = history.list_status_events_by_document("doc-json-flow")
    recorded_artifacts = history.list_artifacts_by_document("doc-json-flow")

    assert outcome.status == "indexed"
    assert saved is not None and saved.status == DocumentStatus.INDEXED
    assert len(recorded_runs) == 1
    only_run = recorded_runs[0]
    assert only_run.trigger_type == "upload"
    assert only_run.run_status == "succeeded"
    transitions = [event.to_status for event in recorded_events]
    assert transitions == ["pending", "stored", "parsed", "indexed"]
    artifact_types = {artifact.artifact_type for artifact in recorded_artifacts}
    assert artifact_types == {
        "layouts",
        "structure_nodes",
        "semantic_blocks",
        "vector_chunks",
    }
def test_document_command_service_retry_creates_second_json_processing_run(tmp_path: Path):
    """Retry the same seeded document twice and check the JSON store holds two ``retry`` runs."""
    object_name = "doc-json-retry/retry.pdf"
    documents = JsonDocumentRepository(str(tmp_path / "documents.json"))
    history = JsonDocumentProcessingStore(str(tmp_path / "document_processing.json"))
    blobs = FakeBinaryStore()

    # Seed the repository row and its binary payload before retrying.
    seeded_document = Document(
        doc_id="doc-json-retry",
        doc_name="Retry Doc",
        file_name="retry.pdf",
        object_name=object_name,
        content_type="application/pdf",
        size_bytes=4,
        regulation_type="车辆安全",
        version="2026",
        metadata={"generate_summary": False},
    )
    documents.create(seeded_document)
    blobs.save(
        object_name=object_name,
        data=b"data",
        content_type="application/pdf",
        metadata={"doc_id": "doc-json-retry"},
    )

    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FakeParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=FakeEmbeddingProvider(),
        vector_index=FakeVectorIndex(),
        document_processing_store=history,
    )

    first_outcome = command_service.retry("doc-json-retry")
    second_outcome = command_service.retry("doc-json-retry")
    recorded_runs = history.list_runs_by_document("doc-json-retry")

    assert first_outcome.status == "indexed"
    assert second_outcome.status == "indexed"
    assert len(recorded_runs) == 2
    trigger_types = {run.trigger_type for run in recorded_runs}
    assert trigger_types == {"retry"}
def test_document_command_service_delete_removes_json_processing_history(tmp_path: Path):
    """Upload then delete a document and check the JSON store returns no history for it."""
    documents = JsonDocumentRepository(str(tmp_path / "documents.json"))
    history = JsonDocumentProcessingStore(str(tmp_path / "document_processing.json"))
    blobs = FakeBinaryStore()
    command_service = DocumentCommandService(
        document_repository=documents,
        binary_store=blobs,
        parser=FakeParser(),
        chunk_builder=FakeChunkBuilder(),
        embedding_provider=FakeEmbeddingProvider(),
        vector_index=FakeVectorIndex(),
        document_processing_store=history,
    )
    upload_request = {
        "doc_id": "doc-json-delete",
        "file_name": "delete.pdf",
        "content": b"delete me",
        "content_type": "application/pdf",
        "doc_name": "Delete Doc",
        "regulation_type": "车辆安全",
        "version": "2026",
        "generate_summary": False,
    }
    command_service.upload_and_process(**upload_request)

    was_deleted = command_service.delete("doc-json-delete")

    assert was_deleted is True
    assert history.list_runs_by_document("doc-json-delete") == []
    assert history.list_status_events_by_document("doc-json-delete") == []
    assert history.list_artifacts_by_document("doc-json-delete") == []
def test_bootstrap_returns_json_processing_store_for_json_backend(tmp_path: Path):
    """Switch bootstrap settings to the ``json`` backend and check the class of the store it builds."""
    saved_backend, saved_path = (
        bootstrap.settings.document_repository_backend,
        bootstrap.settings.document_processing_metadata_path,
    )
    metadata_file = tmp_path / "document_processing.json"
    # Clear the factory's cache so the settings assigned below are the ones it sees.
    bootstrap.get_document_processing_store.cache_clear()
    try:
        bootstrap.settings.document_repository_backend = "json"
        bootstrap.settings.document_processing_metadata_path = str(metadata_file)
        built_store = bootstrap.get_document_processing_store()
        assert built_store.__class__.__name__ == "JsonDocumentProcessingStore"
    finally:
        # Put the shared settings back and clear the cache again, pass or fail.
        bootstrap.settings.document_repository_backend = saved_backend
        bootstrap.settings.document_processing_metadata_path = saved_path
        bootstrap.get_document_processing_store.cache_clear()
def test_bootstrap_defaults_to_aliyun_parser_and_chunk_builder():
    """Clear both factory caches, then check the class names of what bootstrap builds."""
    bootstrap.get_parser.cache_clear()
    bootstrap.get_chunk_builder.cache_clear()

    built_parser = bootstrap.get_parser()
    built_chunk_builder = bootstrap.get_chunk_builder()

    parser_class_name = built_parser.__class__.__name__
    chunk_builder_class_name = built_chunk_builder.__class__.__name__
    assert parser_class_name == "AliyunDocumentParser"
    assert chunk_builder_class_name == "AliyunVectorChunkBuilder"