2026-04-23 09:58:47 +08:00
|
|
|
|
import uuid
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
from celery import Celery
|
|
|
|
|
|
from celery.schedules import crontab
|
|
|
|
|
|
|
|
|
|
|
|
from .core.config import settings
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Celery application.
# The same Redis URL (settings.redis_url) is used as both broker and result backend.
celery_app = Celery("compliance", broker=settings.redis_url, backend=settings.redis_url)

# One dedicated queue per workload: document parsing, source monitoring, push.
# NOTE(review): "app.worker.run_all_monitors" has no route here, so it is sent to
# Celery's default queue; confirm a worker consumes that queue, otherwise the
# nightly beat job is never executed.
_TASK_ROUTES = {
    "app.worker.process_file_task": {"queue": "parse"},
    "app.worker.fetch_regulation_source": {"queue": "monitor"},
    "app.worker.send_notifications": {"queue": "push"},
}

# Beat schedule: fire the monitor fan-out every day at 02:00 (crontab is
# interpreted in the app timezone configured below).
_BEAT_SCHEDULE = {
    "daily-regulation-monitor": {
        "task": "app.worker.run_all_monitors",
        "schedule": crontab(hour=2, minute=0),
    },
}

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="Asia/Shanghai",
    task_routes=_TASK_ROUTES,
    beat_schedule=_BEAT_SCHEDULE,
)
|
|
|
|
|
|
|
|
|
|
|
|
# ── 文件处理任务(解析 + 向量化)────────────────
|
|
|
|
|
|
|
|
|
|
|
|
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Celery entry point: parse a stored document and index it in Milvus.

    Synchronous wrapper that drives the async ``_process_file`` pipeline on a
    fresh event loop.

    Args:
        file_id: UUID string of the ``File`` row to process.
        task_id: UUID string of the ``Task`` row tracking progress.
        workspace_id: Workspace identifier attached to every stored chunk.

    NOTE(review): ``max_retries=3`` currently has no effect. Celery only
    retries when the task calls ``self.retry()`` or declares
    ``autoretry_for``. Before enabling retries, make the Milvus insert
    idempotent. Milvus ``insert`` presumably does not deduplicate primary
    keys, so a retry after a partial insert could leave duplicate chunk rows.
    Confirm this before enabling retries.
    """
    from asyncio import run

    pipeline = _process_file(file_id, task_id, workspace_id)
    run(pipeline)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse one uploaded regulation file, chunk and embed it, and store it in Milvus.

    Pipeline:
      1. Load the ``File`` and ``Task`` rows. If the file is missing, log it and return.
      2. Parse the stored document with ``parse_document``.
      3. Extract regulation metadata.
      4. Split the text with ``legal_chunk``.
      5. Embed the chunks in batches and insert them into the
         ``regulation_chunks`` Milvus collection.

    ``Task.progress`` moves 10 -> 40 -> up to 100 as the stages complete.
    The ``Task`` row is optional; without it, no progress is recorded.

    Args:
        file_id: UUID string of the ``File`` row to process.
        task_id: UUID string of the ``Task`` row tracking this job.
        workspace_id: Workspace identifier stored alongside every chunk.

    Raises:
        ValueError: if ``file_id``/``task_id`` is not a valid UUID, or the
            parsed document is empty.
        Exception: any failure inside the pipeline is re-raised after the
            file (and task, if present) have been marked ``"failed"``.
    """
    from pathlib import Path

    from sqlalchemy import select

    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.embed import embed_texts
    from .services.parse import parse_document
    from .services.regulation_parser import extract_regulation_meta, legal_chunk

    async with AsyncSessionLocal() as db:
        # Look up the file record; nothing to do if it no longer exists.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return

        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        try:
            # Mark the job as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document (via the mcp-server parse service).
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            # Treat an explicit None the same as a missing key.
            markdown = parse_result.get("markdown") or ""

            if not markdown.strip():
                raise ValueError("文档解析结果为空")

            # Step 2: extract regulation metadata (issuing authority, document
            # number, effective date, regulation type).
            reg_meta = extract_regulation_meta(markdown)
            file_record.status = "parsed"
            # NOTE(review): SQLAlchemy Declarative reserves the attribute name
            # ``metadata``, so confirm the File model really maps a column under
            # this attribute; otherwise this dict is never persisted.
            file_record.metadata = {
                "regulation_name": reg_meta.regulation_name,
                "issuing_authority": reg_meta.issuing_authority,
                "doc_number": reg_meta.doc_number,
                "effective_date": reg_meta.effective_date,
                "regulation_type": reg_meta.regulation_type,
                "parser": parse_result.get("parser", ""),
                "page_count": parse_result.get("page_count", 0),
            }
            if task:
                task.progress = 40
            await db.commit()

            # Step 3: regulation-specific chunking (split on chapter/article
            # boundaries, keeping clause numbers).
            chunks = legal_chunk(markdown, reg_meta, chunk_size=512, overlap=64)
            total = len(chunks)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)} 块,法规:{reg_meta.regulation_name!r}")

            # Step 4: embed and insert into Milvus in batches.
            batch_size = 16
            col = get_milvus_collection("regulation_chunks")

            for i in range(0, total, batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)

                dense_vecs = embed_result["dense"]  # list[list[float]], 1024-dim per original note
                sparse_vecs = embed_result.get("sparse", [{}] * len(batch))  # list[dict[str, float]]

                # Column-ordered entity lists; the order must match the
                # collection schema (pk, file_id, workspace_id, chunk_idx,
                # content, dense_vec, sparse_vec, clause_no, article_no,
                # regulation_name, metadata).
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],  # pk
                    [file_id] * len(batch),  # file_id
                    [workspace_id] * len(batch),  # workspace_id
                    [c["idx"] for c in batch],  # chunk_idx
                    texts,  # content
                    dense_vecs,  # dense_vec
                    sparse_vecs,  # sparse_vec
                    [c["clause_no"] for c in batch],  # clause_no
                    [c["article_no"] for c in batch],  # article_no
                    [c["regulation_name"] for c in batch],  # regulation_name
                    [
                        {  # metadata
                            "filename": file_record.original_name,
                            "page": c.get("page", 0),
                            "doc_number": reg_meta.doc_number,
                        }
                        for c in batch
                    ],
                ]
                col.insert(entities)

                if task:
                    # Clamp to the chunks actually processed: the final batch
                    # can be shorter than batch_size, and using i + batch_size
                    # directly would push progress above 100.
                    done = min(i + batch_size, total)
                    task.progress = 40 + int(60 * done / total)
                    await db.commit()

            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成,共 {len(chunks)} 个向量块")

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"文件 {file_id} 处理失败:{e}")
            # If the failure came from a flush/commit, the session refuses any
            # further work until rolled back. Rolling back also discards
            # uncommitted intermediate state (e.g. "parsed") before the failure
            # is recorded.
            await db.rollback()
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 法规监控任务 ────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Beat-scheduled entry point: enqueue a fetch for every active monitor source.

    Synchronous wrapper that runs ``_run_all_monitors`` on a fresh event loop.
    """
    from asyncio import run

    run(_run_all_monitors())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _run_all_monitors():
    """Enqueue one ``fetch_regulation_source`` task per active regulation source.

    Selects every ``RegulationSource`` whose ``is_active`` is true, dispatches
    a Celery task with the source id as a string, and logs each source's name.
    """
    from sqlalchemy import select

    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    # filter_by(is_active=True) renders the same predicate as is_active == True.
    active_stmt = select(RegulationSource).filter_by(is_active=True)

    async with AsyncSessionLocal() as session:
        active_rows = (await session.execute(active_stmt)).scalars().all()
        for row in active_rows:
            fetch_regulation_source.delay(str(row.id))
            logger.info(f"触发监控源抓取:{row.name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Celery entry point: check a single regulation source for updates.

    Runs ``_fetch_source`` on a fresh event loop. Any failure, such as a
    network error while fetching, is retried up to ``max_retries`` (2) times
    using the task's default retry delay. After the last attempt the original
    exception propagates.

    Args:
        source_id: UUID string of the ``RegulationSource`` to check.
    """
    import asyncio

    try:
        asyncio.run(_fetch_source(source_id))
    except Exception as exc:
        # max_retries has no effect unless the task calls self.retry().
        # self.retry raises a Retry signal, or re-raises ``exc`` once the
        # retry budget is exhausted.
        raise self.retry(exc=exc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _fetch_source(source_id: str):
    """Check one regulation source and record a ``RegulationUpdate`` if it changed.

    Loads the ``RegulationSource`` row and passes a plain-dict snapshot of it
    (id, name, url, last_hash) to ``check_source_for_updates``.

    If that returns change data:
      - the source's ``last_hash`` is replaced with ``new_hash``;
      - a ``RegulationUpdate`` is added holding a truncated copy of ``raw_content``.

    ``last_fetched_at`` is refreshed after every completed check, whether or
    not anything changed.

    Args:
        source_id: UUID string of the ``RegulationSource`` to check.

    Raises:
        ValueError: if ``source_id`` is not a valid UUID.
    """
    from sqlalchemy import select

    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    # Maximum number of characters of raw content stored per update.
    max_raw_content_chars = 50000

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
        )
        source = result.scalar_one_or_none()
        if not source:
            # Previously a silent return; surface deleted/unknown sources.
            logger.warning(f"监控源 {source_id} 不存在")
            return

        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            db.add(
                RegulationUpdate(
                    source_id=uuid.UUID(source_id),
                    change_type="updated",
                    raw_content=update_data["raw_content"][:max_raw_content_chars],
                    importance="normal",
                )
            )

        # Record the check time in both the changed and unchanged cases.
        source.last_fetched_at = datetime.now(timezone.utc)
        await db.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    """Placeholder push-notification task; it only logs that it was invoked."""
    placeholder_message = "推送通知任务执行(待实现)"
    logger.info(placeholder_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public alias of the Celery app, exported for use by the FastAPI application.
worker: Celery = celery_app
|