Files
AIRegulation-Demo-Test-Backend/app/workflows/compliance_workflow.py
2026-05-11 11:22:55 +08:00

175 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import TypedDict, List
from langgraph.graph import StateGraph, END
class ComplianceState(TypedDict):
    """Shared state passed between the nodes of the compliance workflow."""

    # Path of the document to analyse; the only key supplied by the caller.
    document_path: str
    # Plain text produced by the "parse" node.
    raw_text: str
    # Segment dicts produced by the "segment" node.
    segments: List[dict]
    # Segments enriched with a "regulations" list by the "match" node.
    matched_regulations: List[dict]
    # Aggregate score / counters produced by the "score" node.
    risk_dashboard: dict
    # Suggested fixes produced by the "suggest" node.
    priority_actions: List[dict]
def parse_document(state: ComplianceState) -> dict:
    """Parse the document referenced by the state into text.

    Args:
        state: Workflow state; only ``state["document_path"]`` is read.

    Returns:
        A partial state update ``{"raw_text": <parsed text>}``.
    """
    # Imported lazily, like the other nodes in this module.
    from app.services import get_document_service

    # NOTE(review): presumably the raw-upload and parsed-output directories,
    # judging by the path names -- confirm against get_document_service.
    raw_dir = "/airegulation/demo-mao/backend/data/raw"
    parsed_dir = "/airegulation/demo-mao/backend/data/parsed"

    service = get_document_service(raw_dir, parsed_dir)
    parsed_text = service.parse_document(state["document_path"])
    return {"raw_text": parsed_text}
def _build_segmentation_prompt(raw_text: str) -> str:
    """Build the LLM prompt asking for a semantic segmentation of *raw_text*.

    Only the first 3000 characters of the document are embedded in the prompt.
    Not wired into ``segment_document`` yet; kept so the prompt is ready for
    the planned LLM-based segmentation.
    """
    return f"""请分析以下设计方案文档,按照设计意图将其分成若干语义段落。
文档内容:
{raw_text[:3000]}
请输出JSON格式的分段结果每个段落包含
- intent: 段落意图/主题
- startPos: 在原文中的起始位置(大致)
- endPos: 在原文中的结束位置(大致)
- keywords: 关键词列表
输出格式:
[{{"intent": "...", "startPos": 0, "endPos": 100, "keywords": [...]}}]"""


def segment_document(state: ComplianceState) -> dict:
    """Split the parsed document into semantic segments.

    Simplified placeholder: no LLM is called. A single segment covering the
    first 500 characters of ``state["raw_text"]`` is returned with a fixed
    intent and keyword list.

    Args:
        state: Workflow state; only ``state["raw_text"]`` is read.

    Returns:
        A partial state update ``{"segments": [segment]}`` where the segment
        has the keys ``id``, ``intent``, ``content`` and ``keywords``.
    """
    # TODO: send _build_segmentation_prompt(state["raw_text"]) to the LLM
    # service and parse its JSON answer. Until then the prompt is not built and
    # the LLM service is not imported, since neither result was ever used.
    overview_segment = {
        "id": 1,
        "intent": "整体设计概述",
        "content": state["raw_text"][:500],
        "keywords": ["设计", "方案"],
    }
    return {"segments": [overview_segment]}
def _categorize_score(score: float) -> str:
    """Bucket a match score: > 0.85 is "high", > 0.6 is "medium", else "low"."""
    if score > 0.85:
        return "high"
    if score > 0.6:
        return "medium"
    return "low"


def match_regulations(state: ComplianceState) -> dict:
    """Attach the best-matching regulations to every segment.

    For each segment the keywords are joined with spaces, embedded, and used
    to query the vector store for the top 5 hits. Each hit is converted into a
    regulation dict with the keys ``id``, ``name``, ``clause``, ``score``,
    ``match_keyword``, ``category`` and ``full_content``.

    Args:
        state: Workflow state; only ``state["segments"]`` is read.

    Returns:
        A partial state update ``{"matched_regulations": [...]}`` holding one
        copy of each input segment with an added ``"regulations"`` list. The
        input segment dicts are left unmodified.
    """
    segments = state["segments"]
    if not segments:
        # Nothing to match: skip loading the embedding / vector services.
        return {"matched_regulations": []}

    from app.services import embedding_service, milvus_service

    matched = []
    for segment in segments:
        keyword_text = " ".join(segment.get("keywords", []))
        regulations = []
        # An empty query carries no meaning to embed; report no matches
        # instead of searching with it.
        if keyword_text.strip():
            embedding = embedding_service.embed_single(keyword_text)
            for doc in milvus_service.search(embedding, top_k=5):
                score = doc["score"]
                regulations.append({
                    "id": doc["id"],
                    "name": doc["doc_name"],
                    # clause_id is optional on a hit, so this may be None.
                    "clause": doc.get("clause_id"),
                    "score": score,
                    "match_keyword": keyword_text,
                    "category": _categorize_score(score),
                    "full_content": doc["content"],
                })
        # Shallow copy so the dicts in state["segments"] are not mutated.
        matched.append({**segment, "regulations": regulations})
    return {"matched_regulations": matched}
