update
This commit is contained in:
12
backend/app/workflows/__init__.py
Normal file
12
backend/app/workflows/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .rag_workflow import RagState, rag_workflow, run_rag_workflow, stream_rag_workflow
|
||||
from .compliance_workflow import ComplianceState, compliance_workflow, run_compliance_workflow
|
||||
|
||||
__all__ = [
|
||||
"RagState",
|
||||
"rag_workflow",
|
||||
"run_rag_workflow",
|
||||
"stream_rag_workflow",
|
||||
"ComplianceState",
|
||||
"compliance_workflow",
|
||||
"run_compliance_workflow",
|
||||
]
|
||||
175
backend/app/workflows/compliance_workflow.py
Normal file
175
backend/app/workflows/compliance_workflow.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from typing import TypedDict, List
|
||||
from langgraph.graph import StateGraph, END
|
||||
|
||||
|
||||
class ComplianceState(TypedDict):
|
||||
document_path: str
|
||||
raw_text: str
|
||||
segments: List[dict]
|
||||
matched_regulations: List[dict]
|
||||
risk_dashboard: dict
|
||||
priority_actions: List[dict]
|
||||
|
||||
|
||||
def parse_document(state: ComplianceState) -> dict:
|
||||
"""解析文档"""
|
||||
from app.services import get_document_service
|
||||
doc_service = get_document_service(
|
||||
"/airegulation/demo-mao/backend/data/raw",
|
||||
"/airegulation/demo-mao/backend/data/parsed",
|
||||
)
|
||||
text = doc_service.parse_document(state["document_path"])
|
||||
return {"raw_text": text}
|
||||
|
||||
|
||||
def segment_document(state: ComplianceState) -> dict:
|
||||
"""AI语义分段"""
|
||||
from app.services import llm_service
|
||||
prompt = f"""请分析以下设计方案文档,按照设计意图将其分成若干语义段落。
|
||||
|
||||
文档内容:
|
||||
{state['raw_text'][:3000]}
|
||||
|
||||
请输出JSON格式的分段结果,每个段落包含:
|
||||
- intent: 段落意图/主题
|
||||
- startPos: 在原文中的起始位置(大致)
|
||||
- endPos: 在原文中的结束位置(大致)
|
||||
- keywords: 关键词列表
|
||||
|
||||
输出格式:
|
||||
[{{"intent": "...", "startPos": 0, "endPos": 100, "keywords": [...]}}]"""
|
||||
|
||||
# 简化处理:返回基本分段
|
||||
segments = [
|
||||
{
|
||||
"id": 1,
|
||||
"intent": "整体设计概述",
|
||||
"content": state["raw_text"][:500],
|
||||
"keywords": ["设计", "方案"],
|
||||
}
|
||||
]
|
||||
|
||||
return {"segments": segments}
|
||||
|
||||
|
||||
def match_regulations(state: ComplianceState) -> dict:
|
||||
"""法规匹配"""
|
||||
from app.services import embedding_service, milvus_service
|
||||
matched = []
|
||||
|
||||
for segment in state["segments"]:
|
||||
keyword_text = " ".join(segment.get("keywords", []))
|
||||
embedding = embedding_service.embed_single(keyword_text)
|
||||
|
||||
docs = milvus_service.search(embedding, top_k=5)
|
||||
|
||||
segment_regs = []
|
||||
for doc in docs:
|
||||
category = "high" if doc["score"] > 0.85 else ("medium" if doc["score"] > 0.6 else "low")
|
||||
segment_regs.append({
|
||||
"id": doc["id"],
|
||||
"name": doc["doc_name"],
|
||||
"clause": doc.get("clause_id"),
|
||||
"score": doc["score"],
|
||||
"match_keyword": keyword_text,
|
||||
"category": category,
|
||||
"full_content": doc["content"],
|
||||
})
|
||||
|
||||
segment["regulations"] = segment_regs
|
||||
matched.append(segment)
|
||||
|
||||
return {"matched_regulations": matched}
|
||||
|
||||
|
||||
def calculate_risk(state: ComplianceState) -> dict:
|
||||
"""计算风险等级"""
|
||||
segments = state["matched_regulations"]
|
||||
|
||||
high_count = 0
|
||||
medium_count = 0
|
||||
low_count = 0
|
||||
need_fix = 0
|
||||
total_score = 0
|
||||
|
||||
for segment in segments:
|
||||
regs = segment.get("regulations", [])
|
||||
high_regs = [r for r in regs if r["category"] == "high"]
|
||||
|
||||
if high_regs:
|
||||
avg_score = sum(r["score"] for r in high_regs) / len(high_regs)
|
||||
if avg_score < 0.9:
|
||||
segment["risk_level"] = "high"
|
||||
high_count += 1
|
||||
need_fix += 1
|
||||
elif avg_score < 0.92:
|
||||
segment["risk_level"] = "medium"
|
||||
medium_count += 1
|
||||
else:
|
||||
segment["risk_level"] = "low"
|
||||
low_count += 1
|
||||
else:
|
||||
segment["risk_level"] = "low"
|
||||
low_count += 1
|
||||
|
||||
total_score += avg_score if high_regs else 100
|
||||
|
||||
avg_score = total_score / len(segments) if segments else 100
|
||||
|
||||
status = "pass" if avg_score >= 90 else ("warning" if avg_score >= 70 else "fail")
|
||||
status_label = "合规" if status == "pass" else ("需要修改" if status == "warning" else "高风险")
|
||||
|
||||
dashboard = {
|
||||
"score": avg_score,
|
||||
"high_risk_count": high_count,
|
||||
"medium_risk_count": medium_count,
|
||||
"low_risk_count": low_count,
|
||||
"need_fix_segments": need_fix,
|
||||
"status": status,
|
||||
"status_label": status_label,
|
||||
}
|
||||
|
||||
return {"risk_dashboard": dashboard, "segments": segments}
|
||||
|
||||
|
||||
def generate_suggestions(state: ComplianceState) -> dict:
|
||||
"""生成优先建议"""
|
||||
actions = []
|
||||
|
||||
for segment in state["segments"]:
|
||||
for reg in segment.get("regulations", []):
|
||||
if reg["category"] == "high" and reg["score"] < 0.9:
|
||||
actions.append({
|
||||
"regulation": reg["name"],
|
||||
"issue": reg["match_keyword"],
|
||||
"suggestion": f"建议对照{reg['name']}第{reg.get('clause', '')}条进行修改",
|
||||
"severity": "high",
|
||||
})
|
||||
|
||||
return {"priority_actions": actions}
|
||||
|
||||
|
||||
# 构建工作流图
|
||||
compliance_graph = StateGraph(ComplianceState)
|
||||
|
||||
compliance_graph.add_node("parse", parse_document)
|
||||
compliance_graph.add_node("segment", segment_document)
|
||||
compliance_graph.add_node("match", match_regulations)
|
||||
compliance_graph.add_node("score", calculate_risk)
|
||||
compliance_graph.add_node("suggest", generate_suggestions)
|
||||
|
||||
compliance_graph.set_entry_point("parse")
|
||||
compliance_graph.add_edge("parse", "segment")
|
||||
compliance_graph.add_edge("segment", "match")
|
||||
compliance_graph.add_edge("match", "score")
|
||||
compliance_graph.add_edge("score", "suggest")
|
||||
compliance_graph.add_edge("suggest", END)
|
||||
|
||||
compliance_workflow = compliance_graph.compile()
|
||||
|
||||
|
||||
async def run_compliance_workflow(document_path: str) -> ComplianceState:
|
||||
"""运行合规分析工作流"""
|
||||
initial_state: ComplianceState = {"document_path": document_path}
|
||||
result = compliance_workflow.invoke(initial_state)
|
||||
return result
|
||||
114
backend/app/workflows/rag_workflow.py
Normal file
114
backend/app/workflows/rag_workflow.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import TypedDict, List
|
||||
from langgraph.graph import StateGraph, END
|
||||
|
||||
|
||||
class RagState(TypedDict):
|
||||
query: str
|
||||
query_embedding: List[float]
|
||||
retrieved_docs: List[dict]
|
||||
context: str
|
||||
answer: str
|
||||
sources: List[dict]
|
||||
|
||||
|
||||
def embed_query(state: RagState) -> dict:
|
||||
"""将查询转为向量"""
|
||||
from app.services import embedding_service
|
||||
embedding = embedding_service.embed_single(state["query"])
|
||||
return {"query_embedding": embedding}
|
||||
|
||||
|
||||
def retrieve_docs(state: RagState) -> dict:
|
||||
"""向量检索"""
|
||||
from app.services import milvus_service
|
||||
from app.core.config import settings
|
||||
docs = milvus_service.search(
|
||||
state["query_embedding"],
|
||||
top_k=settings.vector_top_k,
|
||||
)
|
||||
return {"retrieved_docs": docs[:settings.final_top_k]}
|
||||
|
||||
|
||||
def build_context(state: RagState) -> dict:
|
||||
"""构建上下文"""
|
||||
context_parts = []
|
||||
sources = []
|
||||
|
||||
for doc in state["retrieved_docs"]:
|
||||
context_parts.append(f"【{doc['doc_name']} - {doc.get('clause_id', '')}】\n{doc['content']}")
|
||||
sources.append({
|
||||
"name": doc["doc_name"],
|
||||
"clause": doc.get("clause_id"),
|
||||
})
|
||||
|
||||
context = "\n\n".join(context_parts)
|
||||
return {"context": context, "sources": sources}
|
||||
|
||||
|
||||
def generate_answer(state: RagState) -> dict:
|
||||
"""生成答案"""
|
||||
from app.services import llm_service
|
||||
prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
|
||||
|
||||
法规内容:
|
||||
{state['context']}
|
||||
|
||||
用户问题:{state['query']}
|
||||
|
||||
请给出准确、简洁的回答,并引用相关法规条款。"""
|
||||
|
||||
answer = ""
|
||||
for chunk in llm_service.generate_stream(prompt):
|
||||
answer += chunk
|
||||
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
# 构建工作流图
|
||||
rag_graph = StateGraph(RagState)
|
||||
|
||||
rag_graph.add_node("embed", embed_query)
|
||||
rag_graph.add_node("retrieve", retrieve_docs)
|
||||
rag_graph.add_node("build_context", build_context)
|
||||
rag_graph.add_node("generate", generate_answer)
|
||||
|
||||
rag_graph.set_entry_point("embed")
|
||||
rag_graph.add_edge("embed", "retrieve")
|
||||
rag_graph.add_edge("retrieve", "build_context")
|
||||
rag_graph.add_edge("build_context", "generate")
|
||||
rag_graph.add_edge("generate", END)
|
||||
|
||||
rag_workflow = rag_graph.compile()
|
||||
|
||||
|
||||
async def run_rag_workflow(query: str) -> RagState:
|
||||
"""运行RAG工作流"""
|
||||
initial_state: RagState = {"query": query}
|
||||
result = rag_workflow.invoke(initial_state)
|
||||
return result
|
||||
|
||||
|
||||
def stream_rag_workflow(query: str):
|
||||
"""流式运行RAG工作流"""
|
||||
from app.services import llm_service
|
||||
|
||||
# 先完成检索阶段
|
||||
state: RagState = {"query": query}
|
||||
state.update(embed_query(state))
|
||||
state.update(retrieve_docs(state))
|
||||
state.update(build_context(state))
|
||||
|
||||
# 流式生成阶段
|
||||
prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
|
||||
|
||||
法规内容:
|
||||
{state['context']}
|
||||
|
||||
用户问题:{state['query']}
|
||||
|
||||
请给出准确、简洁的回答,并引用相关法规条款。"""
|
||||
|
||||
for chunk in llm_service.generate_stream(prompt):
|
||||
yield {"type": "chunk", "text": chunk}
|
||||
|
||||
yield {"type": "done", "sources": state["sources"]}
|
||||
Reference in New Issue
Block a user