175 lines
5.4 KiB
Python
175 lines
5.4 KiB
Python
from typing import TypedDict, List
|
||
from langgraph.graph import StateGraph, END
|
||
|
||
|
||
class ComplianceState(TypedDict):
    """State dictionary shared by every node of the compliance workflow.

    Each node receives this mapping and returns a partial dict containing
    only the keys it produces.
    """

    # Path of the document to analyse; read by ``parse_document``.
    document_path: str
    # Text returned by the document service; produced by ``parse_document``.
    raw_text: str
    # Segment dicts; produced by ``segment_document`` and later replaced by
    # ``calculate_risk`` with the regulation-annotated segments.
    segments: List[dict]
    # Segments carrying a ``regulations`` list; produced by
    # ``match_regulations``.
    matched_regulations: List[dict]
    # Aggregate score, per-level counters and status; produced by
    # ``calculate_risk``.
    risk_dashboard: dict
    # Suggested fixes; produced by ``generate_suggestions``.
    priority_actions: List[dict]
|
||
|
||
|
||
def parse_document(state: ComplianceState) -> dict:
    """Parse the document referenced by the state.

    Args:
        state: Workflow state; only ``document_path`` is read.

    Returns:
        A partial state update ``{"raw_text": ...}`` holding whatever the
        document service's ``parse_document`` returns for that path.
    """
    # Function-scope import: the service layer is loaded only when the node
    # actually runs.
    from app.services import get_document_service

    raw_dir = "/airegulation/demo-mao/backend/data/raw"
    parsed_dir = "/airegulation/demo-mao/backend/data/parsed"
    service = get_document_service(raw_dir, parsed_dir)

    return {"raw_text": service.parse_document(state["document_path"])}
|
||
|
||
|
||
# Number of leading characters of the document embedded in the LLM prompt.
_PROMPT_CHAR_LIMIT = 3000
# Number of leading characters used as the content of the fallback segment.
_FALLBACK_SEGMENT_CHARS = 500


def _build_segmentation_prompt(raw_text: str) -> str:
    """Build the prompt asking an LLM to split *raw_text* by design intent."""
    return f"""请分析以下设计方案文档,按照设计意图将其分成若干语义段落。

文档内容:
{raw_text[:_PROMPT_CHAR_LIMIT]}

请输出JSON格式的分段结果,每个段落包含:
- intent: 段落意图/主题
- startPos: 在原文中的起始位置(大致)
- endPos: 在原文中的结束位置(大致)
- keywords: 关键词列表

输出格式:
[{{"intent": "...", "startPos": 0, "endPos": 100, "keywords": [...]}}]"""


def segment_document(state: ComplianceState) -> dict:
    """Split the parsed document into semantic segments.

    Simplified implementation: no LLM call is made yet, so a single
    overview segment built from the first 500 characters of the text is
    returned.

    Args:
        state: Workflow state; only ``raw_text`` is read.

    Returns:
        A partial state update ``{"segments": [...]}`` where each segment
        has the keys ``id``, ``intent``, ``content`` and ``keywords``.

    Raises:
        KeyError: If ``raw_text`` is missing from the state.
    """
    raw_text = state["raw_text"]

    # TODO: send ``_build_segmentation_prompt(raw_text)`` to the LLM service
    # and parse its JSON reply. Until that is wired in, the prompt is not
    # built and the LLM service is not imported, so this node has no
    # dependency on it.
    overview_segment = {
        "id": 1,
        "intent": "整体设计概述",
        "content": raw_text[:_FALLBACK_SEGMENT_CHARS],
        "keywords": ["设计", "方案"],
    }
    return {"segments": [overview_segment]}
|
||
|
||
|
||
def _score_category(score: float) -> str:
    """Bucket a match score: > 0.85 is "high", > 0.6 is "medium", else "low"."""
    if score > 0.85:
        return "high"
    if score > 0.6:
        return "medium"
    return "low"


def match_regulations(state: ComplianceState) -> dict:
    """Attach the best-matching regulations to every segment.

    For each segment the keywords are joined with spaces, embedded, and
    used to search the vector store for the top 5 hits. Every hit becomes a
    regulation dict with the keys ``id``, ``name``, ``clause``, ``score``,
    ``match_keyword``, ``category`` and ``full_content``.

    Args:
        state: Workflow state; only ``segments`` is read.

    Returns:
        A partial state update ``{"matched_regulations": [...]}`` holding a
        copy of every segment extended with a ``regulations`` list.
    """
    # Function-scope import: the service layer is loaded only when the node
    # actually runs.
    from app.services import embedding_service, milvus_service

    matched = []
    for segment in state["segments"]:
        keyword_text = " ".join(segment.get("keywords", []))
        embedding = embedding_service.embed_single(keyword_text)
        hits = milvus_service.search(embedding, top_k=5)

        regulations = [
            {
                "id": hit["id"],
                "name": hit["doc_name"],
                "clause": hit.get("clause_id"),
                "score": hit["score"],
                "match_keyword": keyword_text,
                "category": _score_category(hit["score"]),
                "full_content": hit["content"],
            }
            for hit in hits
        ]

        # Build a new dict instead of writing into the segment that lives in
        # the input state: a node should publish its results through the
        # returned update, not by mutating the state it was handed.
        matched.append({**segment, "regulations": regulations})

    return {"matched_regulations": matched}
