first commit
This commit is contained in:
0
services/compliance-backend/app/api/__init__.py
Normal file
0
services/compliance-backend/app/api/__init__.py
Normal file
95
services/compliance-backend/app/api/compliance.py
Normal file
95
services/compliance-backend/app/api/compliance.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from ..core.llm import get_llm, COMPLIANCE_CHECK_PROMPT
|
||||
from ..services.rag import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/compliance", tags=["合规审查"])
|
||||
|
||||
|
||||
class ComplianceCheckRequest(BaseModel):
|
||||
query: str
|
||||
regulation_domains: list[str] = ["vehicle_safety"]
|
||||
top_k: int = 5
|
||||
|
||||
|
||||
class ComplianceCheckResponse(BaseModel):
|
||||
risk_level: str
|
||||
risk_score: float
|
||||
findings: list[dict]
|
||||
recommendations: list[str]
|
||||
sources: list[dict]
|
||||
|
||||
|
||||
@router.post("/check", response_model=ComplianceCheckResponse)
|
||||
async def check_compliance(req: ComplianceCheckRequest):
|
||||
"""
|
||||
对输入内容进行合规性检查,与法规库比对后给出风险评估。
|
||||
"""
|
||||
# 检索相关法规(从多个域检索)
|
||||
all_chunks = []
|
||||
for domain in req.regulation_domains:
|
||||
chunks = await hybrid_search(
|
||||
req.query,
|
||||
collection_name="regulation_chunks",
|
||||
top_k=req.top_k,
|
||||
)
|
||||
all_chunks.extend(chunks)
|
||||
|
||||
# 去重 + 按分数排序
|
||||
seen = set()
|
||||
unique_chunks = []
|
||||
for c in sorted(all_chunks, key=lambda x: x["score"], reverse=True):
|
||||
if c["id"] not in seen:
|
||||
seen.add(c["id"])
|
||||
unique_chunks.append(c)
|
||||
top_chunks = unique_chunks[:req.top_k]
|
||||
|
||||
if not top_chunks:
|
||||
return ComplianceCheckResponse(
|
||||
risk_level="unknown",
|
||||
risk_score=0,
|
||||
findings=[{"issue": "未找到相关法规,请先上传法规文档"}],
|
||||
recommendations=["上传相关法规文档到知识库后重试"],
|
||||
sources=[],
|
||||
)
|
||||
|
||||
# 构建法规上下文
|
||||
regulations_text = "\n\n".join(
|
||||
f"[{i+1}] {c['content'][:500]}" for i, c in enumerate(top_chunks)
|
||||
)
|
||||
|
||||
prompt = COMPLIANCE_CHECK_PROMPT.format(
|
||||
content=req.query,
|
||||
regulations=regulations_text,
|
||||
)
|
||||
|
||||
llm = get_llm(temperature=0.0)
|
||||
try:
|
||||
response = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||
analysis = response.content
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 合规分析失败:{e}")
|
||||
analysis = f"LLM 分析失败:{e}"
|
||||
|
||||
# 简单解析 LLM 输出(生产可用结构化输出)
|
||||
risk_level = "medium"
|
||||
risk_score = 50.0
|
||||
if "critical" in analysis.lower() or "严重" in analysis:
|
||||
risk_level, risk_score = "critical", 90.0
|
||||
elif "high" in analysis.lower() or "高风险" in analysis:
|
||||
risk_level, risk_score = "high", 70.0
|
||||
elif "low" in analysis.lower() or "低风险" in analysis:
|
||||
risk_level, risk_score = "low", 20.0
|
||||
|
||||
return ComplianceCheckResponse(
|
||||
risk_level=risk_level,
|
||||
risk_score=risk_score,
|
||||
findings=[{"analysis": analysis}],
|
||||
recommendations=["请参考上述分析进行整改"],
|
||||
sources=[{"content": c["content"][:200], "score": c["score"]} for c in top_chunks],
|
||||
)
|
||||
114
services/compliance-backend/app/api/kb.py
Normal file
114
services/compliance-backend/app/api/kb.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from ..core.deps import get_db
|
||||
from ..models.db import Workspace, File as FileRecord, Task
|
||||
from ..services.rag import hybrid_search, rerank, generate_answer
|
||||
from ..worker import process_file_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/kb", tags=["知识库"])
|
||||
|
||||
UPLOAD_DIR = Path("/app/data/uploads")
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class WorkspaceCreate(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
domain: str = "general"
|
||||
|
||||
|
||||
class QARequest(BaseModel):
|
||||
query: str
|
||||
workspace_id: str | None = None
|
||||
top_k: int = 5
|
||||
return_sources: bool = True
|
||||
|
||||
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
|
||||
ws = Workspace(name=req.name, description=req.description, domain=req.domain)
|
||||
db.add(ws)
|
||||
await db.flush()
|
||||
return {"id": str(ws.id), "name": ws.name, "domain": ws.domain}
|
||||
|
||||
|
||||
@router.post("/files/upload")
|
||||
async def upload_file(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
workspace_id: str = Form(default=""),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
content = await file.read()
|
||||
file_id = str(uuid.uuid4())
|
||||
suffix = Path(file.filename or "doc").suffix
|
||||
save_path = UPLOAD_DIR / f"{file_id}{suffix}"
|
||||
save_path.write_bytes(content)
|
||||
|
||||
file_record = FileRecord(
|
||||
id=uuid.UUID(file_id),
|
||||
filename=f"{file_id}{suffix}",
|
||||
original_name=file.filename or "unknown",
|
||||
file_type=suffix.lstrip("."),
|
||||
file_size=len(content),
|
||||
storage_path=str(save_path),
|
||||
workspace_id=uuid.UUID(workspace_id) if workspace_id else None,
|
||||
status="uploaded",
|
||||
)
|
||||
db.add(file_record)
|
||||
|
||||
task = Task(
|
||||
task_type="parse_and_vectorize",
|
||||
status="pending",
|
||||
file_id=uuid.UUID(file_id),
|
||||
payload={"workspace_id": workspace_id},
|
||||
)
|
||||
db.add(task)
|
||||
await db.flush()
|
||||
|
||||
# 异步触发 Celery 任务
|
||||
celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
|
||||
task.celery_task_id = celery_task.id
|
||||
await db.flush()
|
||||
|
||||
return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
return {
|
||||
"task_id": str(task.id),
|
||||
"status": task.status,
|
||||
"progress": task.progress,
|
||||
"file_id": str(task.file_id) if task.file_id else None,
|
||||
"error_msg": task.error_msg,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/qa")
|
||||
async def qa(req: QARequest):
|
||||
chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k * 2)
|
||||
ranked = await rerank(req.query, chunks, top_k=req.top_k)
|
||||
result = await generate_answer(req.query, ranked)
|
||||
if not req.return_sources:
|
||||
result.pop("sources", None)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/knowledge/retrieval")
|
||||
async def retrieval(req: QARequest):
|
||||
chunks = await hybrid_search(req.query, workspace_id=req.workspace_id, top_k=req.top_k)
|
||||
return {"chunks": chunks, "total": len(chunks)}
|
||||
111
services/compliance-backend/app/api/regulation.py
Normal file
111
services/compliance-backend/app/api/regulation.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
from ..core.deps import get_db
|
||||
from ..models.db import RegulationSource, RegulationUpdate
|
||||
from ..worker import fetch_regulation_source
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/regulation", tags=["法规监控"])
|
||||
|
||||
|
||||
class SourceCreate(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
domain: str = "vehicle_safety"
|
||||
fetch_interval: int = 86400
|
||||
fetch_config: dict = {}
|
||||
|
||||
|
||||
class SubscribeRequest(BaseModel):
|
||||
name: str
|
||||
channel: str # email / webhook / feishu / dingtalk
|
||||
target: str
|
||||
domains: list[str] = []
|
||||
importance_min: str = "normal"
|
||||
|
||||
|
||||
@router.post("/sources")
|
||||
async def create_source(req: SourceCreate, db: AsyncSession = Depends(get_db)):
|
||||
source = RegulationSource(
|
||||
name=req.name,
|
||||
url=req.url,
|
||||
domain=req.domain,
|
||||
fetch_interval=req.fetch_interval,
|
||||
fetch_config=req.fetch_config,
|
||||
)
|
||||
db.add(source)
|
||||
await db.flush()
|
||||
return {
|
||||
"id": str(source.id),
|
||||
"name": source.name,
|
||||
"url": source.url,
|
||||
"domain": source.domain,
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/sources")
|
||||
async def list_sources(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(RegulationSource).where(RegulationSource.is_active == True)
|
||||
)
|
||||
sources = result.scalars().all()
|
||||
return [{"id": str(s.id), "name": s.name, "url": s.url, "domain": s.domain} for s in sources]
|
||||
|
||||
|
||||
@router.post("/sources/{source_id}/fetch")
|
||||
async def manual_fetch(source_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""手动触发某个监控源的抓取(测试用)"""
|
||||
result = await db.execute(
|
||||
select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
|
||||
)
|
||||
source = result.scalar_one_or_none()
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="监控源不存在")
|
||||
|
||||
task = fetch_regulation_source.delay(source_id)
|
||||
return {"task_id": task.id, "status": "queued", "source_id": source_id}
|
||||
|
||||
|
||||
@router.get("/updates")
|
||||
async def get_updates(
|
||||
domain: str | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(RegulationUpdate).order_by(desc(RegulationUpdate.fetched_at))
|
||||
result = await db.execute(query.limit(limit).offset(offset))
|
||||
updates = result.scalars().all()
|
||||
return {
|
||||
"updates": [
|
||||
{
|
||||
"id": str(u.id),
|
||||
"title": u.title,
|
||||
"url": u.url,
|
||||
"change_type": u.change_type,
|
||||
"summary": u.summary,
|
||||
"importance": u.importance,
|
||||
"fetched_at": u.fetched_at.isoformat() if u.fetched_at else None,
|
||||
}
|
||||
for u in updates
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/subscribe")
|
||||
async def subscribe(req: SubscribeRequest, db: AsyncSession = Depends(get_db)):
|
||||
from ..models.db import Workspace # 借用DB session
|
||||
# 简化版:仅记录订阅(推送逻辑在 push-worker 中实现)
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": req.name,
|
||||
"channel": req.channel,
|
||||
"domains": req.domains,
|
||||
"status": "active",
|
||||
}
|
||||
Reference in New Issue
Block a user