Files
AIRegulation-Deployment/services/compliance-backend/app/worker.py

235 lines
8.8 KiB
Python
Raw Normal View History

2026-04-23 09:58:47 +08:00
import uuid
import logging
from datetime import datetime, timezone
from celery import Celery
from celery.schedules import crontab
from .core.config import settings
logger = logging.getLogger(__name__)

# Celery application: the same Redis URL from settings serves as both the
# message broker and the result backend.
celery_app = Celery("compliance", broker=settings.redis_url, backend=settings.redis_url)

# Serialization: JSON is the only accepted content type, for tasks and results.
celery_app.conf.update(
    accept_content=["json"],
    task_serializer="json",
    result_serializer="json",
)

# Routing and scheduling.
# NOTE(review): "app.worker.run_all_monitors" has no entry in task_routes, so it
# is not sent to any of the parse/monitor/push queues listed here — confirm a
# worker consumes whatever queue it ends up on.
celery_app.conf.update(
    timezone="Asia/Shanghai",
    task_routes={
        "app.worker.process_file_task": {"queue": "parse"},
        "app.worker.fetch_regulation_source": {"queue": "monitor"},
        "app.worker.send_notifications": {"queue": "push"},
    },
    beat_schedule={
        # Scheduled with crontab(hour=2, minute=0), i.e. once a day at 02:00.
        "daily-regulation-monitor": {
            "task": "app.worker.run_all_monitors",
            "schedule": crontab(hour=2, minute=0),
        },
    },
)
# ── File processing task (parse + vectorize) ────────────────
@celery_app.task(bind=True, name="app.worker.process_file_task", max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Parse a document, vectorize it, and store the vectors in Milvus.

    Synchronous Celery entry point that drives the async ``_process_file``
    pipeline to completion on a fresh event loop.

    NOTE(review): ``max_retries`` is set but nothing visible here calls
    ``self.retry`` — confirm whether retries are actually intended.
    """
    from asyncio import run as run_until_done

    pipeline = _process_file(file_id, task_id, workspace_id)
    run_until_done(pipeline)
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse one uploaded file, chunk it, embed the chunks and insert them into Milvus.

    The file row moves through the statuses ``parsing`` -> ``parsed`` ->
    ``vectorized`` (or ``failed``). When a matching task row exists, its
    ``status``/``progress`` are updated alongside.

    Args:
        file_id: String form of the ``File`` primary key (a UUID).
        task_id: String form of the ``Task`` primary key (a UUID); the task
            row is optional and is skipped when not found.
        workspace_id: Workspace identifier stored with every inserted chunk.

    Raises:
        ValueError: If ``file_id``/``task_id`` is not a valid UUID string, or
            if the parser returns empty markdown.
        Exception: Any error raised during processing is re-raised after the
            file (and task, if present) has been marked ``failed``.
    """
    from pathlib import Path
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.parse import parse_document
    from .services.embed import embed_texts
    from .services.regulation_parser import extract_regulation_meta, legal_chunk

    async with AsyncSessionLocal() as db:
        # Look up the file record; without it there is nothing to process.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return

        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        try:
            # Mark the work as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document (the original note says this goes
            # through mcp-server).
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")
            if not markdown.strip():
                raise ValueError("文档解析结果为空")

            # Step 2: extract regulation metadata (issuing authority,
            # document number, effective date, regulation type).
            reg_meta = extract_regulation_meta(markdown)
            file_record.status = "parsed"
            # NOTE(review): ``metadata`` is a reserved attribute name on
            # SQLAlchemy declarative models — confirm how File maps this field.
            file_record.metadata = {
                "regulation_name": reg_meta.regulation_name,
                "issuing_authority": reg_meta.issuing_authority,
                "doc_number": reg_meta.doc_number,
                "effective_date": reg_meta.effective_date,
                "regulation_type": reg_meta.regulation_type,
                "parser": parse_result.get("parser", ""),
                "page_count": parse_result.get("page_count", 0),
            }
            if task:
                task.progress = 40
            await db.commit()

            # Step 3: regulation-specific chunking (per the original note:
            # split on chapter/article boundaries, keeping clause numbers).
            chunks = legal_chunk(markdown, reg_meta, chunk_size=512, overlap=64)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)} 块,法规:{reg_meta.regulation_name!r}")

            # Step 4: embed and write to Milvus in batches.
            batch_size = 16
            total_chunks = len(chunks)
            col = get_milvus_collection("regulation_chunks")
            for i in range(0, total_chunks, batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)
                # Original note: list[list[float]], 1024-dimensional.
                dense_vecs = embed_result["dense"]
                # Original note: list[dict[str, float]]; empty dicts when absent.
                sparse_vecs = embed_result.get("sparse", [{}] * len(batch))

                # Column-ordered payload: one list per collection field.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],   # pk
                    [file_id] * len(batch),                     # file_id
                    [workspace_id] * len(batch),                # workspace_id
                    [c["idx"] for c in batch],                  # chunk_idx
                    [c["content"] for c in batch],              # content
                    dense_vecs,                                 # dense_vec
                    sparse_vecs,                                # sparse_vec
                    [c["clause_no"] for c in batch],            # clause_no
                    [c["article_no"] for c in batch],           # article_no
                    [c["regulation_name"] for c in batch],      # regulation_name
                    [{                                          # metadata
                        "filename": file_record.original_name,
                        "page": c.get("page", 0),
                        "doc_number": reg_meta.doc_number,
                    } for c in batch],
                ]
                col.insert(entities)

                if task:
                    # Clamp to the real chunk count: the last batch is usually
                    # partial, and using i + batch_size directly would push
                    # progress past 100.
                    processed = min(i + batch_size, total_chunks)
                    task.progress = 40 + int(60 * processed / total_chunks)
                    await db.commit()
            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成,共 {len(chunks)} 个向量块")
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"文件 {file_id} 处理失败:{e}")
            # NOTE(review): if the failure came from the DB session itself, a
            # rollback may be needed before this commit can succeed — confirm
            # against the session configuration.
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise
# ── Regulation monitoring tasks ────────────────────────────────
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Scheduled trigger that fans out a fetch for every active monitoring source.

    Synchronous Celery entry point; the work itself is done by the async
    ``_run_all_monitors`` coroutine, run here to completion.
    """
    from asyncio import run as run_until_done

    fan_out = _run_all_monitors()
    run_until_done(fan_out)
async def _run_all_monitors():
    """Enqueue one ``fetch_regulation_source`` task per active ``RegulationSource``."""
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    # ``== True`` builds a SQL expression here; it is not a Python comparison.
    active_query = select(RegulationSource).where(
        RegulationSource.is_active == True  # noqa: E712
    )

    async with AsyncSessionLocal() as db:
        rows = await db.execute(active_query)
        for source in rows.scalars().all():
            # Hand each source to its own Celery task, identified by string id.
            fetch_regulation_source.delay(str(source.id))
            logger.info(f"触发监控源抓取:{source.name}")
@celery_app.task(bind=True, name="app.worker.fetch_regulation_source", max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Check a single regulation source for updates.

    Synchronous Celery entry point that runs the async ``_fetch_source``
    coroutine to completion.

    NOTE(review): ``max_retries`` is set but nothing visible here calls
    ``self.retry`` — confirm whether retries are actually intended.
    """
    from asyncio import run as run_until_done

    fetch = _fetch_source(source_id)
    run_until_done(fetch)
