235 lines
8.8 KiB
Python
235 lines
8.8 KiB
Python
import uuid
|
||
import logging
|
||
from datetime import datetime, timezone
|
||
from celery import Celery
|
||
from celery.schedules import crontab
|
||
|
||
from .core.config import settings
|
||
|
||
logger = logging.getLogger(__name__)

# Celery application: Redis serves as both the message broker and the result backend.
celery_app = Celery(
    main="compliance",
    backend=settings.redis_url,
    broker=settings.redis_url,
)

celery_app.conf.update(
    # JSON only, for task payloads and for stored results.
    accept_content=["json"],
    task_serializer="json",
    result_serializer="json",
    # App timezone; Celery uses it when evaluating the crontab beat entry below.
    timezone="Asia/Shanghai",
    # One dedicated queue per workload type.
    task_routes={
        "app.worker.process_file_task": {"queue": "parse"},
        "app.worker.fetch_regulation_source": {"queue": "monitor"},
        "app.worker.send_notifications": {"queue": "push"},
    },
    # Periodic jobs driven by celery beat: fan out all regulation monitors at 02:00.
    beat_schedule={
        "daily-regulation-monitor": {
            "schedule": crontab(minute=0, hour=2),
            "task": "app.worker.run_all_monitors",
        },
    },
)
|
||
|
||
# ── File processing task (parse + vectorize) ────────────────


@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Celery entry point: parse one uploaded document and vectorize it into Milvus.

    Runs the async pipeline ``_process_file`` to completion on a fresh event
    loop. Any exception raised by the pipeline propagates and fails the task.

    NOTE(review): ``bind=True`` / ``max_retries=3`` are declared, but
    ``self.retry`` is never called, so the task is never actually retried.
    """
    import asyncio as _asyncio

    pipeline = _process_file(file_id, task_id, workspace_id)
    _asyncio.run(pipeline)
|
||
|
||
|
||
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse a stored document, chunk it along legal structure and index it in Milvus.

    Pipeline:
      1. Load the ``File`` row, plus the optional ``Task`` row used for progress.
      2. Read the file from ``storage_path`` and parse it to markdown.
      3. Extract regulation metadata and store it on the file record.
      4. Split the markdown with ``legal_chunk``, then embed and insert the
         chunks into the ``regulation_chunks`` collection in batches of 16.

    File status moves parsing -> parsed -> vectorized, or to failed on any
    error. Task progress moves 10 -> 40 -> (40..100 while inserting) -> 100.

    Args:
        file_id: UUID string of the ``File`` row. The function returns early
            (after logging) if no such row exists.
        task_id: UUID string of the ``Task`` row. A missing task is tolerated;
            progress is then simply not reported.
        workspace_id: stored verbatim on every Milvus entity.

    Raises:
        ValueError: if an id is not a valid UUID, or the parsed markdown is empty.
        Exception: any pipeline failure is re-raised after the file/task rows
            have been marked ``failed`` (best effort).
    """
    from pathlib import Path

    from sqlalchemy import select

    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.embed import embed_texts
    from .services.parse import parse_document
    from .services.regulation_parser import extract_regulation_meta, legal_chunk

    async with AsyncSessionLocal() as db:
        # Look up the file record; nothing to do if it is gone.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return

        # The task row only drives status/progress reporting, so it is optional.
        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        try:
            # Mark the work as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document (via the mcp-server).
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")

            if not markdown.strip():
                raise ValueError("文档解析结果为空")

            # Step 2: extract regulation metadata (issuing authority / document
            # number / effective date / regulation type).
            reg_meta = extract_regulation_meta(markdown)
            file_record.status = "parsed"
            # NOTE(review): SQLAlchemy declarative models reserve the attribute
            # name ``metadata``. Confirm the File model really maps a column to
            # this attribute; otherwise this assignment is never persisted.
            file_record.metadata = {
                "regulation_name": reg_meta.regulation_name,
                "issuing_authority": reg_meta.issuing_authority,
                "doc_number": reg_meta.doc_number,
                "effective_date": reg_meta.effective_date,
                "regulation_type": reg_meta.regulation_type,
                "parser": parse_result.get("parser", ""),
                "page_count": parse_result.get("page_count", 0),
            }
            if task:
                task.progress = 40
            await db.commit()

            # Step 3: legal-aware chunking (split on chapter/article boundaries,
            # keeping clause numbers).
            chunks = legal_chunk(markdown, reg_meta, chunk_size=512, overlap=64)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)} 块,法规:{reg_meta.regulation_name!r}")

            # Step 4: embed and insert into Milvus, one batch at a time.
            batch_size = 16
            total = len(chunks)
            col = get_milvus_collection("regulation_chunks")

            for i in range(0, total, batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)

                dense_vecs = embed_result["dense"]  # list[list[float]], 1024-dim per the original note
                # Fall back to empty sparse vectors when the embedder returns none.
                sparse_vecs = embed_result.get("sparse", [{}] * len(batch))  # list[dict[str, float]]

                # Column-based insert: each list is one field. Presumably matched
                # positionally to the collection schema, so keep the order in sync.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],  # pk
                    [file_id] * len(batch),  # file_id
                    [workspace_id] * len(batch),  # workspace_id
                    [c["idx"] for c in batch],  # chunk_idx
                    [c["content"] for c in batch],  # content
                    dense_vecs,  # dense_vec
                    sparse_vecs,  # sparse_vec
                    [c["clause_no"] for c in batch],  # clause_no
                    [c["article_no"] for c in batch],  # article_no
                    [c["regulation_name"] for c in batch],  # regulation_name
                    [
                        {  # metadata
                            "filename": file_record.original_name,
                            "page": c.get("page", 0),
                            "doc_number": reg_meta.doc_number,
                        }
                        for c in batch
                    ],
                ]
                col.insert(entities)

                if task:
                    # The last batch may be partial. Count only chunks actually
                    # processed so progress never exceeds 100.
                    done = min(i + batch_size, total)
                    task.progress = 40 + int(60 * done / total)
                    await db.commit()

            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成,共 {len(chunks)} 个向量块")

        except Exception as e:
            logger.exception(f"文件 {file_id} 处理失败:{e}")
            try:
                # A failed flush/commit leaves the session unusable until it is
                # rolled back. Rolling back also discards half-applied changes
                # before the failure state is recorded.
                await db.rollback()
                file_record.status = "failed"
                file_record.error_msg = str(e)
                if task:
                    task.status = "failed"
                    task.error_msg = str(e)
                await db.commit()
            except Exception:
                # Don't let a bookkeeping failure mask the original error.
                logger.exception(f"文件 {file_id} 失败状态写入失败")
            # Bare raise re-raises the original pipeline error `e`.
            raise
