96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
|
|
import uuid
|
|||
|
|
import logging
|
|||
|
|
from fastapi import APIRouter, HTTPException
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from langchain.schema import HumanMessage, SystemMessage
|
|||
|
|
|
|||
|
|
from ..core.llm import get_llm, COMPLIANCE_CHECK_PROMPT
|
|||
|
|
from ..services.rag import hybrid_search
|
|||
|
|
|
|||
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router grouping the compliance-review endpoints under the /api/compliance prefix.
router = APIRouter(
    prefix="/api/compliance",
    tags=["合规审查"],
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ComplianceCheckRequest(BaseModel):
    """Request body for the compliance check endpoint.

    Attributes:
        query: The content to be checked; it is used both as the retrieval
            query and as the content placed into the LLM prompt.
        regulation_domains: Regulation domains to search. Defaults to the
            single domain ``"vehicle_safety"``.
        top_k: Number of regulation chunks to request per search and to keep
            after de-duplication. Defaults to 5.
    """

    query: str
    regulation_domains: list[str] = ["vehicle_safety"]
    top_k: int = 5
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ComplianceCheckResponse(BaseModel):
    """Response body returned by the compliance check endpoint.

    Attributes:
        risk_level: Coarse risk label set by the endpoint, e.g. ``"critical"``,
            ``"high"``, ``"medium"``, ``"low"`` or ``"unknown"``.
        risk_score: Numeric score that accompanies ``risk_level``.
        findings: List of finding dictionaries (the endpoint uses the keys
            ``"issue"`` or ``"analysis"``).
        recommendations: Human-readable remediation hints.
        sources: Regulation excerpts backing the assessment; each entry holds
            a truncated ``"content"`` and its retrieval ``"score"``.
    """

    risk_level: str
    risk_score: float
    findings: list[dict]
    recommendations: list[str]
    sources: list[dict]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/check", response_model=ComplianceCheckResponse)
async def check_compliance(req: ComplianceCheckRequest):
    """Check the submitted content against the regulation library and rate its risk.

    The endpoint retrieves regulation chunks with ``hybrid_search``, keeps the
    ``req.top_k`` highest-scoring unique chunks, asks the LLM to analyse
    ``req.query`` against them, and maps the free-text analysis onto a coarse
    risk level.

    Args:
        req: The content to check, the regulation domains to search and the
            number of regulation chunks to keep.

    Returns:
        A ``ComplianceCheckResponse``. ``risk_level`` is ``"unknown"`` with a
        score of 0 when no regulation chunk was retrieved or when the LLM call
        failed; otherwise it is ``"critical"``, ``"high"``, ``"medium"`` or
        ``"low"`` as derived from the LLM analysis text.
    """
    # Retrieve candidate regulation chunks, one search per requested domain.
    # NOTE(review): the domain is never forwarded to hybrid_search, so every
    # iteration issues the same query and the domain list only controls how
    # often it runs. hybrid_search's signature is not visible here -- confirm
    # which parameter filters by domain and pass it through.
    all_chunks = []
    for _domain in req.regulation_domains:
        chunks = await hybrid_search(
            req.query,
            collection_name="regulation_chunks",
            top_k=req.top_k,
        )
        all_chunks.extend(chunks)

    # De-duplicate by chunk id, keeping the best-scoring occurrence of each.
    seen_ids = set()
    unique_chunks = []
    for chunk in sorted(all_chunks, key=lambda item: item["score"], reverse=True):
        if chunk["id"] in seen_ids:
            continue
        seen_ids.add(chunk["id"])
        unique_chunks.append(chunk)
    top_chunks = unique_chunks[:req.top_k]

    if not top_chunks:
        # Nothing to compare against: report that explicitly instead of guessing.
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0.0,
            findings=[{"issue": "未找到相关法规,请先上传法规文档"}],
            recommendations=["上传相关法规文档到知识库后重试"],
            sources=[],
        )

    sources = [
        {"content": chunk["content"][:200], "score": chunk["score"]}
        for chunk in top_chunks
    ]

    # Build the numbered regulation context; each excerpt is capped at 500 chars.
    regulations_text = "\n\n".join(
        f"[{index}] {chunk['content'][:500]}"
        for index, chunk in enumerate(top_chunks, start=1)
    )
    prompt = COMPLIANCE_CHECK_PROMPT.format(
        content=req.query,
        regulations=regulations_text,
    )

    llm = get_llm(temperature=0.0)
    try:
        response = await llm.ainvoke([HumanMessage(content=prompt)])
    except Exception as e:
        # A failed LLM call must not be scored: previously the error text was
        # run through the keyword parser and reported as a real assessment
        # ("medium" by default, or whatever keyword the exception message
        # happened to contain).
        logger.exception("LLM 合规分析失败:%s", e)
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0.0,
            findings=[{"analysis": f"LLM 分析失败:{e}"}],
            recommendations=["LLM 分析失败,请稍后重试"],
            sources=sources,
        )

    analysis = response.content
    risk_level, risk_score = _classify_risk(analysis)

    return ComplianceCheckResponse(
        risk_level=risk_level,
        risk_score=risk_score,
        findings=[{"analysis": analysis}],
        recommendations=["请参考上述分析进行整改"],
        sources=sources,
    )


def _classify_risk(analysis: str) -> tuple[str, float]:
    """Map the free-text LLM analysis onto a ``(risk_level, risk_score)`` pair.

    Levels are tested from most to least severe and the first hit wins; text
    without any known keyword is rated ``("medium", 50.0)``. This is a simple
    keyword heuristic, not a structured parse of the LLM output.
    """
    import re

    lowered = analysis.lower()
    # (level, English keyword, Chinese phrase, score), most severe first.
    rules = (
        ("critical", "critical", "严重", 90.0),
        ("high", "high", "高风险", 70.0),
        ("low", "low", "低风险", 20.0),
    )
    for level, keyword, phrase, score in rules:
        # English keywords must stand alone between non-letters so that words
        # such as "allow", "below" or "highlight" do not trigger a level.
        # ASCII lookarounds are used instead of \b because CJK characters
        # count as word characters and would hide a keyword that is adjacent
        # to Chinese text.
        standalone = re.search(rf"(?<![a-z]){keyword}(?![a-z])", lowered)
        if standalone is not None or phrase in analysis:
            return level, score
    return "medium", 50.0
|