async def _fetch_source(source_id: str):
    """Check one regulation source for changes and record the outcome.

    Loads the ``RegulationSource`` row, passes a plain-dict snapshot of it to
    ``check_source_for_updates`` and, when that returns a truthy result,
    stores the new hash and adds a ``RegulationUpdate`` row. The source's
    ``last_fetched_at`` is refreshed whether or not a change was detected.

    Args:
        source_id: String form of the ``RegulationSource`` primary key (a UUID).

    Raises:
        ValueError: If ``source_id`` is not a valid UUID string.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    source_uuid = uuid.UUID(source_id)

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == source_uuid)
        )
        source = result.scalar_one_or_none()
        if not source:
            # Log instead of returning silently, as _process_file does for a
            # missing file record.
            logger.warning(f"监控源 {source_id} 不存在")
            return

        # Plain-dict snapshot handed to the update checker.
        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            db.add(
                RegulationUpdate(
                    source_id=source_uuid,
                    change_type="updated",
                    # Stored content is capped at the first 50000 characters.
                    raw_content=update_data["raw_content"][:50000],
                    importance="normal",
                )
            )

        # The fetch timestamp is refreshed on every check, changed or not.
        source.last_fetched_at = datetime.now(timezone.utc)
        await db.commit()
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    """Placeholder notification task: it currently only logs that it ran."""
    placeholder_message = "推送通知任务执行(待实现)"
    logger.info(placeholder_message)
# Exported for use by FastAPI
worker = celery_app