first commit

This commit is contained in:
2026-04-23 09:58:47 +08:00
commit 448e078d99
49 changed files with 5188 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
import uuid
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from langchain.schema import HumanMessage, SystemMessage
from ..core.llm import get_llm, COMPLIANCE_CHECK_PROMPT
from ..services.rag import hybrid_search
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the compliance endpoints; every route declared on it is served under /api/compliance.
router = APIRouter(prefix="/api/compliance", tags=["合规审查"])
class ComplianceCheckRequest(BaseModel):
    """Request body accepted by the compliance check endpoint."""

    # Text to check; it is used both as the retrieval query and as the
    # content placed into the LLM prompt.
    query: str
    # Regulation domains the caller wants consulted.
    regulation_domains: list[str] = ["vehicle_safety"]
    # Number of regulation chunks requested from retrieval and kept for the prompt.
    top_k: int = 5
class ComplianceCheckResponse(BaseModel):
    """Risk assessment returned by the compliance check endpoint."""

    # Coarse label; the endpoint emits "unknown", "low", "medium", "high" or "critical".
    risk_level: str
    # Numeric score paired with the label (the endpoint uses values between 0 and 90).
    risk_score: float
    # Free-form finding records, e.g. {"analysis": ...} or {"issue": ...}.
    findings: list[dict]
    # Human-readable follow-up actions.
    recommendations: list[str]
    # Regulation excerpts (with their retrieval score) the assessment was based on.
    sources: list[dict]
def _classify_risk(analysis: str) -> tuple[str, float]:
    """Map free-text LLM analysis onto a coarse ``(risk_level, risk_score)`` pair.

    English keywords are matched as whole words (not as substrings), so that
    "allow", "below" or "follow" no longer count as "low" and "highlight" no
    longer counts as "high". Chinese keywords are matched as plain substrings.
    Precedence is critical > high > low; anything else is "medium".
    """
    import re  # local import: keeps this block self-contained

    def has_word(word: str) -> bool:
        # The keyword must not be embedded in a longer run of ASCII letters.
        pattern = rf"(?<![A-Za-z]){word}(?![A-Za-z])"
        return re.search(pattern, analysis, re.IGNORECASE) is not None

    if has_word("critical") or "严重" in analysis:
        return "critical", 90.0
    if has_word("high") or "高风险" in analysis:
        return "high", 70.0
    if has_word("low") or "低风险" in analysis:
        return "low", 20.0
    return "medium", 50.0


@router.post("/check", response_model=ComplianceCheckResponse)
async def check_compliance(req: ComplianceCheckRequest):
    """Check the submitted text against the regulation store and assess its risk.

    Steps:
      1. Retrieve regulation chunks for ``req.query``.
      2. Deduplicate by chunk id, keeping the ``req.top_k`` highest scores.
      3. Ask the LLM to analyse the text against those regulations.
      4. Derive a coarse risk level/score from keywords in the LLM output.

    Returns:
        A ``ComplianceCheckResponse``. ``risk_level`` is "unknown" with a
        score of 0 when no regulation chunk is found or when the LLM call
        fails (the error text is reported as a finding instead of being
        keyword-scanned as if it were an analysis).
    """
    # NOTE(review): the original issued one identical search per entry in
    # ``req.regulation_domains`` without ever passing the domain along, so the
    # extra calls only produced duplicates that were dropped right after.
    # hybrid_search's signature is not visible here, so the domain still cannot
    # be forwarded; a single call is made instead. An empty domain list still
    # means "search nothing", as before.
    all_chunks = []
    if req.regulation_domains:
        all_chunks = list(
            await hybrid_search(
                req.query,
                collection_name="regulation_chunks",
                top_k=req.top_k,
            )
        )

    # Deduplicate by id, best score first, then keep the top_k best.
    seen_ids = set()
    unique_chunks = []
    for chunk in sorted(all_chunks, key=lambda c: c["score"], reverse=True):
        if chunk["id"] in seen_ids:
            continue
        seen_ids.add(chunk["id"])
        unique_chunks.append(chunk)
    top_chunks = unique_chunks[:req.top_k]

    if not top_chunks:
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0,
            findings=[{"issue": "未找到相关法规,请先上传法规文档"}],
            recommendations=["上传相关法规文档到知识库后重试"],
            sources=[],
        )

    sources = [
        {"content": chunk["content"][:200], "score": chunk["score"]}
        for chunk in top_chunks
    ]

    # Build the numbered regulation context; each excerpt is capped at 500 chars.
    regulations_text = "\n\n".join(
        f"[{i+1}] {chunk['content'][:500]}" for i, chunk in enumerate(top_chunks)
    )
    prompt = COMPLIANCE_CHECK_PROMPT.format(
        content=req.query,
        regulations=regulations_text,
    )

    llm = get_llm(temperature=0.0)
    try:
        response = await llm.ainvoke([HumanMessage(content=prompt)])
    except Exception as e:
        # Best-effort: report the failure in the response rather than raising,
        # but do not run the keyword classifier over an error message.
        logger.exception("LLM 合规分析失败:%s", e)
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0,
            findings=[{"analysis": f"LLM 分析失败:{e}"}],
            recommendations=["请参考上述分析进行整改"],
            sources=sources,
        )

    analysis = response.content
    # Simple keyword parsing of the LLM output (structured output would be
    # preferable in production).
    risk_level, risk_score = _classify_risk(analysis)
    return ComplianceCheckResponse(
        risk_level=risk_level,
        risk_score=risk_score,
        findings=[{"analysis": analysis}],
        recommendations=["请参考上述分析进行整改"],
        sources=sources,
    )

