141 lines
5.3 KiB
Python
141 lines
5.3 KiB
Python
|
|
import asyncio
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker
|
||
|
|
from app.domain.retrieval.models import RetrievedChunk
|
||
|
|
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||
|
|
|
||
|
|
|
||
|
|
# ── helpers ──────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def _make_chunk(score: float) -> RetrievedChunk:
    """Build a ``RetrievedChunk`` fixture whose only varying field is ``score``.

    Every other field is a fixed placeholder, so chunks produced by this
    helper differ from one another solely by their score.
    """
    fields = {
        "chunk_id": "c1",
        "doc_id": "d1",
        "doc_title": "Test Doc",
        "section_title": "S1",
        "text": "some text",
        "score": score,
        "page_start": 1,
    }
    return RetrievedChunk(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'):
|
||
|
|
client = MagicMock()
|
||
|
|
response = MagicMock()
|
||
|
|
response.is_success = True
|
||
|
|
response.content = content
|
||
|
|
client.chat.return_value = response
|
||
|
|
return client
|
||
|
|
|
||
|
|
|
||
|
|
def _make_mock_retrieval():
|
||
|
|
svc = MagicMock()
|
||
|
|
svc.retrieve.return_value = []
|
||
|
|
return svc
|
||
|
|
|
||
|
|
|
||
|
|
# ── PassThroughReranker tests ─────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def test_pass_through_returns_top_k():
    """Asking for top_k=2 of three chunks yields two results, the first scoring 0.9."""
    candidates = [_make_chunk(score) for score in (0.9, 0.8, 0.7)]

    ranked = PassThroughReranker().rerank(query="test", chunks=candidates, top_k=2)

    assert len(ranked) == 2
    assert ranked[0].score == 0.9
|
||
|
|
|
||
|
|
|
||
|
|
def test_pass_through_returns_all_when_top_k_exceeds():
    """A top_k larger than the input size returns every available chunk."""
    only_chunk = _make_chunk(0.5)

    ranked = PassThroughReranker().rerank(query="test", chunks=[only_chunk], top_k=10)

    assert len(ranked) == 1
|
||
|
|
|
||
|
|
|
||
|
|
# ── clause pipeline tests ─────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def test_process_single_clause_returns_finding():
    """process_single_clause returns a dict with a finding, the index, and chunks.

    With a successful mock client and a retrieval mock that returns no
    chunks, the result must carry a non-None finding, echo index 0, and
    report an empty chunk list.
    """
    from app.application.compliance.pipeline import process_single_clause

    llm = _make_mock_client()
    retrieval = _make_mock_retrieval()

    outcome = process_single_clause("test clause", 0, retrieval, llm)

    assert outcome["finding"] is not None
    assert outcome["index"] == 0
    assert outcome["chunks"] == []
|
||
|
|
|
||
|
|
|
||
|
|
def test_run_clauses_parallel_runs_all():
    """run_clauses_parallel returns one result per clause, indexed in input order."""
    from app.application.compliance.pipeline import run_clauses_parallel

    llm = _make_mock_client()
    retrieval = _make_mock_retrieval()
    clause_texts = ["clause one", "clause two", "clause three"]

    outcomes = asyncio.run(run_clauses_parallel(clause_texts, retrieval, llm))

    assert len(outcomes) == 3
    # Each result's "index" must equal its position in the returned list.
    for position, outcome in enumerate(outcomes):
        assert outcome["index"] == position
|
||
|
|
|
||
|
|
|
||
|
|
def test_run_clauses_parallel_handles_clause_failure():
    """A client that raises on every chat call still yields one result per clause.

    Each result is expected to have ``finding`` set to None and an empty
    ``chunks`` list instead of the exception propagating.
    """
    from app.application.compliance.pipeline import run_clauses_parallel

    retrieval = _make_mock_retrieval()
    exploding_llm = MagicMock()
    exploding_llm.chat.side_effect = RuntimeError("LLM exploded")
    clause_texts = ["clause one", "clause two"]

    outcomes = asyncio.run(run_clauses_parallel(clause_texts, retrieval, exploding_llm))

    assert len(outcomes) == 2
    # Two separate passes: findings are checked for every result before chunks.
    for outcome in outcomes:
        assert outcome["finding"] is None
    for outcome in outcomes:
        assert outcome["chunks"] == []
|
||
|
|
|
||
|
|
|
||
|
|
# ── record fixtures for finding-context and suggestion tests ──────────────────
|
||
|
|
|
||
|
|
def _sample_analysis() -> AnalysisRecord:
    """Build a fixed ``AnalysisRecord`` fixture (id "a1") with no findings or actions."""
    fields = {
        "id": "a1",
        "created_at": datetime(2026, 6, 8),
        "created_by": "u",
        "doc_name": "doc.pdf",
        "standard_name": "EU AI Act",
        "risk_score": 72,
        "conclusion": "Gaps found.",
        "actions": [],
        "para_text": "para",
        "highlight_terms": [],
        "findings": [],
    }
    return AnalysisRecord(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
def _sample_finding(status: str = "risk") -> FindingRecord:
    """Build a fixed ``FindingRecord`` fixture (id "f1") with the given status.

    The finding is linked to analysis id "a1", the same id used by
    ``_sample_analysis``.
    """
    fields = {
        "id": "f1",
        "analysis_id": "a1",
        "seq": 0,
        "title": "Missing CSMS",
        "description": "No CSMS certification.",
        "status": status,
        "clause_ref": "Art.9.1",
    }
    return FindingRecord(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
# ── finding-context and suggestion tests ──────────────────────────────────────
|
||
|
|
|
||
|
|
def test_build_finding_context_contains_required_fields():
    """The built context mentions the document, standard, finding title, and clause."""
    from app.application.compliance.pipeline import build_finding_context

    context = build_finding_context(_sample_finding(), _sample_analysis())

    # Values come from the sample finding and sample analysis fixtures.
    for expected in ("doc.pdf", "EU AI Act", "Missing CSMS", "Art.9.1"):
        assert expected in context
|
||
|
|
|
||
|
|
|
||
|
|
def test_generate_suggestions_returns_three_questions():
    """A successful reply carrying three questions yields three string suggestions."""
    from app.application.compliance.pipeline import generate_suggestions

    reply_payload = '{"questions": ["Q1?", "Q2?", "Q3?"]}'
    llm = _make_mock_client(reply_payload)

    suggested = generate_suggestions(_sample_finding("risk"), _sample_analysis(), llm)

    assert len(suggested) == 3
    for question in suggested:
        assert isinstance(question, str)
|
||
|
|
|
||
|
|
|
||
|
|
def test_generate_suggestions_falls_back_on_error():
    """An unsuccessful client reply still produces three suggestions."""
    from app.application.compliance.pipeline import generate_suggestions

    failed_reply = MagicMock(is_success=False)
    failing_llm = MagicMock()
    failing_llm.chat.return_value = failed_reply

    suggested = generate_suggestions(_sample_finding(), _sample_analysis(), failing_llm)

    # The fallback path is expected to always supply exactly three questions.
    assert len(suggested) == 3
|