import uuid
import logging
from datetime import datetime, timezone

from celery import Celery
from celery.schedules import crontab

from .core.config import settings

logger = logging.getLogger(__name__)

# Celery configuration: Redis (settings.redis_url) serves as both broker and
# result backend.
celery_app = Celery(
    "compliance",
    broker=settings.redis_url,
    backend=settings.redis_url,
)

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="Asia/Shanghai",
    # Each task family is routed to its own queue.
    task_routes={
        "app.worker.process_file_task": {"queue": "parse"},
        "app.worker.fetch_regulation_source": {"queue": "monitor"},
        "app.worker.send_notifications": {"queue": "push"},
    },
    # Beat triggers the monitor fan-out once a day at 02:00.
    beat_schedule={
        "daily-regulation-monitor": {
            "task": "app.worker.run_all_monitors",
            "schedule": crontab(hour=2, minute=0),
        },
    },
)


# ── File processing task (parse + vectorize) ────────────────
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str) -> None:
    """Parse a document and store its embeddings in Milvus.

    Synchronous Celery entry point that drives the async pipeline in
    ``_process_file`` to completion on a fresh event loop.

    Args:
        file_id: String form of the File row's UUID.
        task_id: String form of the Task row's UUID used for progress tracking.
        workspace_id: Workspace identifier stored with every inserted chunk.

    Raises:
        Exception: Whatever ``_process_file`` raises is propagated unchanged.
    """
    # NOTE(review): the task is bound with max_retries=3, but self.retry() is
    # never called here, so a failure is not retried by this function.
    import asyncio

    asyncio.run(_process_file(file_id, task_id, workspace_id))


async def _process_file(file_id: str, task_id: str, workspace_id: str) -> None:
    """Parse one stored file, chunk it, embed the chunks and insert them into Milvus.

    Status and progress are written to the File row (and to the Task row when
    one is found) as the pipeline advances: ``parsing`` -> ``parsed`` ->
    ``vectorized``, or ``failed`` with ``error_msg`` set on any exception.
    If no File row matches ``file_id`` the function logs an error and returns.

    Args:
        file_id: String form of the File row's UUID; also used as the prefix of
            each inserted entity id (``"<file_id>_<chunk idx>"``).
        task_id: String form of the Task row's UUID. If no row matches, the
            pipeline still runs, just without task progress updates.
        workspace_id: Stored alongside every inserted chunk.

    Raises:
        ValueError: If ``file_id``/``task_id`` is not a valid UUID string, or
            if the parsed markdown is empty.
        Exception: Any other error from parsing, embedding, Milvus or the
            database is re-raised after the failure has been recorded.
    """
    from pathlib import Path
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.parse import parse_document, chunk_text
    from .services.embed import embed_texts

    async with AsyncSessionLocal() as db:
        # Look up the file record.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return

        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        try:
            # Mark the work as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document.
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")
            if not markdown.strip():
                raise ValueError("文档解析结果为空")

            file_record.status = "parsed"
            if task:
                task.progress = 40
            await db.commit()

            # Step 2: split into chunks.
            chunks = chunk_text(markdown, chunk_size=512, overlap=64)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)} 块")

            # Step 3: embed and insert, one batch at a time.
            batch_size = 16
            col = get_milvus_collection("regulation_chunks")
            total = len(chunks)

            for i in range(0, total, batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)
                dense_vecs = embed_result["dense"]

                # Column-oriented payload.
                # NOTE(review): the column order presumably mirrors the
                # "regulation_chunks" collection schema -- confirm.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],
                    [file_id] * len(batch),
                    [workspace_id] * len(batch),
                    [c["idx"] for c in batch],
                    texts,
                    dense_vecs,
                    [
                        {"filename": file_record.original_name, "page": c.get("page", 0)}
                        for c in batch
                    ],
                ]
                col.insert(entities)

                if task:
                    # Count the chunks actually processed: the final batch may
                    # be shorter than batch_size, and using i + batch_size
                    # would push progress past 100.
                    processed = i + len(batch)
                    task.progress = 40 + int(60 * processed / total)
                    await db.commit()

            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成")

        except Exception as e:
            logger.error(f"文件 {file_id} 处理失败:{e}")
            # If the error came from a flush/commit, the session refuses further
            # work until it is rolled back; roll back first so the failure
            # status below can actually be persisted.
            await db.rollback()
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise


# ── Regulation monitoring tasks ─────────────────────────────
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors() -> None:
    """Scheduled trigger that fans out a fetch for every active monitoring source."""
    import asyncio

    asyncio.run(_run_all_monitors())


async def _run_all_monitors() -> None:
    """Queue one ``fetch_regulation_source`` task per active RegulationSource row."""
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            # "== True" is intentional: it builds a SQL expression.
            select(RegulationSource).where(RegulationSource.is_active == True)  # noqa: E712
        )
        sources = result.scalars().all()

        for source in sources:
            fetch_regulation_source.delay(str(source.id))
            logger.info(f"触发监控源抓取:{source.name}")
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str) -> None:
    """Check a single regulation source for changes.

    Synchronous Celery entry point that drives ``_fetch_source`` to completion
    on a fresh event loop.

    Args:
        source_id: String form of the RegulationSource row's UUID.

    Raises:
        Exception: Whatever ``_fetch_source`` raises is propagated unchanged.
    """
    # NOTE(review): the task is bound with max_retries=2, but self.retry() is
    # never called here, so a failure is not retried by this function.
    import asyncio

    asyncio.run(_fetch_source(source_id))


async def _fetch_source(source_id: str) -> None:
    """Fetch one source, record a RegulationUpdate if it changed, and stamp the fetch time.

    Returns silently when no RegulationSource row matches ``source_id``.
    ``last_fetched_at`` is updated whether or not a change was detected.

    Args:
        source_id: String form of the RegulationSource row's UUID.

    Raises:
        ValueError: If ``source_id`` is not a valid UUID string.
        Exception: Errors from ``check_source_for_updates`` or the database
            propagate to the caller; nothing is committed in that case.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    source_uuid = uuid.UUID(source_id)

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == source_uuid)
        )
        source = result.scalar_one_or_none()
        if not source:
            return

        # Plain-dict snapshot handed to the monitor service.
        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }

        update_data = await check_source_for_updates(source_dict)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            db.add(
                RegulationUpdate(
                    source_id=source_uuid,
                    change_type="updated",
                    # Stored content is capped at 50,000 characters.
                    raw_content=update_data["raw_content"][:50000],
                    importance="normal",
                )
            )

        # The fetch time is recorded on every run, changed or not.
        source.last_fetched_at = datetime.now(timezone.utc)
        await db.commit()


@celery_app.task(name="app.worker.send_notifications")
def send_notifications() -> None:
    """Placeholder notification task; currently only logs that it ran."""
    logger.info("推送通知任务执行(待实现)")


# Exported for use by FastAPI.
worker = celery_app