"""Tests for the pass-through reranker and the compliance pipeline helpers.

The pipeline functions are exercised with ``MagicMock`` stand-ins for the
LLM client and the retrieval service, so no network or vector store is
needed to run this module.
"""

import asyncio
from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest

from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker
from app.domain.retrieval.models import RetrievedChunk
from app.domain.compliance.ports import AnalysisRecord, FindingRecord


# ── shared mock helpers ──────────────────────────────────────────────────────

# Default payload handed back by the mocked LLM client.
_DEFAULT_LLM_CONTENT = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'


def _make_chunk(score: float) -> RetrievedChunk:
    """Build a ``RetrievedChunk`` with fixed metadata and the given ``score``."""
    return RetrievedChunk(
        chunk_id="c1",
        doc_id="d1",
        doc_title="Test Doc",
        section_title="S1",
        text="some text",
        score=score,
        page_start=1,
    )


def _make_mock_client(content: str = _DEFAULT_LLM_CONTENT) -> MagicMock:
    """Return a mock client whose ``chat()`` yields a successful response.

    The response has ``is_success = True`` and carries ``content`` verbatim.
    """
    response = MagicMock()
    response.is_success = True
    response.content = content

    client = MagicMock()
    client.chat.return_value = response
    return client


def _make_mock_retrieval() -> MagicMock:
    """Return a mock retrieval service whose ``retrieve()`` yields no chunks."""
    svc = MagicMock()
    svc.retrieve.return_value = []
    return svc


# ── PassThroughReranker tests ────────────────────────────────────────────────

def test_pass_through_returns_top_k():
    """Three chunks in, ``top_k=2``: two come back and the first scores 0.9."""
    reranker = PassThroughReranker()
    chunks = [_make_chunk(0.9), _make_chunk(0.8), _make_chunk(0.7)]

    result = reranker.rerank(query="test", chunks=chunks, top_k=2)

    assert len(result) == 2
    assert result[0].score == 0.9


def test_pass_through_returns_all_when_top_k_exceeds():
    """A ``top_k`` larger than the input returns every chunk supplied."""
    reranker = PassThroughReranker()
    chunks = [_make_chunk(0.5)]

    result = reranker.rerank(query="test", chunks=chunks, top_k=10)

    assert len(result) == 1


# ── clause pipeline tests ────────────────────────────────────────────────────
# NOTE(review): the pipeline module is imported inside each test rather than
# at module level — presumably so an import failure there is reported per
# test instead of aborting collection of this file; confirm before hoisting.

def test_process_single_clause_returns_finding():
    """A successful LLM reply yields a finding, the given index, and no chunks."""
    from app.application.compliance.pipeline import process_single_clause

    client = _make_mock_client()
    svc = _make_mock_retrieval()

    result = process_single_clause("test clause", 0, svc, client)

    assert result["finding"] is not None
    assert result["index"] == 0
    assert result["chunks"] == []


def test_run_clauses_parallel_runs_all():
    """Every clause produces a result, returned in input order by ``index``."""
    from app.application.compliance.pipeline import run_clauses_parallel

    client = _make_mock_client()
    svc = _make_mock_retrieval()
    clauses = ["clause one", "clause two", "clause three"]

    results = asyncio.run(run_clauses_parallel(clauses, svc, client))

    assert len(results) == 3
    # Compare whole lists so a failure shows the actual index sequence.
    assert [r["index"] for r in results] == list(range(len(clauses)))


def test_run_clauses_parallel_handles_clause_failure():
    """When the LLM call raises, each clause still yields an empty result."""
    from app.application.compliance.pipeline import run_clauses_parallel

    svc = _make_mock_retrieval()
    bad_client = MagicMock()
    bad_client.chat.side_effect = RuntimeError("LLM exploded")

    results = asyncio.run(run_clauses_parallel(
        ["clause one", "clause two"], svc, bad_client
    ))

    assert len(results) == 2
    # Assert per item so a failure identifies the offending result.
    for r in results:
        assert r["finding"] is None
        assert r["chunks"] == []


# ── record fixtures for finding-context and suggestion tests ─────────────────

def _sample_analysis() -> AnalysisRecord:
    """Return a fixed ``AnalysisRecord`` for ``doc.pdf`` against "EU AI Act"."""
    return AnalysisRecord(
        id="a1",
        created_at=datetime(2026, 6, 8),
        created_by="u",
        doc_name="doc.pdf",
        standard_name="EU AI Act",
        risk_score=72,
        conclusion="Gaps found.",
        actions=[],
        para_text="para",
        highlight_terms=[],
        findings=[],
    )


def _sample_finding(status: str = "risk") -> FindingRecord:
    """Return a fixed ``FindingRecord`` ("Missing CSMS") with the given status."""
    return FindingRecord(
        id="f1",
        analysis_id="a1",
        seq=0,
        title="Missing CSMS",
        description="No CSMS certification.",
        status=status,
        clause_ref="Art.9.1",
    )


# ── finding-context and suggestion tests ─────────────────────────────────────

def test_build_finding_context_contains_required_fields():
    """The context text mentions the document, standard, title and clause ref."""
    from app.application.compliance.pipeline import build_finding_context

    ctx = build_finding_context(_sample_finding(), _sample_analysis())

    for expected in ("doc.pdf", "EU AI Act", "Missing CSMS", "Art.9.1"):
        assert expected in ctx


def test_generate_suggestions_returns_three_questions():
    """A well-formed LLM reply with three questions yields three strings."""
    from app.application.compliance.pipeline import generate_suggestions

    client = _make_mock_client(
        '{"questions": ["Q1?", "Q2?", "Q3?"]}'
    )

    questions = generate_suggestions(_sample_finding("risk"), _sample_analysis(), client)

    assert len(questions) == 3
    for q in questions:
        assert isinstance(q, str)


def test_generate_suggestions_falls_back_on_error():
    """An unsuccessful LLM response still yields three questions."""
    from app.application.compliance.pipeline import generate_suggestions

    bad_resp = MagicMock()
    bad_resp.is_success = False
    bad_client = MagicMock()
    bad_client.chat.return_value = bad_resp

    questions = generate_suggestions(_sample_finding(), _sample_analysis(), bad_client)

    assert len(questions) == 3  # fallback always returns 3