import datetime
import json
import logging
import time
import traceback
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from sqlalchemy import and_
from sqlalchemy.orm import sessionmaker

from database import IndexJob, IndexJobStatus
from utils import custom_serializer


@dataclass
class Task:
    """Task object"""
    id: str
    payload: Any
    priority: int = 0
    status: IndexJobStatus = IndexJobStatus.PENDING
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[Exception] = None
    result: Any = None

    def __lt__(self, other):
        """Used for priority queue sorting; a higher priority value sorts first."""
        return self.priority > other.priority


@dataclass
class ProcessingStats:
    """Processing statistics information"""
    total_tasks: int = 0
    completed_tasks: int = 0
    failed_tasks: int = 0
    cancelled_tasks: int = 0
    average_processing_time: float = 0.0
    throughput: float = 0.0  # Number of tasks processed per second
    # Use default_factory so each instance captures its own start time,
    # rather than the time the class was first defined.
    start_time: datetime.datetime = field(default_factory=datetime.datetime.now)

    @property
    def success_rate(self) -> float:
        """Success rate"""
        if self.total_tasks == 0:
            return 0.0
        return self.completed_tasks / self.total_tasks

    @property
    def pending_tasks(self) -> int:
        """Number of pending tasks"""
        return self.total_tasks - self.completed_tasks - self.failed_tasks - self.cancelled_tasks

    @property
    def elapsed_time(self) -> float:
        """Elapsed time in seconds"""
        time_diff = datetime.datetime.now() - self.start_time
        return time_diff.total_seconds()

    @property
    def eta(self) -> float:
        """Estimated remaining time in seconds"""
        if self.completed_tasks == 0:
            return 0.0
        elapsed = self.elapsed_time
        if elapsed == 0:
            return 0.0
        rate = self.completed_tasks / elapsed
        return self.pending_tasks / rate


class TaskProcessorInterface(ABC):
    """Interface implemented by concrete task processors."""

    @abstractmethod
    def process(self, task: Task) -> Any:
        pass


class TaskProcessor:
    """Task processor"""

    def __init__(self,
                 task_processor: TaskProcessorInterface,
                 max_workers: int = 4,
                 logger: Optional[logging.Logger] = None,
                 database_engine: Optional[Any] = None,
                 data_config: Optional[Dict[str, Any]] = None):

        if data_config is None:
            raise ValueError("data_config must be provided")

        self.task_processor = task_processor
        self.max_workers = max_workers
        self.logger = logger or logging.getLogger(__name__)
        self.database_engine = database_engine

        # Simple statistics
        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time: Optional[datetime.datetime] = None

        # Processing report collection
        self.processing_reports: List[Dict[str, Any]] = []

        # Control variable
        self.should_stop = False

        self.data_config = data_config
        self.datasource_name: str = data_config.get("datasource_name", "default")

    def process_tasks(self, tasks: List[Any]) -> None:
        """Process task list - simple and effective"""
        self.total_tasks = len(tasks)
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time = datetime.datetime.now()
        self.processing_reports = []

        self.logger.info(f"Starting to process {self.total_tasks} tasks")

        # Use thread pool to process tasks
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_task = {executor.submit(self._process_single_task, task): task
                              for task in tasks}

            # Wait for tasks to complete
            for future in as_completed(future_to_task):
                if self.should_stop:
                    break

                task = future_to_task[future]
                try:
                    result = future.result()
                    self.completed_tasks += 1

                    # Record successful processing report
                    report: Dict[str, Any] = {
                        'task_id': getattr(task, 'id', 'unknown'),
                        'status': 'success',
                        'message': getattr(result, 'message', 'Processing completed'),
                        'chunks_count': getattr(result, 'chunks_count', 0),
                        'processing_time': getattr(result, 'processing_time', 0),
                    }
                    self.processing_reports.append(report)

                    # Log progress after every completed task
                    self._log_progress()

                except Exception:
                    self.failed_tasks += 1
                    self.logger.error(f"Task processing failed: {traceback.format_exc()}")
                    # Record failed processing report
                    report = {
                        'task_id': getattr(task, 'id', 'unknown'),
                        'status': 'failed',
                        'error': traceback.format_exc(),
                        'processing_time': 0,
                    }
                    self.processing_reports.append(report)

        # Output final statistics
        self.finalize_job_status_and_log()

    def _process_single_task(self, task: Any) -> Any:
        """Process a single task"""
        return self.task_processor.process(task)

    def get_processing_reports(self) -> List[Dict[str, Any]]:
        """Get processing reports"""
        return self.processing_reports

    def _log_progress(self) -> None:
        """Output progress information (estimate remaining time based on average time per processed document)"""
        if self.start_time is None:
            return
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds()
        total_processed = self.completed_tasks + self.failed_tasks
        remaining = self.total_tasks - total_processed
        # Total processing time for processed tasks
        total_processing_time = sum(r.get('processing_time', 0) for r in self.processing_reports)
        avg_processing_time = (total_processing_time / total_processed) if total_processed > 0 else 0
        eta = avg_processing_time * remaining
        if total_processed > 0:
            rate = total_processed / elapsed if elapsed > 0 else 0
            self.logger.info(
                f"Progress: {total_processed}/{self.total_tasks} "
                f"({100.0 * total_processed / self.total_tasks:.1f}%) "
                f"Success: {self.completed_tasks} Failed: {self.failed_tasks} "
                f"Rate: {rate:.2f} tasks/second "
                f"Average time: {avg_processing_time:.2f} seconds/task "
                f"Estimated remaining: {eta / 60:.1f} minutes"
            )

    def finalize_job_status_and_log(self) -> None:
        """Statistics, write IndexJob status, and output all log details."""
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds() if self.start_time else 0
        success_count = self.completed_tasks
        fail_count = self.failed_tasks
        total_count = self.total_tasks
        success_rate = (success_count / total_count * 100) if total_count > 0 else 0.0
        status = IndexJobStatus.FAILED.value
        if total_count == success_count:
            status = IndexJobStatus.SUCCESS.value
        elif success_count > 0 and fail_count > 0:
            status = IndexJobStatus.PARTIAL_SUCCESS.value

        report: Dict[str, Any] = {
            "status": status,
            "success_rate": f"{success_rate:.4f}%",
            "total_tasks": total_count,
            "completed": success_count,
            "failed": fail_count,
            "start_time": self.start_time,
            "end_time": datetime.datetime.now(datetime.timezone.utc),
            "processing_time": f"{elapsed:.4f} sec",
            "total_elapsed": f"{elapsed / 3600:.4f} hours",
            "average_speed": f"{total_count / elapsed:.5f} tasks/sec" if elapsed > 0 else "0 tasks/sec"
        }
        # Database write section
        if self.database_engine:
            try:
                Session = sessionmaker(bind=self.database_engine)
                session = Session()
                try:
                    current_job = (
                        session.query(IndexJob)
                        .filter(and_(IndexJob.status == "processing",
                                     IndexJob.datasource_name == self.datasource_name))
                        .order_by(IndexJob.id.desc())
                        .first()
                    )
                    if current_job:
                        setattr(current_job, 'status', status)
                        setattr(current_job, 'finished_time', report["end_time"])
                        setattr(current_job, 'success_object_count', success_count)
                        setattr(current_job, 'failed_object_count', fail_count)
                        setattr(current_job, 'detailed_message', json.dumps(report, default=custom_serializer, ensure_ascii=False))
                        session.commit()
                        self.logger.info(f"IndexJob status updated: {current_job.status}, Success: {current_job.success_object_count}, Failed: {current_job.failed_object_count}")
                    else:
                        self.logger.warning("No IndexJob record with processing status found")
                finally:
                    session.close()
            except Exception as e:
                self.logger.error(f"Failed to update IndexJob status: {e}")
        # Output merged report content
        self.logger.info(f"Final report: {json.dumps(report, default=custom_serializer, ensure_ascii=False)}")
        if self.processing_reports:
            success_reports = [r for r in self.processing_reports if r['status'] == 'success']
            failed_reports = [r for r in self.processing_reports if r['status'] == 'failed']
            if success_reports:
                total_chunks = sum(r.get('chunks_count', 0) for r in success_reports)
                avg_processing_time = sum(r.get('processing_time', 0) for r in success_reports) / len(success_reports)
                self.logger.info(f"Success reports: {len(success_reports)} tasks, total {total_chunks} chunks, average processing time {avg_processing_time:.2f} sec")
            if failed_reports:
                self.logger.error(f"Failed reports: {len(failed_reports)} tasks")
                for r in failed_reports[:5]:
                    self.logger.error(f"  - {r['task_id']}: {r['error']}")
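

# Minimal usage sketch (illustrative only, not part of this module's API): shows
# how a concrete TaskProcessorInterface implementation and a TaskProcessor can be
# wired together. "EchoProcessor", the "demo" datasource name, and the ad-hoc
# Result object are hypothetical; real callers supply their own processor,
# data_config, and (optionally) a database engine.
if __name__ == "__main__":
    class EchoProcessor(TaskProcessorInterface):
        def process(self, task: Task) -> Any:
            # Pretend to do work and return an object exposing the attributes the
            # reporting code reads via getattr(): message, chunks_count, processing_time.
            start = time.time()
            return type("Result", (), {
                "message": f"echoed {task.payload}",
                "chunks_count": 1,
                "processing_time": time.time() - start,
            })()

    logging.basicConfig(level=logging.INFO)
    demo_tasks = [Task(id=str(i), payload=f"doc-{i}") for i in range(3)]
    processor = TaskProcessor(
        task_processor=EchoProcessor(),
        max_workers=2,
        data_config={"datasource_name": "demo"},  # hypothetical config value
    )
    processor.process_tasks(demo_tasks)
    print(processor.get_processing_reports())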