242 lines
9.3 KiB
Python
242 lines
9.3 KiB
Python
|
|
"""LLM-driven pipeline for regulatory event enrichment."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from app.config.settings import settings
|
||
|
|
from app.infrastructure.embedding.openai_compatible_embedding_provider import (
|
||
|
|
OpenAICompatibleEmbeddingProvider,
|
||
|
|
)
|
||
|
|
from app.services.llm.llm_factory import get_llm_client
|
||
|
|
|
||
|
|
_EXTRACT_SYSTEM = (
|
||
|
|
"You are a regulatory compliance expert specialising in automotive standards "
|
||
|
|
"(GB, UN-ECE, ISO, EU). Extract structured information from regulation text. "
|
||
|
|
"Return valid JSON only — no markdown fences, no extra keys."
|
||
|
|
)
|
||
|
|
|
||
|
|
_ASSESS_SYSTEM = (
|
||
|
|
"You are an automotive compliance analyst. Given a regulation and related document excerpts, "
|
||
|
|
"identify which documents are affected and what actions are required. "
|
||
|
|
"Return a JSON array only."
|
||
|
|
)
|
||
|
|
|
||
|
|
_DIFF_SYSTEM = (
|
||
|
|
"You are a regulatory change analyst. Given an old and new version of a regulation paragraph, "
|
||
|
|
"classify the type of change and summarise it. "
|
||
|
|
"Return JSON only: {\"change_type\": \"tightened|relaxed|added|removed\", \"summary\": \"...\"}"
|
||
|
|
)
|
||
|
|
|
||
|
|
_SIMILARITY_THRESHOLD = 0.85
|
||
|
|
|
||
|
|
|
||
|
|
def _cosine(a: list[float], b: list[float]) -> float:
|
||
|
|
dot = sum(x * y for x, y in zip(a, b))
|
||
|
|
norm_a = math.sqrt(sum(x * x for x in a))
|
||
|
|
norm_b = math.sqrt(sum(x * x for x in b))
|
||
|
|
if norm_a == 0 or norm_b == 0:
|
||
|
|
return 0.0
|
||
|
|
return dot / (norm_a * norm_b)
|
||
|
|
|
||
|
|
|
||
|
|
def _llm_json(client: Any, messages: list[dict]) -> Any:
|
||
|
|
"""Call LLM and parse JSON response; return None on failure."""
|
||
|
|
try:
|
||
|
|
resp = client.chat(messages)
|
||
|
|
text = (resp.content or "").strip()
|
||
|
|
if text.startswith("```"):
|
||
|
|
text = text.split("```")[1]
|
||
|
|
if text.startswith("json"):
|
||
|
|
text = text[4:]
|
||
|
|
return json.loads(text)
|
||
|
|
except Exception as exc:
|
||
|
|
logger.warning("LLM JSON parse failed: {}", exc)
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
class LlmPipeline:
    """Three-step enrichment pipeline for crawled regulatory events.

    The steps are independent methods invoked by the caller:

    1. ``extract_structure`` asks the LLM for obligations, deadlines, scope,
       penalties and an impact level.
    2. ``assess_impact`` retrieves related document chunks and asks the LLM
       for a per-document recommendation.
    3. ``compute_diff`` embeds the paragraphs of two text versions and has
       the LLM classify the pairs that drifted apart.

    Retrieval and embedding failures are logged and mapped to fallback
    values; LLM replies that are missing or malformed are replaced by
    defaults.
    """

    def __init__(self) -> None:
        """Create the LLM client and embedding provider from app settings."""
        self._client = get_llm_client(
            provider=settings.llm_provider,
            model=settings.llm_model,
        )
        self._embedder = OpenAICompatibleEmbeddingProvider()

    # ------------------------------------------------------------------
    # Step 1: Structure extraction
    # ------------------------------------------------------------------

    def extract_structure(self, event: dict) -> dict:
        """Extract obligations, deadlines, scope, penalties, impact_level.

        Args:
            event: Event mapping; the keys ``standard_code``, ``title``,
                ``source_label``, ``summary`` and ``tags`` are read, all
                optional.

        Returns:
            A dict that always contains the five keys named above. Keys the
            LLM omitted are filled with defaults; if the LLM reply is not a
            JSON object the defaults alone are returned.
        """
        # str() so a non-string tag cannot make the join raise.
        tags = ", ".join(str(tag) for tag in (event.get("tags") or []))
        prompt = f"""Extract structured compliance information from this regulation:

