first commit
This commit is contained in:
24
services/compliance-backend/Dockerfile
Normal file
24
services/compliance-backend/Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
FROM python:3.12-slim

WORKDIR /app

# curl is required by the HEALTHCHECK instruction below.
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Use uv to speed up dependency installation.
RUN pip install uv --no-cache-dir

# Dependencies are installed from pyproject.toml before the application code is copied.
COPY pyproject.toml .
# The mirror is served over HTTPS, so TLS verification is left enabled:
# --trusted-host (which turns certificate checks off for that host) is intentionally not passed.
RUN uv pip install --system --no-cache -r pyproject.toml \
    --index-url https://pypi.tuna.tsinghua.edu.cn/simple

COPY app/ ./app/

HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=5 \
    CMD curl -f http://localhost:8000/health || exit 1

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
||||
0
services/compliance-backend/app/__init__.py
Normal file
0
services/compliance-backend/app/__init__.py
Normal file
0
services/compliance-backend/app/api/__init__.py
Normal file
0
services/compliance-backend/app/api/__init__.py
Normal file
95
services/compliance-backend/app/api/compliance.py
Normal file
95
services/compliance-backend/app/api/compliance.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from ..core.llm import get_llm, COMPLIANCE_CHECK_PROMPT
|
||||
from ..services.rag import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/compliance", tags=["合规审查"])
|
||||
|
||||
|
||||
class ComplianceCheckRequest(BaseModel):
    """Request body of the compliance-check endpoint."""

    # Content to be reviewed; also used as the retrieval query.
    query: str
    # Regulation domains the caller wants the content checked against.
    regulation_domains: list[str] = ["vehicle_safety"]
    # Maximum number of regulation chunks handed to the LLM as context.
    top_k: int = 5
|
||||
|
||||
|
||||
class ComplianceCheckResponse(BaseModel):
    """Response body of the compliance-check endpoint."""

    # Overall risk label (the endpoint emits low / medium / high / critical / unknown).
    risk_level: str
    # Numeric risk score associated with the label.
    risk_score: float
    # Findings produced by the analysis, one dict per finding.
    findings: list[dict]
    # Suggested remediation steps.
    recommendations: list[str]
    # Regulation excerpts (content + retrieval score) the analysis was based on.
    sources: list[dict]
|
||||
|
||||
|
||||
def _parse_risk_assessment(analysis: str) -> tuple[str, float]:
    """Derive ``(risk_level, risk_score)`` from the LLM's free-text analysis.

    The level declared on the "风险等级" line is preferred. Scanning the whole
    text for keywords is only a fallback, and the template heading "严重程度"
    is ignored there because it would otherwise always match "严重" and mark
    every answer as critical.
    """
    import re

    level_scores = {"critical": 90.0, "high": 70.0, "medium": 50.0, "low": 20.0}
    level_keywords = {
        "critical": ("critical", "严重"),
        "high": ("high", "高"),
        "medium": ("medium", "中"),
        "low": ("low", "低"),
    }

    declared = re.search(r"风险等级\s*[::]?\s*(.*)", analysis)
    if declared:
        verdict = declared.group(1).lower()
        hits = []
        for level, words in level_keywords.items():
            positions = [verdict.find(word) for word in words if word in verdict]
            if positions:
                hits.append((min(positions), level))
        if hits:
            # The keyword that appears first on the line is the declared level.
            level = min(hits)[1]
            return level, level_scores[level]

    # Fallback: keyword scan over the whole answer, minus the template heading.
    text = analysis.replace("严重程度", "")
    lowered = text.lower()
    if "critical" in lowered or "严重" in text:
        level = "critical"
    elif "high" in lowered or "高风险" in text:
        level = "high"
    elif "low" in lowered or "低风险" in text:
        level = "low"
    else:
        level = "medium"
    return level, level_scores[level]


@router.post("/check", response_model=ComplianceCheckResponse)
async def check_compliance(req: ComplianceCheckRequest):
    """Check the submitted content against the regulation library and assess its risk.

    Retrieves the most relevant regulation chunks, asks the LLM to evaluate the
    content against them, and maps the answer to a risk level and score.

    Returns a response with ``risk_level="unknown"`` when no regulation chunk is
    found or when the LLM call fails.
    """
    # NOTE(review): hybrid_search has no domain filter, so the former
    # one-search-per-domain loop returned the same hits every time. A single
    # search produces the same de-duplicated result. TODO: filter by domain once
    # the search supports it.
    all_chunks = []
    if req.regulation_domains:
        all_chunks = await hybrid_search(
            req.query,
            collection_name="regulation_chunks",
            top_k=req.top_k,
        )

    # De-duplicate by chunk id, keeping the highest-scoring occurrence.
    seen = set()
    unique_chunks = []
    for chunk in sorted(all_chunks, key=lambda x: x["score"], reverse=True):
        if chunk["id"] in seen:
            continue
        seen.add(chunk["id"])
        unique_chunks.append(chunk)
    top_chunks = unique_chunks[:req.top_k]

    if not top_chunks:
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0,
            findings=[{"issue": "未找到相关法规,请先上传法规文档"}],
            recommendations=["上传相关法规文档到知识库后重试"],
            sources=[],
        )

    sources = [{"content": c["content"][:200], "score": c["score"]} for c in top_chunks]

    # Build the regulation context, truncating each chunk to 500 characters.
    regulations_text = "\n\n".join(
        f"[{i+1}] {c['content'][:500]}" for i, c in enumerate(top_chunks)
    )
    prompt = COMPLIANCE_CHECK_PROMPT.format(
        content=req.query,
        regulations=regulations_text,
    )

    llm = get_llm(temperature=0.0)
    try:
        response = await llm.ainvoke([HumanMessage(content=prompt)])
    except Exception:
        # Do not score an error message as if it were an analysis, and do not
        # leak exception details to the client.
        logger.exception("LLM 合规分析失败")
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0,
            findings=[{"issue": "LLM 分析失败,请稍后重试"}],
            recommendations=["请稍后重试"],
            sources=sources,
        )
    analysis = response.content

    # Simple parsing of the LLM output (structured output would be more robust).
    risk_level, risk_score = _parse_risk_assessment(analysis)

    return ComplianceCheckResponse(
        risk_level=risk_level,
        risk_score=risk_score,
        findings=[{"analysis": analysis}],
        recommendations=["请参考上述分析进行整改"],
        sources=sources,
    )