def calculate_risk(state: ComplianceState) -> dict:
    """Assign a risk level to every segment and build the risk dashboard.

    Per segment, only regulations whose ``category`` is ``"high"`` are
    considered. Their mean score decides the level: below 0.9 is ``"high"``,
    below 0.92 is ``"medium"``, otherwise ``"low"``. A segment without any
    high-category regulation is ``"low"``.

    The dashboard score is on a 0-100 scale: each segment contributes its mean
    high-category score multiplied by 100, or 100 when it has no such
    regulation. The dashboard status is ``"pass"`` at >= 90, ``"warning"`` at
    >= 70 and ``"fail"`` below that.

    Args:
        state: Workflow state; only ``state["matched_regulations"]`` is read.
            Each segment dict receives a ``"risk_level"`` key in place.

    Returns:
        A partial state update with ``"risk_dashboard"`` and ``"segments"``
        (the same segment dicts, now carrying ``"risk_level"``).
    """
    segments = state["matched_regulations"]
    level_counts = {"high": 0, "medium": 0, "low": 0}
    total_score = 0.0

    for segment in segments:
        high_scores = [
            reg["score"]
            for reg in segment.get("regulations", [])
            if reg["category"] == "high"
        ]
        if high_scores:
            mean_score = sum(high_scores) / len(high_scores)
            if mean_score < 0.9:
                level = "high"
            elif mean_score < 0.92:
                level = "medium"
            else:
                level = "low"
            # Match scores are compared against 0.9 / 0.92 above, i.e. they
            # live on a 0-1 scale; rescale so they can be averaged together
            # with the 100 used for segments without high matches and compared
            # against the 90 / 70 status thresholds below.
            total_score += mean_score * 100
        else:
            level = "low"
            total_score += 100
        segment["risk_level"] = level
        level_counts[level] += 1

    overall_score = total_score / len(segments) if segments else 100

    if overall_score >= 90:
        status, status_label = "pass", "合规"
    elif overall_score >= 70:
        status, status_label = "warning", "需要修改"
    else:
        status, status_label = "fail", "高风险"

    dashboard = {
        "score": overall_score,
        "high_risk_count": level_counts["high"],
        "medium_risk_count": level_counts["medium"],
        "low_risk_count": level_counts["low"],
        # Every high-risk segment is one that needs fixing.
        "need_fix_segments": level_counts["high"],
        "status": status,
        "status_label": status_label,
    }
    return {"risk_dashboard": dashboard, "segments": segments}
def generate_suggestions(state: ComplianceState) -> dict:
    """Generate priority actions for weakly matched high-category regulations.

    One action is emitted for every regulation whose ``category`` is
    ``"high"`` and whose ``score`` is below 0.9.

    Args:
        state: Workflow state; only ``state["segments"]`` is read. Segments
            without a ``"regulations"`` key are skipped.

    Returns:
        A partial state update ``{"priority_actions": [...]}`` where every
        action has the keys ``regulation``, ``issue``, ``suggestion`` and
        ``severity``.
    """
    actions = []
    for segment in state["segments"]:
        for reg in segment.get("regulations", []):
            if reg["category"] != "high" or reg["score"] >= 0.9:
                continue
            name = reg["name"]
            # "clause" may be present with the value None; `or ""` keeps the
            # literal text "None" out of the suggestion, which a .get()
            # default alone would not.
            clause = reg.get("clause") or ""
            actions.append({
                "regulation": name,
                "issue": reg["match_keyword"],
                "suggestion": f"建议对照{name}{clause}条进行修改",
                "severity": "high",
            })
    return {"priority_actions": actions}
# Build the workflow graph: a linear pipeline whose nodes run in the order
# they are listed below.
compliance_graph = StateGraph(ComplianceState)

_PIPELINE = (
    ("parse", parse_document),
    ("segment", segment_document),
    ("match", match_regulations),
    ("score", calculate_risk),
    ("suggest", generate_suggestions),
)
for _node_name, _node_fn in _PIPELINE:
    compliance_graph.add_node(_node_name, _node_fn)

compliance_graph.set_entry_point("parse")

# Chain every node to its successor; the last node leads to END.
_ORDER = [_name for _name, _ in _PIPELINE] + [END]
for _src, _dst in zip(_ORDER, _ORDER[1:]):
    compliance_graph.add_edge(_src, _dst)

compliance_workflow = compliance_graph.compile()
async def run_compliance_workflow(document_path: str) -> ComplianceState:
    """Run the compliance analysis workflow for one document.

    Args:
        document_path: Path of the document to analyse; stored as the
            ``"document_path"`` key of the initial state.

    Returns:
        The final workflow state produced by the compiled graph.
    """
    # Only the input key is seeded; the remaining ComplianceState keys are
    # filled in by the workflow nodes.
    initial_state = {"document_path": document_path}
    # Use the async entry point: calling the synchronous invoke() from inside
    # a coroutine would block the event loop for the whole pipeline run.
    return await compliance_workflow.ainvoke(initial_state)