from datetime import datetime
# Redis-backed bookkeeping for task state.
class TaskStateManager:
    """Manager for tracking and recovering task states across all operation types.

    Every task is stored as a Redis hash under ``task:<task_id>``. Two set
    indices are maintained next to it:

    * ``running_tasks:<operation_type>``
    * ``tasks_for_operation:<operation_type>:<operation_id>``

    Values read back from Redis are normalised to ``str``, so the manager
    behaves the same whether or not the client decodes responses itself.
    """

    _KEY_PREFIX = "task:"
    _RUNNING_TTL_SECONDS = 86400 * 2  # 48 hours for registered tasks
    _COMPLETED_TTL_SECONDS = 86400  # 24 hours once a task has completed
    _DEFAULT_MAX_RETRIES = 3  # fallback when the hash has no max_retries field

    def __init__(self, redis_client):
        """Remember the Redis client used for every read and write.

        Args:
            redis_client: Client object exposing the redis-py command methods
                used by this class (``hset``, ``hget``, ``hgetall``,
                ``hincrby``, ``exists``, ``expire``, ``sadd``, ``srem``,
                ``scan_iter``).
        """
        self.redis = redis_client

    @staticmethod
    def _to_text(value):
        """Return ``value`` as ``str``, decoding ``bytes`` replies as UTF-8."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    @classmethod
    def _task_key(cls, task_id):
        """Return the Redis hash key that holds the state of ``task_id``."""
        return f"{cls._KEY_PREFIX}{task_id}"

    def _get_field(self, key, field):
        """Read one hash field and return it as ``str`` (``None`` if absent)."""
        return self._to_text(self.redis.hget(key, field))

    def register_task(
        self, task_id, operation_type, operation_id, additional_info=None, max_retries=3
    ):
        """
        Register a task in Redis with its metadata.

        The task hash is created with status ``RUNNING`` and a 48 hour TTL,
        and the task ID is added to both index sets.

        Args:
            task_id (str): Celery task ID
            operation_type (str): Type of operation (transfer, archive, package, delete, verify)
            operation_id (int): Database ID of the operation
            additional_info (dict, optional): Additional context about the
                operation. Values are stored as strings; a key that matches
                one of the base fields overwrites that field.
            max_retries (int, optional): Maximum retry count for this task
        """
        key = self._task_key(task_id)

        # A single timestamp keeps start_time and the first heartbeat identical.
        now_iso = datetime.now().isoformat()

        # Base data for all task types
        data = {
            "operation_type": operation_type,
            "operation_id": str(operation_id),
            "status": "RUNNING",
            "start_time": now_iso,
            "heartbeat": now_iso,
            "retry_count": "0",
            "max_retries": str(max_retries),
        }

        # Add additional info if provided
        if additional_info:
            data.update({k: str(v) for k, v in additional_info.items()})

        # Store in Redis with TTL (hset+mapping replaces the deprecated hmset)
        self.redis.hset(key, mapping=data)
        self.redis.expire(key, self._RUNNING_TTL_SECONDS)

        # Maintain indices for each operation type and ID
        self.redis.sadd(f"running_tasks:{operation_type}", task_id)
        self.redis.sadd(f"tasks_for_operation:{operation_type}:{operation_id}", task_id)

    def update_heartbeat(self, task_id):
        """Update task heartbeat to indicate it's still running.

        Does nothing when no hash exists for ``task_id``.
        """
        key = self._task_key(task_id)
        if self.redis.exists(key):
            self.redis.hset(key, "heartbeat", datetime.now().isoformat())

    def complete_task(self, task_id):
        """Mark task as completed and remove it from the index sets.

        The hash itself is kept with status ``COMPLETED`` and a 24 hour TTL.
        Does nothing when no hash exists for ``task_id``.
        """
        key = self._task_key(task_id)
        if not self.redis.exists(key):
            return

        # Read the operation info as text so the index keys match the ones
        # written by register_task even when the client returns bytes.
        operation_type = self._get_field(key, "operation_type")
        operation_id = self._get_field(key, "operation_id")
        if operation_type and operation_id:
            self.redis.srem(f"running_tasks:{operation_type}", task_id)
            self.redis.srem(
                f"tasks_for_operation:{operation_type}:{operation_id}", task_id
            )

        # Mark as completed
        self.redis.hset(
            key,
            mapping={
                "status": "COMPLETED",
                "end_time": datetime.now().isoformat(),
            },
        )
        # Keep completed task info for a while before deleting
        self.redis.expire(key, self._COMPLETED_TTL_SECONDS)

    def fail_task(self, task_id, error_message, is_retryable=True):
        """
        Mark task as failed.

        The task is removed from ``running_tasks:<operation_type>``. For a
        retryable failure the stored ``retry_count`` is incremented and
        compared with ``max_retries``; once exceeded, or when the failure is
        not retryable, the status becomes ``FAILED_PERMANENT``.

        Args:
            task_id (str): Celery task ID
            error_message: Description of the failure; stored as ``str(error_message)``
            is_retryable (bool, optional): Whether the failure may be retried

        Returns:
            tuple: (can_retry, operation_type, operation_id). When no hash
            exists for ``task_id`` this is ``(False, None, None)``.
        """
        key = self._task_key(task_id)
        if not self.redis.exists(key):
            return False, None, None

        # Get operation info
        operation_type = self._get_field(key, "operation_type")
        operation_id = self._get_field(key, "operation_id")

        # Update status; str() lets callers pass an exception object directly.
        self.redis.hset(
            key,
            mapping={
                "status": "FAILED",
                "error": str(error_message),
                "end_time": datetime.now().isoformat(),
            },
        )

        # Remove from running tasks
        if operation_type:
            self.redis.srem(f"running_tasks:{operation_type}", task_id)

        can_retry = False
        if is_retryable:
            # The failure that was just recorded counts against the budget.
            retry_count = int(self.redis.hincrby(key, "retry_count", 1))
            max_retries = int(
                self.redis.hget(key, "max_retries") or self._DEFAULT_MAX_RETRIES
            )
            can_retry = retry_count <= max_retries

        if not can_retry:
            # Either not retryable at all, or max retries exceeded.
            self.redis.hset(key, "status", "FAILED_PERMANENT")
        return can_retry, operation_type, operation_id

    def get_stalled_tasks(self, heartbeat_timeout=300):
        """
        Find tasks that haven't updated their heartbeat recently.

        Only hashes whose status is ``RUNNING`` are considered. A task whose
        heartbeat is missing or cannot be parsed is reported as stalled with
        ``last_heartbeat`` and ``stalled_for`` set to ``None``.

        Args:
            heartbeat_timeout (int, optional): Seconds without a heartbeat
                after which a running task counts as stalled

        Returns:
            list: List of dicts with task information (``task_id``,
            ``operation_type``, ``operation_id``, ``last_heartbeat``,
            ``stalled_for``)
        """
        now = datetime.now()
        prefix_len = len(self._KEY_PREFIX)
        stalled_tasks = []
        seen = set()

        # SCAN walks the keyspace incrementally instead of blocking the server
        # like KEYS does; it may return a key more than once, hence `seen`.
        for raw_key in self.redis.scan_iter(match=f"{self._KEY_PREFIX}*"):
            task_key = self._to_text(raw_key)
            if task_key in seen:
                continue
            seen.add(task_key)

            task_data = {
                self._to_text(field): self._to_text(value)
                for field, value in self.redis.hgetall(task_key).items()
            }
            if task_data.get("status") != "RUNNING":
                continue

            # Check heartbeat
            try:
                last_heartbeat = datetime.fromisoformat(task_data.get("heartbeat"))
                stalled_for = (now - last_heartbeat).total_seconds()
            except (ValueError, TypeError):
                # Invalid heartbeat format, consider stalled
                last_heartbeat = None
                stalled_for = None
            else:
                if stalled_for <= heartbeat_timeout:
                    continue

            stalled_tasks.append(
                {
                    # Strip only the leading prefix: task IDs may themselves
                    # contain the text "task:".
                    "task_id": task_key[prefix_len:],
                    "operation_type": task_data.get("operation_type"),
                    "operation_id": task_data.get("operation_id"),
                    "last_heartbeat": last_heartbeat,
                    "stalled_for": stalled_for,
                }
            )
        return stalled_tasks