|
||||
114
services/compliance-backend/app/api/kb.py
Normal file
114
services/compliance-backend/app/api/kb.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import uuid
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from ..core.deps import get_db
|
||||
from ..models.db import Workspace, File as FileRecord, Task
|
||||
from ..services.rag import hybrid_search, rerank, generate_answer
|
||||
from ..worker import process_file_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/kb", tags=["知识库"])
|
||||
|
||||
UPLOAD_DIR = Path("/app/data/uploads")
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class WorkspaceCreate(BaseModel):
    """Request body for creating a workspace."""

    # Display name of the workspace (required).
    name: str
    # Optional free-text description.
    description: str = ""
    # Domain label stored with the workspace.
    domain: str = "general"
|
||||
|
||||
|
||||
class QARequest(BaseModel):
    """Request body shared by the QA and retrieval endpoints."""

    # Question or search text.
    query: str
    # When set, retrieval is restricted to this workspace.
    workspace_id: str | None = None
    # Number of chunks to return / feed to answer generation.
    top_k: int = 5
    # Whether the QA response should include its "sources" entry.
    return_sources: bool = True
|
||||
|
||||
|
||||
@router.post("/workspaces")
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
    """Create a workspace and return its id, name and domain."""
    workspace = Workspace(
        name=req.name,
        description=req.description,
        domain=req.domain,
    )
    db.add(workspace)
    # Flush so the generated primary key is populated before it is returned.
    await db.flush()

    summary = {
        "id": str(workspace.id),
        "name": workspace.name,
        "domain": workspace.domain,
    }
    return summary
