catonline_ai/vw-document-ai-indexer/task_processor.py

import time
from typing import List, Any, Optional, Dict
import logging
from dataclasses import dataclass, field
import json
import datetime
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from abc import ABC, abstractmethod
from sqlalchemy import and_
from sqlalchemy.orm import sessionmaker
from database import IndexJobStatus, IndexJob
from utils import custom_serializer


@dataclass
class Task:
    """Task object"""
    id: str
    payload: Any
    priority: int = 0
    status: IndexJobStatus = IndexJobStatus.PENDING
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[Exception] = None
    result: Any = None

    def __lt__(self, other):
        """Used for priority queue sorting"""
        return self.priority > other.priority
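
# Usage sketch (assumption: Task objects are pushed onto a heapq-based priority
# queue, which this module itself does not show). Because __lt__ is inverted,
# the highest-priority task is popped first from a min-heap:
#
#     import heapq
#     q: list = []
#     heapq.heappush(q, Task(id="a", payload=None, priority=1))
#     heapq.heappush(q, Task(id="b", payload=None, priority=5))
#     heapq.heappop(q).id  # -> "b", since priority 5 compares as "less than" 1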


@dataclass
class ProcessingStats:
    """Processing statistics information"""
    total_tasks: int = 0
    completed_tasks: int = 0
    failed_tasks: int = 0
    cancelled_tasks: int = 0
    average_processing_time: float = 0.0
    throughput: float = 0.0  # Number of tasks processed per second
    # Use default_factory so each instance records its own creation time
    # (a plain default would be evaluated once, at class-definition time).
    start_time: datetime.datetime = field(default_factory=datetime.datetime.now)

    @property
    def success_rate(self) -> float:
        """Fraction of all tasks that completed successfully"""
        if self.total_tasks == 0:
            return 0.0
        return self.completed_tasks / self.total_tasks

    @property
    def pending_tasks(self) -> int:
        """Number of tasks not yet completed, failed, or cancelled"""
        return self.total_tasks - self.completed_tasks - self.failed_tasks - self.cancelled_tasks

    @property
    def elapsed_time(self) -> float:
        """Seconds elapsed since start_time"""
        time_diff = datetime.datetime.now() - self.start_time
        return time_diff.total_seconds()

    @property
    def eta(self) -> float:
        """Estimated remaining time in seconds, based on the completion rate so far"""
        if self.completed_tasks == 0 or self.elapsed_time == 0:
            return 0.0
        rate = self.completed_tasks / self.elapsed_time
        if rate == 0:
            return 0.0
        return self.pending_tasks / rate
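
# Usage sketch for ProcessingStats (assumption: calling code updates the counters
# as tasks finish; TaskProcessor below tracks its own counters directly instead):
#
#     stats = ProcessingStats(total_tasks=100, completed_tasks=40, failed_tasks=10)
#     stats.success_rate   # 0.4
#     stats.pending_tasks  # 50
#     stats.eta            # pending_tasks / (completed_tasks / elapsed_time), in seconds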


class TaskProcessorInterface(ABC):
    """Interface for task processors; implementations handle a single task."""

    @abstractmethod
    def process(self, task: Task) -> Any:
        pass


class TaskProcessor:
    """Task processor"""

    def __init__(self,
                 task_processor: TaskProcessorInterface,
                 max_workers: int = 4,
                 logger: Optional[logging.Logger] = None,
                 database_engine: Optional[Any] = None,
                 data_config: Optional[Dict[str, Any]] = None):
        if data_config is None:
            raise ValueError("data_config must be provided")
        self.task_processor = task_processor
        self.max_workers = max_workers
        self.logger = logger or logging.getLogger(__name__)
        self.database_engine = database_engine
        # Simple statistics
        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time: Optional[datetime.datetime] = None
        # Processing report collection
        self.processing_reports: List[Dict[str, Any]] = []
        # Control variable
        self.should_stop = False
        self.data_config = data_config
        self.datasource_name: str = data_config.get("datasource_name", "default")

    def process_tasks(self, tasks: List[Any]) -> None:
        """Process a task list: run tasks in a thread pool and collect per-task reports."""
        self.total_tasks = len(tasks)
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time = datetime.datetime.now()
        self.processing_reports = []
        self.logger.info(f"Starting to process {self.total_tasks} tasks")
        # Use a thread pool to process tasks
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_task = {executor.submit(self._process_single_task, task): task
                              for task in tasks}
            # Wait for tasks to complete
            for future in as_completed(future_to_task):
                if self.should_stop:
                    break
                task = future_to_task[future]
                try:
                    result = future.result()
                    self.completed_tasks += 1
                    # Record a success report for this task
                    report: Dict[str, Any] = {
                        'task_id': getattr(task, 'id', 'unknown'),
                        'status': 'success',
                        'message': getattr(result, 'message', 'Processing completed'),
                        'chunks_count': getattr(result, 'chunks_count', 0),
                        'processing_time': getattr(result, 'processing_time', 0),
                    }
                    self.processing_reports.append(report)
                    # Log progress after every task
                    self._log_progress()
                except Exception:
                    self.failed_tasks += 1
                    self.logger.error(f"Task processing failed: {traceback.format_exc()}")
                    # Record a failure report for this task
                    report = {
                        'task_id': getattr(task, 'id', 'unknown'),
                        'status': 'failed',
                        'error': traceback.format_exc(),
                        'processing_time': 0,
                    }
                    self.processing_reports.append(report)
        # Output final statistics and persist job status
        self.finalize_job_status_and_log()

    def _process_single_task(self, task: Any) -> Any:
        """Process a single task"""
        return self.task_processor.process(task)

    def get_processing_reports(self) -> List[Dict[str, Any]]:
        """Get processing reports"""
        return self.processing_reports

    def _log_progress(self) -> None:
        """Log progress; remaining time is estimated from the average time per processed task."""
        if self.start_time is None:
            return
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds()
        total_processed = self.completed_tasks + self.failed_tasks
        remaining = self.total_tasks - total_processed
        # Total processing time over all processed tasks
        total_processing_time = sum(r.get('processing_time', 0) for r in self.processing_reports)
        avg_processing_time = (total_processing_time / total_processed) if total_processed > 0 else 0
        eta = avg_processing_time * remaining
        if total_processed > 0:
            rate = total_processed / elapsed if elapsed > 0 else 0
            self.logger.info(
                f"Progress: {total_processed}/{self.total_tasks} "
                f"({100.0 * total_processed / self.total_tasks:.1f}%) "
                f"Success: {self.completed_tasks} Failed: {self.failed_tasks} "
                f"Rate: {rate:.2f} tasks/second "
                f"Average time: {avg_processing_time:.2f} seconds/task "
                f"Estimated remaining: {eta / 60:.1f} minutes"
            )

    def finalize_job_status_and_log(self) -> None:
        """Compute final statistics, update the IndexJob status, and log the full report."""
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds() if self.start_time else 0
        success_count = self.completed_tasks
        fail_count = self.failed_tasks
        total_count = self.total_tasks
        success_rate = (success_count / total_count * 100) if total_count > 0 else 0.0
        status = IndexJobStatus.FAILED.value
        if total_count == success_count:
            status = IndexJobStatus.SUCCESS.value
        elif success_count > 0 and fail_count > 0:
            status = IndexJobStatus.PARTIAL_SUCCESS.value
        report: Dict[str, Any] = {
            "status": status,
            "success_rate": f"{success_rate:.4f}%",
            "total_tasks": total_count,
            "completed": success_count,
            "failed": fail_count,
            "start_time": self.start_time,
            "end_time": datetime.datetime.now(datetime.timezone.utc),
            "processing_time": f"{elapsed:.4f} sec",
            "total_elapsed": f"{elapsed / 3600:.4f} hours",
            "average_speed": f"{total_count / elapsed:.5f} tasks/sec" if elapsed > 0 else "0 tasks/sec"
        }
        # Write the result to the current IndexJob record, if a database engine was provided
        if self.database_engine:
            try:
                Session = sessionmaker(bind=self.database_engine)
                session = Session()
                try:
                    current_job = (
                        session.query(IndexJob)
                        .filter(and_(IndexJob.status == "processing",
                                     IndexJob.datasource_name == self.datasource_name))
                        .order_by(IndexJob.id.desc())
                        .first()
                    )
                    if current_job:
                        setattr(current_job, 'finished_time', report["end_time"])
                        setattr(current_job, 'success_object_count', success_count - fail_count)
                        setattr(current_job, 'failed_object_count', fail_count)
                        setattr(current_job, 'detailed_message',
                                json.dumps(report, default=custom_serializer, ensure_ascii=False))
                        session.commit()
                        self.logger.info(
                            f"IndexJob status updated: {current_job.status}, "
                            f"Success: {current_job.success_object_count}, "
                            f"Failed: {current_job.failed_object_count}"
                        )
                    else:
                        self.logger.warning("No IndexJob record with 'processing' status found")
                finally:
                    session.close()
            except Exception as e:
                self.logger.error(f"Failed to update IndexJob status: {e}")
        # Log the merged report content
        self.logger.info(f"Final report: {json.dumps(report, default=custom_serializer, ensure_ascii=False)}")
        if self.processing_reports:
            success_reports = [r for r in self.processing_reports if r['status'] == 'success']
            failed_reports = [r for r in self.processing_reports if r['status'] == 'failed']
            if success_reports:
                total_chunks = sum(r.get('chunks_count', 0) for r in success_reports)
                avg_processing_time = sum(r.get('processing_time', 0) for r in success_reports) / len(success_reports)
                self.logger.info(
                    f"Success reports: {len(success_reports)} tasks, total {total_chunks} chunks, "
                    f"average processing time {avg_processing_time:.2f} sec"
                )
            if failed_reports:
                self.logger.error(f"Failed reports: {len(failed_reports)} tasks")
                for r in failed_reports[:5]:
                    self.logger.error(f" - {r['task_id']}: {r['error']}")