update for 1. 优化 2.中英切换
This commit is contained in:
0
backend/tests/compliance/__init__.py
Normal file
0
backend/tests/compliance/__init__.py
Normal file
140
backend/tests/compliance/test_pipeline.py
Normal file
140
backend/tests/compliance/test_pipeline.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker
|
||||
from app.domain.retrieval.models import RetrievedChunk
|
||||
from app.domain.compliance.ports import AnalysisRecord, FindingRecord
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_chunk(score: float) -> RetrievedChunk:
|
||||
return RetrievedChunk(
|
||||
chunk_id="c1",
|
||||
doc_id="d1",
|
||||
doc_title="Test Doc",
|
||||
section_title="S1",
|
||||
text="some text",
|
||||
score=score,
|
||||
page_start=1,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'):
|
||||
client = MagicMock()
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.content = content
|
||||
client.chat.return_value = response
|
||||
return client
|
||||
|
||||
|
||||
def _make_mock_retrieval():
|
||||
svc = MagicMock()
|
||||
svc.retrieve.return_value = []
|
||||
return svc
|
||||
|
||||
|
||||
# ── existing tests ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_pass_through_returns_top_k():
|
||||
reranker = PassThroughReranker()
|
||||
chunks = [_make_chunk(0.9), _make_chunk(0.8), _make_chunk(0.7)]
|
||||
result = reranker.rerank(query="test", chunks=chunks, top_k=2)
|
||||
assert len(result) == 2
|
||||
assert result[0].score == 0.9
|
||||
|
||||
|
||||
def test_pass_through_returns_all_when_top_k_exceeds():
|
||||
reranker = PassThroughReranker()
|
||||
chunks = [_make_chunk(0.5)]
|
||||
result = reranker.rerank(query="test", chunks=chunks, top_k=10)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
# ── new tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_process_single_clause_returns_finding():
|
||||
from app.application.compliance.pipeline import process_single_clause
|
||||
client = _make_mock_client()
|
||||
svc = _make_mock_retrieval()
|
||||
result = process_single_clause("test clause", 0, svc, client)
|
||||
assert result["finding"] is not None
|
||||
assert result["index"] == 0
|
||||
assert result["chunks"] == []
|
||||
|
||||
|
||||
def test_run_clauses_parallel_runs_all():
|
||||
from app.application.compliance.pipeline import run_clauses_parallel
|
||||
client = _make_mock_client()
|
||||
svc = _make_mock_retrieval()
|
||||
clauses = ["clause one", "clause two", "clause three"]
|
||||
results = asyncio.run(run_clauses_parallel(clauses, svc, client))
|
||||
assert len(results) == 3
|
||||
assert all(r["index"] == i for i, r in enumerate(results))
|
||||
|
||||
|
||||
def test_run_clauses_parallel_handles_clause_failure():
|
||||
from app.application.compliance.pipeline import run_clauses_parallel
|
||||
svc = _make_mock_retrieval()
|
||||
bad_client = MagicMock()
|
||||
bad_client.chat.side_effect = RuntimeError("LLM exploded")
|
||||
results = asyncio.run(run_clauses_parallel(
|
||||
["clause one", "clause two"], svc, bad_client
|
||||
))
|
||||
assert len(results) == 2
|
||||
assert all(r["finding"] is None for r in results)
|
||||
assert all(r["chunks"] == [] for r in results)
|
||||
|
||||
|
||||
# ── helpers for new tests ─────────────────────────────────────────────────────
|
||||
|
||||
def _sample_analysis() -> AnalysisRecord:
|
||||
return AnalysisRecord(
|
||||
id="a1", created_at=datetime(2026, 6, 8), created_by="u",
|
||||
doc_name="doc.pdf", standard_name="EU AI Act",
|
||||
risk_score=72, conclusion="Gaps found.", actions=[], para_text="para",
|
||||
highlight_terms=[], findings=[],
|
||||
)
|
||||
|
||||
|
||||
def _sample_finding(status: str = "risk") -> FindingRecord:
|
||||
return FindingRecord(
|
||||
id="f1", analysis_id="a1", seq=0,
|
||||
title="Missing CSMS", description="No CSMS certification.",
|
||||
status=status, clause_ref="Art.9.1",
|
||||
)
|
||||
|
||||
|
||||
# ── new tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_build_finding_context_contains_required_fields():
|
||||
from app.application.compliance.pipeline import build_finding_context
|
||||
ctx = build_finding_context(_sample_finding(), _sample_analysis())
|
||||
assert "doc.pdf" in ctx
|
||||
assert "EU AI Act" in ctx
|
||||
assert "Missing CSMS" in ctx
|
||||
assert "Art.9.1" in ctx
|
||||
|
||||
|
||||
def test_generate_suggestions_returns_three_questions():
|
||||
from app.application.compliance.pipeline import generate_suggestions
|
||||
client = _make_mock_client(
|
||||
'{"questions": ["Q1?", "Q2?", "Q3?"]}'
|
||||
)
|
||||
questions = generate_suggestions(_sample_finding("risk"), _sample_analysis(), client)
|
||||
assert len(questions) == 3
|
||||
assert all(isinstance(q, str) for q in questions)
|
||||
|
||||
|
||||
def test_generate_suggestions_falls_back_on_error():
|
||||
from app.application.compliance.pipeline import generate_suggestions
|
||||
bad_client = MagicMock()
|
||||
bad_resp = MagicMock()
|
||||
bad_resp.is_success = False
|
||||
bad_client.chat.return_value = bad_resp
|
||||
questions = generate_suggestions(_sample_finding(), _sample_analysis(), bad_client)
|
||||
assert len(questions) == 3 # fallback always returns 3
|
||||
98
backend/tests/compliance/test_repository.py
Normal file
98
backend/tests/compliance/test_repository.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime
|
||||
from app.domain.compliance.ports import (
|
||||
AnalysisRecord,
|
||||
FindingRecord,
|
||||
ComplianceRepository,
|
||||
)
|
||||
|
||||
|
||||
def _mock_pool():
|
||||
"""Return a mock psycopg2 ThreadedConnectionPool."""
|
||||
conn = MagicMock()
|
||||
cursor = MagicMock()
|
||||
cursor.__enter__ = MagicMock(return_value=cursor)
|
||||
cursor.__exit__ = MagicMock(return_value=False)
|
||||
conn.cursor.return_value = cursor
|
||||
pool = MagicMock()
|
||||
pool.getconn.return_value = conn
|
||||
return pool, conn, cursor
|
||||
|
||||
|
||||
@patch("app.infrastructure.compliance.repository.psycopg2.pool.ThreadedConnectionPool")
|
||||
def test_save_analysis_returns_uuid(mock_pool_cls):
|
||||
from app.infrastructure.compliance.repository import PostgresComplianceRepository
|
||||
pool, conn, cursor = _mock_pool()
|
||||
mock_pool_cls.return_value = pool
|
||||
cursor.fetchone.return_value = {"id": "abc-123"}
|
||||
|
||||
repo = PostgresComplianceRepository(
|
||||
host="localhost", port=5432, user="u", password="p", dbname="db"
|
||||
)
|
||||
record = AnalysisRecord(
|
||||
id="", created_at=datetime.utcnow(), created_by="user1",
|
||||
doc_name="doc.pdf", standard_name="EU AI Act",
|
||||
risk_score=50, conclusion="OK", actions=[], para_text="p",
|
||||
highlight_terms=[], findings=[],
|
||||
)
|
||||
result = repo.save_analysis(record)
|
||||
assert result == "abc-123"
|
||||
|
||||
|
||||
def test_analysis_record_construction():
|
||||
record = AnalysisRecord(
|
||||
id="",
|
||||
created_at=datetime.utcnow(),
|
||||
created_by="user1",
|
||||
doc_name="test.pdf",
|
||||
standard_name="EU AI Act",
|
||||
risk_score=72,
|
||||
conclusion="Several gaps found.",
|
||||
actions=[{"label": "Fix", "value": "Update docs"}],
|
||||
para_text="The system shall...",
|
||||
highlight_terms=["CSMS", "ISO 21434"],
|
||||
findings=[
|
||||
FindingRecord(
|
||||
id="",
|
||||
analysis_id="",
|
||||
seq=0,
|
||||
title="Missing CSMS",
|
||||
description="No CSMS certification found.",
|
||||
status="risk",
|
||||
clause_ref="Art.9.1",
|
||||
)
|
||||
],
|
||||
)
|
||||
assert record.doc_name == "test.pdf"
|
||||
assert len(record.findings) == 1
|
||||
assert record.findings[0].status == "risk"
|
||||
|
||||
|
||||
def test_compliance_repository_is_abstract():
|
||||
import inspect
|
||||
assert inspect.isabstract(ComplianceRepository)
|
||||
|
||||
|
||||
def test_generate_docx_returns_bytes():
|
||||
from app.infrastructure.compliance.docx_export import generate_docx
|
||||
record = AnalysisRecord(
|
||||
id="test-id", created_at=datetime(2026, 6, 8), created_by="user1",
|
||||
doc_name="test.pdf", standard_name="EU AI Act",
|
||||
risk_score=72, conclusion="Several gaps found.",
|
||||
actions=[{"label": "Fix", "value": "Update CSMS docs"}],
|
||||
para_text="The system shall implement CSMS.",
|
||||
highlight_terms=["CSMS"],
|
||||
findings=[
|
||||
FindingRecord(
|
||||
id="f1", analysis_id="test-id", seq=0,
|
||||
title="Missing CSMS", description="No CSMS cert.",
|
||||
status="risk", clause_ref="Art.9.1",
|
||||
)
|
||||
],
|
||||
)
|
||||
data = generate_docx(record)
|
||||
assert isinstance(data, bytes)
|
||||
assert len(data) > 1000 # DOCX is at minimum a ZIP with ~1 KB overhead
|
||||
# Verify it's a valid ZIP (DOCX = ZIP container)
|
||||
import zipfile, io
|
||||
assert zipfile.is_zipfile(io.BytesIO(data))
|
||||
Reference in New Issue
Block a user