|
||||
|
||||
|
||||
@router.post("/files/upload")
async def upload_file(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    workspace_id: str = Form(default=""),
    db: AsyncSession = Depends(get_db),
):
    """Store an uploaded file and queue it for parsing and vectorization.

    The file is written to ``UPLOAD_DIR`` under a generated UUID name, a file
    record and a task record are created, and the Celery task is dispatched.

    Raises:
        HTTPException: 400 when ``workspace_id`` is not a valid UUID.
    """
    # Validate before touching the disk so a bad id does not leave an orphan file.
    try:
        workspace_uuid = uuid.UUID(workspace_id) if workspace_id else None
    except ValueError:
        raise HTTPException(status_code=400, detail="无效的 workspace_id") from None

    content = await file.read()
    file_id = str(uuid.uuid4())
    # Only the extension of the client-supplied name is reused on disk.
    suffix = Path(file.filename or "doc").suffix
    save_path = UPLOAD_DIR / f"{file_id}{suffix}"
    save_path.write_bytes(content)

    file_record = FileRecord(
        id=uuid.UUID(file_id),
        filename=f"{file_id}{suffix}",
        original_name=file.filename or "unknown",
        file_type=suffix.lstrip("."),
        file_size=len(content),
        storage_path=str(save_path),
        workspace_id=workspace_uuid,
        status="uploaded",
    )
    db.add(file_record)

    task = Task(
        task_type="parse_and_vectorize",
        status="pending",
        file_id=uuid.UUID(file_id),
        payload={"workspace_id": workspace_id},
    )
    db.add(task)

    # Commit BEFORE dispatching: a flush alone keeps the rows inside this open
    # transaction, so the worker (presumably on its own DB connection) could
    # start before the file and task rows are visible to it.
    await db.commit()

    # Trigger the Celery task asynchronously.
    celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
    task.celery_task_id = celery_task.id
    await db.flush()

    return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
    """Return the status of a processing task.

    Raises:
        HTTPException: 400 when ``task_id`` is not a valid UUID, 404 when no
            task with that id exists.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError:
        # A malformed id used to surface as an unhandled ValueError (HTTP 500).
        raise HTTPException(status_code=400, detail="无效的任务 ID") from None

    result = await db.execute(select(Task).where(Task.id == task_uuid))
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    return {
        "task_id": str(task.id),
        "status": task.status,
        "progress": task.progress,
        "file_id": str(task.file_id) if task.file_id else None,
        "error_msg": task.error_msg,
        "completed_at": task.completed_at.isoformat() if task.completed_at else None,
    }
|
||||
|
||||
|
||||
@router.post("/qa")
async def qa(req: QARequest):
    """Answer a question from retrieved chunks.

    Retrieves twice ``top_k`` candidates, reranks them down to ``top_k`` and
    generates the answer; the "sources" entry is removed when the caller did
    not ask for it.
    """
    candidates = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k * 2,
    )
    best = await rerank(req.query, candidates, top_k=req.top_k)
    answer = await generate_answer(req.query, best)

    if req.return_sources:
        return answer
    answer.pop("sources", None)
    return answer
|
||||
|
||||
|
||||
@router.post("/knowledge/retrieval")
async def retrieval(req: QARequest):
    """Return the retrieved chunks for a query, without answer generation."""
    found = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k,
    )
    payload = {"chunks": found, "total": len(found)}
    return payload
|
||||
111
services/compliance-backend/app/api/regulation.py
Normal file
111
services/compliance-backend/app/api/regulation.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
from ..core.deps import get_db
|
||||
from ..models.db import RegulationSource, RegulationUpdate
|
||||
from ..worker import fetch_regulation_source
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/regulation", tags=["法规监控"])
|
||||
|
||||
|
||||
class SourceCreate(BaseModel):
    """Request body for registering a regulation monitoring source."""

    # Display name of the source.
    name: str
    # URL stored for the source.
    url: str
    # Regulation domain the source belongs to.
    domain: str = "vehicle_safety"
    # Fetch interval; the default 86400 presumably means seconds (one day) — TODO confirm.
    fetch_interval: int = 86400
    # Free-form fetch configuration stored alongside the source.
    fetch_config: dict = {}
|
||||
|
||||
|
||||
class SubscribeRequest(BaseModel):
    """Request body for subscribing to regulation update notifications."""

    # Name of the subscription.
    name: str
    # Delivery channel: email / webhook / feishu / dingtalk
    channel: str
    # Delivery target for the chosen channel.
    target: str
    # Domains to subscribe to.
    domains: list[str] = []
    # Minimum importance of updates to be delivered.
    importance_min: str = "normal"
|
||||
|
||||
|
||||
@router.post("/sources")
async def create_source(req: SourceCreate, db: AsyncSession = Depends(get_db)):
    """Register a new regulation monitoring source and return its summary."""
    fields = {
        "name": req.name,
        "url": req.url,
        "domain": req.domain,
        "fetch_interval": req.fetch_interval,
        "fetch_config": req.fetch_config,
    }
    source = RegulationSource(**fields)
    db.add(source)
    # Flush so the generated primary key is populated before it is returned.
    await db.flush()

    summary = {
        "id": str(source.id),
        "name": source.name,
        "url": source.url,
        "domain": source.domain,
        "status": "active",
    }
    return summary
|
||||
|
||||
|
||||
@router.get("/sources")
async def list_sources(db: AsyncSession = Depends(get_db)):
    """List all active regulation monitoring sources."""
    stmt = select(RegulationSource).where(RegulationSource.is_active == True)  # noqa: E712
    rows = (await db.execute(stmt)).scalars().all()

    listing = []
    for row in rows:
        listing.append(
            {"id": str(row.id), "name": row.name, "url": row.url, "domain": row.domain}
        )
    return listing
|
||||
|
||||
|
||||
@router.post("/sources/{source_id}/fetch")
async def manual_fetch(source_id: str, db: AsyncSession = Depends(get_db)):
    """Manually trigger a fetch of one monitoring source (for testing).

    Raises:
        HTTPException: 400 when ``source_id`` is not a valid UUID, 404 when the
            source does not exist.
    """
    try:
        source_uuid = uuid.UUID(source_id)
    except ValueError:
        # A malformed id used to surface as an unhandled ValueError (HTTP 500).
        raise HTTPException(status_code=400, detail="无效的监控源 ID") from None

    result = await db.execute(
        select(RegulationSource).where(RegulationSource.id == source_uuid)
    )
    source = result.scalar_one_or_none()
    if not source:
        raise HTTPException(status_code=404, detail="监控源不存在")

    task = fetch_regulation_source.delay(source_id)
    return {"task_id": task.id, "status": "queued", "source_id": source_id}
|
||||
|
||||
|
||||
@router.get("/updates")
async def get_updates(
    domain: str | None = None,
    limit: int = 20,
    offset: int = 0,
    db: AsyncSession = Depends(get_db),
):
    """List regulation updates, newest first.

    Args:
        domain: When given, only updates whose source belongs to this domain
            are returned.
        limit: Maximum number of updates to return.
        offset: Number of updates to skip.
    """
    query = select(RegulationUpdate).order_by(desc(RegulationUpdate.fetched_at))
    if domain is not None:
        # The domain lives on the source, not on the update row, so join to it.
        # Previously this parameter was accepted but silently ignored.
        query = query.join(
            RegulationSource, RegulationUpdate.source_id == RegulationSource.id
        ).where(RegulationSource.domain == domain)

    result = await db.execute(query.limit(limit).offset(offset))
    updates = result.scalars().all()
    return {
        "updates": [
            {
                "id": str(u.id),
                "title": u.title,
                "url": u.url,
                "change_type": u.change_type,
                "summary": u.summary,
                "importance": u.importance,
                "fetched_at": u.fetched_at.isoformat() if u.fetched_at else None,
            }
            for u in updates
        ]
    }
|
||||
|
||||
|
||||
@router.post("/subscribe")
async def subscribe(req: SubscribeRequest, db: AsyncSession = Depends(get_db)):
    """Acknowledge a subscription request.

    Simplified version: the response echoes the request with a freshly
    generated id. Nothing is written to the database here; per the original
    note, the push logic is implemented in push-worker. ``db`` is kept in the
    signature but is currently unused.
    """
    return {
        "id": str(uuid.uuid4()),
        "name": req.name,
        "channel": req.channel,
        "domains": req.domains,
        "status": "active",
    }
|
||||
0
services/compliance-backend/app/core/__init__.py
Normal file
0
services/compliance-backend/app/core/__init__.py
Normal file
37
services/compliance-backend/app/core/config.py
Normal file
37
services/compliance-backend/app/core/config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration, loaded from the environment and a ``.env`` file.

    Unknown keys are ignored. NOTE(review): the credential defaults below look
    like development placeholders — confirm they are overridden in deployment.
    """

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    # --- Application ---
    app_env: str = "development"
    log_level: str = "INFO"
    api_secret_key: str = "change_this_key"

    # --- Databases ---
    database_url: str = "postgresql+asyncpg://compliance:compliance123@postgres:5432/compliance_db"
    redis_url: str = "redis://:redis123@redis:6379/0"

    # --- Milvus ---
    milvus_host: str = "milvus"
    milvus_port: int = 19530

    # --- Neo4j ---
    neo4j_uri: str = "bolt://neo4j:7687"
    neo4j_user: str = "neo4j"
    neo4j_password: str = "neo4j123"

    # --- AI services ---
    embedding_service_url: str = "http://embedding-service:8010"
    mcp_server_url: str = "http://mcp-server:8011"

    # --- LLM ---
    llm_provider: str = "deepseek"  # deepseek / qwen
    deepseek_api_key: str = ""
    deepseek_model: str = "deepseek-chat"
    dashscope_api_key: str = ""
    qwen_model: str = "qwen-plus"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
54
services/compliance-backend/app/core/deps.py
Normal file
54
services/compliance-backend/app/core/deps.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from functools import lru_cache
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from pymilvus import connections, Collection
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from .config import settings
|
||||
|
||||
# ── PostgreSQL ──────────────────────────────────
|
||||
engine = create_async_engine(settings.database_url, pool_size=10, max_overflow=20)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for one request.

    The session is committed after the consumer finishes; on any exception it
    is rolled back and the exception is re-raised.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|
||||
|
||||
|
||||
# ── Milvus ──────────────────────────────────────
|
||||
def get_milvus_collection(name: str) -> Collection:
    """Connect to Milvus using the configured host/port and return collection *name*."""
    connections.connect(
        host=settings.milvus_host,
        port=settings.milvus_port,
    )
    collection = Collection(name)
    return collection
|
||||
|
||||
|
||||
# ── Neo4j ───────────────────────────────────────
|
||||
_neo4j_driver = None


def get_neo4j():
    """Return the module-wide Neo4j async driver, creating it on first use."""
    global _neo4j_driver
    if _neo4j_driver is not None:
        return _neo4j_driver

    credentials = (settings.neo4j_user, settings.neo4j_password)
    _neo4j_driver = AsyncGraphDatabase.driver(settings.neo4j_uri, auth=credentials)
    return _neo4j_driver
