Files
AIRegulation-Deployment/services/compliance-backend/app/worker.py
2026-04-23 09:58:47 +08:00

213 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
import logging
from datetime import datetime, timezone
from celery import Celery
from celery.schedules import crontab
from .core.config import settings
logger = logging.getLogger(__name__)

# ── Celery application ──────────────────────────
# The same Redis URL (settings.redis_url) serves as both the message broker
# and the result backend.
celery_app = Celery(
    "compliance",
    backend=settings.redis_url,
    broker=settings.redis_url,
)

# JSON-only serialization, Asia/Shanghai as the app timezone, one dedicated
# queue per task family, and a beat entry that triggers the regulation
# monitor daily at hour=2, minute=0 (presumably evaluated in the configured
# timezone, per Celery's crontab semantics — confirm).
celery_app.conf.update(
    accept_content=["json"],
    task_serializer="json",
    result_serializer="json",
    timezone="Asia/Shanghai",
    task_routes={
        # Document parsing + vectorization.
        "app.worker.process_file_task": {"queue": "parse"},
        # Per-source fetches fanned out by run_all_monitors.
        "app.worker.fetch_regulation_source": {"queue": "monitor"},
        # Notification push.
        "app.worker.send_notifications": {"queue": "push"},
        # NOTE(review): app.worker.run_all_monitors has no route here, so it
        # presumably lands on Celery's default queue — confirm a worker
        # consumes that queue.
    },
    beat_schedule={
        "daily-regulation-monitor": {
            "task": "app.worker.run_all_monitors",
            "schedule": crontab(minute=0, hour=2),
        },
    },
)
# ── File processing task (parse + vectorize) ────────────────
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Celery entry point: parse a stored document and index its chunks into Milvus.

    Drives the async pipeline ``_process_file`` to completion with
    ``asyncio.run`` (a new event loop per invocation). Any exception raised by
    the pipeline propagates and fails the Celery task.

    NOTE(review): ``bind=True`` / ``max_retries=3`` are declared, but
    ``self.retry`` is never called, so Celery never retries this task. Chunks
    already inserted into Milvus are not removed on failure, so adding a
    naive retry would duplicate them — add cleanup before enabling retries.

    NOTE(review): because every invocation gets a new event loop, confirm
    that the async engine behind AsyncSessionLocal does not reuse pooled
    connections bound to a previous (closed) loop.
    """
    import asyncio

    pipeline = _process_file(file_id, task_id, workspace_id)
    asyncio.run(pipeline)
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse one uploaded file, chunk it, embed the chunks and insert them into Milvus.

    Progress is persisted on the ``File`` row (``status``: "parsing" ->
    "parsed" -> "vectorized", or "failed" with ``error_msg``) and, when it
    exists, on the matching ``Task`` row (``status``, ``progress`` 10 -> 40 ->
    up to 100, ``completed_at``, ``error_msg``).

    Args:
        file_id: UUID string of the ``File`` record to process.
        task_id: UUID string of the ``Task`` record tracking this job. The job
            still runs when no such task row exists.
        workspace_id: Workspace identifier stored with every chunk in Milvus.

    Raises:
        Exception: any failure while parsing, chunking, embedding or inserting
            is recorded on the File/Task rows and then re-raised.

    NOTE(review): chunks inserted into Milvus before a failure are not
    removed, so a later re-run of the same file may leave duplicates.
    """
    from pathlib import Path
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.parse import parse_document, chunk_text
    from .services.embed import embed_texts

    async with AsyncSessionLocal() as db:
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()

        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        if not file_record:
            missing = f"文件 {file_id} 不存在"
            logger.error(missing)
            # Mark the tracking task as failed so it does not remain in its
            # prior status indefinitely.
            if task:
                task.status = "failed"
                task.error_msg = missing
                await db.commit()
            return

        try:
            # Record that processing has started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the raw document into markdown.
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")
            if not markdown.strip():
                raise ValueError("文档解析结果为空")
            file_record.status = "parsed"
            if task:
                task.progress = 40
            await db.commit()

            # Step 2: split into overlapping chunks.
            chunks = chunk_text(markdown, chunk_size=512, overlap=64)
            total = len(chunks)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)}")

            # Step 3: embed and insert in batches; task progress moves 40 -> 100.
            batch_size = 16
            col = get_milvus_collection("regulation_chunks")
            filename = file_record.original_name  # loop-invariant
            for i in range(0, total, batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)
                dense_vecs = embed_result["dense"]
                # Column-oriented insert: one list per collection field.
                # NOTE(review): this order must match the field order of the
                # "regulation_chunks" collection schema — confirm.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],
                    [file_id] * len(batch),
                    [workspace_id] * len(batch),
                    [c["idx"] for c in batch],
                    texts,
                    dense_vecs,
                    [{"filename": filename, "page": c.get("page", 0)} for c in batch],
                ]
                col.insert(entities)
                if task:
                    # Clamp to the real number of processed chunks: the last
                    # batch may be shorter than batch_size, and using
                    # i + batch_size directly pushes progress above 100.
                    done = min(i + batch_size, total)
                    task.progress = 40 + int(60 * done / total)
                    await db.commit()
            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成")
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(f"文件 {file_id} 处理失败:{e}")
            # If the failure came from a DB call, the session refuses further
            # commits until rolled back; without this the "failed" status
            # below would never be persisted. Harmless for non-DB failures,
            # since every earlier change was already committed.
            await db.rollback()
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise
# ── Regulation monitoring tasks ─────────────────
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Beat entry point: dispatch a fetch task for every active monitoring source.

    Runs the async helper ``_run_all_monitors`` to completion via
    ``asyncio.run``.
    """
    import asyncio as _asyncio

    _asyncio.run(_run_all_monitors())
async def _run_all_monitors():
    """Load every active ``RegulationSource`` and enqueue one fetch task per source.

    Each source id is passed as a string to ``fetch_regulation_source.delay``,
    and every dispatch is logged with the source name.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    # `== True` builds a SQL comparison expression; it is not a Python bool test.
    active_only = select(RegulationSource).where(
        RegulationSource.is_active == True  # noqa: E712
    )
    async with AsyncSessionLocal() as session:
        rows = await session.execute(active_only)
        for src in rows.scalars().all():
            fetch_regulation_source.delay(str(src.id))
            logger.info(f"触发监控源抓取:{src.name}")
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Check one regulation source for changes (Celery entry point).

    Args:
        source_id: UUID string of the ``RegulationSource`` to check.

    Any exception from the fetch (e.g. a transient network error in the
    monitor service) is retried via ``self.retry`` up to ``max_retries``
    times with Celery's default retry delay. After that, the original
    exception is raised and the task fails.
    """
    import asyncio

    try:
        asyncio.run(_fetch_source(source_id))
    except Exception as exc:
        # max_retries=2 was declared but never exercised without this call.
        # _fetch_source persists only in its final commit, so a failed
        # attempt leaves no partial state and re-running it is safe.
        raise self.retry(exc=exc)
async def _fetch_source(source_id: str):
    """Check one regulation source for changes and persist the outcome.

    Args:
        source_id: UUID string of the ``RegulationSource`` row. If no such row
            exists, the function returns without doing anything.

    Always stamps ``last_fetched_at`` (UTC). When ``check_source_for_updates``
    returns a truthy result, the function also stores its ``new_hash`` as the
    source's ``last_hash`` and adds a ``RegulationUpdate`` row. That row's
    ``raw_content`` is truncated to its first 50,000 items (characters,
    assuming it is a str). Everything is written in a single commit at the
    end.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
        )
        source = result.scalar_one_or_none()
        if not source:
            return

        # Plain-dict snapshot handed to the monitor service.
        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            db.add(
                RegulationUpdate(
                    source_id=uuid.UUID(source_id),
                    change_type="updated",
                    raw_content=update_data["raw_content"][:50000],
                    importance="normal",
                )
            )

        # Both outcomes record the fetch time and commit once.
        source.last_fetched_at = datetime.now(timezone.utc)
        await db.commit()
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    """Placeholder push-notification task; it currently only logs that it ran.

    TODO: implement actual notification delivery.
    """
    message = "推送通知任务执行(待实现)"
    logger.info(message)
# Public alias of the Celery app, exported for use by the FastAPI side.
worker: Celery = celery_app