|
||
|
||
|
||
# ── Regulation monitoring tasks ────────────────────────────────


@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Beat entry point: trigger a fetch for every active monitoring source.

    Scheduled by the ``daily-regulation-monitor`` beat entry. The actual work
    is done by the async helper ``_run_all_monitors`` on a fresh event loop.
    """
    import asyncio as _aio

    _aio.run(_run_all_monitors())
|
||
|
||
|
||
async def _run_all_monitors():
    """Enqueue ``fetch_regulation_source`` once for each ``RegulationSource`` with ``is_active`` set."""
    from sqlalchemy import select

    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    # `== True` builds a SQL comparison; it is not a Python truthiness test.
    active_only = select(RegulationSource).where(RegulationSource.is_active == True)  # noqa: E712

    async with AsyncSessionLocal() as session:
        rows = await session.execute(active_only)
        for src in rows.scalars().all():
            # One asynchronous fetch task per source.
            fetch_regulation_source.delay(str(src.id))
            logger.info(f"触发监控源抓取:{src.name}")
|
||
|
||
|
||
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Celery entry point: check a single regulation source for updates.

    Delegates to the async helper ``_fetch_source`` on a fresh event loop.

    NOTE(review): ``max_retries=2`` is declared, but ``self.retry`` is never
    called, so failures are not retried.
    """
    import asyncio as _aio

    coro = _fetch_source(source_id)
    _aio.run(coro)
|
||
|
||
|
||
async def _fetch_source(source_id: str):
    """Check one regulation source for changes and record the outcome.

    Always stamps ``last_fetched_at``. When ``check_source_for_updates``
    returns a truthy result, this also stores ``new_hash`` as the source's
    ``last_hash`` and adds a ``RegulationUpdate`` row (``change_type``
    "updated", ``importance`` "normal"). That row holds ``raw_content``
    sliced to its first 50,000 elements, i.e. characters if it is a str.

    Args:
        source_id: UUID string of the ``RegulationSource`` row. If the row does
            not exist, a warning is logged and nothing is written.
    """
    from sqlalchemy import select

    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
        )
        source = result.scalar_one_or_none()
        if not source:
            logger.warning(f"监控源 {source_id} 不存在")
            return

        # Plain-dict snapshot handed to the monitor service.
        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]

        # Stamp the fetch time whether or not anything changed.
        source.last_fetched_at = datetime.now(timezone.utc)

        if update_data:
            db.add(
                RegulationUpdate(
                    source_id=uuid.UUID(source_id),
                    change_type="updated",
                    raw_content=update_data["raw_content"][:50000],
                    importance="normal",
                )
            )

        await db.commit()
|
||
|
||
|
||
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    """Placeholder push-notification task; it only logs a message for now."""
    message = "推送通知任务执行(待实现)"
    logger.info(message)
|
||
|
||
|
||
# Exported for use by FastAPI (an alias of the Celery app).
worker = celery_app
|