⚡ 分析质量提升
并行子句处理(速度提升 3–5 倍)、交叉编码器(Cross-Encoder)重排序、置信度过滤、修复 highlight_terms 失效的 Bug、减少 LLM 静默失败。
收益
- 更快、更准确的分析
- 消除当前 Bug
难度
- 需要改造 pipeline.py
diff --git a/.env b/.env index a92f5b7..45aa0e1 100644 --- a/.env +++ b/.env @@ -48,8 +48,8 @@ CHUNK_OVERLAP=50 MAX_FILE_SIZE_MB=100 PARSER_BACKEND=aliyun CHUNK_BACKEND=aliyun -# 文档元数据存储后端:json(默认)或 postgres -DOCUMENT_REPOSITORY_BACKEND=json +# 文档元数据存储后端:启用 postgres 以激活合规分析历史记录(Direction B)及 Finding Chat 持久化(Direction C) +DOCUMENT_REPOSITORY_BACKEND=postgres # Set to true only when a Celery worker is actually running (./dev.sh start worker). # Default false: processing runs in FastAPI's threadpool — no external worker needed. USE_CELERY_WORKER=false diff --git a/.env.development b/.env.development index b8a2285..ac899e7 100644 --- a/.env.development +++ b/.env.development @@ -31,5 +31,5 @@ POSTGRES_PASSWORD=postgresql123456 POSTGRES_DB=compliance_db # ===== 文档元数据后端 ===== -# 改为 postgres 以启用 PG 持久化(structure_nodes + semantic_blocks 入库) +# 改为 postgres 以启用合规分析历史记录(Direction B)和 Finding Chat(Direction C) DOCUMENT_REPOSITORY_BACKEND=json diff --git a/.env.example b/.env.example index 13a7539..1bc8445 100644 --- a/.env.example +++ b/.env.example @@ -49,7 +49,11 @@ MAX_FILE_SIZE_MB=100 DOCUMENT_METADATA_PATH=backend/data/documents.json PARSER_BACKEND=aliyun CHUNK_BACKEND=aliyun -# DOCUMENT_REPOSITORY_BACKEND=json(默认,无需数据库)或 postgres(启用 PG 持久化) +# 文档元数据存储后端:json(默认,无需数据库)或 postgres(启用 PG 持久化) +# ⚠ 以下功能需要 postgres(设为 json 时功能静默降级或报 500): +# - Direction B: 合规分析历史记录 (/compliance/history/*) +# - Direction B: DOCX 报告下载 +# - Direction C: Finding Chat 消息持久化 DOCUMENT_REPOSITORY_BACKEND=json # Set to true only when a Celery worker is running (./dev.sh start worker). # Default false: document processing runs in FastAPI's threadpool (no external worker needed). diff --git a/.superpowers/brainstorm/1055-1780892298/content/directions.html b/.superpowers/brainstorm/1055-1780892298/content/directions.html new file mode 100644 index 0000000..debfc0b --- /dev/null +++ b/.superpowers/brainstorm/1055-1780892298/content/directions.html @@ -0,0 +1,56 @@ +
基于代码深度分析,发现了 4 个有价值的改进方向。选择你最希望深入的那个。
+ + + +💡 也可以多选,或者在终端告诉我你有其他想法。
diff --git a/.superpowers/brainstorm/1055-1780892298/state/events b/.superpowers/brainstorm/1055-1780892298/state/events new file mode 100644 index 0000000..7e717fe --- /dev/null +++ b/.superpowers/brainstorm/1055-1780892298/state/events @@ -0,0 +1,3 @@ +{"type":"click","text":"C\n \n 💬 深度 Chat 增强\n 每个 Finding 独立对话线程(持久化)、Chat 上下文绑定真实检索到的法规原文、多轮追问记忆、快捷建议问句生成。\n \n 收益Finding 解读深度大幅提升用户粘性强\n 难度需重构 chat 端点","choice":"C","id":null,"timestamp":1780897984866} +{"type":"click","text":"B\n \n 📋 分析历史 & 专业报告\n 持久化分析记录(PostgreSQL)、历史对比、PDF/DOCX 专业报告导出、分析版本追踪。\n \n 收益结果不再丢失可交付给客户的报告\n 难度需要新增数据库表","choice":"B","id":null,"timestamp":1780897985879} +{"type":"click","text":"A\n \n ⚡ 分析质量提升\n 并行子句处理(速度 3–5×)、跨编码器重排序、置信度过滤、修复 highlight_terms 失效 Bug、减少 LLM 静默失败。\n \n 收益更快、更准确的分析消除当前 Bug\n 难度需要改造 pipeline.py","choice":"A","id":null,"timestamp":1780897986554} diff --git a/.superpowers/brainstorm/1055-1780892298/state/server-stopped b/.superpowers/brainstorm/1055-1780892298/state/server-stopped new file mode 100644 index 0000000..259fc09 --- /dev/null +++ b/.superpowers/brainstorm/1055-1780892298/state/server-stopped @@ -0,0 +1 @@ +{"reason":"idle timeout","timestamp":1780894411095} diff --git a/.superpowers/brainstorm/1055-1780892298/state/server.pid b/.superpowers/brainstorm/1055-1780892298/state/server.pid new file mode 100644 index 0000000..ca88378 --- /dev/null +++ b/.superpowers/brainstorm/1055-1780892298/state/server.pid @@ -0,0 +1 @@ +1055 diff --git a/backend/app/api/routes/compliance.py b/backend/app/api/routes/compliance.py index 3654eaf..d0be8c5 100644 --- a/backend/app/api/routes/compliance.py +++ b/backend/app/api/routes/compliance.py @@ -85,10 +85,9 @@ async def analyze_stream( Events: stage | source | finding | done | error """ from app.application.compliance.pipeline import ( - check_clause_compliance, extract_text_from_doc_id, extract_text_from_file, - retrieve_for_clause, + run_clauses_parallel, split_into_clauses, synthesize_conclusion, ) @@ -136,22 +135,28 @@ async 
def analyze_stream( await asyncio.sleep(0) clauses: list[str] = await asyncio.to_thread(split_into_clauses, para_text, client) - # ── Stage 3: retrieve + gap check per clause ────────────────── + # ── Stage 3: retrieve + gap check (parallel across all clauses) ──────────── findings: list[dict] = [] - for i, clause in enumerate(clauses): - yield _sse({ - "type": "stage", - "stage": "analyzing", - "label": f"Analyzing clause {i + 1}/{len(clauses)}…", - }) - await asyncio.sleep(0) + yield _sse({ + "type": "stage", + "stage": "analyzing", + "label": f"Analyzing {len(clauses)} clauses in parallel…", + }) + await asyncio.sleep(0) - chunks = await asyncio.to_thread( - retrieve_for_clause, clause, retrieval_service, 5, domains or None - ) + clause_results = await run_clauses_parallel( + clauses, retrieval_service, client, + top_k=5, + domains=domains or None, + ) - # Emit source events + for res in clause_results: + i = res["index"] + chunks = res["chunks"] + finding = res["finding"] + + # Emit source events for this clause for chunk in chunks[:3]: yield _sse({ "type": "source", @@ -161,12 +166,11 @@ async def analyze_stream( "status": "retrieved", "full_content": (getattr(chunk, "text", "") or "")[:300], }) - await asyncio.sleep(0) - finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client) if finding: findings.append(finding) yield _sse({"type": "finding", **finding}) + await asyncio.sleep(0) # ── Stage 4: synthesize conclusion ──────────────────────────── @@ -178,6 +182,45 @@ async def analyze_stream( ) yield _sse({"type": "done", **conclusion_data}) + # Auto-save analysis to database + try: + from app.shared.bootstrap import get_compliance_repository + from app.domain.compliance.ports import AnalysisRecord, FindingRecord + from datetime import datetime + + repo = get_compliance_repository() + finding_records = [ + FindingRecord( + id="", + analysis_id="", + seq=i, + title=f.get("title", ""), + description=f.get("desc", ""), + 
status=f.get("status", "ok"), + clause_ref=f.get("clause_ref"), + ) + for i, f in enumerate(findings) + ] + record = AnalysisRecord( + id="", + created_at=datetime.utcnow(), + created_by=current_user.username if hasattr(current_user, "username") else None, + doc_name=file_name or (title or "Pasted text"), + standard_name=title or "", + risk_score=conclusion_data.get("risk_score", 0), + conclusion=conclusion_data.get("conclusion", ""), + actions=conclusion_data.get("actions", []), + para_text=conclusion_data.get("para_text", ""), + highlight_terms=conclusion_data.get("highlight_terms", []), + findings=finding_records, + ) + analysis_id = await asyncio.to_thread(repo.save_analysis, record) + yield _sse({"type": "saved", "analysis_id": analysis_id}) + except NotImplementedError: + pass # No postgres backend configured — skip saving + except Exception as exc: + logger.warning("Failed to auto-save compliance analysis: {}", exc) + except Exception as exc: logger.exception("analyze-stream pipeline error") yield _sse({"type": "error", "text": str(exc)}) @@ -225,3 +268,226 @@ async def compliance_chat(segment_id: int, request: ComplianceChatRequest): media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) + + +@router.get("/history") +async def list_history( + limit: int = 20, + offset: int = 0, + current_user: UserClaims = Depends(get_current_user), +): + """Return paginated list of saved compliance analyses (newest first).""" + from app.shared.bootstrap import get_compliance_repository + try: + repo = get_compliance_repository() + records = await asyncio.to_thread(repo.list_analyses, limit, offset) + return [ + { + "id": r.id, + "created_at": r.created_at.isoformat(), + "created_by": r.created_by, + "doc_name": r.doc_name, + "standard_name": r.standard_name, + "risk_score": r.risk_score, + "finding_count": len(r.findings), + } + for r in records + ] + except NotImplementedError: + return [] + + 
+@router.get("/history/{analysis_id}") +async def get_history_item( + analysis_id: str, + current_user: UserClaims = Depends(get_current_user), +): + """Return full analysis record including findings.""" + from app.shared.bootstrap import get_compliance_repository + from fastapi import HTTPException + repo = get_compliance_repository() + record = await asyncio.to_thread(repo.get_analysis, analysis_id) + if not record: + raise HTTPException(status_code=404, detail="Analysis not found") + return { + "id": record.id, + "created_at": record.created_at.isoformat(), + "created_by": record.created_by, + "doc_name": record.doc_name, + "standard_name": record.standard_name, + "risk_score": record.risk_score, + "conclusion": record.conclusion, + "actions": record.actions, + "para_text": record.para_text, + "highlight_terms": record.highlight_terms, + "findings": [ + { + "id": f.id, + "seq": f.seq, + "title": f.title, + "description": f.description, + "status": f.status, + "clause_ref": f.clause_ref, + } + for f in record.findings + ], + } + + +@router.delete("/history/{analysis_id}", status_code=204) +async def delete_history_item( + analysis_id: str, + current_user: UserClaims = Depends(get_current_user), +): + """Delete a saved analysis (cascade removes findings and chat messages).""" + from app.shared.bootstrap import get_compliance_repository + repo = get_compliance_repository() + await asyncio.to_thread(repo.delete_analysis, analysis_id) + + +@router.get("/history/{analysis_id}/download") +async def download_history_docx( + analysis_id: str, + current_user: UserClaims = Depends(get_current_user), +): + """Return a DOCX compliance report for the given analysis.""" + from app.shared.bootstrap import get_compliance_repository + from app.infrastructure.compliance.docx_export import generate_docx + from fastapi import HTTPException + from fastapi.responses import Response + + repo = get_compliance_repository() + record = await asyncio.to_thread(repo.get_analysis, 
analysis_id) + if not record: + raise HTTPException(status_code=404, detail="Analysis not found") + + docx_bytes = await asyncio.to_thread(generate_docx, record) + safe_name = (record.doc_name or "report").replace(" ", "_")[:50] + filename = f"compliance_{safe_name}_{record.created_at.strftime('%Y%m%d')}.docx" + return Response( + content=docx_bytes, + media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) + + +@router.get("/analyses/{analysis_id}/findings/{finding_id}/chat") +async def get_finding_chat_history( + analysis_id: str, + finding_id: str, + current_user: UserClaims = Depends(get_current_user), +): + """Return persisted chat messages for a finding thread, oldest first.""" + from app.shared.bootstrap import get_compliance_repository + try: + repo = get_compliance_repository() + messages = await asyncio.to_thread(repo.get_messages, finding_id) + return messages + except NotImplementedError: + return [] + + +@router.post("/analyses/{analysis_id}/findings/{finding_id}/suggestions") +async def get_finding_suggestions( + analysis_id: str, + finding_id: str, + current_user: UserClaims = Depends(get_current_user), +): + """Generate 3 LLM-powered follow-up question suggestions for a finding.""" + from app.application.compliance.pipeline import generate_suggestions + from app.shared.bootstrap import get_compliance_repository + from app.services.llm.llm_factory import get_llm_client + from fastapi import HTTPException + + repo = get_compliance_repository() + analysis = await asyncio.to_thread(repo.get_analysis, analysis_id) + if not analysis: + raise HTTPException(status_code=404, detail="Analysis not found") + + finding = next((f for f in analysis.findings if f.id == finding_id), None) + if not finding: + raise HTTPException(status_code=404, detail="Finding not found") + + client = get_llm_client(provider=settings.llm_provider, model=settings.llm_model) + 
questions = await asyncio.to_thread(generate_suggestions, finding, analysis, client) + return {"questions": questions} + + +@router.post("/analyses/{analysis_id}/findings/{finding_id}/chat") +async def finding_chat( + analysis_id: str, + finding_id: str, + request: ComplianceChatRequest, + current_user: UserClaims = Depends(get_current_user), +): + """Stream a grounded chat response for a specific finding. + + Loads the finding and analysis from DB to build grounded context. + Persists both user message and assistant response to finding_chat_messages. + """ + from app.application.compliance.pipeline import build_finding_context + from app.shared.bootstrap import get_compliance_repository + from fastapi import HTTPException + + repo = get_compliance_repository() + analysis = await asyncio.to_thread(repo.get_analysis, analysis_id) + if not analysis: + raise HTTPException(status_code=404, detail="Analysis not found") + finding = next((f for f in analysis.findings if f.id == finding_id), None) + if not finding: + raise HTTPException(status_code=404, detail="Finding not found") + + # Persist user message + await asyncio.to_thread( + repo.save_message, analysis_id, finding_id, "user", request.query + ) + + # Build message history (last 10 messages = 5 turns) + history = await asyncio.to_thread(repo.get_messages, finding_id) + history_messages = [ + {"role": m["role"], "content": m["content"]} + for m in history[-10:] + ] + + # Build grounded system context + system_context = build_finding_context(finding, analysis) + full_query = f"[Compliance Finding Context]\n{system_context}\n\nUser question: {request.query}" + + assistant_buffer: list[str] = [] + + async def generate() -> AsyncGenerator[str, None]: + try: + _, event_stream = get_agent_conversation_service().stream_chat( + query=full_query, + top_k=5, + prompt_template="compliance_qa", + ) + for event in event_stream: + event_type = event.get("event", "") + if event_type == "content": + text = event.get("data", "") + 
if text: + assistant_buffer.append(text) + yield _sse({"type": "chunk", "text": text}) + elif event_type == "done": + yield _sse({"type": "done"}) + await asyncio.sleep(0) + except Exception as exc: + logger.exception("finding_chat stream error") + yield _sse({"type": "error", "text": str(exc)}) + finally: + # Persist assistant response after stream completes + full_response = "".join(assistant_buffer) + if full_response: + try: + await asyncio.to_thread( + repo.save_message, analysis_id, finding_id, "assistant", full_response + ) + except Exception as exc: + logger.warning("Failed to persist assistant message: {}", exc) + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) diff --git a/backend/app/application/compliance/pipeline.py b/backend/app/application/compliance/pipeline.py index 92b406e..de348a4 100644 --- a/backend/app/application/compliance/pipeline.py +++ b/backend/app/application/compliance/pipeline.py @@ -5,6 +5,7 @@ All functions are synchronous — call them via asyncio.to_thread() in async SSE from __future__ import annotations +import asyncio import json import os import re @@ -12,10 +13,20 @@ import tempfile from typing import TYPE_CHECKING from loguru import logger +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential + +# Shared retry policy for LLM calls: 3 attempts, exponential back-off 1–4 s. 
+_llm_retry = retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=4), + retry=retry_if_exception_type((ValueError, TimeoutError, ConnectionError)), + reraise=True, +) if TYPE_CHECKING: from app.application.knowledge import KnowledgeRetrievalService from app.domain.retrieval import RetrievedChunk + from app.domain.compliance.ports import AnalysisRecord, FindingRecord from app.services.llm.base_client import BaseLLMClient @@ -109,17 +120,67 @@ def retrieve_for_clause( return retrieval_service.retrieve(query=clause, top_k=top_k, filters=domains) +def process_single_clause( + clause: str, + index: int, + retrieval_service: "KnowledgeRetrievalService", + client: "BaseLLMClient", + top_k: int = 5, + domains: str | None = None, +) -> dict: + """Process one clause: retrieve relevant regulations then check compliance. + + Returns a dict with keys: index, chunks, finding (may be None on LLM failure). + Designed to run inside asyncio.to_thread() for parallel execution. + """ + chunks = retrieve_for_clause(clause, retrieval_service, top_k, domains) + finding = check_clause_compliance(clause, chunks, client) + return {"index": index, "chunks": chunks, "finding": finding} + + +async def run_clauses_parallel( + clauses: list[str], + retrieval_service: "KnowledgeRetrievalService", + client: "BaseLLMClient", + top_k: int = 5, + domains: str | None = None, +) -> list[dict]: + """Run all clauses through retrieve+gap-check in parallel. + + Results are returned in the original clause order even though processing + is concurrent. Exceptions in individual clauses are caught and returned as + dicts with finding=None so the stream continues for remaining clauses. + + Both retrieval_service and client must be thread-safe — they are shared + across all asyncio.to_thread() calls without locking. 
+ """ + tasks = [ + asyncio.to_thread( + process_single_clause, + clause, i, retrieval_service, client, top_k, domains, + ) + for i, clause in enumerate(clauses) + ] + raw = await asyncio.gather(*tasks, return_exceptions=True) + results = [] + for i, r in enumerate(raw): + if isinstance(r, Exception): + logger.warning("Clause {} processing failed: {}", i, r) + results.append({"index": i, "chunks": [], "finding": None}) + else: + results.append(r) + return results + + def check_clause_compliance( clause: str, chunks: list["RetrievedChunk"], client: "BaseLLMClient", ) -> dict | None: - if not chunks: - return None reg_context = "\n".join( f"[{i+1}] {c.doc_title} {c.section_title or ''}: {c.text[:300]}" for i, c in enumerate(chunks[:5]) - ) + ) if chunks else "(no regulatory context retrieved)" prompt = ( "You are a compliance expert. Judge whether the following business clause " "complies with the retrieved regulations.\n\n" @@ -135,9 +196,19 @@ def check_clause_compliance( "status: ok=compliant, warn=gap exists, risk=critical/missing\n" "Return ONLY the JSON object." 
) - response = client.chat([{"role": "user", "content": prompt}], max_tokens=500) - if not response.is_success: + + def _do_check(): + resp = client.chat([{"role": "user", "content": prompt}], max_tokens=500) + if not resp.is_success: + raise ValueError("LLM returned non-success for gap check") + return resp + + try: + response = _llm_retry(_do_check)() + except Exception as exc: + logger.warning("check_clause_compliance LLM call failed after retries: {}", exc) return None + try: result = _extract_json(response.content) if isinstance(result, dict) and "status" in result: @@ -182,12 +253,11 @@ def synthesize_conclusion( ' {"label": "Priority", "value": "High/Medium/Low", "risk": true}\n' ' ],\n' ' "risk_score": 0-100 (integer, higher=riskier),\n' - ' "highlight_terms": ["Key terms to highlight, max 10 terms"],\n' + ' "highlight_terms": ["term1", "term2"], // up to 10 key technical/legal terms actually present in the text\n' ' "para_text": "Original text or summary (max 600 chars)"\n' "}\n" "Return ONLY the JSON object." ) - response = client.chat([{"role": "user", "content": prompt}], max_tokens=1200) fallback = { "conclusion": "Compliance analysis complete. 
Review findings and create remediation plan.", "actions": [ @@ -198,8 +268,19 @@ def synthesize_conclusion( "highlight_terms": [], "para_text": para_text[:800], } - if not response.is_success: + + def _do_synthesize(): + resp = client.chat([{"role": "user", "content": prompt}], max_tokens=1200) + if not resp.is_success: + raise ValueError("LLM returned non-success for synthesis") + return resp + + try: + response = _llm_retry(_do_synthesize)() + except Exception as exc: + logger.warning("synthesize_conclusion LLM call failed after retries: {}", exc) return fallback + try: result = _extract_json(response.content) if isinstance(result, dict): @@ -212,4 +293,78 @@ def synthesize_conclusion( } except (ValueError, TypeError) as exc: logger.warning("Conclusion synthesis JSON parse failed: {}", exc) - return fallback \ No newline at end of file + return fallback + + +_SUGGESTION_FOCUS = { + "risk": "Focus on remediation steps, required certifications, and timeline to resolve.", + "warn": "Focus on identifying the specific compliance gap and how to close it.", + "ok": "Focus on maintaining compliance evidence and monitoring future changes.", +} + +_SUGGESTION_FALLBACK = { + "risk": [ + "What specific certifications or documents are required to remediate this finding?", + "What is the typical remediation timeline for this type of non-compliance?", + "Which regulation clause defines the exact requirement?", + ], + "warn": [ + "What is the exact gap between the current state and the requirement?", + "What evidence would demonstrate partial compliance?", + "Which regulation clause applies to this warning?", + ], + "ok": [ + "What documentation should be maintained to evidence this compliance?", + "How should this area be monitored as regulations evolve?", + "Are there related clauses that may affect this compliant area?", + ], +} + + +def build_finding_context(finding: "FindingRecord", analysis: "AnalysisRecord") -> str: + """Build a grounded system context string for a 
finding chat thread. + + Combines finding details with analysis metadata so the LLM has full + context without relying on the frontend to pass segment_context. + """ + return ( + f"Document: {analysis.doc_name}\n" + f"Standard: {analysis.standard_name}\n" + f"Finding [{finding.seq + 1}]: {finding.title}\n" + f"Status: {finding.status}\n" + f"Clause reference: {finding.clause_ref or 'N/A'}\n" + f"Description: {finding.description}\n" + f"Overall conclusion: {analysis.conclusion}\n" + ) + + +def generate_suggestions( + finding: "FindingRecord", + analysis: "AnalysisRecord", + client: "BaseLLMClient", +) -> list[str]: + """Generate 3 context-aware follow-up questions for a finding chat thread. + + Returns exactly 3 question strings. Falls back to static templates on error. + """ + fallback = _SUGGESTION_FALLBACK.get(finding.status, _SUGGESTION_FALLBACK["warn"]) + context = build_finding_context(finding, analysis) + focus = _SUGGESTION_FOCUS.get(finding.status, _SUGGESTION_FOCUS["warn"]) + prompt = ( + f"{context}\n\n" + f"Task: {focus}\n" + "Generate exactly 3 concise follow-up questions a compliance analyst would ask.\n" + 'Return JSON: {"questions": ["question 1", "question 2", "question 3"]}\n' + "Return ONLY the JSON object." 
+ ) + response = client.chat([{"role": "user", "content": prompt}], max_tokens=300) + if not response.is_success: + return fallback + try: + result = _extract_json(response.content) + questions = result.get("questions", []) + if isinstance(questions, list) and len(questions) >= 3: + return [str(q) for q in questions[:3]] + except (ValueError, TypeError) as exc: + logger.warning("generate_suggestions JSON parse failed: {}", exc) + return fallback diff --git a/backend/app/domain/compliance/__init__.py b/backend/app/domain/compliance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/domain/compliance/ports.py b/backend/app/domain/compliance/ports.py new file mode 100644 index 0000000..b27a93e --- /dev/null +++ b/backend/app/domain/compliance/ports.py @@ -0,0 +1,66 @@ +"""Domain ports for compliance history persistence.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional + + +@dataclass +class FindingRecord: + """Single finding row linked to an analysis.""" + id: str + analysis_id: str + seq: int + title: str + description: str + status: str # "ok" | "warn" | "risk" + clause_ref: Optional[str] = None + + +@dataclass +class AnalysisRecord: + """Full compliance analysis record with nested findings.""" + id: str # UUID string; empty string means not yet persisted + created_at: datetime + created_by: Optional[str] + doc_name: str + standard_name: str + risk_score: int + conclusion: str + actions: list # list[dict] — serialised action items + para_text: str + highlight_terms: list # list[str] + findings: list[FindingRecord] = field(default_factory=list) + + +class ComplianceRepository(ABC): + """Port for persisting and retrieving compliance analysis records.""" + + @abstractmethod + def save_analysis(self, record: AnalysisRecord) -> str: + """Persist a new analysis record and return the assigned UUID string.""" + + 
@abstractmethod + def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]: + """Return analyses ordered by created_at DESC, without nested findings.""" + + @abstractmethod + def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]: + """Return a single analysis with all nested findings, or None.""" + + @abstractmethod + def delete_analysis(self, analysis_id: str) -> None: + """Delete an analysis and all related findings and chat messages (cascade).""" + + @abstractmethod + def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str: + """Persist a chat message and return its UUID string.""" + + @abstractmethod + def get_messages(self, finding_id: str) -> list[dict]: + """Return chat messages for a finding ordered by created_at ASC. + + Each dict has keys: id, role, content, created_at (ISO string). + """ diff --git a/backend/app/infrastructure/compliance/__init__.py b/backend/app/infrastructure/compliance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/infrastructure/compliance/docx_export.py b/backend/app/infrastructure/compliance/docx_export.py new file mode 100644 index 0000000..68fe3e5 --- /dev/null +++ b/backend/app/infrastructure/compliance/docx_export.py @@ -0,0 +1,101 @@ +"""DOCX report generator for compliance analysis results. + +Uses python-docx (already in requirements.txt). Returns raw bytes so the +caller can stream the response without writing to disk. 
+""" +from __future__ import annotations + +from datetime import datetime, timezone +from io import BytesIO + +from docx import Document +from docx.shared import Pt, RGBColor +from docx.enum.text import WD_ALIGN_PARAGRAPH + +from app.domain.compliance.ports import AnalysisRecord + +_STATUS_LABEL = {"ok": "Compliant", "warn": "Warning", "risk": "Non-Compliant"} +_STATUS_COLOR = { + "ok": RGBColor(0x22, 0x8B, 0x22), + "warn": RGBColor(0xFF, 0x8C, 0x00), + "risk": RGBColor(0xDC, 0x14, 0x3C), +} + + +def generate_docx(record: AnalysisRecord) -> bytes: + """Generate a compliance report DOCX and return its raw bytes. + + Structure: + - Cover: document name, standard, date, risk score + - Executive summary (conclusion) + - Findings table + - Recommended actions + - Footer note + """ + doc = Document() + + # ── Cover ────────────────────────────────────────────────────────────────── + title_para = doc.add_heading("Compliance Analysis Report", level=0) + title_para.alignment = WD_ALIGN_PARAGRAPH.CENTER + + doc.add_paragraph("") + meta_table = doc.add_table(rows=4, cols=2) + meta_table.style = "Table Grid" + labels = ["Document", "Standard", "Date", "Risk Score"] + values = [ + record.doc_name, + record.standard_name, + record.created_at.strftime("%Y-%m-%d %H:%M UTC") if record.created_at else "", + f"{record.risk_score} / 100", + ] + for i, (label, value) in enumerate(zip(labels, values)): + meta_table.cell(i, 0).text = label + meta_table.cell(i, 1).text = value + + # ── Executive Summary ────────────────────────────────────────────────────── + doc.add_heading("Executive Summary", level=1) + doc.add_paragraph(record.conclusion) + + # ── Findings ─────────────────────────────────────────────────────────────── + doc.add_heading("Findings", level=1) + if record.findings: + table = doc.add_table(rows=1, cols=4) + table.style = "Table Grid" + hdr = table.rows[0].cells + for i, h in enumerate(["#", "Status", "Title", "Description / Clause"]): + hdr[i].text = h + for run in 
hdr[i].paragraphs[0].runs: + run.bold = True + + for f in record.findings: + row = table.add_row().cells + row[0].text = str(f.seq + 1) + row[1].text = _STATUS_LABEL.get(f.status, f.status) + row[2].text = f.title + desc = f.description + if f.clause_ref: + desc += f"\n[{f.clause_ref}]" + row[3].text = desc + else: + doc.add_paragraph("No findings recorded.") + + # ── Recommended Actions ──────────────────────────────────────────────────── + doc.add_heading("Recommended Actions", level=1) + for i, action in enumerate(record.actions, start=1): + label = action.get("label", "Action") + value = action.get("value", "") + doc.add_paragraph(f"{i}. {label}: {value}", style="List Number") + + # ── Footer note ──────────────────────────────────────────────────────────── + doc.add_paragraph("") + footer = doc.add_paragraph( + f"Generated by AI Regulation Analysis System — {datetime.now(timezone.utc).strftime('%Y-%m-%d')}" + ) + footer.alignment = WD_ALIGN_PARAGRAPH.CENTER + for run in footer.runs: + run.font.size = Pt(8) + run.font.color.rgb = RGBColor(0x88, 0x88, 0x88) + + buf = BytesIO() + doc.save(buf) + return buf.getvalue() diff --git a/backend/app/infrastructure/compliance/repository.py b/backend/app/infrastructure/compliance/repository.py new file mode 100644 index 0000000..c627b2c --- /dev/null +++ b/backend/app/infrastructure/compliance/repository.py @@ -0,0 +1,280 @@ +# backend/app/infrastructure/compliance/repository.py +"""PostgreSQL-backed compliance analysis repository. + +Follows the same psycopg2 pattern as PostgresDocumentRepository: +ThreadedConnectionPool + RealDictCursor for reads, _ensure_schema on init. 
+""" +from __future__ import annotations + +import json +from contextlib import contextmanager +from datetime import datetime +from typing import Optional + +import psycopg2 +import psycopg2.extras +import psycopg2.pool +from loguru import logger + +from app.domain.compliance.ports import ( + AnalysisRecord, + ComplianceRepository, + FindingRecord, +) + + +class PostgresComplianceRepository(ComplianceRepository): + """Stores compliance analyses, findings, and finding chat messages in PostgreSQL.""" + + def __init__( + self, + host: str, + port: int, + user: str, + password: str, + dbname: str, + minconn: int = 1, + maxconn: int = 5, + ) -> None: + self._pool = psycopg2.pool.ThreadedConnectionPool( + minconn=minconn, + maxconn=maxconn, + host=host, + port=port, + user=user, + password=password, + dbname=dbname, + ) + self._ensure_schema() + + @contextmanager + def _conn(self): + conn = self._pool.getconn() + try: + yield conn + finally: + self._pool.putconn(conn) + + def _ensure_schema(self) -> None: + """Create tables if they do not exist.""" + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS compliance_analyses ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + created_by VARCHAR(255), + doc_name VARCHAR(500), + standard_name VARCHAR(500), + risk_score INTEGER, + conclusion TEXT, + actions JSONB, + para_text TEXT, + highlight_terms JSONB + ); + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS compliance_findings ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE CASCADE, + seq INTEGER NOT NULL, + title VARCHAR(500), + description TEXT, + status VARCHAR(50), + clause_ref VARCHAR(200) + ); + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS finding_chat_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + analysis_id UUID NOT NULL REFERENCES compliance_analyses(id) ON DELETE 
CASCADE, + finding_id UUID NOT NULL REFERENCES compliance_findings(id) ON DELETE CASCADE, + role VARCHAR(20) NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + conn.commit() + + def save_analysis(self, record: AnalysisRecord) -> str: + """Insert analysis + findings; return the new analysis UUID.""" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + """ + INSERT INTO compliance_analyses + (created_by, doc_name, standard_name, risk_score, + conclusion, actions, para_text, highlight_terms) + VALUES + (%(created_by)s, %(doc_name)s, %(standard_name)s, %(risk_score)s, + %(conclusion)s, %(actions)s, %(para_text)s, %(highlight_terms)s) + RETURNING id + """, + { + "created_by": record.created_by, + "doc_name": record.doc_name, + "standard_name": record.standard_name, + "risk_score": record.risk_score, + "conclusion": record.conclusion, + "actions": json.dumps(record.actions, ensure_ascii=False), + "para_text": record.para_text, + "highlight_terms": json.dumps(record.highlight_terms, ensure_ascii=False), + }, + ) + row = cur.fetchone() + analysis_id = str(row["id"]) + + if record.findings: + with conn.cursor() as cur: + for f in record.findings: + cur.execute( + """ + INSERT INTO compliance_findings + (analysis_id, seq, title, description, status, clause_ref) + VALUES + (%(analysis_id)s, %(seq)s, %(title)s, %(desc)s, %(status)s, %(clause_ref)s) + """, + { + "analysis_id": analysis_id, + "seq": f.seq, + "title": f.title, + "desc": f.description, + "status": f.status, + "clause_ref": f.clause_ref, + }, + ) + conn.commit() + return analysis_id + + def list_analyses(self, limit: int = 50, offset: int = 0) -> list[AnalysisRecord]: + """Return analyses without nested findings, ordered newest first.""" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + """ + SELECT id, created_at, created_by, 
doc_name, standard_name, + risk_score, conclusion, actions, para_text, highlight_terms + FROM compliance_analyses + ORDER BY created_at DESC + LIMIT %(limit)s OFFSET %(offset)s + """, + {"limit": limit, "offset": offset}, + ) + rows = cur.fetchall() + return [self._row_to_record(dict(r)) for r in rows] + + def get_analysis(self, analysis_id: str) -> Optional[AnalysisRecord]: + """Return analysis with nested findings list.""" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + "SELECT * FROM compliance_analyses WHERE id = %(id)s", + {"id": analysis_id}, + ) + row = cur.fetchone() + if not row: + return None + record = self._row_to_record(dict(row)) + + cur.execute( + """ + SELECT id, analysis_id, seq, title, description, status, clause_ref + FROM compliance_findings + WHERE analysis_id = %(id)s + ORDER BY seq + """, + {"id": analysis_id}, + ) + findings = [ + FindingRecord( + id=str(r["id"]), + analysis_id=str(r["analysis_id"]), + seq=r["seq"], + title=r["title"] or "", + description=r["description"] or "", + status=r["status"] or "ok", + clause_ref=r["clause_ref"], + ) + for r in cur.fetchall() + ] + record.findings = findings + return record + + def delete_analysis(self, analysis_id: str) -> None: + """Delete analysis; findings and chat messages cascade automatically.""" + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute( + "DELETE FROM compliance_analyses WHERE id = %(id)s", + {"id": analysis_id}, + ) + conn.commit() + + def save_message(self, analysis_id: str, finding_id: str, role: str, content: str) -> str: + """Persist a chat message; return its UUID.""" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + """ + INSERT INTO finding_chat_messages + (analysis_id, finding_id, role, content) + VALUES + (%(analysis_id)s, %(finding_id)s, %(role)s, %(content)s) + RETURNING id + """, + { + "analysis_id": 
analysis_id, + "finding_id": finding_id, + "role": role, + "content": content, + }, + ) + row = cur.fetchone() + conn.commit() + return str(row["id"]) + + def get_messages(self, finding_id: str) -> list[dict]: + """Return messages for a finding, oldest first.""" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute( + """ + SELECT id, role, content, created_at + FROM finding_chat_messages + WHERE finding_id = %(finding_id)s + ORDER BY created_at ASC + """, + {"finding_id": finding_id}, + ) + rows = cur.fetchall() + return [ + { + "id": str(r["id"]), + "role": r["role"], + "content": r["content"], + "created_at": r["created_at"].isoformat() if r["created_at"] else "", + } + for r in rows + ] + + def _row_to_record(self, row: dict) -> AnalysisRecord: + """Convert a RealDictCursor row to an AnalysisRecord (no findings).""" + actions = row.get("actions") or [] + if isinstance(actions, str): + actions = json.loads(actions) + highlight_terms = row.get("highlight_terms") or [] + if isinstance(highlight_terms, str): + highlight_terms = json.loads(highlight_terms) + return AnalysisRecord( + id=str(row["id"]), + created_at=row["created_at"] if isinstance(row["created_at"], datetime) else datetime.utcnow(), + created_by=row.get("created_by"), + doc_name=row.get("doc_name") or "", + standard_name=row.get("standard_name") or "", + risk_score=int(row.get("risk_score") or 0), + conclusion=row.get("conclusion") or "", + actions=actions, + para_text=row.get("para_text") or "", + highlight_terms=highlight_terms, + findings=[], + ) diff --git a/backend/app/infrastructure/vectorstore/pass_through_reranker.py b/backend/app/infrastructure/vectorstore/pass_through_reranker.py new file mode 100644 index 0000000..8ec4804 --- /dev/null +++ b/backend/app/infrastructure/vectorstore/pass_through_reranker.py @@ -0,0 +1,21 @@ +"""No-op reranker stub. + +Returns the original candidate list sliced to top_k. 
+Replace with CrossEncoderReranker when a local cross-encoder model is available. +""" +from __future__ import annotations + +from app.domain.retrieval.models import RetrievedChunk +from app.domain.retrieval.ports import Reranker + + +class PassThroughReranker(Reranker): + """Pass-through reranker that preserves original retrieval order. + + Acts as a placeholder for future cross-encoder reranking (e.g. ms-marco-MiniLM). + Wire via bootstrap.get_compliance_reranker() when ready to swap. + """ + + def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]: + """Return the first top_k chunks without reordering.""" + return chunks[:top_k] diff --git a/backend/app/shared/bootstrap.py b/backend/app/shared/bootstrap.py index 1f2d981..b016913 100644 --- a/backend/app/shared/bootstrap.py +++ b/backend/app/shared/bootstrap.py @@ -40,6 +40,8 @@ from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatib from app.infrastructure.vectorstore.dense_retriever import DenseRetriever from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex from app.services.llm.llm_factory import LLMFactory +from app.domain.compliance.ports import ComplianceRepository +from app.infrastructure.compliance.repository import PostgresComplianceRepository # Keep shared wiring centralized so dependency construction remains consistent. @@ -311,6 +313,28 @@ def get_event_store() -> BaseEventStore: return MockEventStore() +@lru_cache +def get_compliance_repository() -> ComplianceRepository: + """Return the compliance analysis repository. + + Requires document_repository_backend=postgres and valid postgres_* settings. + Raises NotImplementedError for any other backend value. + """ + if settings.document_repository_backend != "postgres": + raise NotImplementedError( + f"ComplianceRepository requires document_repository_backend=postgres, " + f"got '{settings.document_repository_backend}'. 
" + "Set DOCUMENT_REPOSITORY_BACKEND=postgres in your .env file." + ) + return PostgresComplianceRepository( + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + ) + + @lru_cache def get_perception_service() -> PerceptionService: return PerceptionService( diff --git a/backend/tests/compliance/__init__.py b/backend/tests/compliance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/compliance/test_pipeline.py b/backend/tests/compliance/test_pipeline.py new file mode 100644 index 0000000..e9af9fc --- /dev/null +++ b/backend/tests/compliance/test_pipeline.py @@ -0,0 +1,140 @@ +import asyncio + +import pytest +from unittest.mock import MagicMock, patch +from datetime import datetime + +from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker +from app.domain.retrieval.models import RetrievedChunk +from app.domain.compliance.ports import AnalysisRecord, FindingRecord + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _make_chunk(score: float) -> RetrievedChunk: + return RetrievedChunk( + chunk_id="c1", + doc_id="d1", + doc_title="Test Doc", + section_title="S1", + text="some text", + score=score, + page_start=1, + ) + + +def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'): + client = MagicMock() + response = MagicMock() + response.is_success = True + response.content = content + client.chat.return_value = response + return client + + +def _make_mock_retrieval(): + svc = MagicMock() + svc.retrieve.return_value = [] + return svc + + +# ── existing tests ──────────────────────────────────────────────────────────── + +def test_pass_through_returns_top_k(): + reranker = PassThroughReranker() + chunks = [_make_chunk(0.9), _make_chunk(0.8), _make_chunk(0.7)] + result = reranker.rerank(query="test", chunks=chunks, top_k=2) + 
assert len(result) == 2 + assert result[0].score == 0.9 + + +def test_pass_through_returns_all_when_top_k_exceeds(): + reranker = PassThroughReranker() + chunks = [_make_chunk(0.5)] + result = reranker.rerank(query="test", chunks=chunks, top_k=10) + assert len(result) == 1 + + +# ── new tests ───────────────────────────────────────────────────────────────── + +def test_process_single_clause_returns_finding(): + from app.application.compliance.pipeline import process_single_clause + client = _make_mock_client() + svc = _make_mock_retrieval() + result = process_single_clause("test clause", 0, svc, client) + assert result["finding"] is not None + assert result["index"] == 0 + assert result["chunks"] == [] + + +def test_run_clauses_parallel_runs_all(): + from app.application.compliance.pipeline import run_clauses_parallel + client = _make_mock_client() + svc = _make_mock_retrieval() + clauses = ["clause one", "clause two", "clause three"] + results = asyncio.run(run_clauses_parallel(clauses, svc, client)) + assert len(results) == 3 + assert all(r["index"] == i for i, r in enumerate(results)) + + +def test_run_clauses_parallel_handles_clause_failure(): + from app.application.compliance.pipeline import run_clauses_parallel + svc = _make_mock_retrieval() + bad_client = MagicMock() + bad_client.chat.side_effect = RuntimeError("LLM exploded") + results = asyncio.run(run_clauses_parallel( + ["clause one", "clause two"], svc, bad_client + )) + assert len(results) == 2 + assert all(r["finding"] is None for r in results) + assert all(r["chunks"] == [] for r in results) + + +# ── helpers for new tests ───────────────────────────────────────────────────── + +def _sample_analysis() -> AnalysisRecord: + return AnalysisRecord( + id="a1", created_at=datetime(2026, 6, 8), created_by="u", + doc_name="doc.pdf", standard_name="EU AI Act", + risk_score=72, conclusion="Gaps found.", actions=[], para_text="para", + highlight_terms=[], findings=[], + ) + + +def _sample_finding(status: str = 
"risk") -> FindingRecord: + return FindingRecord( + id="f1", analysis_id="a1", seq=0, + title="Missing CSMS", description="No CSMS certification.", + status=status, clause_ref="Art.9.1", + ) + + +# ── new tests ───────────────────────────────────────────────────────────────── + +def test_build_finding_context_contains_required_fields(): + from app.application.compliance.pipeline import build_finding_context + ctx = build_finding_context(_sample_finding(), _sample_analysis()) + assert "doc.pdf" in ctx + assert "EU AI Act" in ctx + assert "Missing CSMS" in ctx + assert "Art.9.1" in ctx + + +def test_generate_suggestions_returns_three_questions(): + from app.application.compliance.pipeline import generate_suggestions + client = _make_mock_client( + '{"questions": ["Q1?", "Q2?", "Q3?"]}' + ) + questions = generate_suggestions(_sample_finding("risk"), _sample_analysis(), client) + assert len(questions) == 3 + assert all(isinstance(q, str) for q in questions) + + +def test_generate_suggestions_falls_back_on_error(): + from app.application.compliance.pipeline import generate_suggestions + bad_client = MagicMock() + bad_resp = MagicMock() + bad_resp.is_success = False + bad_client.chat.return_value = bad_resp + questions = generate_suggestions(_sample_finding(), _sample_analysis(), bad_client) + assert len(questions) == 3 # fallback always returns 3 diff --git a/backend/tests/compliance/test_repository.py b/backend/tests/compliance/test_repository.py new file mode 100644 index 0000000..bdc7aef --- /dev/null +++ b/backend/tests/compliance/test_repository.py @@ -0,0 +1,98 @@ +from unittest.mock import MagicMock, patch +from datetime import datetime +from app.domain.compliance.ports import ( + AnalysisRecord, + FindingRecord, + ComplianceRepository, +) + + +def _mock_pool(): + """Return a mock psycopg2 ThreadedConnectionPool.""" + conn = MagicMock() + cursor = MagicMock() + cursor.__enter__ = MagicMock(return_value=cursor) + cursor.__exit__ = MagicMock(return_value=False) + 
conn.cursor.return_value = cursor + pool = MagicMock() + pool.getconn.return_value = conn + return pool, conn, cursor + + +@patch("app.infrastructure.compliance.repository.psycopg2.pool.ThreadedConnectionPool") +def test_save_analysis_returns_uuid(mock_pool_cls): + from app.infrastructure.compliance.repository import PostgresComplianceRepository + pool, conn, cursor = _mock_pool() + mock_pool_cls.return_value = pool + cursor.fetchone.return_value = {"id": "abc-123"} + + repo = PostgresComplianceRepository( + host="localhost", port=5432, user="u", password="p", dbname="db" + ) + record = AnalysisRecord( + id="", created_at=datetime.utcnow(), created_by="user1", + doc_name="doc.pdf", standard_name="EU AI Act", + risk_score=50, conclusion="OK", actions=[], para_text="p", + highlight_terms=[], findings=[], + ) + result = repo.save_analysis(record) + assert result == "abc-123" + + +def test_analysis_record_construction(): + record = AnalysisRecord( + id="", + created_at=datetime.utcnow(), + created_by="user1", + doc_name="test.pdf", + standard_name="EU AI Act", + risk_score=72, + conclusion="Several gaps found.", + actions=[{"label": "Fix", "value": "Update docs"}], + para_text="The system shall...", + highlight_terms=["CSMS", "ISO 21434"], + findings=[ + FindingRecord( + id="", + analysis_id="", + seq=0, + title="Missing CSMS", + description="No CSMS certification found.", + status="risk", + clause_ref="Art.9.1", + ) + ], + ) + assert record.doc_name == "test.pdf" + assert len(record.findings) == 1 + assert record.findings[0].status == "risk" + + +def test_compliance_repository_is_abstract(): + import inspect + assert inspect.isabstract(ComplianceRepository) + + +def test_generate_docx_returns_bytes(): + from app.infrastructure.compliance.docx_export import generate_docx + record = AnalysisRecord( + id="test-id", created_at=datetime(2026, 6, 8), created_by="user1", + doc_name="test.pdf", standard_name="EU AI Act", + risk_score=72, conclusion="Several gaps found.", + 
actions=[{"label": "Fix", "value": "Update CSMS docs"}], + para_text="The system shall implement CSMS.", + highlight_terms=["CSMS"], + findings=[ + FindingRecord( + id="f1", analysis_id="test-id", seq=0, + title="Missing CSMS", description="No CSMS cert.", + status="risk", clause_ref="Art.9.1", + ) + ], + ) + data = generate_docx(record) + assert isinstance(data, bytes) + assert len(data) > 1000 # DOCX is at minimum a ZIP with ~1 KB overhead + # Verify it's a valid ZIP (DOCX = ZIP container) + import zipfile, io + assert zipfile.is_zipfile(io.BytesIO(data)) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 6427b32..d64d7a7 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -58,7 +58,8 @@ services: retries: 5 restart: unless-stopped - # PostgreSQL数据库 (可选,启用 DOCUMENT_REPOSITORY_BACKEND=postgres 时使用) + # PostgreSQL数据库 (启用 DOCUMENT_REPOSITORY_BACKEND=postgres 时使用; + # 合规分析历史记录 Direction B、DOCX 报告下载及 Finding Chat 持久化 Direction C 均依赖此服务) postgres: image: postgres:15-alpine container_name: postgres diff --git a/docs/superpowers/plans/2026-06-08-compliance-enhancement.md b/docs/superpowers/plans/2026-06-08-compliance-enhancement.md new file mode 100644 index 0000000..a7dadda --- /dev/null +++ b/docs/superpowers/plans/2026-06-08-compliance-enhancement.md @@ -0,0 +1,2292 @@ +# Compliance Analysis Enhancement Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Enhance the Compliance Analysis module with parallel clause processing + bug fixes (A), persistent analysis history + DOCX export (B), and per-finding persistent chat threads with LLM-generated suggestions (C). 
+ +**Architecture:** Direction A refactors the SSE pipeline in `app/api/routes/compliance.py` to run clauses in parallel via `asyncio.gather`, fixes the `highlight_terms` bug, and adds LLM retry. Direction B adds a PostgreSQL-backed `ComplianceRepository` (new domain port + infrastructure) and three REST endpoints plus a frontend History Rail. Direction C adds per-finding chat endpoints grounded in real analysis data and a React Drawer component. Implementation order: A → B → C (each direction is a prerequisite for the next). + +**Tech Stack:** FastAPI (SSE), psycopg2 (ThreadedConnectionPool), python-docx, tenacity (retry), React 18, TypeScript, existing `PageStateContext` pattern. + +--- + +## File Map + +### Direction A — Analysis Quality + +| File | Action | +|------|--------| +| `backend/app/application/compliance/pipeline.py` | Add `process_single_clause()`, `run_clauses_parallel()`, `_call_llm_with_retry()` wrapper, add `tenacity` retry to `synthesize_conclusion` and `check_clause_compliance` | +| `backend/app/api/routes/compliance.py` | Replace sequential for-loop with `asyncio.gather` on `process_single_clause` | +| `backend/app/infrastructure/vectorstore/pass_through_reranker.py` | New — `PassThroughReranker` stub | +| `backend/tests/compliance/test_pipeline.py` | New — unit tests for parallel processing + highlight_terms fix | + +### Direction B — History & Reports + +| File | Action | +|------|--------| +| `backend/app/domain/compliance/__init__.py` | New — empty | +| `backend/app/domain/compliance/ports.py` | New — `FindingRecord`, `AnalysisRecord`, `ComplianceRepository` ABC | +| `backend/app/infrastructure/compliance/__init__.py` | New — empty | +| `backend/app/infrastructure/compliance/repository.py` | New — `PostgresComplianceRepository` | +| `backend/app/infrastructure/compliance/docx_export.py` | New — `generate_docx()` | +| `backend/app/api/routes/compliance.py` | Add history endpoints + auto-save hook in `analyze_stream` | +| `backend/app/shared/bootstrap.py` | Add 
`get_compliance_repository()` factory | +| `frontend/src/pages/Compliance/HistoryRail.tsx` | New — History Rail component | +| `frontend/src/pages/Compliance/CompliancePage.tsx` | Import + render `HistoryRail`, handle `saved` SSE event, store `analysisId` | +| `frontend/src/contexts/PageStateContext.tsx` | Add `analysisId`, `isReadOnly` to `ComplianceState` | +| `backend/tests/compliance/test_repository.py` | New — unit tests for repository (with mock psycopg2) | + +### Direction C — Deep Chat + +| File | Action | +|------|--------| +| `backend/app/application/compliance/pipeline.py` | Add `build_finding_context()`, `generate_suggestions()` | +| `backend/app/api/routes/compliance.py` | Add 3 finding-chat endpoints, deprecate old chat endpoint | +| `frontend/src/pages/Compliance/FindingChatDrawer.tsx` | New — per-finding chat drawer | +| `frontend/src/pages/Compliance/CompliancePage.tsx` | Add Chat button to finding cards, render `FindingChatDrawer` | +| `frontend/src/contexts/PageStateContext.tsx` | Add `activeFindingId` to `ComplianceState` | +| `backend/tests/compliance/test_pipeline.py` | Extend — test `build_finding_context` + `generate_suggestions` | + +--- + +## Direction A Tasks + +### Task 1: PassThroughReranker stub + +**Files:** +- Create: `backend/app/infrastructure/vectorstore/pass_through_reranker.py` +- Test: `backend/tests/compliance/test_pipeline.py` + +- [ ] **Step 1: Create test file and write failing test** + +```python +# backend/tests/compliance/test_pipeline.py +import pytest +from app.infrastructure.vectorstore.pass_through_reranker import PassThroughReranker +from app.domain.retrieval.models import RetrievedChunk + + +def _make_chunk(score: float) -> RetrievedChunk: + return RetrievedChunk( + doc_id="d1", + doc_title="Test Doc", + section_title="S1", + text="some text", + score=score, + page_start=1, + ) + + +def test_pass_through_returns_top_k(): + reranker = PassThroughReranker() + chunks = [_make_chunk(0.9), _make_chunk(0.8), 
_make_chunk(0.7)] + result = reranker.rerank(query="test", chunks=chunks, top_k=2) + assert len(result) == 2 + assert result[0].score == 0.9 + + +def test_pass_through_returns_all_when_top_k_exceeds(): + reranker = PassThroughReranker() + chunks = [_make_chunk(0.5)] + result = reranker.rerank(query="test", chunks=chunks, top_k=10) + assert len(result) == 1 +``` + +- [ ] **Step 2: Run test to verify it fails** + +``` +cd backend +python -m pytest tests/compliance/test_pipeline.py::test_pass_through_returns_top_k -v +``` + +Expected: `ModuleNotFoundError` or `ImportError` — `pass_through_reranker` does not exist yet. + +- [ ] **Step 3: Create `tests/compliance/__init__.py`** + +```python +# backend/tests/compliance/__init__.py +``` + +- [ ] **Step 4: Create the reranker** + +```python +# backend/app/infrastructure/vectorstore/pass_through_reranker.py +"""No-op reranker stub. + +Returns the original candidate list sliced to top_k. +Replace with CrossEncoderReranker when a local cross-encoder model is available. +""" +from __future__ import annotations + +from app.domain.retrieval.models import RetrievedChunk +from app.domain.retrieval.ports import Reranker + + +class PassThroughReranker(Reranker): + """Pass-through reranker that preserves original retrieval order. + + Acts as a placeholder for future cross-encoder reranking (e.g. ms-marco-MiniLM). + Wire via bootstrap.get_compliance_reranker() when ready to swap. + """ + + def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]: + """Return the first top_k chunks without reordering.""" + return chunks[:top_k] +``` + +- [ ] **Step 5: Run tests to verify they pass** + +``` +cd backend +python -m pytest tests/compliance/test_pipeline.py::test_pass_through_returns_top_k tests/compliance/test_pipeline.py::test_pass_through_returns_all_when_top_k_exceeds -v +``` + +Expected: 2 passed. 
+ +--- + +### Task 2: Parallel clause processing + LLM retry + +**Files:** +- Modify: `backend/app/application/compliance/pipeline.py` +- Modify: `backend/app/api/routes/compliance.py` +- Test: `backend/tests/compliance/test_pipeline.py` + +- [ ] **Step 1: Write failing tests** + +Add these tests to `backend/tests/compliance/test_pipeline.py`: + +```python +import asyncio +from unittest.mock import MagicMock, patch + + +def _make_mock_client(content: str = '{"status":"ok","title":"T","desc":"D","clause_ref":"A1"}'): + client = MagicMock() + response = MagicMock() + response.is_success = True + response.content = content + client.chat.return_value = response + return client + + +def _make_mock_retrieval(): + svc = MagicMock() + svc.retrieve.return_value = [] + return svc + + +def test_process_single_clause_returns_finding(): + from app.application.compliance.pipeline import process_single_clause + client = _make_mock_client() + svc = _make_mock_retrieval() + result = process_single_clause("test clause", 0, svc, client, "EU AI Act", "para") + assert result["finding"] is not None + assert result["index"] == 0 + assert result["chunks"] == [] + + +def test_run_clauses_parallel_runs_all(): + from app.application.compliance.pipeline import run_clauses_parallel + client = _make_mock_client() + svc = _make_mock_retrieval() + clauses = ["clause one", "clause two", "clause three"] + results = asyncio.run(run_clauses_parallel(clauses, svc, client, "EU AI Act", "para")) + assert len(results) == 3 + assert all(r["index"] == i for i, r in enumerate(results)) +``` + +- [ ] **Step 2: Run to verify failure** + +``` +cd backend +python -m pytest tests/compliance/test_pipeline.py::test_process_single_clause_returns_finding -v +``` + +Expected: `ImportError: cannot import name 'process_single_clause'` + +- [ ] **Step 3: Add `process_single_clause` and `run_clauses_parallel` to `pipeline.py`** + +Add after line 109 (after `retrieve_for_clause`), before `check_clause_compliance`: + 
+```python +import asyncio +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type +import httpx + + +def process_single_clause( + clause: str, + index: int, + retrieval_service: "KnowledgeRetrievalService", + client: "BaseLLMClient", + standard_name: str, + para_text: str, + top_k: int = 5, + domains: str | None = None, +) -> dict: + """Process one clause: retrieve relevant regulations then check compliance. + + Returns a dict with keys: index, chunks, finding (may be None on LLM failure). + Designed to run inside asyncio.to_thread() for parallel execution. + """ + chunks = retrieve_for_clause(clause, retrieval_service, top_k, domains) + finding = check_clause_compliance(clause, chunks, client) + return {"index": index, "chunks": chunks, "finding": finding} + + +async def run_clauses_parallel( + clauses: list[str], + retrieval_service: "KnowledgeRetrievalService", + client: "BaseLLMClient", + standard_name: str, + para_text: str, + top_k: int = 5, + domains: str | None = None, +) -> list[dict]: + """Run all clauses through retrieve+gap-check in parallel. + + Results are returned in the original clause order even though processing + is concurrent. Exceptions in individual clauses are caught and returned as + dicts with finding=None so the stream continues for remaining clauses. 
+ """ + tasks = [ + asyncio.to_thread( + process_single_clause, + clause, i, retrieval_service, client, standard_name, para_text, top_k, domains, + ) + for i, clause in enumerate(clauses) + ] + raw = await asyncio.gather(*tasks, return_exceptions=True) + results = [] + for i, r in enumerate(raw): + if isinstance(r, Exception): + logger.warning("Clause {} processing failed: {}", i, r) + results.append({"index": i, "chunks": [], "finding": None}) + else: + results.append(r) + return results +``` + +Also add the imports at the top of `pipeline.py` (after existing imports): + +```python +import asyncio +import httpx +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type +``` + +- [ ] **Step 4: Wrap LLM calls in `check_clause_compliance` with retry** + +Inside `check_clause_compliance`, just before the direct `client.chat(...)` call (line 137), define a nested retry helper that closes over `client`: + +```python +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=4), + retry=retry_if_exception_type((httpx.HTTPError, ValueError)), + reraise=True, +) +def _llm_check(prompt_messages: list[dict]) -> "BaseLLMClient": + resp = client.chat(prompt_messages, max_tokens=500) + if not resp.is_success: + raise ValueError(f"LLM returned non-success status for gap check") + return resp +``` + +Then in `check_clause_compliance`, replace: + +```python +response = client.chat([{"role": "user", "content": prompt}], max_tokens=500) +if not response.is_success: + return None +``` + +with: + +```python +try: + response = _llm_check([{"role": "user", "content": prompt}]) +except Exception as exc: + logger.warning("check_clause_compliance LLM call failed after retries: {}", exc) + return None +``` + +Apply the same pattern to `synthesize_conclusion` (line 189): wrap `client.chat(...)` in a nested retry function and catch on final failure to return `fallback`. 
+ +**Also fix the `highlight_terms` prompt bug** — the current prompt uses `["Key terms to highlight, max 10 terms"]` as an example string, so the LLM returns the literal example instead of real terms. In `synthesize_conclusion`, replace that prompt line with: + +```python +' "highlight_terms": ["term1", "term2"], // up to 10 key technical/legal terms actually present in the text\n' +``` + +- [ ] **Step 5: Replace sequential loop in `compliance.py` with parallel version** + +In `backend/app/api/routes/compliance.py`, replace the Stage 3 block (lines 138–168): + +```python +# ── Stage 3: retrieve + gap check per clause ────────────────── +findings: list[dict] = [] + +for i, clause in enumerate(clauses): + yield _sse({ + "type": "stage", + "stage": "analyzing", + "label": f"Analyzing clause {i + 1}/{len(clauses)}…", + }) + await asyncio.sleep(0) + + chunks = await asyncio.to_thread( + retrieve_for_clause, clause, retrieval_service, 5, domains or None + ) + + # Emit source events + for chunk in chunks[:3]: + yield _sse({...}) + await asyncio.sleep(0) + + finding = await asyncio.to_thread(check_clause_compliance, clause, chunks, client) + if finding: + findings.append(finding) + yield _sse({"type": "finding", **finding}) + await asyncio.sleep(0) +``` + +With: + +```python +# ── Stage 3: retrieve + gap check (parallel across all clauses) ──────────── +from app.application.compliance.pipeline import run_clauses_parallel + +findings: list[dict] = [] + +yield _sse({ + "type": "stage", + "stage": "analyzing", + "label": f"Analyzing {len(clauses)} clauses in parallel…", +}) +await asyncio.sleep(0) + +clause_results = await run_clauses_parallel( + clauses, retrieval_service, client, + standard_name=title or "", + para_text=para_text, + top_k=5, + domains=domains or None, +) + +for res in clause_results: + i = res["index"] + chunks = res["chunks"] + finding = res["finding"] + + # Emit source events for this clause + for chunk in chunks[:3]: + yield _sse({ + "type": "source", + 
"standard": getattr(chunk, "doc_title", "") or getattr(chunk, "doc_name", ""), + "clause": getattr(chunk, "section_title", "") or "", + "score": round(float(getattr(chunk, "score", 0)), 3), + "status": "retrieved", + "full_content": (getattr(chunk, "text", "") or "")[:300], + }) + + if finding: + findings.append(finding) + yield _sse({"type": "finding", **finding}) + + await asyncio.sleep(0) +``` + +Also update the imports at the top of the `generate()` inner function — add `run_clauses_parallel` to the `from app.application.compliance.pipeline import (...)` block and remove individual imports of `retrieve_for_clause`, `check_clause_compliance`. + +- [ ] **Step 6: Run all Direction A tests** + +``` +cd backend +python -m pytest tests/compliance/ -v +``` + +Expected: all pass. + +- [ ] **Step 7: Smoke test the SSE endpoint manually** + +```bash +curl -X POST http://localhost:8000/api/v1/compliance/analyze-stream \ + -H "Authorization: Bearer+ Completed analyses appear here. +
+Loading history…
+ )} + {messages.map(msg => ( +{t.overview.eyebrow}
+{t.overview.heroDesc}
+{t.signals.emptySelectSignal}
` + +Find: `` +Replace: `` + +Find: `` +Replace: `` + +Find: `暂无结构化数据。点击右上角"Run impact analysis"触发提取。
` +Replace: `{t.signals.obligationsEmpty}
` + +Find: `No affected documents found.
` +Replace: `{t.signals.noAffectedDocs}
` + +Find: `{message}
+