first commit

This commit is contained in:
2026-04-23 09:58:47 +08:00
commit 448e078d99
49 changed files with 5188 additions and 0 deletions

View File

@@ -0,0 +1,212 @@
import uuid
import logging
from datetime import datetime, timezone
from celery import Celery
from celery.schedules import crontab
from .core.config import settings
logger = logging.getLogger(__name__)

# Celery configuration: the same settings.redis_url is used for both the
# broker and the result backend.
celery_app = Celery("compliance", broker=settings.redis_url, backend=settings.redis_url)

# Task messages and results are JSON only.
celery_app.conf.task_serializer = "json"
celery_app.conf.accept_content = ["json"]
celery_app.conf.result_serializer = "json"
celery_app.conf.timezone = "Asia/Shanghai"

# Each task name is routed to its own queue.
celery_app.conf.task_routes = {
    "app.worker.process_file_task": {"queue": "parse"},
    "app.worker.fetch_regulation_source": {"queue": "monitor"},
    "app.worker.send_notifications": {"queue": "push"},
}

# Periodic schedule: fan out to every regulation monitor once a day at 02:00.
celery_app.conf.beat_schedule = {
    "daily-regulation-monitor": {
        "task": "app.worker.run_all_monitors",
        "schedule": crontab(hour=2, minute=0),
    },
}
# ── File processing task (parse + vectorize) ────────────────
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Parse a document and store its vectors in Milvus.

    Synchronous Celery entry point: it drives the async ``_process_file``
    pipeline to completion on a fresh event loop.

    Args:
        file_id: Identifier of the file record to process.
        task_id: Identifier of the task record used for progress tracking.
        workspace_id: Workspace the resulting chunks are attributed to.
    """
    from asyncio import run as run_until_done

    pipeline = _process_file(file_id, task_id, workspace_id)
    run_until_done(pipeline)
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse one stored file, chunk it, embed the chunks and insert them into Milvus.

    Status and progress are written to the ``File`` row and, when one is
    found, to the ``Task`` row as the pipeline advances. If the file row does
    not exist the function logs an error and returns without doing anything.

    Args:
        file_id: String form of ``File.id`` (parsed with ``uuid.UUID``).
        task_id: String form of ``Task.id``; a missing task row is tolerated.
        workspace_id: Stored alongside every inserted chunk.

    Raises:
        ValueError: If an id is not a valid UUID string (raised before any
            status is recorded), or if the parsed markdown is empty.
        Exception: Any failure during parsing, embedding or insertion is
            recorded on the file/task rows and then re-raised.
    """
    from pathlib import Path
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.parse import parse_document, chunk_text
    from .services.embed import embed_texts

    async with AsyncSessionLocal() as db:
        # Look up the file record.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return

        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()

        try:
            # Mark the work as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document.
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")
            if not markdown.strip():
                raise ValueError("文档解析结果为空")

            file_record.status = "parsed"
            if task:
                task.progress = 40
            await db.commit()

            # Step 2: split into chunks.
            chunks = chunk_text(markdown, chunk_size=512, overlap=64)
            total_chunks = len(chunks)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)}")

            # Step 3: embed and insert in batches.
            batch_size = 16
            col = get_milvus_collection("regulation_chunks")
            for i in range(0, total_chunks, batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)
                dense_vecs = embed_result["dense"]

                # Column-oriented payload; the column order presumably mirrors
                # the collection schema - verify against its definition.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],
                    [file_id] * len(batch),
                    [workspace_id] * len(batch),
                    [c["idx"] for c in batch],
                    [c["content"] for c in batch],
                    dense_vecs,
                    [{"filename": file_record.original_name, "page": c.get("page", 0)} for c in batch],
                ]
                col.insert(entities)

                if task:
                    # The final batch may be shorter than batch_size, so clamp
                    # the processed count; otherwise progress overshoots 100.
                    processed = min(i + batch_size, total_chunks)
                    task.progress = 40 + int(60 * processed / total_chunks)
                    await db.commit()

            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成")

        except Exception as e:
            # logger.exception keeps the traceback next to the original message.
            logger.exception(f"文件 {file_id} 处理失败:{e}")
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise
# ── Regulation monitoring tasks ────────────────────────────────
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Scheduled trigger that fans out to every active monitoring source.

    Synchronous Celery entry point that runs the async ``_run_all_monitors``
    coroutine to completion.
    """
    from asyncio import run as run_until_done

    fan_out = _run_all_monitors()
    run_until_done(fan_out)
async def _run_all_monitors():
    """Queue one ``fetch_regulation_source`` task per active regulation source.

    Selects every ``RegulationSource`` whose ``is_active`` flag is true,
    enqueues a fetch task with the source id as a string, and logs each
    dispatch.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    # "== True" builds a SQL expression here; it is not a Python truth test.
    active_sources_query = select(RegulationSource).where(
        RegulationSource.is_active == True  # noqa: E712
    )

    async with AsyncSessionLocal() as session:
        rows = await session.execute(active_sources_query)
        for source in rows.scalars().all():
            fetch_regulation_source.delay(str(source.id))
            logger.info(f"触发监控源抓取:{source.name}")
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Check a single regulation source for changes.

    Synchronous Celery entry point that runs the async ``_fetch_source``
    coroutine to completion.

    Args:
        source_id: String form of the regulation source id.
    """
    from asyncio import run as run_until_done

    check = _fetch_source(source_id)
    run_until_done(check)
async def _fetch_source(source_id: str):
    """Check one regulation source for updates and persist the outcome.

    Loads the ``RegulationSource`` row, passes a plain-dict snapshot of it to
    ``check_source_for_updates`` and, when that returns a truthy result,
    stores the new hash and adds a ``RegulationUpdate`` row. The source's
    ``last_fetched_at`` is refreshed on every completed check, whether or not
    a change was found. Returns silently if the source row does not exist.

    Args:
        source_id: String form of ``RegulationSource.id`` (parsed with
            ``uuid.UUID``).

    Raises:
        ValueError: If ``source_id`` is not a valid UUID string.
    """
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
        )
        source = result.scalar_one_or_none()
        if not source:
            return

        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)

        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            db.add(
                RegulationUpdate(
                    source_id=uuid.UUID(source_id),
                    change_type="updated",
                    # Only the first 50000 characters of the content are stored.
                    raw_content=update_data["raw_content"][:50000],
                    importance="normal",
                )
            )

        # Both outcomes record the fetch time and commit, so do it once here.
        source.last_fetched_at = datetime.now(timezone.utc)
        await db.commit()
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    """Placeholder notification task: it currently only writes a log line."""
    placeholder_message = "推送通知任务执行(待实现)"
    logger.info(placeholder_message)
# Exported for use by FastAPI: `worker` is an alias of the Celery app above.
worker = celery_app