View File

@@ -0,0 +1,114 @@
import uuid
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from ..core.deps import get_db
from ..models.db import Workspace, File as FileRecord, Task
from ..services.rag import hybrid_search, rerank, generate_answer
from ..worker import process_file_task
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the knowledge-base endpoints; every route declared on it is served under /api/kb.
router = APIRouter(prefix="/api/kb", tags=["知识库"])
# Directory on disk where uploaded files are stored.
UPLOAD_DIR = Path("/app/data/uploads")
# Import-time side effect: create the upload directory (and missing parents) if it does not exist.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
class WorkspaceCreate(BaseModel):
    """Request body for creating a workspace."""

    # Name stored on the new workspace.
    name: str
    # Optional free-text description; empty when omitted.
    description: str = ""
    # Domain label stored on the workspace; "general" when omitted.
    domain: str = "general"
class QARequest(BaseModel):
    """Request body shared by the question-answering and retrieval endpoints."""

    # Question / search text.
    query: str
    # Optional workspace id forwarded to the search; None when omitted.
    workspace_id: str | None = None
    # Number of chunks to return (the QA endpoint retrieves twice as many before reranking).
    top_k: int = 5
    # When False, the QA endpoint removes "sources" from its answer.
    return_sources: bool = True
@router.post("/workspaces")
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
    """Create a workspace row and return its id, name and domain.

    The session is only flushed here, not committed; committing is presumably
    done by the ``get_db`` dependency — TODO confirm.
    """
    workspace = Workspace(
        name=req.name,
        description=req.description,
        domain=req.domain,
    )
    db.add(workspace)
    await db.flush()
    return {
        "id": str(workspace.id),
        "name": workspace.name,
        "domain": workspace.domain,
    }
@router.post("/files/upload")
async def upload_file(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    workspace_id: str = Form(default=""),
    db: AsyncSession = Depends(get_db),
):
    """Store an uploaded file, record it, and queue the processing task.

    The file is saved under ``UPLOAD_DIR`` as ``<new uuid><original suffix>``,
    a file record and a "parse_and_vectorize" task row are added to the
    session, and the Celery task ``process_file_task`` is dispatched.

    Args:
        background_tasks: Unused here; kept for interface compatibility.
        file: The uploaded file.
        workspace_id: Optional workspace UUID as a string; empty means none.
        db: Database session.

    Returns:
        A dict with ``file_id``, ``task_id`` and ``status`` ("processing").

    Raises:
        HTTPException: 400 when ``workspace_id`` is non-empty but not a valid
            UUID. This is checked before anything is written, so a bad id no
            longer leaves an orphan file on disk behind a 500 response.
    """
    import asyncio  # local import: keeps this block self-contained

    workspace_uuid = None
    if workspace_id:
        try:
            workspace_uuid = uuid.UUID(workspace_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="workspace_id 格式无效") from None

    content = await file.read()
    file_uuid = uuid.uuid4()
    file_id = str(file_uuid)
    suffix = Path(file.filename or "doc").suffix
    stored_name = f"{file_id}{suffix}"
    save_path = UPLOAD_DIR / stored_name
    # Write in a worker thread so the blocking disk I/O does not stall the event loop.
    await asyncio.to_thread(save_path.write_bytes, content)

    file_record = FileRecord(
        id=file_uuid,
        filename=stored_name,
        original_name=file.filename or "unknown",
        file_type=suffix.lstrip("."),
        file_size=len(content),
        storage_path=str(save_path),
        workspace_id=workspace_uuid,
        status="uploaded",
    )
    db.add(file_record)

    task = Task(
        task_type="parse_and_vectorize",
        status="pending",
        file_id=file_uuid,
        payload={"workspace_id": workspace_id},
    )
    db.add(task)
    # Flush so that task.id is populated before it is handed to Celery.
    await db.flush()

    # Trigger the Celery task asynchronously.
    # NOTE(review): the session has only been flushed at this point, not
    # committed; if the commit happens after this handler returns, the worker
    # could start before the rows are visible to it — confirm against get_db.
    celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
    task.celery_task_id = celery_task.id
    await db.flush()

    return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
