Fix some things
This commit is contained in:
241
backend/app/infrastructure/perception/llm_pipeline.py
Normal file
241
backend/app/infrastructure/perception/llm_pipeline.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""LLM-driven pipeline for regulatory event enrichment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.infrastructure.embedding.openai_compatible_embedding_provider import (
|
||||
OpenAICompatibleEmbeddingProvider,
|
||||
)
|
||||
from app.services.llm.llm_factory import get_llm_client
|
||||
|
||||
# System prompt for step 1 (structure extraction); demands bare JSON output.
_EXTRACT_SYSTEM = (
    "You are a regulatory compliance expert specialising in automotive standards "
    "(GB, UN-ECE, ISO, EU). Extract structured information from regulation text. "
    "Return valid JSON only — no markdown fences, no extra keys."
)

# System prompt for step 2 (impact assessment); the reply is expected to be a JSON array.
_ASSESS_SYSTEM = (
    "You are an automotive compliance analyst. Given a regulation and related document excerpts, "
    "identify which documents are affected and what actions are required. "
    "Return a JSON array only."
)

# System prompt for step 3 (semantic diff); the reply is expected to be a JSON
# object carrying "change_type" and "summary".
_DIFF_SYSTEM = (
    "You are a regulatory change analyst. Given an old and new version of a regulation paragraph, "
    "classify the type of change and summarise it. "
    'Return JSON only: {"change_type": "tightened|relaxed|added|removed", "summary": "..."}'
)

# Paragraph pairs whose embedding cosine similarity is below this value are
# treated as changed and sent to the LLM for classification.
_SIMILARITY_THRESHOLD = 0.85
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity between vectors *a* and *b*.

    The dot product pairs elements positionally with ``zip``, so trailing
    elements of a longer vector do not contribute to it; each norm is taken
    over its whole vector.  When either vector has zero magnitude (which
    includes an empty vector) the result is ``0.0`` rather than a division
    by zero.
    """
    magnitude_a = math.sqrt(sum(component * component for component in a))
    magnitude_b = math.sqrt(sum(component * component for component in b))
    if magnitude_a != 0 and magnitude_b != 0:
        inner = sum(left * right for left, right in zip(a, b))
        return inner / (magnitude_a * magnitude_b)
    return 0.0
|
||||
|
||||
|
||||
def _llm_json(client: Any, messages: list[dict]) -> Any:
    """Call the LLM and parse its reply as JSON; return None on any failure.

    Args:
        client: Object exposing ``chat(messages)`` whose result has a
            ``content`` attribute holding the reply text (or ``None``).
        messages: Chat messages passed to ``client.chat`` unchanged.

    Returns:
        The decoded JSON value.  ``None`` is returned (and a warning logged)
        when the call raises or the reply cannot be decoded.

    A reply that starts with a Markdown code fence is unwrapped before
    parsing.  A reply that is not valid JSON as a whole but contains a fence
    somewhere (for example after a prose preamble) is retried on the contents
    of its first fence.
    """
    try:
        resp = client.chat(messages)
        text = (resp.content or "").strip()
        if text.startswith("```"):
            return json.loads(_first_fenced_block(text))
        try:
            # Try the raw reply first so JSON whose string values happen to
            # contain backticks keeps parsing exactly as before.
            return json.loads(text)
        except ValueError:
            if "```" not in text:
                raise
            return json.loads(_first_fenced_block(text))
    except Exception as exc:
        # Deliberately broad: this helper is best-effort and callers rely on
        # getting None instead of an exception.
        logger.warning("LLM JSON parse failed: {}", exc)
        return None


def _first_fenced_block(text: str) -> str:
    """Return the contents of the first ``` fence in *text*, minus a leading ``json`` tag.

    *text* must contain at least one fence marker.  The language tag is
    matched case-insensitively so a fence tagged ``JSON`` is handled too.
    """
    payload = text.split("```")[1]
    if payload[:4].lower() == "json":
        payload = payload[4:]
    return payload