|
||||
|
||||
|
||||
# ── HTTP 客户端(复用连接池)────────────────────
|
||||
_http_client = None


def get_http_client() -> httpx.AsyncClient:
    """Return the module-wide HTTP client (120 s timeout), creating it on first use."""
    global _http_client
    if _http_client is not None:
        return _http_client

    _http_client = httpx.AsyncClient(timeout=120.0)
    return _http_client
|
||||
56
services/compliance-backend/app/core/llm.py
Normal file
56
services/compliance-backend/app/core/llm.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from .config import settings
|
||||
|
||||
|
||||
def get_llm(temperature: float = 0.1) -> ChatOpenAI:
    """Build the LLM client for the configured provider.

    Both supported providers (DeepSeek and Qwen) are reached through
    ``ChatOpenAI`` with a provider-specific base URL.

    Raises:
        ValueError: if ``settings.llm_provider`` is neither "deepseek" nor "qwen".
    """
    provider = settings.llm_provider
    if provider == "deepseek":
        model_name = settings.deepseek_model
        key = settings.deepseek_api_key
        endpoint = "https://api.deepseek.com/v1"
    elif provider == "qwen":
        model_name = settings.qwen_model
        key = settings.dashscope_api_key
        endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    else:
        raise ValueError(f"不支持的 LLM 提供商:{settings.llm_provider}")

    return ChatOpenAI(
        model=model_name,
        api_key=key,
        base_url=endpoint,
        temperature=temperature,
        max_retries=3,
        timeout=120,
    )
|
||||
|
||||
|
||||
RAG_SYSTEM_PROMPT = """你是一位专业的汽车行业合规专家,具备深厚的法规知识(GB标准、UN-ECE、ISO 45001、IATF 16949等)。
|
||||
|
||||
回答规则:
|
||||
1. 仅基于提供的参考文献回答,不添加不在文献中的信息
|
||||
2. 每个关键陈述必须标注来源(格式:[来源:文件名,第X页])
|
||||
3. 如果参考文献不足以回答问题,明确说明
|
||||
4. 使用专业但清晰的语言,适合工程师和法务人员阅读
|
||||
5. 对于数值要求(如绝缘电阻值、时间限制等),精确引用原文"""
|
||||
|
||||
|
||||
COMPLIANCE_CHECK_PROMPT = """你是一位专业的汽车合规审查专家。
|
||||
|
||||
请对以下内容进行合规性评估:
|
||||
|
||||
【待审查内容】
|
||||
{content}
|
||||
|
||||
【相关法规要求】
|
||||
{regulations}
|
||||
|
||||
请按以下格式输出:
|
||||
1. 整体风险等级:[low/medium/high/critical]
|
||||
2. 风险分数:[0-100]
|
||||
3. 发现的合规问题(逐条列出):
|
||||
- 问题描述
|
||||
- 违反的具体法规条款
|
||||
- 严重程度
|
||||
4. 整改建议(具体可操作)"""
|
||||
84
services/compliance-backend/app/main.py
Normal file
84
services/compliance-backend/app/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
from .api import kb, compliance, regulation
|
||||
from .core.config import settings
|
||||
|
||||
# Structured logging: filter at the configured level, falling back to INFO
# when the configured name is not a known logging level.
_log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
structlog.configure(
    wrapper_class=structlog.make_filtering_bound_logger(_log_level)
)
logger = structlog.get_logger()

app = FastAPI(
    title="AI合规智能中枢 API",
    description="面向车企与工厂的全链路合规智能平台",
    version="0.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS: every origin is allowed in development, none otherwise.
_cors_origins = ["*"] if settings.app_env == "development" else []
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prometheus metrics
Instrumentator().instrument(app).expose(app)

# Register the API routers (same order as before: kb, compliance, regulation).
for _api_module in (kb, compliance, regulation):
    app.include_router(_api_module.router)
|
||||
|
||||
|
||||
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log method, path, status code and duration (ms) of every request."""
    # perf_counter is monotonic, so the measured duration cannot go negative or
    # jump when the system clock is adjusted (time.time() can).
    start = time.perf_counter()
    response = await call_next(request)
    duration_ms = int((time.perf_counter() - start) * 1000)
    logger.info(
        "request",
        method=request.method,
        path=request.url.path,
        status=response.status_code,
        duration_ms=duration_ms,
    )
    return response
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Health check that also probes the dependent services.

    Each dependency's ``/health`` endpoint is queried with a 5 s timeout and
    reported as "ok" (HTTP 200), "degraded" (any other status) or
    "unavailable" (the request raised). The top-level status is always "ok".
    """
    import httpx
    from .core.config import settings

    dependencies = (
        ("embedding", settings.embedding_service_url),  # embedding service
        ("mcp", settings.mcp_server_url),  # MCP server
    )

    report = {"status": "ok", "services": {}}
    for service_name, base_url in dependencies:
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                reply = await client.get(f"{base_url}/health")
            state = "ok" if reply.status_code == 200 else "degraded"
        except Exception:
            state = "unavailable"
        report["services"][service_name] = state

    return report
|
||||
0
services/compliance-backend/app/models/__init__.py
Normal file
0
services/compliance-backend/app/models/__init__.py
Normal file
113
services/compliance-backend/app/models/db.py
Normal file
113
services/compliance-backend/app/models/db.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Integer, BigInteger, Boolean, Text, ARRAY, Numeric
|
||||
from sqlalchemy import DateTime, ForeignKey, func
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB, INET
|
||||
from sqlalchemy.orm import DeclarativeBase, relationship
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""
|
||||
|
||||
|
||||
class Workspace(Base):
    """A workspace that groups uploaded files (table ``workspaces``)."""

    __tablename__ = "workspaces"

    # Identity
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False)

    # Descriptive fields
    description = Column(Text)
    domain = Column(String(100))
    created_by = Column(String(255))

    # Timestamps are filled in by the database.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Files that belong to this workspace.
    files = relationship("File", back_populates="workspace")
