import datetime
import json
import logging
import time
import traceback
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from sqlalchemy import and_
from sqlalchemy.orm import sessionmaker

from database import IndexJob, IndexJobStatus
from utils import custom_serializer


@dataclass
class Task:
    """Task object"""
    id: str
    payload: Any
    priority: int = 0
    status: IndexJobStatus = IndexJobStatus.PENDING
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[Exception] = None
    result: Any = None

    def __lt__(self, other):
        """Used for priority queue sorting (higher priority sorts first)"""
        return self.priority > other.priority


@dataclass
class ProcessingStats:
    """Processing statistics"""
    total_tasks: int = 0
    completed_tasks: int = 0
    failed_tasks: int = 0
    cancelled_tasks: int = 0
    average_processing_time: float = 0.0
    throughput: float = 0.0  # Tasks processed per second
    # default_factory so the timestamp is taken per instance, not at class definition time
    start_time: datetime.datetime = field(default_factory=datetime.datetime.now)

    @property
    def success_rate(self) -> float:
        """Success rate"""
        if self.total_tasks == 0:
            return 0.0
        return self.completed_tasks / self.total_tasks

    @property
    def pending_tasks(self) -> int:
        """Number of pending tasks"""
        return self.total_tasks - self.completed_tasks - self.failed_tasks - self.cancelled_tasks

    @property
    def elapsed_time(self) -> float:
        """Elapsed time in seconds"""
        time_diff = datetime.datetime.now() - self.start_time
        return time_diff.total_seconds()

    @property
    def eta(self) -> float:
        """Estimated remaining time in seconds"""
        if self.completed_tasks == 0 or self.elapsed_time == 0:
            return 0.0
        rate = self.completed_tasks / self.elapsed_time
        if rate == 0:
            return 0.0
        return self.pending_tasks / rate


class TaskProcessorInterface(ABC):
    @abstractmethod
    def process(self, task: Task) -> Any:
        pass


class TaskProcessor:
    """Task processor"""

    def __init__(self,
                 task_processor: TaskProcessorInterface,
                 max_workers: int = 4,
                 logger: Optional[logging.Logger] = None,
                 database_engine: Optional[Any] = None,
                 data_config: Optional[Dict[str, Any]] = None):
        if data_config is None:
            raise ValueError("data_config must be provided")
        self.task_processor = task_processor
        self.max_workers = max_workers
        self.logger = logger or logging.getLogger(__name__)
        self.database_engine = database_engine

        # Simple statistics
        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time: Optional[datetime.datetime] = None

        # Collected per-task processing reports
        self.processing_reports: List[Dict[str, Any]] = []

        # Control flag
        self.should_stop = False

        self.data_config = data_config
        self.datasource_name: str = data_config.get("datasource_name", "default")

    def process_tasks(self, tasks: List[Any]) -> None:
        """Process a list of tasks - simple and effective"""
        self.total_tasks = len(tasks)
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time = datetime.datetime.now()
        self.processing_reports = []

        self.logger.info(f"Starting to process {self.total_tasks} tasks")

        # Process tasks with a thread pool
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_task = {executor.submit(self._process_single_task, task): task for task in tasks}

            # Wait for tasks to complete
            for future in as_completed(future_to_task):
                if self.should_stop:
                    break

                task = future_to_task[future]
                try:
                    result = future.result()
                    self.completed_tasks += 1

                    # Record a success report
                    report: Dict[str, Any] = {
                        'task_id': getattr(task, 'id', 'unknown'),
                        'status': 'success',
                        'message': getattr(result, 'message', 'Processing completed'),
                        'chunks_count': getattr(result, 'chunks_count', 0),
                        'processing_time': getattr(result, 'processing_time', 0),
                    }
                    self.processing_reports.append(report)

                    # Log progress after each completed task
                    self._log_progress()
                except Exception:
                    self.failed_tasks += 1
                    self.logger.error(f"Task processing failed: {traceback.format_exc()}")

                    # Record a failure report
                    report = {
                        'task_id': getattr(task, 'id', 'unknown'),
                        'status': 'failed',
                        'error': traceback.format_exc(),
                        'processing_time': 0,
                    }
                    self.processing_reports.append(report)

        # Log final statistics
        self.finalize_job_status_and_log()

    def _process_single_task(self, task: Any) -> Any:
        """Process a single task"""
        return self.task_processor.process(task)

    def get_processing_reports(self) -> List[Dict[str, Any]]:
        """Get the processing reports"""
        return self.processing_reports

    def _log_progress(self) -> None:
        """Log progress (remaining time is estimated from the average time per processed task)"""
        if self.start_time is None:
            return
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds()
        total_processed = self.completed_tasks + self.failed_tasks
        remaining = self.total_tasks - total_processed

        # Total processing time across processed tasks
        total_processing_time = sum(r.get('processing_time', 0) for r in self.processing_reports)
        avg_processing_time = (total_processing_time / total_processed) if total_processed > 0 else 0
        eta = avg_processing_time * remaining

        if total_processed > 0:
            rate = total_processed / elapsed if elapsed > 0 else 0
            self.logger.info(
                f"Progress: {total_processed}/{self.total_tasks} "
                f"({100.0 * total_processed / self.total_tasks:.1f}%) "
                f"Success: {self.completed_tasks} Failed: {self.failed_tasks} "
                f"Rate: {rate:.2f} tasks/second "
                f"Average time: {avg_processing_time:.2f} seconds/task "
                f"Estimated remaining: {eta / 60:.1f} minutes"
            )

    def finalize_job_status_and_log(self) -> None:
        """Compute statistics, write the IndexJob status, and log the full details."""
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds() if self.start_time else 0
        success_count = self.completed_tasks
        fail_count = self.failed_tasks
        total_count = self.total_tasks
        success_rate = (success_count / total_count * 100) if total_count > 0 else 0.0

        status = IndexJobStatus.FAILED.value
        if total_count == success_count:
            status = IndexJobStatus.SUCCESS.value
        elif success_count > 0 and fail_count > 0:
            status = IndexJobStatus.PARTIAL_SUCCESS.value

        report: Dict[str, Any] = {
            "status": status,
            "success_rate": f"{success_rate:.4f}%",
            "total_tasks": total_count,
            "completed": success_count,
            "failed": fail_count,
            "start_time": self.start_time,
            "end_time": datetime.datetime.now(datetime.timezone.utc),
            "processing_time": f"{elapsed:.4f} sec",
            "total_elapsed": f"{elapsed / 3600:.4f} hours",
            "average_speed": f"{total_count / elapsed:.5f} tasks/sec" if elapsed > 0 else "0 tasks/sec",
        }

        # Write the result to the database
        if self.database_engine:
            try:
                Session = sessionmaker(bind=self.database_engine)
                session = Session()
                try:
                    current_job = (
                        session.query(IndexJob)
                        .filter(and_(IndexJob.status == "processing",
                                     IndexJob.datasource_name == self.datasource_name))
                        .order_by(IndexJob.id.desc())
                        .first()
                    )
                    if current_job:
                        setattr(current_job, 'finished_time', report["end_time"])
                        setattr(current_job, 'success_object_count', success_count)
                        setattr(current_job, 'failed_object_count', fail_count)
                        setattr(current_job, 'detailed_message',
                                json.dumps(report, default=custom_serializer, ensure_ascii=False))
                        session.commit()
                        self.logger.info(
                            f"IndexJob status updated: {current_job.status}, "
                            f"Success: {current_job.success_object_count}, "
                            f"Failed: {current_job.failed_object_count}"
                        )
                    else:
                        self.logger.warning("No IndexJob record with processing status found")
                finally:
                    session.close()
            except Exception as e:
                self.logger.error(f"Failed to update IndexJob status: {e}")

        # Log the merged report
        self.logger.info(f"Final report: {json.dumps(report, default=custom_serializer, ensure_ascii=False)}")

        if self.processing_reports:
            success_reports = [r for r in self.processing_reports if r['status'] == 'success']
            failed_reports = [r for r in self.processing_reports if r['status'] == 'failed']

            if success_reports:
                total_chunks = sum(r.get('chunks_count', 0) for r in success_reports)
                avg_processing_time = sum(r.get('processing_time', 0) for r in success_reports) / len(success_reports)
                self.logger.info(
                    f"Success reports: {len(success_reports)} tasks, total {total_chunks} chunks, "
                    f"average processing time {avg_processing_time:.2f} sec"
                )

            if failed_reports:
                self.logger.error(f"Failed reports: {len(failed_reports)} tasks")
                for r in failed_reports[:5]:
                    self.logger.error(f" - {r['task_id']}: {r['error']}")
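

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `DemoResult` and `EchoProcessor`
# are hypothetical names introduced here for demonstration; they are not part
# of this module. The sketch assumes no database engine is available, so
# `database_engine` is left as None and the IndexJob update path is skipped.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    @dataclass
    class DemoResult:
        """Result shape the success report reads via getattr()."""
        message: str
        chunks_count: int
        processing_time: float

    class EchoProcessor(TaskProcessorInterface):
        """Trivial processor that simulates a short unit of work."""

        def process(self, task: Task) -> DemoResult:
            started = time.time()
            time.sleep(0.1)  # simulate work
            return DemoResult(
                message=f"processed {task.payload}",
                chunks_count=1,
                processing_time=time.time() - started,
            )

    demo_tasks = [Task(id=str(i), payload=f"doc-{i}") for i in range(5)]
    processor = TaskProcessor(
        task_processor=EchoProcessor(),
        max_workers=2,
        database_engine=None,  # assumption: no IndexJob table in this sketch
        data_config={"datasource_name": "demo"},
    )
    processor.process_tasks(demo_tasks)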