|
||||
|
||||
|
||||
class LlmPipeline:
    """Three-step enrichment pipeline for crawled regulatory events.

    Each step is an independent method:

    1. ``extract_structure`` asks the LLM for obligations, deadlines, scope,
       penalties and an impact level.
    2. ``assess_impact`` retrieves related document chunks and asks the LLM
       for a per-document recommendation.
    3. ``compute_diff`` compares two versions of a text paragraph by
       paragraph using embeddings and has the LLM classify the pairs that
       differ.
    """

    def __init__(self) -> None:
        """Build the LLM client from application settings and create the embedder."""
        self._client = get_llm_client(
            provider=settings.llm_provider,
            model=settings.llm_model,
        )
        self._embedder = OpenAICompatibleEmbeddingProvider()

    # ------------------------------------------------------------------
    # Step 1: Structure extraction
    # ------------------------------------------------------------------

    def extract_structure(self, event: dict) -> dict:
        """Extract obligations, deadlines, scope, penalties, impact_level from event text.

        Args:
            event: Event mapping.  The keys ``standard_code``, ``title``,
                ``source_label``, ``summary`` and ``tags`` are read; all are
                optional.

        Returns:
            The dict decoded from the LLM reply, returned as-is.  When the
            reply is not a JSON object, a default dict with empty values and
            an ``impact_level`` of ``"medium"`` is returned instead.
        """
        # Stringify each tag so a non-string item cannot make the join fail.
        tags = ", ".join(map(str, event.get("tags") or []))
        prompt = f"""Extract structured compliance information from this regulation:

Standard: {event.get('standard_code', '')}
Title: {event.get('title', '')}
Source: {event.get('source_label', '')}
Summary: {event.get('summary', '')}
Tags: {tags}

Return JSON with exactly these keys:
{{
"obligations": [{{"text": "...", "deontic": "must|shall|may|prohibited", "subject": "...", "object": "...", "condition": ""}}],
"deadlines": [{{"date": "YYYY-MM-DD or null", "description": "..."}}],
"scope": "one sentence describing who/what this applies to",
"penalties": "one sentence on consequences of non-compliance, or null",
"impact_level": "high|medium|low"
}}"""

        messages = [
            {"role": "system", "content": _EXTRACT_SYSTEM},
            {"role": "user", "content": prompt},
        ]
        result = _llm_json(self._client, messages)
        if not isinstance(result, dict):
            return {
                "obligations": [],
                "deadlines": [],
                "scope": "",
                "penalties": "",
                "impact_level": "medium",
            }
        return result

    # ------------------------------------------------------------------
    # Step 2: Impact assessment
    # ------------------------------------------------------------------

    def assess_impact(self, event: dict, retrieval_service: Any) -> list[dict]:
        """Use RAG to find affected documents and generate recommendations.

        Args:
            event: Event mapping.  ``obligations``, ``standard_code``,
                ``title`` and ``summary`` are read.
            retrieval_service: Object exposing ``retrieve(query=..., top_k=...)``
                that returns chunks with ``doc_id``, ``doc_title``, ``score``
                and ``text`` attributes (``section_title`` is optional).

        Returns:
            The list decoded from the LLM reply, with each item's ``score``
            overwritten by the retrieval score of the matching ``doc_id``.
            If the reply is not a JSON array, the retrieved excerpts
            themselves (one per document).  An empty list when retrieval
            fails or finds nothing.
        """
        obligations = event.get("obligations") or []
        if not isinstance(obligations, list):
            # The obligations usually originate from LLM output; tolerate a wrong shape.
            obligations = []
        obligation_texts = " ".join(
            str(o.get("text") or "") for o in obligations[:3] if isinstance(o, dict)
        )
        query = f"{event.get('standard_code', '')} {event.get('title', '')} {obligation_texts}"

        try:
            chunks = retrieval_service.retrieve(query=query, top_k=5)
        except Exception as exc:
            logger.warning("RAG retrieval failed: {}", exc)
            return []

        if not chunks:
            return []

        # Keep only the first chunk seen for each document.
        seen: set[str] = set()
        doc_excerpts: list[dict] = []
        for chunk in chunks:
            if chunk.doc_id in seen:
                continue
            seen.add(chunk.doc_id)
            doc_excerpts.append({
                "doc_id": chunk.doc_id,
                "doc_name": chunk.doc_title,
                "score": round(float(chunk.score if chunk.score is not None else 0), 4),
                "snippet": (chunk.text or "")[:300],
                "clause": getattr(chunk, "section_title", "") or "",
            })

        context = "\n".join(
            f"[{d['doc_name']} {d['clause']}] score={d['score']}: {d['snippet']}"
            for d in doc_excerpts
        )
        prompt = f"""Regulation: {event.get('standard_code')} — {event.get('title')}
Obligations: {obligation_texts or event.get('summary', '')}

Affected documents found in knowledge base:
{context}

For each document, assess impact and recommend action. Return JSON array:
[{{"doc_id":"...","doc_name":"...","score":0.0,"key_clauses":"...","recommendation":"one sentence action"}}]"""

        messages = [
            {"role": "system", "content": _ASSESS_SYSTEM},
            {"role": "user", "content": prompt},
        ]
        result = _llm_json(self._client, messages)
        if isinstance(result, list):
            # Use the retrieval score rather than whatever number the LLM echoed back.
            score_map = {d["doc_id"]: d["score"] for d in doc_excerpts}
            for item in result:
                if isinstance(item, dict) and item.get("doc_id") in score_map:
                    item["score"] = score_map[item["doc_id"]]
            return result
        return doc_excerpts

    # ------------------------------------------------------------------
    # Step 3: Semantic diff
    # ------------------------------------------------------------------

    def compute_diff(self, old_text: str, new_text: str) -> dict:
        """Compare old and new regulation text; return changed sections and summary.

        Both texts are split on newlines into non-empty, stripped paragraphs.
        Paragraphs are paired by position; a pair whose embedding cosine
        similarity is below ``_SIMILARITY_THRESHOLD`` is classified by the
        LLM.  Paragraphs without a positional counterpart are reported as
        ``added`` or ``removed``.

        Args:
            old_text: Previous version of the text (``None`` is treated as empty).
            new_text: Current version of the text (``None`` is treated as empty).

        Returns:
            A dict with ``changed_sections`` (list of dicts with ``old_text``,
            ``new_text``, ``similarity``, ``change_type`` and ``summary``) and
            ``change_summary`` (str).
        """
        old_paras = [p.strip() for p in (old_text or "").split("\n") if p.strip()]
        new_paras = [p.strip() for p in (new_text or "").split("\n") if p.strip()]

        if not old_paras or not new_paras:
            return {"changed_sections": [], "change_summary": "No comparable text."}

        all_paras = old_paras + new_paras
        try:
            # Single embedding call for both versions.
            all_embeddings = self._embedder.embed_texts(all_paras)
            if len(all_embeddings) != len(all_paras):
                # A short result would otherwise surface as an IndexError below.
                raise ValueError(
                    f"expected {len(all_paras)} embeddings, got {len(all_embeddings)}"
                )
        except Exception as exc:
            logger.warning("Embedding for diff failed: {}", exc)
            return {"changed_sections": [], "change_summary": "Diff unavailable (embedding error)."}

        old_embeddings = all_embeddings[: len(old_paras)]
        new_embeddings = all_embeddings[len(old_paras):]

        changed_sections: list[dict] = []

        # Positions present in both versions — compare via embeddings.
        for old_para, new_para, old_vec, new_vec in zip(
            old_paras, new_paras, old_embeddings, new_embeddings
        ):
            sim = _cosine(old_vec, new_vec)
            if sim < _SIMILARITY_THRESHOLD:
                messages = [
                    {"role": "system", "content": _DIFF_SYSTEM},
                    {"role": "user", "content": f"OLD: {old_para[:500]}\nNEW: {new_para[:500]}"},
                ]
                classification = _llm_json(self._client, messages)
                if not isinstance(classification, dict):
                    # Failed call, or a reply that is not a JSON object (e.g. an array).
                    classification = {}
                changed_sections.append({
                    "old_text": old_para[:300],
                    "new_text": new_para[:300],
                    "similarity": round(sim, 3),
                    # Coerce to str so a null/non-string value cannot break the summary.
                    "change_type": str(classification.get("change_type") or "modified"),
                    "summary": str(classification.get("summary") or ""),
                })

        # Trailing paragraphs that exist only in the new version.
        changed_sections.extend(
            {
                "old_text": "",
                "new_text": para[:300],
                "similarity": 0.0,
                "change_type": "added",
                "summary": "New paragraph added.",
            }
            for para in new_paras[len(old_paras):]
        )
        # Trailing paragraphs that exist only in the old version.
        changed_sections.extend(
            {
                "old_text": para[:300],
                "new_text": "",
                "similarity": 0.0,
                "change_type": "removed",
                "summary": "Paragraph removed.",
            }
            for para in old_paras[len(new_paras):]
        )

        if not changed_sections:
            change_summary = "No substantive changes detected between versions."
        else:
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # summary text is the same on every run (a set is not ordered).
            kinds = dict.fromkeys(s["change_type"] for s in changed_sections)
            change_summary = (
                f"{len(changed_sections)} paragraph(s) changed: "
                + ", ".join(kinds)
                + ". "
                + changed_sections[0]["summary"]
            )

        return {"changed_sections": changed_sections, "change_summary": change_summary}
|
||||
Reference in New Issue
Block a user