"""LLM-driven pipeline for regulatory event enrichment.""" from __future__ import annotations import json import math from typing import Any from loguru import logger from app.config.settings import settings from app.infrastructure.embedding.openai_compatible_embedding_provider import ( OpenAICompatibleEmbeddingProvider, ) from app.services.llm.llm_factory import get_llm_client _EXTRACT_SYSTEM = ( "You are a regulatory compliance expert specialising in automotive standards " "(GB, UN-ECE, ISO, EU). Extract structured information from regulation text. " "Return valid JSON only — no markdown fences, no extra keys." ) _ASSESS_SYSTEM = ( "You are an automotive compliance analyst. Given a regulation and related document excerpts, " "identify which documents are affected and what actions are required. " "Return a JSON array only." ) _DIFF_SYSTEM = ( "You are a regulatory change analyst. Given an old and new version of a regulation paragraph, " "classify the type of change and summarise it. " "Return JSON only: {\"change_type\": \"tightened|relaxed|added|removed\", \"summary\": \"...\"}" ) _SIMILARITY_THRESHOLD = 0.85 def _cosine(a: list[float], b: list[float]) -> float: dot = sum(x * y for x, y in zip(a, b)) norm_a = math.sqrt(sum(x * x for x in a)) norm_b = math.sqrt(sum(x * x for x in b)) if norm_a == 0 or norm_b == 0: return 0.0 return dot / (norm_a * norm_b) def _llm_json(client: Any, messages: list[dict]) -> Any: """Call LLM and parse JSON response; return None on failure.""" try: resp = client.chat(messages) text = (resp.content or "").strip() if text.startswith("```"): text = text.split("```")[1] if text.startswith("json"): text = text[4:] return json.loads(text) except Exception as exc: logger.warning("LLM JSON parse failed: {}", exc) return None class LlmPipeline: """Three-step enrichment pipeline for crawled regulatory events.""" def __init__(self) -> None: self._client = get_llm_client( provider=settings.llm_provider, model=settings.llm_model, ) self._embedder = OpenAICompatibleEmbeddingProvider() # ------------------------------------------------------------------ # Step 1: Structure extraction # ------------------------------------------------------------------ def extract_structure(self, event: dict) -> dict: """Extract obligations, deadlines, scope, penalties, impact_level from event text.""" prompt = f"""Extract structured compliance information from this regulation: Standard: {event.get('standard_code', '')} Title: {event.get('title', '')} Source: {event.get('source_label', '')} Summary: {event.get('summary', '')} Tags: {', '.join(event.get('tags') or [])} Return JSON with exactly these keys: {{ "obligations": [{{"text": "...", "deontic": "must|shall|may|prohibited", "subject": "...", "object": "...", "condition": ""}}], "deadlines": [{{"date": "YYYY-MM-DD or null", "description": "..."}}], "scope": "one sentence describing who/what this applies to", "penalties": "one sentence on consequences of non-compliance, or null", "impact_level": "high|medium|low" }}""" messages = [ {"role": "system", "content": _EXTRACT_SYSTEM}, {"role": "user", "content": prompt}, ] result = _llm_json(self._client, messages) if not isinstance(result, dict): return { "obligations": [], "deadlines": [], "scope": "", "penalties": "", "impact_level": "medium", } return result # ------------------------------------------------------------------ # Step 2: Impact assessment # ------------------------------------------------------------------ def assess_impact(self, event: dict, retrieval_service: Any) -> list[dict]: """Use RAG to find affected documents and generate recommendations.""" obligations = event.get("obligations") or [] obligation_texts = " ".join(o.get("text", "") for o in obligations[:3]) query = f"{event.get('standard_code', '')} {event.get('title', '')} {obligation_texts}" try: chunks = retrieval_service.retrieve(query=query, top_k=5) except Exception as exc: logger.warning("RAG retrieval failed: {}", exc) return [] if not chunks: return [] seen: set[str] = set() doc_excerpts: list[dict] = [] for chunk in chunks: if chunk.doc_id not in seen: seen.add(chunk.doc_id) doc_excerpts.append({ "doc_id": chunk.doc_id, "doc_name": chunk.doc_title, "score": round(float(chunk.score if chunk.score is not None else 0), 4), "snippet": (chunk.text or "")[:300], "clause": getattr(chunk, "section_title", "") or "", }) context = "\n".join( f"[{d['doc_name']} {d['clause']}] score={d['score']}: {d['snippet']}" for d in doc_excerpts ) prompt = f"""Regulation: {event.get('standard_code')} — {event.get('title')} Obligations: {obligation_texts or event.get('summary', '')} Affected documents found in knowledge base: {context} For each document, assess impact and recommend action. Return JSON array: [{{"doc_id":"...","doc_name":"...","score":0.0,"key_clauses":"...","recommendation":"one sentence action"}}]""" messages = [ {"role": "system", "content": _ASSESS_SYSTEM}, {"role": "user", "content": prompt}, ] result = _llm_json(self._client, messages) if isinstance(result, list): score_map = {d["doc_id"]: d["score"] for d in doc_excerpts} for item in result: if isinstance(item, dict) and item.get("doc_id") in score_map: item["score"] = score_map[item["doc_id"]] return result return doc_excerpts # ------------------------------------------------------------------ # Step 3: Semantic diff # ------------------------------------------------------------------ def compute_diff(self, old_text: str, new_text: str) -> dict: """Compare old and new regulation text; return changed sections and summary.""" old_paras = [p.strip() for p in old_text.split("\n") if p.strip()] new_paras = [p.strip() for p in new_text.split("\n") if p.strip()] if not old_paras or not new_paras: return {"changed_sections": [], "change_summary": "No comparable text."} all_paras = old_paras + new_paras try: all_embeddings = self._embedder.embed_texts(all_paras) except Exception as exc: logger.warning("Embedding for diff failed: {}", exc) return {"changed_sections": [], "change_summary": "Diff unavailable (embedding error)."} old_embeddings = all_embeddings[: len(old_paras)] new_embeddings = all_embeddings[len(old_paras):] changed_sections: list[dict] = [] max_len = max(len(old_paras), len(new_paras)) for i in range(max_len): if i >= len(old_paras): # New paragraph added changed_sections.append({ "old_text": "", "new_text": new_paras[i][:300], "similarity": 0.0, "change_type": "added", "summary": "New paragraph added.", }) continue if i >= len(new_paras): # Old paragraph removed changed_sections.append({ "old_text": old_paras[i][:300], "new_text": "", "similarity": 0.0, "change_type": "removed", "summary": "Paragraph removed.", }) continue # Both exist — compare via embeddings sim = _cosine(old_embeddings[i], new_embeddings[i]) if sim < _SIMILARITY_THRESHOLD: messages = [ {"role": "system", "content": _DIFF_SYSTEM}, {"role": "user", "content": f"OLD: {old_paras[i][:500]}\nNEW: {new_paras[i][:500]}"}, ] classification = _llm_json(self._client, messages) or {} changed_sections.append({ "old_text": old_paras[i][:300], "new_text": new_paras[i][:300], "similarity": round(sim, 3), "change_type": classification.get("change_type", "modified"), "summary": classification.get("summary", ""), }) if not changed_sections: change_summary = "No substantive changes detected between versions." else: types = [s["change_type"] for s in changed_sections] change_summary = ( f"{len(changed_sections)} paragraph(s) changed: " + ", ".join(f"{t}" for t in set(types)) + ". " + (changed_sections[0].get("summary", "") if changed_sections else "") ) return {"changed_sections": changed_sections, "change_summary": change_summary}