first commit
This commit is contained in:
212
services/compliance-backend/app/worker.py
Normal file
212
services/compliance-backend/app/worker.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from .core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Celery 配置
|
||||
celery_app = Celery(
|
||||
"compliance",
|
||||
broker=settings.redis_url,
|
||||
backend=settings.redis_url,
|
||||
)
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
timezone="Asia/Shanghai",
|
||||
task_routes={
|
||||
"app.worker.process_file_task": {"queue": "parse"},
|
||||
"app.worker.fetch_regulation_source": {"queue": "monitor"},
|
||||
"app.worker.send_notifications": {"queue": "push"},
|
||||
},
|
||||
beat_schedule={
|
||||
"daily-regulation-monitor": {
|
||||
"task": "app.worker.run_all_monitors",
|
||||
"schedule": crontab(hour=2, minute=0),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# ── 文件处理任务(解析 + 向量化)────────────────
|
||||
|
||||
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
|
||||
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
|
||||
"""解析文档并向量化存入 Milvus"""
|
||||
import asyncio
|
||||
asyncio.run(_process_file(file_id, task_id, workspace_id))
|
||||
|
||||
|
||||
async def _process_file(file_id: str, task_id: str, workspace_id: str):
|
||||
from pathlib import Path
|
||||
from sqlalchemy import select
|
||||
from .core.deps import AsyncSessionLocal, get_milvus_collection
|
||||
from .models.db import File, Task
|
||||
from .services.parse import parse_document, chunk_text
|
||||
from .services.embed import embed_texts
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
# 查找文件记录
|
||||
result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
|
||||
file_record = result.scalar_one_or_none()
|
||||
if not file_record:
|
||||
logger.error(f"文件 {file_id} 不存在")
|
||||
return
|
||||
|
||||
task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
|
||||
task = task_result.scalar_one_or_none()
|
||||
|
||||
try:
|
||||
# 更新状态
|
||||
file_record.status = "parsing"
|
||||
if task:
|
||||
task.status = "running"
|
||||
task.progress = 10
|
||||
await db.commit()
|
||||
|
||||
# Step 1:解析文档
|
||||
file_content = Path(file_record.storage_path).read_bytes()
|
||||
parse_result = await parse_document(file_content, file_record.original_name)
|
||||
markdown = parse_result.get("markdown", "")
|
||||
|
||||
if not markdown.strip():
|
||||
raise ValueError("文档解析结果为空")
|
||||
|
||||
file_record.status = "parsed"
|
||||
if task:
|
||||
task.progress = 40
|
||||
await db.commit()
|
||||
|
||||
# Step 2:分块
|
||||
chunks = chunk_text(markdown, chunk_size=512, overlap=64)
|
||||
logger.info(f"文件 {file_id} 分割为 {len(chunks)} 块")
|
||||
|
||||
# Step 3:向量化(分批处理)
|
||||
batch_size = 16
|
||||
col = get_milvus_collection("regulation_chunks")
|
||||
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch = chunks[i:i + batch_size]
|
||||
texts = [c["content"] for c in batch]
|
||||
embed_result = await embed_texts(texts, batch_size=batch_size)
|
||||
dense_vecs = embed_result["dense"]
|
||||
|
||||
entities = [
|
||||
[f"{file_id}_{c['idx']}" for c in batch],
|
||||
[file_id] * len(batch),
|
||||
[workspace_id] * len(batch),
|
||||
[c["idx"] for c in batch],
|
||||
[c["content"] for c in batch],
|
||||
dense_vecs,
|
||||
[{"filename": file_record.original_name, "page": c.get("page", 0)} for c in batch],
|
||||
]
|
||||
col.insert(entities)
|
||||
|
||||
if task:
|
||||
task.progress = 40 + int(60 * (i + batch_size) / len(chunks))
|
||||
await db.commit()
|
||||
|
||||
col.flush()
|
||||
|
||||
# 完成
|
||||
file_record.status = "vectorized"
|
||||
if task:
|
||||
task.status = "completed"
|
||||
task.progress = 100
|
||||
task.completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
logger.info(f"文件 {file_id} 处理完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文件 {file_id} 处理失败:{e}")
|
||||
file_record.status = "failed"
|
||||
file_record.error_msg = str(e)
|
||||
if task:
|
||||
task.status = "failed"
|
||||
task.error_msg = str(e)
|
||||
await db.commit()
|
||||
raise
|
||||
|
||||
|
||||
# ── 法规监控任务 ────────────────────────────────
|
||||
|
||||
@celery_app.task(name="app.worker.run_all_monitors")
|
||||
def run_all_monitors():
|
||||
"""定时触发所有活跃监控源"""
|
||||
import asyncio
|
||||
asyncio.run(_run_all_monitors())
|
||||
|
||||
|
||||
async def _run_all_monitors():
|
||||
from sqlalchemy import select
|
||||
from .core.deps import AsyncSessionLocal
|
||||
from .models.db import RegulationSource
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(RegulationSource).where(RegulationSource.is_active == True)
|
||||
)
|
||||
sources = result.scalars().all()
|
||||
for source in sources:
|
||||
fetch_regulation_source.delay(str(source.id))
|
||||
logger.info(f"触发监控源抓取:{source.name}")
|
||||
|
||||
|
||||
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
|
||||
def fetch_regulation_source(self, source_id: str):
|
||||
import asyncio
|
||||
asyncio.run(_fetch_source(source_id))
|
||||
|
||||
|
||||
async def _fetch_source(source_id: str):
|
||||
import hashlib
|
||||
from sqlalchemy import select
|
||||
from .core.deps import AsyncSessionLocal
|
||||
from .models.db import RegulationSource, RegulationUpdate
|
||||
from .services.monitor import check_source_for_updates
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
|
||||
)
|
||||
source = result.scalar_one_or_none()
|
||||
if not source:
|
||||
return
|
||||
|
||||
source_dict = {
|
||||
"id": str(source.id),
|
||||
"name": source.name,
|
||||
"url": source.url,
|
||||
"last_hash": source.last_hash,
|
||||
}
|
||||
update_data = await check_source_for_updates(source_dict)
|
||||
|
||||
if update_data:
|
||||
logger.info(f"检测到变更:{source.name}")
|
||||
source.last_hash = update_data["new_hash"]
|
||||
source.last_fetched_at = datetime.now(timezone.utc)
|
||||
|
||||
update = RegulationUpdate(
|
||||
source_id=uuid.UUID(source_id),
|
||||
change_type="updated",
|
||||
raw_content=update_data["raw_content"][:50000],
|
||||
importance="normal",
|
||||
)
|
||||
db.add(update)
|
||||
await db.commit()
|
||||
else:
|
||||
source.last_fetched_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@celery_app.task(name="app.worker.send_notifications")
|
||||
def send_notifications():
|
||||
logger.info("推送通知任务执行(待实现)")
|
||||
|
||||
|
||||
# 导出供 FastAPI 使用
|
||||
worker = celery_app
|
||||
Reference in New Issue
Block a user