Files
AIRegulation-DocAnalysis/backend/tests/compliance/test_pipeline.py

141 lines
5.3 KiB
Python
Raw Normal View History

2026-06-10 11:10:36 +08:00
import asyncio
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker
from app.domain.retrieval.models import RetrievedChunk
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
# ── helpers ──────────────────────────────────────────────────────────────────
def _make_chunk(score: float) -> RetrievedChunk:
    """Build a placeholder ``RetrievedChunk`` carrying the given *score*.

    Every other field is a fixed dummy value, so chunks produced by this
    helper differ from one another only in their score.
    """
    fixed_fields = {
        "chunk_id": "c1",
        "doc_id": "d1",
        "doc_title": "Test Doc",
        "section_title": "S1",
        "text": "some text",
    }
    return RetrievedChunk(**fixed_fields, score=score, page_start=1)
def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'):
client = MagicMock()
response = MagicMock()
response.is_success = True
response.content = content
client.chat.return_value = response
return client
def _make_mock_retrieval():
svc = MagicMock()
svc.retrieve.return_value = []
return svc
# ── existing tests ────────────────────────────────────────────────────────────
def test_pass_through_returns_top_k():
    """Reranking three chunks with ``top_k=2`` yields two, the first scoring 0.9."""
    reranker = PassThroughReranker()
    candidates = [_make_chunk(score) for score in (0.9, 0.8, 0.7)]

    kept = reranker.rerank(query="test", chunks=candidates, top_k=2)

    assert len(kept) == 2
    assert kept[0].score == 0.9
def test_pass_through_returns_all_when_top_k_exceeds():
    """A ``top_k`` larger than the input still returns the single chunk supplied."""
    only_chunk = _make_chunk(0.5)

    kept = PassThroughReranker().rerank(query="test", chunks=[only_chunk], top_k=10)

    assert len(kept) == 1
# ── new tests ─────────────────────────────────────────────────────────────────
def test_process_single_clause_returns_finding():
    """One clause with mocked collaborators yields a finding, its index, and no chunks."""
    from app.application.compliance.pipeline import process_single_clause

    llm = _make_mock_client()
    retrieval = _make_mock_retrieval()

    outcome = process_single_clause("test clause", 0, retrieval, llm)

    assert outcome["finding"] is not None
    assert outcome["index"] == 0
    # The mocked retrieval service returns an empty list.
    assert outcome["chunks"] == []
def test_run_clauses_parallel_runs_all():
    """Three clauses give three results whose ``index`` matches their position."""
    from app.application.compliance.pipeline import run_clauses_parallel

    llm = _make_mock_client()
    retrieval = _make_mock_retrieval()
    clause_texts = ["clause one", "clause two", "clause three"]

    outcomes = asyncio.run(run_clauses_parallel(clause_texts, retrieval, llm))

    assert len(outcomes) == 3
    for position, outcome in enumerate(outcomes):
        assert outcome["index"] == position
def test_run_clauses_parallel_handles_clause_failure():
    """A client that raises on ``chat`` still yields one result per clause.

    Each result is expected to carry ``finding`` of ``None`` and an empty
    ``chunks`` list instead of the error propagating out of the run.
    """
    from app.application.compliance.pipeline import run_clauses_parallel

    retrieval = _make_mock_retrieval()
    failing_client = MagicMock()
    failing_client.chat.side_effect = RuntimeError("LLM exploded")

    pending = run_clauses_parallel(["clause one", "clause two"], retrieval, failing_client)
    outcomes = asyncio.run(pending)

    assert len(outcomes) == 2
    for outcome in outcomes:
        assert outcome["finding"] is None
    for outcome in outcomes:
        assert outcome["chunks"] == []
# ── helpers for new tests ─────────────────────────────────────────────────────
def _sample_analysis() -> AnalysisRecord:
    """Build a fixed ``AnalysisRecord`` (id ``a1``) with no findings or actions."""
    return AnalysisRecord(
        id="a1",
        created_at=datetime(2026, 6, 8),
        created_by="u",
        doc_name="doc.pdf",
        standard_name="EU AI Act",
        risk_score=72,
        conclusion="Gaps found.",
        actions=[],
        para_text="para",
        highlight_terms=[],
        findings=[],
    )
def _sample_finding(status: str = "risk") -> FindingRecord:
    """Build a fixed ``FindingRecord`` (id ``f1``, analysis ``a1``) with the given *status*."""
    return FindingRecord(
        id="f1",
        analysis_id="a1",
        seq=0,
        title="Missing CSMS",
        description="No CSMS certification.",
        status=status,
        clause_ref="Art.9.1",
    )
# ── new tests ─────────────────────────────────────────────────────────────────
def test_build_finding_context_contains_required_fields():
    """The context mentions the doc name, standard, finding title and clause ref."""
    from app.application.compliance.pipeline import build_finding_context

    context = build_finding_context(_sample_finding(), _sample_analysis())

    for expected in ("doc.pdf", "EU AI Act", "Missing CSMS", "Art.9.1"):
        assert expected in context
def test_generate_suggestions_returns_three_questions():
    """A response holding three questions yields three string suggestions."""
    from app.application.compliance.pipeline import generate_suggestions

    llm = _make_mock_client('{"questions": ["Q1?", "Q2?", "Q3?"]}')

    suggested = generate_suggestions(_sample_finding("risk"), _sample_analysis(), llm)

    assert len(suggested) == 3
    for question in suggested:
        assert isinstance(question, str)
def test_generate_suggestions_falls_back_on_error():
    """An unsuccessful chat response still produces three suggestions."""
    from app.application.compliance.pipeline import generate_suggestions

    failed_response = MagicMock()
    failed_response.is_success = False
    failing_client = MagicMock()
    failing_client.chat.return_value = failed_response

    fallback = generate_suggestions(_sample_finding(), _sample_analysis(), failing_client)

    assert len(fallback) == 3  # fallback always returns 3