|
||
|
||
|
||
def calculate_risk(state: ComplianceState) -> dict:
    """Assign a risk level to every matched segment and build the dashboard.

    For each segment, the mean score of its ``"high"``-category regulations
    decides the level: below 0.9 is ``"high"``, below 0.92 is ``"medium"``,
    otherwise ``"low"``. Segments without any high-category regulation are
    ``"low"``. The level is written to ``segment["risk_level"]`` on the
    segment dicts taken from ``matched_regulations``.

    The dashboard score is on a 0-100 scale: a segment contributes its mean
    high-category score times 100, or 100 when it has no high-category
    regulation. The overall score is the mean over all segments (100 when
    there are none); >= 90 is ``"pass"``, >= 70 is ``"warning"``, anything
    lower is ``"fail"``.

    Args:
        state: Workflow state; only ``matched_regulations`` is read.

    Returns:
        A partial state update with ``risk_dashboard`` and ``segments`` (the
        same segment list, now carrying ``risk_level``).
    """
    segments = state["matched_regulations"]

    counts = {"high": 0, "medium": 0, "low": 0}
    total_score = 0.0

    for segment in segments:
        high_scores = [
            reg["score"]
            for reg in segment.get("regulations", [])
            if reg["category"] == "high"
        ]

        if high_scores:
            mean_score = sum(high_scores) / len(high_scores)
            if mean_score < 0.9:
                level = "high"
            elif mean_score < 0.92:
                level = "medium"
            else:
                level = "low"
            # Match scores are fractions in 0..1 while the dashboard and its
            # 90/70 thresholds use 0..100, so scale before accumulating.
            total_score += mean_score * 100
        else:
            level = "low"
            total_score += 100

        segment["risk_level"] = level
        counts[level] += 1

    overall_score = total_score / len(segments) if segments else 100

    if overall_score >= 90:
        status, status_label = "pass", "合规"
    elif overall_score >= 70:
        status, status_label = "warning", "需要修改"
    else:
        status, status_label = "fail", "高风险"

    dashboard = {
        "score": overall_score,
        "high_risk_count": counts["high"],
        "medium_risk_count": counts["medium"],
        "low_risk_count": counts["low"],
        # Every high-risk segment needs fixing, so the two counters match.
        "need_fix_segments": counts["high"],
        "status": status,
        "status_label": status_label,
    }

    return {"risk_dashboard": dashboard, "segments": segments}
|
||
|
||
|
||
def generate_suggestions(state: ComplianceState) -> dict:
    """Generate the priority actions for the analysed document.

    One action is emitted for every regulation whose category is ``"high"``
    and whose score is below 0.9, in segment order.

    Args:
        state: Workflow state; only ``segments`` is read. Segments without
            a ``regulations`` key are skipped.

    Returns:
        A partial state update ``{"priority_actions": [...]}`` where each
        action has the keys ``regulation``, ``issue``, ``suggestion`` and
        ``severity``.
    """
    actions = []

    for segment in state["segments"]:
        for reg in segment.get("regulations", []):
            if reg["category"] != "high" or reg["score"] >= 0.9:
                continue

            # The ``clause`` key can be present with a ``None`` value, in
            # which case a ``.get`` default is never used. Leave the clause
            # reference out instead of rendering "第None条".
            clause = reg.get("clause")
            clause_ref = f"第{clause}条" if clause not in (None, "") else ""

            actions.append(
                {
                    "regulation": reg["name"],
                    "issue": reg["match_keyword"],
                    "suggestion": f"建议对照{reg['name']}{clause_ref}进行修改",
                    "severity": "high",
                }
            )

    return {"priority_actions": actions}
|
||
|
||
|
||
# Build the workflow graph: a linear pipeline
# parse -> segment -> match -> score -> suggest -> END.
compliance_graph = StateGraph(ComplianceState)

# Node name / node function pairs, listed in execution order.
_PIPELINE = (
    ("parse", parse_document),
    ("segment", segment_document),
    ("match", match_regulations),
    ("score", calculate_risk),
    ("suggest", generate_suggestions),
)

for _node_name, _node_fn in _PIPELINE:
    compliance_graph.add_node(_node_name, _node_fn)

compliance_graph.set_entry_point("parse")

# Chain every node to its successor; the last node is connected to END.
_node_order = [_name for _name, _ in _PIPELINE] + [END]
for _source, _target in zip(_node_order, _node_order[1:]):
    compliance_graph.add_edge(_source, _target)

compliance_workflow = compliance_graph.compile()
|
||
|
||
|
||
async def run_compliance_workflow(document_path: str) -> ComplianceState:
    """Run the compliance analysis workflow for one document.

    Args:
        document_path: Path of the document to analyse; it is the only key
            placed in the initial state.

    Returns:
        The state produced by the compiled workflow.
    """
    initial_state: ComplianceState = {"document_path": document_path}
    # Await the asynchronous entry point: calling the blocking ``invoke``
    # from a coroutine would stall the event loop for the whole pipeline.
    return await compliance_workflow.ainvoke(initial_state)