Files
AIRegulation-DocAnalysis/backend/app/workflows/compliance_workflow.py

180 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Define workflow state for compliance workflow."""
from typing import TypedDict, List
from langgraph.graph import StateGraph, END
# Keep workflow state definitions compact so transitions stay easy to audit.
class ComplianceState(TypedDict):
    """Shared state handed from node to node in the compliance workflow.

    Every node reads the keys it needs and returns a partial update
    containing only the keys it produces.
    """

    # Path passed to the document service for parsing.
    document_path: str
    # Text returned by the document service for ``document_path``.
    raw_text: str
    # Segment dicts; the scoring node re-emits them with a "risk_level" key.
    segments: List[dict]
    # Segments carrying a "regulations" list after regulation matching.
    matched_regulations: List[dict]
    # Aggregate score, per-level risk counts and overall status.
    risk_dashboard: dict
    # Suggested remediation actions derived from the matched regulations.
    priority_actions: List[dict]
# Defaults preserve the directories this node has always used.
_DEFAULT_RAW_DIR = "/airegulation/demo-mao/backend/data/raw"
_DEFAULT_PARSED_DIR = "/airegulation/demo-mao/backend/data/parsed"


def parse_document(
    state: ComplianceState,
    *,
    raw_dir: str = _DEFAULT_RAW_DIR,
    parsed_dir: str = _DEFAULT_PARSED_DIR,
) -> dict:
    """Parse the document at ``state["document_path"]`` into text.

    Args:
        state: Workflow state; only ``document_path`` is read.
        raw_dir: First argument forwarded to ``get_document_service``.
            NOTE(review): presumably the raw-upload directory, judging by
            the default path - confirm against the service.
        parsed_dir: Second argument forwarded to ``get_document_service``.
            NOTE(review): presumably the parsed-output directory - confirm.

    Returns:
        Partial state update ``{"raw_text": <parser output>}``.
    """
    # Imported lazily, as elsewhere in this module, so that importing the
    # workflow does not pull in the service layer.
    from app.services import get_document_service

    doc_service = get_document_service(raw_dir, parsed_dir)
    text = doc_service.parse_document(state["document_path"])
    return {"raw_text": text}
def segment_document(state: ComplianceState) -> dict:
"""Handle segment document."""
from app.services import llm_service
prompt = f"""请分析以下设计方案文档,按照设计意图将其分成若干语义段落。
文档内容:
{state['raw_text'][:3000]}
请输出JSON格式的分段结果每个段落包含
- intent: 段落意图/主题
- startPos: 在原文中的起始位置(大致)
- endPos: 在原文中的结束位置(大致)
- keywords: 关键词列表
输出格式:
[{{"intent": "...", "startPos": 0, "endPos": 100, "keywords": [...]}}]"""
# Keep workflow state definitions compact so transitions stay easy to audit.
segments = [
{
"id": 1,
"intent": "整体设计概述",
"content": state["raw_text"][:500],
"keywords": ["设计", "方案"],
}
]
return {"segments": segments}
def match_regulations(state: ComplianceState) -> dict:
    """Attach the best-matching regulations to every segment.

    For each segment the keywords are joined with spaces, embedded through
    ``embedding_service.embed_single`` and used to query
    ``milvus_service.search`` for the top 5 hits. Each hit is bucketed by
    score: above 0.85 is "high", above 0.6 is "medium", anything else "low".

    Note that the segment dicts from ``state["segments"]`` are updated in
    place with a ``"regulations"`` key; the same objects are returned.

    Returns:
        Partial state update ``{"matched_regulations": [<segments>]}``.
    """
    from app.services import embedding_service, milvus_service

    def _as_regulation(hit, query):
        # Bucket first so a hit without a score fails before anything else.
        score = hit["score"]
        if score > 0.85:
            bucket = "high"
        elif score > 0.6:
            bucket = "medium"
        else:
            bucket = "low"
        return {
            "id": hit["id"],
            "name": hit["doc_name"],
            "clause": hit.get("clause_id"),
            "score": hit["score"],
            "match_keyword": query,
            "category": bucket,
            "full_content": hit["content"],
        }

    enriched = []
    for seg in state["segments"]:
        query = " ".join(seg.get("keywords", []))
        vector = embedding_service.embed_single(query)
        hits = milvus_service.search(vector, top_k=5)
        seg["regulations"] = [_as_regulation(hit, query) for hit in hits]
        enriched.append(seg)
    return {"matched_regulations": enriched}
def calculate_risk(state: ComplianceState) -> dict:
    """Assign a risk level to each matched segment and build the dashboard.

    Per segment, the mean score of its "high"-category regulations decides
    the level: below 0.9 is "high" (and counts as needing a fix), below 0.92
    is "medium", otherwise "low". A segment without high-category
    regulations is "low".

    The dashboard score is on a 0-100 scale (segments without high-category
    matches contribute 100 and the status cut-offs are 90 and 70), whereas
    regulation scores are compared against fractional thresholds. The
    per-segment mean is therefore multiplied by 100 before it is averaged;
    without that, one matched segment would drag the score towards zero.

    Args:
        state: Workflow state; ``matched_regulations`` is read and each of
            its segment dicts gains a ``"risk_level"`` key in place.

    Returns:
        Partial state update with ``risk_dashboard`` and the annotated
        segments under ``segments``.
    """
    segments = state["matched_regulations"]
    level_counts = {"high": 0, "medium": 0, "low": 0}
    total_score = 0.0

    for segment in segments:
        high_scores = [
            reg["score"]
            for reg in segment.get("regulations", [])
            if reg["category"] == "high"
        ]
        if high_scores:
            mean_score = sum(high_scores) / len(high_scores)
            if mean_score < 0.9:
                level = "high"
            elif mean_score < 0.92:
                level = "medium"
            else:
                level = "low"
            # Bring the fractional mean onto the dashboard's 0-100 scale.
            segment_score = mean_score * 100
        else:
            level = "low"
            segment_score = 100

        segment["risk_level"] = level
        level_counts[level] += 1
        total_score += segment_score

    overall_score = total_score / len(segments) if segments else 100

    if overall_score >= 90:
        status, status_label = "pass", "合规"
    elif overall_score >= 70:
        status, status_label = "warning", "需要修改"
    else:
        status, status_label = "fail", "高风险"

    dashboard = {
        "score": overall_score,
        "high_risk_count": level_counts["high"],
        "medium_risk_count": level_counts["medium"],
        "low_risk_count": level_counts["low"],
        # Every high-risk segment, and only those, needs a fix.
        "need_fix_segments": level_counts["high"],
        "status": status,
        "status_label": status_label,
    }
    return {"risk_dashboard": dashboard, "segments": segments}
def generate_suggestions(state: ComplianceState) -> dict:
    """Turn weakly-matched high-category regulations into priority actions.

    Every regulation whose category is "high" and whose score is below 0.9
    yields one action that names the regulation, the keyword text that
    matched it, and a suggestion to revise against the regulation's clause.

    Args:
        state: Workflow state; ``segments`` is read, along with each
            segment's optional ``"regulations"`` list.

    Returns:
        Partial state update ``{"priority_actions": [<action dicts>]}``.
    """
    actions = []
    for segment in state["segments"]:
        for reg in segment.get("regulations", []):
            if reg["category"] == "high" and reg["score"] < 0.9:
                # The clause key can be present with a None value; render it
                # as empty text rather than the literal string "None".
                clause = reg.get("clause")
                clause_text = "" if clause is None else clause
                actions.append(
                    {
                        "regulation": reg["name"],
                        "issue": reg["match_keyword"],
                        "suggestion": f"建议对照{reg['name']}{clause_text}条进行修改",
                        "severity": "high",
                    }
                )
    return {"priority_actions": actions}
# Wire the nodes into a linear pipeline: parse -> segment -> match -> score -> suggest.
def _build_compliance_graph() -> StateGraph:
    """Assemble the linear graph parse -> segment -> match -> score -> suggest."""
    graph = StateGraph(ComplianceState)
    steps = (
        ("parse", parse_document),
        ("segment", segment_document),
        ("match", match_regulations),
        ("score", calculate_risk),
        ("suggest", generate_suggestions),
    )
    # Register all nodes before any edge refers to them.
    for step_name, step_fn in steps:
        graph.add_node(step_name, step_fn)
    graph.set_entry_point(steps[0][0])
    # Chain every step to its successor; the final step leads to END.
    order = [step_name for step_name, _ in steps] + [END]
    for source, target in zip(order, order[1:]):
        graph.add_edge(source, target)
    return graph


compliance_graph = _build_compliance_graph()
compliance_workflow = compliance_graph.compile()
async def run_compliance_workflow(document_path: str) -> ComplianceState:
    """Run the compiled compliance workflow for one document.

    The graph is awaited through ``ainvoke`` so that this coroutine does not
    block the event loop for the duration of the run, which is what the
    synchronous ``invoke`` would do.

    Args:
        document_path: Path of the document to analyse; it becomes the
            ``document_path`` entry of the initial state.

    Returns:
        The final workflow state produced by the graph.
    """
    # Seed every key so the initial value really is a ComplianceState; each
    # of the empty placeholders is overwritten by the node that produces it.
    initial_state: ComplianceState = {
        "document_path": document_path,
        "raw_text": "",
        "segments": [],
        "matched_regulations": [],
        "risk_dashboard": {},
        "priority_actions": [],
    }
    return await compliance_workflow.ainvoke(initial_state)