|
||||
|
||||
|
||||
class File(Base):
    """An uploaded file and its processing state (table ``files``)."""

    __tablename__ = "files"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Files are deleted together with their workspace.
    workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"))
    # Name under which the file is stored vs. the name supplied by the uploader.
    filename = Column(String(500), nullable=False)
    original_name = Column(String(500), nullable=False)
    file_type = Column(String(50))
    file_size = Column(BigInteger)
    storage_path = Column(Text)
    parsed_path = Column(Text)
    status = Column(String(50), default="uploaded")
    error_msg = Column(Text)
    # "metadata" is a reserved attribute name on Declarative classes (it holds
    # the MetaData object), so declaring it raises at class-creation time. The
    # Python attribute is named file_metadata; the DB column stays "metadata".
    # default=dict gives every row its own dict instead of one shared object.
    file_metadata = Column("metadata", JSONB, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    workspace = relationship("Workspace", back_populates="files")
    tasks = relationship("Task", back_populates="file")
|
||||
|
||||
|
||||
class Task(Base):
    """A background processing task, optionally tied to a file (table ``tasks``)."""

    __tablename__ = "tasks"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    task_type = Column(String(100), nullable=False)
    status = Column(String(50), default="pending")
    # default=dict (a callable) gives every row its own dict; a literal {}
    # would be one object shared by all rows that rely on the default.
    payload = Column(JSONB, default=dict)
    result = Column(JSONB)
    error_msg = Column(Text)
    progress = Column(Integer, default=0)
    file_id = Column(UUID(as_uuid=True), ForeignKey("files.id"))
    # Id of the Celery task that executes this task.
    celery_task_id = Column(String(255))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
    completed_at = Column(DateTime(timezone=True))

    file = relationship("File", back_populates="tasks")
|
||||
|
||||
|
||||
class ComplianceReport(Base):
    """A stored compliance report for a file (table ``compliance_reports``)."""

    __tablename__ = "compliance_reports"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    file_id = Column(UUID(as_uuid=True), ForeignKey("files.id"))
    regulation_domains = Column(ARRAY(Text))
    overall_risk_level = Column(String(20))
    risk_score = Column(Numeric(5, 2))
    # default=list (a callable) gives every row its own list; a literal []
    # would be one object shared by all rows that rely on the default.
    findings = Column(JSONB, default=list)
    recommendations = Column(JSONB, default=list)
    report_markdown = Column(Text)
    llm_model = Column(String(100))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class RegulationSource(Base):
    """A monitored regulation source (table ``regulation_sources``)."""

    __tablename__ = "regulation_sources"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False)
    url = Column(Text, nullable=False)
    source_type = Column(String(50), default="webpage")
    domain = Column(String(100))
    fetch_interval = Column(Integer, default=86400)
    is_active = Column(Boolean, default=True)
    last_fetched_at = Column(DateTime(timezone=True))
    # Hash of the content seen at the last fetch (64 characters at most).
    last_hash = Column(String(64))
    # default=dict (a callable) gives every row its own dict; a literal {}
    # would be one object shared by all rows that rely on the default.
    fetch_config = Column(JSONB, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class RegulationUpdate(Base):
    """A change detected on a monitored source (table ``regulation_updates``)."""

    __tablename__ = "regulation_updates"

    # Identity and origin
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    source_id = Column(UUID(as_uuid=True), ForeignKey("regulation_sources.id"))

    # What changed
    title = Column(String(500))
    url = Column(Text)
    change_type = Column(String(50))
    summary = Column(Text)
    raw_content = Column(Text)
    diff_content = Column(Text)

    # Notification state
    is_notified = Column(Boolean, default=False)
    importance = Column(String(20), default="normal")

    # Timestamps: fetched_at is filled in by the database.
    fetched_at = Column(DateTime(timezone=True), server_default=func.now())
    published_at = Column(DateTime(timezone=True))
|
||||
21
services/compliance-backend/app/services/embed.py
Normal file
21
services/compliance-backend/app/services/embed.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import httpx
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from ..core.config import settings
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
async def embed_texts(texts: list[str], batch_size: int = 12) -> dict:
    """Call the embedding service and return its JSON response (dense and sparse vectors).

    Retried up to 3 attempts with exponential back-off; a non-2xx response raises.
    """
    endpoint = f"{settings.embedding_service_url}/embed"
    payload = {"texts": texts, "batch_size": batch_size}
    async with httpx.AsyncClient(timeout=120.0) as client:
        reply = await client.post(endpoint, json=payload)
        reply.raise_for_status()
        return reply.json()
|
||||
|
||||
|
||||
async def embed_single(text: str) -> list[float]:
    """Embed one text and return its dense vector."""
    embeddings = await embed_texts([text], batch_size=1)
    dense_vectors = embeddings["dense"]
    return dense_vectors[0]
|
||||
65
services/compliance-backend/app/services/graph.py
Normal file
65
services/compliance-backend/app/services/graph.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from ..core.deps import get_neo4j
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_regulation_node(regulation: dict) -> str | None:
    """Create or update a Regulation node in Neo4j.

    Args:
        regulation: Dict with an "id" and optional "title", "domain",
            "version" and "code" (missing values are stored as "").

    Returns:
        The node id, or None when the query yields no record.
    """
    driver = get_neo4j()
    async with driver.session() as session:
        result = await session.run(
            """
            MERGE (r:Regulation {id: $id})
            SET r.title = $title,
                r.domain = $domain,
                r.version = $version,
                r.code = $code
            RETURN r.id as id
            """,
            id=regulation.get("id"),
            title=regulation.get("title", ""),
            domain=regulation.get("domain", ""),
            version=regulation.get("version", ""),
            code=regulation.get("code", ""),
        )
        record = await result.single()
        return record["id"] if record else None
|
||||
|
||||
|
||||
async def create_clause_node(clause: dict, regulation_id: str) -> str | None:
    """Create or update a Clause node and link it to its regulation.

    Args:
        clause: Dict with an "id" and optional "number" and "content"; the
            content is truncated to 2000 characters.
        regulation_id: Id of the Regulation node the clause belongs to.

    Returns:
        The clause id, or None when the query yields no record (for example
        when no regulation with ``regulation_id`` is matched).
    """
    driver = get_neo4j()
    # `or ""` also covers an explicit None value, which cannot be sliced.
    content = (clause.get("content") or "")[:2000]
    async with driver.session() as session:
        result = await session.run(
            """
            MATCH (r:Regulation {id: $reg_id})
            MERGE (c:Clause {id: $id})
            SET c.number = $number,
                c.content = $content
            MERGE (r)-[:CONTAINS]->(c)
            RETURN c.id as id
            """,
            reg_id=regulation_id,
            id=clause.get("id"),
            number=clause.get("number", ""),
            content=content,
        )
        record = await result.single()
        return record["id"] if record else None
|
||||
|
||||
|
||||
async def search_related_regulations(domain: str, limit: int = 10) -> list[dict]:
    """Return up to *limit* regulations of the given domain (id, title, code, version)."""
    cypher = """
            MATCH (r:Regulation {domain: $domain})
            RETURN r.id as id, r.title as title, r.code as code, r.version as version
            LIMIT $limit
            """
    driver = get_neo4j()
    async with driver.session() as session:
        result = await session.run(cypher, domain=domain, limit=limit)
        regulations = []
        async for record in result:
            regulations.append(dict(record))
        return regulations
|
||||
59
services/compliance-backend/app/services/monitor.py
Normal file
59
services/compliance-backend/app/services/monitor.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def fetch_url(url: str, timeout: int = 30) -> str | None:
    """Fetch *url* and return the response text, or None on any failure.

    Redirects are followed; failures (including non-2xx statuses) are logged
    as warnings rather than raised.
    """
    request_headers = {"User-Agent": "Mozilla/5.0 (compliance-monitor/1.0)"}
    try:
        async with httpx.AsyncClient(
            timeout=timeout,
            headers=request_headers,
            follow_redirects=True,
        ) as client:
            reply = await client.get(url)
            reply.raise_for_status()
            return reply.text
    except Exception as e:
        logger.warning(f"抓取 {url} 失败:{e}")
        return None
|
||||
|
||||
|
||||
def extract_text(html: str) -> str:
    """Extract the main text of an HTML document.

    script, style, nav, footer and header elements are dropped; the remaining
    text fragments are stripped and joined with newlines.
    """
    document = BeautifulSoup(html, "html.parser")
    unwanted = document.find_all(["script", "style", "nav", "footer", "header"])
    for element in unwanted:
        element.decompose()
    return document.get_text(separator="\n", strip=True)
|
||||
|
||||
|
||||
def compute_hash(content: str) -> str:
    """Return the hexadecimal MD5 digest of *content* encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(content.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
async def check_source_for_updates(source: dict) -> dict | None:
    """Check whether a monitored source has changed.

    Returns None when the page could not be fetched or its text hash equals
    ``source["last_hash"]``; otherwise returns a dict describing the new content.
    """
    page = await fetch_url(source["url"])
    if not page:
        return None

    page_text = extract_text(page)
    current_hash = compute_hash(page_text)

    unchanged = source.get("last_hash") == current_hash
    if unchanged:
        logger.info(f"监控源 {source['name']} 无变化")
        return None

    fetched_at = datetime.now(timezone.utc).isoformat()
    return {
        "source_id": source["id"],
        # Keep at most the first 50,000 characters of the extracted text.
        "raw_content": page_text[:50000],
        "new_hash": current_hash,
        "fetched_at": fetched_at,
    }
|
||||
43
services/compliance-backend/app/services/parse.py
Normal file
43
services/compliance-backend/app/services/parse.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import httpx
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from ..core.config import settings
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=30))
async def parse_document(file_content: bytes, filename: str) -> dict:
    """Send a document to mcp-server for parsing and return its JSON response.

    Retried up to 3 attempts with exponential back-off; a non-2xx response raises.
    """
    endpoint = f"{settings.mcp_server_url}/parse-document"
    upload = {"file": (filename, file_content, "application/octet-stream")}
    async with httpx.AsyncClient(timeout=300.0) as client:
        reply = await client.post(endpoint, files=upload)
        reply.raise_for_status()
        return reply.json()
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64) -> list[dict]:
    """Split *text* into overlapping chunks sized by an estimated token count.

    Token counts are approximated as two characters per token. Where possible
    a chunk ends just after a paragraph, line, sentence or word separator
    found in the second half of its window.

    Args:
        text: Text to split.
        chunk_size: Target chunk size in (estimated) tokens.
        overlap: Overlap between consecutive chunks in (estimated) tokens.

    Returns:
        A list of dicts with "idx" (running index), "content" (stripped chunk
        text) and "start"/"end" (character offsets into *text*).
    """
    chars_per_chunk = chunk_size * 2  # roughly 2 characters per token for Chinese
    chars_overlap = overlap * 2
    separators = ("\n\n", "\n", "。", ".", " ")
    total = len(text)

    chunks = []
    start = 0
    idx = 0
    while start < total:
        end = min(start + chars_per_chunk, total)
        # Try to cut at a natural boundary, but only in the second half of the
        # window so chunks do not become too short.
        if end < total:
            for sep in separators:
                pos = text.rfind(sep, start, end)
                if pos > start + chars_per_chunk // 2:
                    end = pos + len(sep)
                    break

        piece = text[start:end].strip()
        if piece:
            chunks.append({"idx": idx, "content": piece, "start": start, "end": end})
            idx += 1

        # Stop once the end of the text is covered. Without this, stepping back
        # by the overlap re-emitted ever-shorter copies of the tail, advancing
        # one character per iteration.
        if end >= total:
            break
        start = max(start + 1, end - chars_overlap)

    return chunks
|
||||
92
services/compliance-backend/app/services/rag.py
Normal file
92
services/compliance-backend/app/services/rag.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from pymilvus import connections, Collection
|
||||
|
||||
from .embed import embed_single, embed_texts
|
||||
from ..core.llm import get_llm, RAG_SYSTEM_PROMPT
|
||||
from ..core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_collection(name: str) -> Collection:
    """Connect to Milvus with the configured host/port and return a collection.

    Args:
        name: Name of the Milvus collection to open.

    Returns:
        A ``Collection`` handle for ``name``.
    """
    milvus_endpoint = {"host": settings.milvus_host, "port": settings.milvus_port}
    connections.connect(**milvus_endpoint)
    return Collection(name)
|
||||
|
||||
|
||||
async def hybrid_search(
    query: str,
    collection_name: str = "regulation_chunks",
    top_k: int = 10,
    workspace_id: str | None = None,
) -> list[dict]:
    """Retrieve the chunks most similar to a query from Milvus.

    Despite the name this is dense-vector retrieval only (BGE-M3 embedding);
    the research version is simplified and leaves out BM25 fusion.

    Args:
        query: Natural-language query; embedded with ``embed_single``.
        collection_name: Milvus collection to search.
        top_k: Maximum number of hits to return.
        workspace_id: When truthy, restricts the search to rows whose
            ``workspace_id`` field equals this value; otherwise no filter
            is applied.

    Returns:
        One dict per hit with keys ``id``, ``content``, ``score`` (float),
        ``file_id``, ``chunk_idx`` and ``metadata``.
    """
    query_vec = await embed_single(query)
    col = _get_collection(collection_name)

    expr = None
    if workspace_id:
        # The id is embedded in a double-quoted string literal of the filter
        # expression. Escape backslashes first, then double quotes, so a
        # crafted value cannot terminate the literal and inject extra
        # conditions. Ordinary ids yield exactly the same expression as before.
        safe_id = str(workspace_id).replace("\\", "\\\\").replace('"', '\\"')
        expr = f'workspace_id == "{safe_id}"'

    results = col.search(
        data=[query_vec],
        anns_field="dense_vec",
        param={"metric_type": "COSINE", "params": {"ef": 100}},
        limit=top_k,
        expr=expr,
        output_fields=["content", "metadata", "file_id", "chunk_idx"],
    )

    # One result set per query vector; only one vector is sent, but flatten
    # generically.
    chunks = []
    for hits in results:
        for hit in hits:
            entity = hit.entity
            chunks.append({
                "id": hit.id,
                "content": entity.get("content", ""),
                "score": float(hit.score),
                "file_id": entity.get("file_id", ""),
                "chunk_idx": entity.get("chunk_idx", 0),
                "metadata": entity.get("metadata", {}),
            })
    return chunks
|
||||
|
||||
|
||||
async def rerank(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
    """Return the ``top_k`` chunks with the highest ``score``.

    Simplified re-ranking for the research version: chunks are ordered by
    their existing ``score`` only (a Cross-Encoder could replace this in
    production). ``query`` is accepted but not used.

    Args:
        query: The user query (unused by this implementation).
        chunks: Candidate chunks; each must have a ``"score"`` key.
        top_k: Number of chunks to keep.

    Returns:
        A new list, sorted by descending score and truncated to ``top_k``;
        the input list is left untouched.
    """
    ranked = list(chunks)
    ranked.sort(key=lambda item: item["score"], reverse=True)
    return ranked[:top_k]
|
||||
|
||||
|
||||
async def generate_answer(query: str, chunks: list[dict]) -> dict:
    """Ask the LLM for a citation-anchored answer based on retrieved chunks.

    Args:
        query: The user's question.
        chunks: Retrieved chunks; each needs a ``"content"`` key and may
            carry ``metadata`` (``filename``, ``page``), ``file_id``,
            ``chunk_idx`` and ``score``.

    Returns:
        A dict with ``answer`` (the LLM reply, or a fallback message that
        includes the error and the first 200 characters of the top chunk when
        the LLM call fails) and ``sources`` (one entry per chunk, content
        truncated to 300 characters). With no chunks, a fixed "nothing
        found" answer and an empty source list are returned without calling
        the LLM.
    """
    # Nothing retrieved: short-circuit before touching the LLM.
    if not chunks:
        return {"answer": "未找到相关法规内容,请上传相关法规文档后重试。", "sources": []}

    # Build the RAG context: one labelled section per chunk, numbered from 1.
    sections = []
    for position, chunk in enumerate(chunks, 1):
        meta = chunk.get("metadata", {})
        filename = meta.get("filename", "未知文件")
        page = meta.get("page", "?")
        label = f"[来源 {position}:{filename},第 {page} 页]"
        sections.append(f"{label}\n{chunk['content']}")

    context = "\n\n---\n\n".join(sections)
    user_prompt = f"参考文献:\n\n{context}\n\n问题:{query}\n\n请基于以上参考文献回答,并标注来源。"

    llm = get_llm(temperature=0.1)
    messages = [SystemMessage(content=RAG_SYSTEM_PROMPT), HumanMessage(content=user_prompt)]

    try:
        answer = (await llm.ainvoke(messages)).content
    except Exception as e:
        # Best-effort fallback: report the failure together with a preview of
        # the highest-ranked chunk instead of raising.
        logger.error(f"LLM 生成失败:{e}")
        answer = f"LLM 生成失败:{e}。检索到的相关内容:{chunks[0]['content'][:200]}..."

    sources = []
    for c in chunks:
        sources.append({
            "content": c["content"][:300],
            "file_id": c.get("file_id", ""),
            "chunk_idx": c.get("chunk_idx", 0),
            "score": c.get("score", 0),
            "metadata": c.get("metadata", {}),
        })
    return {"answer": answer, "sources": sources}
|
||||
212
services/compliance-backend/app/worker.py
Normal file
212
services/compliance-backend/app/worker.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from .core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Celery configuration: the same Redis URL is used as broker and result backend.
celery_app = Celery("compliance", broker=settings.redis_url, backend=settings.redis_url)

celery_app.conf.update(
    # Task payloads and results are exchanged as JSON only.
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="Asia/Shanghai",
    # Each task family is routed to its own queue.
    task_routes={
        "app.worker.process_file_task": {"queue": "parse"},
        "app.worker.fetch_regulation_source": {"queue": "monitor"},
        "app.worker.send_notifications": {"queue": "push"},
    },
    # Periodic jobs: trigger every active regulation monitor once a day
    # (crontab hour=2, minute=0).
    beat_schedule={
        "daily-regulation-monitor": {
            "task": "app.worker.run_all_monitors",
            "schedule": crontab(hour=2, minute=0),
        },
    },
)
|
||||
|
||||
# ── 文件处理任务(解析 + 向量化)────────────────
|
||||
|
||||
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Celery entry point: parse a document and store its vectors in Milvus.

    Runs the async pipeline ``_process_file`` to completion on a fresh event
    loop; any exception it raises propagates to Celery.

    NOTE(review): ``max_retries=3`` is declared, but this body never calls
    ``self.retry`` - confirm whether automatic retries are intended.

    Args:
        file_id: UUID (as string) of the file record to process.
        task_id: UUID (as string) of the task record used for progress.
        workspace_id: Workspace identifier stored alongside each vector.
    """
    import asyncio

    pipeline = _process_file(file_id, task_id, workspace_id)
    asyncio.run(pipeline)
|
||||
|
||||
|
||||
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse a stored file, chunk it, embed the chunks and insert them into Milvus.

    File status moves through ``parsing`` -> ``parsed`` -> ``vectorized`` (or
    ``failed``). When a task row exists its ``status`` / ``progress`` are
    updated along the way: 10 at start, 40 after parsing, 40-100 while
    embedding.

    Args:
        file_id: UUID (as string) of the ``File`` row to process.
        task_id: UUID (as string) of the ``Task`` row used for progress
            reporting; a missing row only disables progress updates.
        workspace_id: Workspace identifier written with every vector.

    Raises:
        ValueError: If the parser returns empty markdown.
        Exception: Any other failure is recorded on the file/task rows
            (``status="failed"``, ``error_msg``) and then re-raised.
    """
    from pathlib import Path
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.parse import parse_document, chunk_text
    from .services.embed import embed_texts

    async with AsyncSessionLocal() as db:
        # Look up the file record; without it there is nothing to process.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return

        # The task row is optional: progress is only reported when it exists.
        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        try:
            # Mark the work as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document into markdown.
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")

            if not markdown.strip():
                raise ValueError("文档解析结果为空")

            file_record.status = "parsed"
            if task:
                task.progress = 40
            await db.commit()

            # Step 2: split the markdown into chunks.
            chunks = chunk_text(markdown, chunk_size=512, overlap=64)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)} 块")

            # Step 3: embed and insert in batches.
            batch_size = 16
            col = get_milvus_collection("regulation_chunks")

            for i in range(0, len(chunks), batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)
                dense_vecs = embed_result["dense"]

                # Column-oriented payload: one list per field, all of the same
                # length. NOTE(review): the order presumably mirrors the
                # collection schema - confirm against its definition.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],
                    [file_id] * len(batch),
                    [workspace_id] * len(batch),
                    [c["idx"] for c in batch],
                    [c["content"] for c in batch],
                    dense_vecs,
                    # "page" falls back to 0 when the chunk dict has no such key.
                    [{"filename": file_record.original_name, "page": c.get("page", 0)} for c in batch],
                ]
                col.insert(entities)

                if task:
                    # Count the chunks actually processed. Using
                    # i + batch_size overshoots on a final partial batch and
                    # pushed progress above 100.
                    processed = i + len(batch)
                    task.progress = 40 + int(60 * processed / len(chunks))
                    await db.commit()

            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成")

        except Exception as e:
            # Record the failure on both rows, then let the error propagate.
            logger.error(f"文件 {file_id} 处理失败:{e}")
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise
|
||||
|
||||
|
||||
# ── 法规监控任务 ────────────────────────────────
|
||||
|
||||
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Celery entry point (scheduled): trigger every active monitoring source.

    Runs the async helper ``_run_all_monitors`` to completion on a fresh
    event loop.
    """
    import asyncio

    dispatch = _run_all_monitors()
    asyncio.run(dispatch)
|
||||
|
||||
|
||||
async def _run_all_monitors():
    """Queue one ``fetch_regulation_source`` task per active regulation source.

    Selects every ``RegulationSource`` row whose ``is_active`` flag is true
    and enqueues a fetch task with the row's id (as a string), logging each
    dispatch.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    # `== True` is intentional: it builds a SQL comparison, not a Python bool.
    active_sources = select(RegulationSource).where(RegulationSource.is_active == True)  # noqa: E712

    async with AsyncSessionLocal() as db:
        rows = await db.execute(active_sources)
        for source in rows.scalars().all():
            fetch_regulation_source.delay(str(source.id))
            logger.info(f"触发监控源抓取:{source.name}")
|
||||
|
||||
|
||||
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Celery entry point: check a single regulation source for changes.

    Runs the async helper ``_fetch_source`` to completion on a fresh event
    loop.

    NOTE(review): ``max_retries=2`` is declared, but this body never calls
    ``self.retry`` - confirm whether automatic retries are intended.

    Args:
        source_id: UUID (as string) of the ``RegulationSource`` row.
    """
    import asyncio

    check = _fetch_source(source_id)
    asyncio.run(check)
|
||||
|
||||
|
||||
async def _fetch_source(source_id: str):
    """Check one regulation source for changes and persist the outcome.

    Loads the ``RegulationSource`` row, passes a plain-dict snapshot of it to
    ``check_source_for_updates`` and always refreshes ``last_fetched_at``.
    When a change is reported, the stored hash is updated and a
    ``RegulationUpdate`` row (``change_type="updated"``,
    ``importance="normal"``, content capped at 50000 characters) is added.
    Nothing happens when the source row does not exist.

    Args:
        source_id: UUID (as string) of the ``RegulationSource`` row.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
        )
        source = result.scalar_one_or_none()
        if not source:
            return

        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)

        # The fetch timestamp is refreshed whether or not anything changed.
        source.last_fetched_at = datetime.now(timezone.utc)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            db.add(
                RegulationUpdate(
                    source_id=uuid.UUID(source_id),
                    change_type="updated",
                    raw_content=update_data["raw_content"][:50000],
                    importance="normal",
                )
            )

        await db.commit()
|
||||
|
||||
|
||||
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    """Celery entry point for push notifications.

    Placeholder: currently only logs that the task ran (delivery is not
    implemented yet).
    """
    placeholder_message = "推送通知任务执行(待实现)"
    logger.info(placeholder_message)
|
||||
|
||||
|
||||
# Exported for use by the FastAPI application: `worker` is an alias of the
# Celery app configured in this module.
worker = celery_app
|
||||
29
services/compliance-backend/pyproject.toml
Normal file
29
services/compliance-backend/pyproject.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[project]
|
||||
name = "compliance-backend"
|
||||
version = "0.1.0"
|
||||
description = "AI合规智能中枢 — 业务后端"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115",
|
||||
"uvicorn[standard]>=0.30",
|
||||
"pydantic>=2.7",
|
||||
"pydantic-settings>=2.4",
|
||||
"sqlalchemy[asyncio]>=2.0",
|
||||
"asyncpg>=0.29",
|
||||
"redis[asyncio]>=5.0",
|
||||
"celery[redis]>=5.4",
|
||||
"pymilvus>=2.4",
|
||||
"neo4j>=5.20",
|
||||
"langchain>=0.3",
|
||||
"langchain-openai>=0.2",
|
||||
"langchain-community>=0.3",
|
||||
"llama-index-core>=0.11",
|
||||
"httpx>=0.27",
|
||||
"python-multipart>=0.0.9",
|
||||
"python-jose[cryptography]>=3.3",
|
||||
"structlog>=24.0",
|
||||
"prometheus-fastapi-instrumentator>=7.0",
|
||||
"tenacity>=8.5",
|
||||
"beautifulsoup4>=4.12",
|
||||
"requests>=2.32",
|
||||
]
|
||||
Reference in New Issue
Block a user