update for 1. 优化 2.中英切换
This commit is contained in:
140
backend/tests/compliance/test_pipeline.py
Normal file
140
backend/tests/compliance/test_pipeline.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker
|
||||
from app.domain.retrieval.models import RetrievedChunk
|
||||
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_chunk(score: float) -> RetrievedChunk:
|
||||
return RetrievedChunk(
|
||||
chunk_id="c1",
|
||||
doc_id="d1",
|
||||
doc_title="Test Doc",
|
||||
section_title="S1",
|
||||
text="some text",
|
||||
score=score,
|
||||
page_start=1,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'):
|
||||
client = MagicMock()
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.content = content
|
||||
client.chat.return_value = response
|
||||
return client
|
||||
|
||||
|
||||
def _make_mock_retrieval():
|
||||
svc = MagicMock()
|
||||
svc.retrieve.return_value = []
|
||||
return svc
|
||||
|
||||
|
||||
# ── existing tests ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_pass_through_returns_top_k():
|
||||
reranker = PassThroughReranker()
|
||||
chunks = [_make_chunk(0.9), _make_chunk(0.8), _make_chunk(0.7)]
|
||||
result = reranker.rerank(query="test", chunks=chunks, top_k=2)
|
||||
assert len(result) == 2
|
||||
assert result[0].score == 0.9
|
||||
|
||||
|
||||
def test_pass_through_returns_all_when_top_k_exceeds():
|
||||
reranker = PassThroughReranker()
|
||||
chunks = [_make_chunk(0.5)]
|
||||
result = reranker.rerank(query="test", chunks=chunks, top_k=10)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
# ── new tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_process_single_clause_returns_finding():
|
||||
from app.application.compliance.pipeline import process_single_clause
|
||||
client = _make_mock_client()
|
||||
svc = _make_mock_retrieval()
|
||||
result = process_single_clause("test clause", 0, svc, client)
|
||||
assert result["finding"] is not None
|
||||
assert result["index"] == 0
|
||||
assert result["chunks"] == []
|
||||
|
||||
|
||||
def test_run_clauses_parallel_runs_all():
|
||||
from app.application.compliance.pipeline import run_clauses_parallel
|
||||
client = _make_mock_client()
|
||||
svc = _make_mock_retrieval()
|
||||
clauses = ["clause one", "clause two", "clause three"]
|
||||
results = asyncio.run(run_clauses_parallel(clauses, svc, client))
|
||||
assert len(results) == 3
|
||||
assert all(r["index"] == i for i, r in enumerate(results))
|
||||
|
||||
|
||||
def test_run_clauses_parallel_handles_clause_failure():
|
||||
from app.application.compliance.pipeline import run_clauses_parallel
|
||||
svc = _make_mock_retrieval()
|
||||
bad_client = MagicMock()
|
||||
bad_client.chat.side_effect = RuntimeError("LLM exploded")
|
||||
results = asyncio.run(run_clauses_parallel(
|
||||
["clause one", "clause two"], svc, bad_client
|
||||
))
|
||||
assert len(results) == 2
|
||||
assert all(r["finding"] is None for r in results)
|
||||
assert all(r["chunks"] == [] for r in results)
|
||||
|
||||
|
||||
# ── helpers for new tests ─────────────────────────────────────────────────────
|
||||
|
||||
def _sample_analysis() -> AnalysisRecord:
|
||||
return AnalysisRecord(
|
||||
id="a1", created_at=datetime(2026, 6, 8), created_by="u",
|
||||
doc_name="doc.pdf", standard_name="EU AI Act",
|
||||
risk_score=72, conclusion="Gaps found.", actions=[], para_text="para",
|
||||
highlight_terms=[], findings=[],
|
||||
)
|
||||
|
||||
|
||||
def _sample_finding(status: str = "risk") -> FindingRecord:
|
||||
return FindingRecord(
|
||||
id="f1", analysis_id="a1", seq=0,
|
||||
title="Missing CSMS", description="No CSMS certification.",
|
||||
status=status, clause_ref="Art.9.1",
|
||||
)
|
||||
|
||||
|
||||
# ── new tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_build_finding_context_contains_required_fields():
|
||||
from app.application.compliance.pipeline import build_finding_context
|
||||
ctx = build_finding_context(_sample_finding(), _sample_analysis())
|
||||
assert "doc.pdf" in ctx
|
||||
assert "EU AI Act" in ctx
|
||||
assert "Missing CSMS" in ctx
|
||||
assert "Art.9.1" in ctx
|
||||
|
||||
|
||||
def test_generate_suggestions_returns_three_questions():
|
||||
from app.application.compliance.pipeline import generate_suggestions
|
||||
client = _make_mock_client(
|
||||
'{"questions": ["Q1?", "Q2?", "Q3?"]}'
|
||||
)
|
||||
questions = generate_suggestions(_sample_finding("risk"), _sample_analysis(), client)
|
||||
assert len(questions) == 3
|
||||
assert all(isinstance(q, str) for q in questions)
|
||||
|
||||
|
||||
def test_generate_suggestions_falls_back_on_error():
|
||||
from app.application.compliance.pipeline import generate_suggestions
|
||||
bad_client = MagicMock()
|
||||
bad_resp = MagicMock()
|
||||
bad_resp.is_success = False
|
||||
bad_client.chat.return_value = bad_resp
|
||||
questions = generate_suggestions(_sample_finding(), _sample_analysis(), bad_client)
|
||||
assert len(questions) == 3 # fallback always returns 3
|
||||
Reference in New Issue
Block a user