@router.get("/tasks/{task_id}")
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
    """Return the status of a processing task.

    Args:
        task_id: Task UUID as a string.
        db: Database session.

    Returns:
        A dict with the task's id, status, progress, file id, error message
        and completion time (ISO 8601, or None when not completed).

    Raises:
        HTTPException: 400 when ``task_id`` is not a valid UUID (previously an
            unhandled ValueError, i.e. a 500); 404 when no such task exists.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="任务 ID 格式无效") from None

    result = await db.execute(select(Task).where(Task.id == task_uuid))
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    return {
        "task_id": str(task.id),
        "status": task.status,
        "progress": task.progress,
        "file_id": str(task.file_id) if task.file_id else None,
        "error_msg": task.error_msg,
        "completed_at": task.completed_at.isoformat() if task.completed_at else None,
    }
@router.post("/qa")
async def qa(req: QARequest):
    """Answer a question from the knowledge base.

    Retrieves twice ``req.top_k`` candidate chunks, reranks them down to
    ``req.top_k``, and generates an answer from the reranked chunks. The
    "sources" entry is removed from the answer when ``req.return_sources``
    is False.
    """
    candidate_count = req.top_k * 2
    candidates = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=candidate_count,
    )
    best = await rerank(req.query, candidates, top_k=req.top_k)
    answer = await generate_answer(req.query, best)
    if req.return_sources:
        return answer
    answer.pop("sources", None)
    return answer
@router.post("/knowledge/retrieval")
async def retrieval(req: QARequest):
    """Return the raw search hits for a query, without reranking or answer generation."""
    hits = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k,
    )
    return {
        "chunks": hits,
        "total": len(hits),
    }

View File

@@ -0,0 +1,111 @@
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from ..core.deps import get_db
from ..models.db import RegulationSource, RegulationUpdate
from ..worker import fetch_regulation_source
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the regulation-monitoring endpoints; every route declared on it is served under /api/regulation.
router = APIRouter(prefix="/api/regulation", tags=["法规监控"])
class SourceCreate(BaseModel):
    """Request body for registering a regulation source to monitor."""

    # Display name of the source.
    name: str
    # URL of the source.
    url: str
    # Regulation domain the source belongs to.
    domain: str = "vehicle_safety"
    # Fetch interval; presumably seconds (86400 = one day) — TODO confirm.
    fetch_interval: int = 86400
    # Free-form fetch settings stored on the source as-is.
    fetch_config: dict = {}
class SubscribeRequest(BaseModel):
    """Request body for subscribing to regulation update notifications."""

    # Name of the subscription.
    name: str
    # Delivery channel: email / webhook / feishu / dingtalk
    channel: str
    # Delivery target for the chosen channel.
    target: str
    # Domains to subscribe to; empty list when omitted.
    domains: list[str] = []
    # Minimum importance of updates to deliver.
    importance_min: str = "normal"
@router.post("/sources")
async def create_source(req: SourceCreate, db: AsyncSession = Depends(get_db)):
    """Register a new regulation source and return its summary.

    The session is flushed (not committed) so the new row's id is available.
    The returned "status" is the fixed string "active"; it is not read back
    from the row.
    """
    new_source = RegulationSource(
        name=req.name,
        url=req.url,
        domain=req.domain,
        fetch_interval=req.fetch_interval,
        fetch_config=req.fetch_config,
    )
    db.add(new_source)
    await db.flush()

    summary = {
        "id": str(new_source.id),
        "name": new_source.name,
        "url": new_source.url,
        "domain": new_source.domain,
        "status": "active",
    }
    return summary
@router.get("/sources")
async def list_sources(db: AsyncSession = Depends(get_db)):
    """List the id, name, url and domain of every active regulation source."""
    # `== True` is intentional: it builds a SQL expression, not a Python bool.
    stmt = select(RegulationSource).where(RegulationSource.is_active == True)  # noqa: E712
    rows = (await db.execute(stmt)).scalars().all()
    return [
        {
            "id": str(src.id),
            "name": src.name,
            "url": src.url,
            "domain": src.domain,
        }
        for src in rows
    ]
@router.post("/sources/{source_id}/fetch")
async def manual_fetch(source_id: str, db: AsyncSession = Depends(get_db)):
    """Manually trigger a fetch of one monitored source (intended for testing).

    Args:
        source_id: Source UUID as a string.
        db: Database session.

    Returns:
        A dict with the queued Celery ``task_id``, ``status`` ("queued") and
        the ``source_id``.

    Raises:
        HTTPException: 400 when ``source_id`` is not a valid UUID (previously
            an unhandled ValueError, i.e. a 500); 404 when no such source
            exists.
    """
    try:
        source_uuid = uuid.UUID(source_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="监控源 ID 格式无效") from None

    result = await db.execute(
        select(RegulationSource).where(RegulationSource.id == source_uuid)
    )
    source = result.scalar_one_or_none()
    if not source:
        raise HTTPException(status_code=404, detail="监控源不存在")

    task = fetch_regulation_source.delay(source_id)
    return {"task_id": task.id, "status": "queued", "source_id": source_id}
@router.get("/updates")
async def get_updates(
    domain: str | None = None,
    limit: int = 20,
    offset: int = 0,
    db: AsyncSession = Depends(get_db),
):
    """Return a page of regulation updates, most recently fetched first.

    Args:
        domain: Accepted but currently NOT applied to the query (see note
            in the body).
        limit: Maximum number of updates to return; must not be negative.
        offset: Number of updates to skip; must not be negative.
        db: Database session.

    Returns:
        ``{"updates": [...]}`` where each entry carries id, title, url,
        change_type, summary, importance and fetched_at (ISO 8601 or None).

    Raises:
        HTTPException: 400 when ``limit`` or ``offset`` is negative, instead
            of passing the value through to the database.
    """
    if limit < 0 or offset < 0:
        raise HTTPException(status_code=400, detail="limit 和 offset 不能为负数")

    # NOTE(review): `domain` is never used to filter the results. The columns
    # of RegulationUpdate are not visible from this module, so no filter is
    # added here — decide whether to filter on the update itself or via its
    # source, then apply it to `query`.
    query = select(RegulationUpdate).order_by(desc(RegulationUpdate.fetched_at))
    result = await db.execute(query.limit(limit).offset(offset))
    updates = result.scalars().all()

    return {
        "updates": [
            {
                "id": str(u.id),
                "title": u.title,
                "url": u.url,
                "change_type": u.change_type,
                "summary": u.summary,
                "importance": u.importance,
                "fetched_at": u.fetched_at.isoformat() if u.fetched_at else None,
            }
            for u in updates
        ]
    }
@router.post("/subscribe")
async def subscribe(req: SubscribeRequest, db: AsyncSession = Depends(get_db)):
    """Acknowledge a subscription request (simplified version).

    The original comment says the subscription is recorded and that push
    delivery is implemented in the push-worker.

    NOTE(review): nothing is written to ``db`` in this handler. The returned
    id is a freshly generated random UUID, and ``req.target`` and
    ``req.importance_min`` are neither stored nor echoed back. Confirm where
    persistence is supposed to happen.

    Returns:
        A dict with a new ``id``, the requested ``name``, ``channel`` and
        ``domains``, and ``status`` fixed to "active".
    """
    # The unused function-scope import of Workspace was removed; it had no effect.
    return {
        "id": str(uuid.uuid4()),
        "name": req.name,
        "channel": req.channel,
        "domains": req.domains,
        "status": "active",
    }