Source code for ccat_data_transfer.task_state_manager

from datetime import datetime


class TaskStateManager:
    """Manager for tracking and recovering task states across all operation types.

    Every task is stored as a Redis hash under ``task:<task_id>``. Two Redis
    sets act as indices:

    * ``running_tasks:<operation_type>`` -- IDs of tasks currently running.
    * ``tasks_for_operation:<operation_type>:<operation_id>`` -- IDs of the
      tasks registered for one specific operation.

    All timestamps are written with ``datetime.now().isoformat()``, i.e. naive
    local time of the calling process.
    NOTE(review): heartbeat ages are therefore only meaningful if every worker
    and the process calling :meth:`get_stalled_tasks` share a timezone --
    confirm against the deployment.

    Replies from Redis are decoded when they arrive as ``bytes``, so the
    manager works with clients created with or without
    ``decode_responses=True``.
    """

    # Prefix of the per-task hash keys.
    KEY_PREFIX = "task:"
    # Lifetime of a freshly registered task record (48 hours).
    RUNNING_TTL_SECONDS = 86400 * 2
    # Lifetime of a record once the task has completed (24 hours).
    COMPLETED_TTL_SECONDS = 86400
    # Fallback used when a record carries no ``max_retries`` field.
    DEFAULT_MAX_RETRIES = 3

    def __init__(self, redis_client):
        """Store the Redis client used for all state bookkeeping.

        Args:
            redis_client: Client object exposing the redis-py command methods
                used by this class (``hset``, ``hget``, ``hgetall``,
                ``hincrby``, ``exists``, ``expire``, ``sadd``, ``srem``,
                ``scan_iter``).
        """
        self.redis = redis_client

    @staticmethod
    def _text(value):
        """Return ``value`` as ``str`` when it is bytes; otherwise unchanged."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return value

    def _task_key(self, task_id):
        """Build the Redis hash key for ``task_id``."""
        return f"{self.KEY_PREFIX}{task_id}"

    def register_task(
        self, task_id, operation_type, operation_id, additional_info=None, max_retries=3
    ):
        """
        Register a task in Redis with its metadata.

        The record starts in status ``RUNNING`` with a retry count of zero and
        is added to both index sets.

        Args:
            task_id (str): Celery task ID
            operation_type (str): Type of operation (transfer, archive, package, delete, verify)
            operation_id (int): Database ID of the operation
            additional_info (dict, optional): Additional context about the
                operation. Values are stored as strings; a key that matches a
                base field overrides that field.
            max_retries (int, optional): Maximum retry count for this task
        """
        key = self._task_key(task_id)
        # One timestamp for both fields so start_time and heartbeat agree.
        now = datetime.now().isoformat()

        # Base data for all task types
        data = {
            "operation_type": operation_type,
            "operation_id": str(operation_id),
            "status": "RUNNING",
            "start_time": now,
            "heartbeat": now,
            "retry_count": "0",
            "max_retries": str(max_retries),
        }

        # Add additional info if provided
        if additional_info:
            data.update({k: str(v) for k, v in additional_info.items()})

        # Store in Redis with TTL. hset(mapping=...) replaces the deprecated
        # hmset call.
        self.redis.hset(key, mapping=data)
        self.redis.expire(key, self.RUNNING_TTL_SECONDS)

        # Maintain indices for each operation type and ID
        self.redis.sadd(f"running_tasks:{operation_type}", task_id)
        self.redis.sadd(f"tasks_for_operation:{operation_type}:{operation_id}", task_id)

    def update_heartbeat(self, task_id):
        """Update task heartbeat to indicate it's still running.

        Does nothing when no record exists for ``task_id``. The record's TTL
        is not refreshed by a heartbeat.
        """
        key = self._task_key(task_id)
        if self.redis.exists(key):
            self.redis.hset(key, "heartbeat", datetime.now().isoformat())

    def complete_task(self, task_id):
        """Mark task as completed and remove it from the index sets.

        The record itself is kept with status ``COMPLETED`` and an
        ``end_time``, and its TTL is shortened to 24 hours. Does nothing when
        no record exists for ``task_id``.
        """
        key = self._task_key(task_id)
        if not self.redis.exists(key):
            return

        # Read the operation info first; it names the index sets to clean up.
        operation_type = self._text(self.redis.hget(key, "operation_type"))
        operation_id = self._text(self.redis.hget(key, "operation_id"))

        if operation_type and operation_id:
            # Remove from indices
            self.redis.srem(f"running_tasks:{operation_type}", task_id)
            self.redis.srem(
                f"tasks_for_operation:{operation_type}:{operation_id}", task_id
            )

        # Mark as completed
        self.redis.hset(
            key,
            mapping={
                "status": "COMPLETED",
                "end_time": datetime.now().isoformat(),
            },
        )

        # Keep completed task info for a while before deleting
        self.redis.expire(key, self.COMPLETED_TTL_SECONDS)

    def fail_task(self, task_id, error_message, is_retryable=True):
        """
        Mark task as failed.

        The task is removed from ``running_tasks:<operation_type>``; it stays
        in the ``tasks_for_operation`` index. For a retryable failure the
        stored retry count is incremented and compared with ``max_retries``:
        while the new count does not exceed the limit the status stays
        ``FAILED`` and the task may be retried, otherwise the status becomes
        ``FAILED_PERMANENT``. A non-retryable failure goes straight to
        ``FAILED_PERMANENT`` without touching the retry count.

        Args:
            task_id (str): Celery task ID
            error_message: Description of the failure; stored as ``str(...)``
                so exception objects are accepted as well.
            is_retryable (bool, optional): Whether the failure may be retried.

        Returns:
            tuple: (can_retry, operation_type, operation_id);
            ``(False, None, None)`` when no record exists for ``task_id``.
        """
        key = self._task_key(task_id)
        if not self.redis.exists(key):
            return False, None, None

        # Get operation info
        operation_type = self._text(self.redis.hget(key, "operation_type"))
        operation_id = self._text(self.redis.hget(key, "operation_id"))

        # Update status
        self.redis.hset(
            key,
            mapping={
                "status": "FAILED",
                "error": str(error_message),
                "end_time": datetime.now().isoformat(),
            },
        )

        # Remove from running tasks
        if operation_type:
            self.redis.srem(f"running_tasks:{operation_type}", task_id)

        if is_retryable:
            # Increment retry count
            retry_count = int(self.redis.hincrby(key, "retry_count", 1))
            stored_limit = self._text(self.redis.hget(key, "max_retries"))
            max_retries = int(stored_limit or self.DEFAULT_MAX_RETRIES)
            if retry_count <= max_retries:
                # Ready to be retried
                return True, operation_type, operation_id

        # Either not retryable or max retries exceeded
        self.redis.hset(key, "status", "FAILED_PERMANENT")
        return False, operation_type, operation_id

    def get_stalled_tasks(self, heartbeat_timeout=300):
        """
        Find tasks that haven't updated their heartbeat recently.

        Only records with status ``RUNNING`` are considered. A record whose
        heartbeat is missing or cannot be parsed is reported as stalled with
        ``last_heartbeat`` and ``stalled_for`` set to ``None``.

        Args:
            heartbeat_timeout (int, optional): Age in seconds beyond which a
                heartbeat counts as stale.

        Returns:
            list: List of dicts with the keys ``task_id``, ``operation_type``,
            ``operation_id``, ``last_heartbeat`` and ``stalled_for``.
        """
        now = datetime.now()
        prefix = self.KEY_PREFIX
        stalled_tasks = []
        seen_keys = set()

        # SCAN walks the keyspace incrementally instead of blocking the
        # server like KEYS does. It may yield a key more than once, hence
        # the de-duplication.
        for raw_key in self.redis.scan_iter(match=f"{prefix}*"):
            task_key = self._text(raw_key)
            if task_key in seen_keys:
                continue
            seen_keys.add(task_key)

            task_data = {
                self._text(field): self._text(value)
                for field, value in self.redis.hgetall(task_key).items()
            }
            if task_data.get("status") != "RUNNING":
                continue

            # Strip only the leading prefix; the ID itself may contain it.
            task_id = task_key[len(prefix):]

            # Check heartbeat
            try:
                last_heartbeat = datetime.fromisoformat(task_data.get("heartbeat"))
                stalled_for = (now - last_heartbeat).total_seconds()
            except (ValueError, TypeError):
                # Invalid heartbeat format, consider stalled
                last_heartbeat = None
                stalled_for = None
            else:
                if stalled_for <= heartbeat_timeout:
                    continue

            stalled_tasks.append(
                {
                    "task_id": task_id,
                    "operation_type": task_data.get("operation_type"),
                    "operation_id": task_data.get("operation_id"),
                    "last_heartbeat": last_heartbeat,
                    "stalled_for": stalled_for,
                }
            )

        return stalled_tasks