Standard: {event.get('standard_code', '')}
Title: {event.get('title', '')}
Source: {event.get('source_label', '')}
Summary: {event.get('summary', '')}
Tags: {tags}

Return JSON with exactly these keys:
{{
"obligations": [{{"text": "...", "deontic": "must|shall|may|prohibited", "subject": "...", "object": "...", "condition": ""}}],
"deadlines": [{{"date": "YYYY-MM-DD or null", "description": "..."}}],
"scope": "one sentence describing who/what this applies to",
"penalties": "one sentence on consequences of non-compliance, or null",
"impact_level": "high|medium|low"
}}"""

        messages = [
            {"role": "system", "content": _EXTRACT_SYSTEM},
            {"role": "user", "content": prompt},
        ]
        defaults = {
            "obligations": [],
            "deadlines": [],
            "scope": "",
            "penalties": "",
            "impact_level": "medium",
        }
        result = _llm_json(self._client, messages)
        if not isinstance(result, dict):
            return defaults
        # Keep everything the LLM returned, but guarantee the documented keys.
        return {**defaults, **result}

    # ------------------------------------------------------------------
    # Step 2: Impact assessment
    # ------------------------------------------------------------------

    def assess_impact(self, event: dict, retrieval_service: Any) -> list[dict]:
        """Use RAG to find affected documents and generate recommendations.

        Args:
            event: Event mapping; ``standard_code``, ``title``, ``summary``
                and ``obligations`` (a list of dicts with a ``text`` key) are
                read, all optional.
            retrieval_service: Object exposing ``retrieve(query=..., top_k=...)``
                returning chunks with ``doc_id``, ``doc_title``, ``score`` and
                ``text`` attributes (``section_title`` is optional).

        Returns:
            The LLM's list of per-document assessments, with ``score``
            overwritten by the retrieval score for known ``doc_id`` values.
            If the LLM reply is not a list, the raw de-duplicated excerpts
            are returned instead. Returns ``[]`` when retrieval fails or
            finds nothing.
        """
        # Obligations usually come from LLM output, so ignore malformed
        # (non-dict) entries instead of crashing on ``.get``.
        obligations = [
            o for o in (event.get("obligations") or []) if isinstance(o, dict)
        ]
        obligation_texts = " ".join(
            str(o.get("text") or "") for o in obligations[:3]
        )
        standard_code = event.get("standard_code", "")
        title = event.get("title", "")
        query = f"{standard_code} {title} {obligation_texts}"

        try:
            chunks = retrieval_service.retrieve(query=query, top_k=5)
        except Exception as exc:
            logger.warning("RAG retrieval failed: {}", exc)
            return []

        if not chunks:
            return []

        # Keep only the first chunk seen for each document.
        seen: set[str] = set()
        doc_excerpts: list[dict] = []
        for chunk in chunks:
            if chunk.doc_id in seen:
                continue
            seen.add(chunk.doc_id)
            doc_excerpts.append({
                "doc_id": chunk.doc_id,
                "doc_name": chunk.doc_title,
                "score": round(float(chunk.score if chunk.score is not None else 0), 4),
                "snippet": (chunk.text or "")[:300],
                "clause": getattr(chunk, "section_title", "") or "",
            })

        context = "\n".join(
            f"[{d['doc_name']} {d['clause']}] score={d['score']}: {d['snippet']}"
            for d in doc_excerpts
        )
        # Same '' defaults as the query, so a missing code/title is not
        # rendered as the literal string "None" in the prompt.
        prompt = f"""Regulation: {standard_code} — {title}
Obligations: {obligation_texts or event.get('summary', '')}

Affected documents found in knowledge base:
{context}

For each document, assess impact and recommend action. Return JSON array:
[{{"doc_id":"...","doc_name":"...","score":0.0,"key_clauses":"...","recommendation":"one sentence action"}}]"""

        messages = [
            {"role": "system", "content": _ASSESS_SYSTEM},
            {"role": "user", "content": prompt},
        ]
        result = _llm_json(self._client, messages)
        if not isinstance(result, list):
            return doc_excerpts

        # Trust retrieval scores over whatever number the LLM echoed back,
        # and drop array items that are not objects (return type is list[dict]).
        score_map = {d["doc_id"]: d["score"] for d in doc_excerpts}
        assessments = [item for item in result if isinstance(item, dict)]
        for item in assessments:
            doc_id = item.get("doc_id")
            try:
                known = doc_id in score_map
            except TypeError:
                # Unhashable doc_id (e.g. a list) cannot match any document.
                known = False
            if known:
                item["score"] = score_map[doc_id]
        return assessments

    # ------------------------------------------------------------------
    # Step 3: Semantic diff
    # ------------------------------------------------------------------

    def compute_diff(self, old_text: str, new_text: str) -> dict:
        """Compare old and new regulation text paragraph by paragraph.

        Paragraphs are the non-blank lines of each text and are aligned by
        position. Aligned pairs whose embedding cosine similarity is below
        ``_SIMILARITY_THRESHOLD`` are classified by the LLM; surplus
        paragraphs on either side are reported as added or removed.

        Args:
            old_text: Previous version of the text.
            new_text: Current version of the text.

        Returns:
            ``{"changed_sections": [...], "change_summary": "..."}``. Each
            section has ``old_text``, ``new_text`` (both cut to 300 chars),
            ``similarity``, ``change_type`` and ``summary``. Changed pairs
            come first, followed by added/removed paragraphs.
        """
        old_paras = [p.strip() for p in (old_text or "").split("\n") if p.strip()]
        new_paras = [p.strip() for p in (new_text or "").split("\n") if p.strip()]

        if not old_paras or not new_paras:
            return {"changed_sections": [], "change_summary": "No comparable text."}

        # One batch call for both versions; split again by position below.
        all_paras = old_paras + new_paras
        try:
            all_embeddings = self._embedder.embed_texts(all_paras)
            if len(all_embeddings) != len(all_paras):
                # A short result would otherwise surface as an IndexError or
                # silently misalign old/new vectors.
                raise ValueError(
                    f"expected {len(all_paras)} embeddings, got {len(all_embeddings)}"
                )
        except Exception as exc:
            logger.warning("Embedding for diff failed: {}", exc)
            return {"changed_sections": [], "change_summary": "Diff unavailable (embedding error)."}

        old_embeddings = all_embeddings[: len(old_paras)]
        new_embeddings = all_embeddings[len(old_paras):]

        changed_sections: list[dict] = []

        # Both exist — compare via embeddings (zip stops at the shorter side).
        aligned = zip(old_paras, new_paras, old_embeddings, new_embeddings)
        for old_para, new_para, old_vec, new_vec in aligned:
            sim = _cosine(old_vec, new_vec)
            if sim < _SIMILARITY_THRESHOLD:
                messages = [
                    {"role": "system", "content": _DIFF_SYSTEM},
                    {"role": "user", "content": f"OLD: {old_para[:500]}\nNEW: {new_para[:500]}"},
                ]
                classification = _llm_json(self._client, messages)
                if not isinstance(classification, dict):
                    # None, a JSON array or a scalar: fall back to defaults.
                    classification = {}
                changed_sections.append({
                    "old_text": old_para[:300],
                    "new_text": new_para[:300],
                    "similarity": round(sim, 3),
                    # str()/or: a null or non-string value from the LLM must
                    # not break the summary concatenation below.
                    "change_type": str(classification.get("change_type") or "modified"),
                    "summary": str(classification.get("summary") or ""),
                })

        # New paragraphs beyond the end of the old version.
        for new_para in new_paras[len(old_paras):]:
            changed_sections.append({
                "old_text": "",
                "new_text": new_para[:300],
                "similarity": 0.0,
                "change_type": "added",
                "summary": "New paragraph added.",
            })

        # Old paragraphs beyond the end of the new version.
        for old_para in old_paras[len(new_paras):]:
            changed_sections.append({
                "old_text": old_para[:300],
                "new_text": "",
                "similarity": 0.0,
                "change_type": "removed",
                "summary": "Paragraph removed.",
            })

        if not changed_sections:
            change_summary = "No substantive changes detected between versions."
        else:
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # the summary text is stable from run to run.
            unique_types = dict.fromkeys(s["change_type"] for s in changed_sections)
            change_summary = (
                f"{len(changed_sections)} paragraph(s) changed: "
                + ", ".join(unique_types)
                + ". "
                + changed_sections[0].get("summary", "")
            )

        return {"changed_sections": changed_sections, "change_summary": change_summary}
|