from celery import Celery
from celery.signals import after_setup_logger
from celery.signals import after_setup_task_logger
from kombu import Exchange, Queue
from celery.signals import worker_init
from ccat_ops_db import models
from typing import List, Optional
from contextlib import contextmanager
from sqlalchemy.orm import Session, sessionmaker
from .config.config import ccat_data_transfer_settings
from .utils import service_shutdown, get_redis_connection
from .task_state_manager import TaskStateManager
from .exceptions import ServiceExit
from .database import DatabaseConnection
from .health_check import HealthCheck
from .logging_utils import get_structured_logger, setup_celery_logging
from .queue_discovery import QueueDiscoveryService
import logging
import time
import threading
from .notification_service import NotificationClient
import json
# Set up logging before creating the Celery app
setup_celery_logging()
# Configure Redis connection with SSL and password
redis_connection = "rediss://:{}@{}:{}/{}".format(
ccat_data_transfer_settings.REDIS_PASSWORD,
ccat_data_transfer_settings.REDIS_HOST,
ccat_data_transfer_settings.REDIS_PORT,
ccat_data_transfer_settings.REDIS_DB,
)
# Initialize Celery app with SSL broker options
app = Celery(
"ccat",
broker=redis_connection,
backend=redis_connection,
broker_connection_retry_on_startup=True,
broker_use_ssl={
"ssl_cert_reqs": "required",
"ssl_ca_certs": ccat_data_transfer_settings.REDIS_CA_CERT,
"ssl_certfile": ccat_data_transfer_settings.REDIS_CERTFILE,
"ssl_keyfile": ccat_data_transfer_settings.REDIS_KEYFILE,
},
redis_backend_use_ssl={
"ssl_cert_reqs": "required",
"ssl_ca_certs": ccat_data_transfer_settings.REDIS_CA_CERT,
"ssl_certfile": ccat_data_transfer_settings.REDIS_CERTFILE,
"ssl_keyfile": ccat_data_transfer_settings.REDIS_KEYFILE,
},
)
# Celery configuration: logging, broker/backend, task, and worker settings
app.conf.update(
worker_hijack_root_logger=False, # Prevent Celery from hijacking root logger
worker_redirect_stdouts=False, # Prevent Celery from redirecting stdout/stderr
worker_redirect_stdouts_level="INFO", # Log level for stdout/stderr
# Redis broker settings
broker_connection_timeout=30,
broker_connection_retry=True,
broker_connection_max_retries=10,
broker_pool_limit=10,
# Redis backend settings
redis_socket_timeout=10,
redis_socket_connect_timeout=10,
redis_retry_on_timeout=True,
redis_max_connections=10,
# Broker transport options - enable EVALSHA for better performance
broker_transport_options={
"visibility_timeout": 3600, # 1 hour
"fanout_prefix": True,
"fanout_patterns": True,
"retry_policy": {
"timeout": 5.0,
"max_retries": 3,
},
"socket_connect_timeout": 10,
"socket_timeout": 10,
"use_evalsha": True, # Enable EVALSHA for better performance
},
# Task settings
task_acks_late=True,
task_reject_on_worker_lost=True,
task_track_started=True,
# Worker settings
    worker_prefetch_multiplier=10,  # Prefetch up to 10 messages per worker process
worker_max_tasks_per_child=100, # Restart worker after 100 tasks
worker_max_memory_per_child=512000, # Restart if memory exceeds 512MB
)
# Disable propagation for the main Celery logger
logging.getLogger("celery").propagate = False
# Move database initialization to worker_init signal
@worker_init.connect
def init_worker(**kwargs):
"""Initialize database connection and BBCP settings when the worker starts"""
from .database import DatabaseConnection
from .bbcp_settings import BBCPSettings
from .logging_utils import get_structured_logger
logger = get_structured_logger(__name__)
# Initialize database
db = DatabaseConnection()
session, engine = db.get_connection()
SQLAlchemyTask.init_session_factory(engine)
# Configure dynamic queues
configure_dynamic_queues(session)
# Initialize BBCP settings
logger.info("Initializing BBCP settings")
settings = BBCPSettings()
logger.info(f"BBCP settings: {settings.get_all_settings()}")
# Define exchanges
data_transfer_exchange = Exchange("data_transfer", type="direct")
# Define static queues (these will be merged with dynamic queues)
STATIC_QUEUES = (
Queue("fyst", data_transfer_exchange, routing_key="fyst"),
Queue("cologne", data_transfer_exchange, routing_key="cologne"),
Queue("us", data_transfer_exchange, routing_key="us"),
Queue(
"fyst-transfer",
data_transfer_exchange,
routing_key="fyst-transfer",
queue_arguments={"x-queue-type": "classic"},
),
Queue(
"cologne-transfer",
data_transfer_exchange,
routing_key="cologne-transfer",
queue_arguments={"x-queue-type": "classic"},
),
Queue(
"us-transfer",
data_transfer_exchange,
routing_key="us-transfer",
queue_arguments={"x-queue-type": "classic"},
),
)
app.conf.imports = (
"ccat_data_transfer.archive_manager",
"ccat_data_transfer.data_integrity_manager",
"ccat_data_transfer.transfer_manager",
"ccat_data_transfer.deletion_manager",
"ccat_data_transfer.disk_monitor",
"ccat_data_transfer.staging_manager",
"ccat_data_transfer.raw_data_package_manager",
"ccat_data_transfer.data_transfer_package_manager",
)
# Define static routes (these will be merged with dynamic routes)
STATIC_ROUTES = {
"ccat:data_transfer:transfer:fyst": {
"queue": "fyst-transfer",
"routing_key": "fyst-transfer",
},
"ccat:data_transfer:create:transfer_package": {
"queue": "fyst",
"routing_key": "fyst",
},
"ccat:data_transfer:transfer:cologne": {
"queue": "cologne-transfer",
"routing_key": "cologne-transfer",
},
"ccat:data_transfer:transfer:us": {
"queue": "us-transfer",
"routing_key": "us-transfer",
},
"ccat:data_transfer:unpack_data_transfer_package:cologne": {
"queue": "cologne",
"routing_key": "cologne",
},
"ccat:data_transfer:unpack_data_transfer_package:us": {
"queue": "us",
"routing_key": "us",
},
"ccat:data_transfer:archive:cologne": {
"queue": "cologne",
"routing_key": "cologne",
},
"ccat:data_transfer:archive:us": {
"queue": "us",
"routing_key": "us",
},
"ccat:data_transfer:delete:cologne": {
"queue": "cologne",
"routing_key": "cologne",
},
"ccat:data_transfer:delete:us": {
"queue": "us",
"routing_key": "us",
},
"ccat:data_transfer:delete:fyst": {
"queue": "fyst",
"routing_key": "fyst",
},
"ccat:data_transfer:monitor_disk_usage:fyst": {
"queue": "fyst",
"routing_key": "fyst",
},
"ccat:data_transfer:staging:ramses": {
"queue": "ramses-staging",
"routing_key": "ramses-staging",
},
"ccat:data_transfer:staging:cheops": {
"queue": "cheops-staging",
"routing_key": "cheops-staging",
},
}
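# Illustrative sketch only: merging these static definitions with the dynamically
# discovered queues is presumably handled by configure_dynamic_queues() at worker
# startup; a purely static configuration would roughly amount to
#
#     app.conf.task_queues = STATIC_QUEUES
#     app.conf.task_routes = STATIC_ROUTES
#
# using the standard Celery task_queues / task_routes settings. The actual merge
# logic lives outside this excerpt.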
def get_worker_queues(
    session: Session, location_identifier: str, operation_type: Optional[str] = None
) -> List[str]:
"""Get queue names for a specific location and optional operation type."""
logger = get_structured_logger(__name__)
# Resolve location from identifier
location = (
session.query(models.DataLocation)
.filter(models.DataLocation.name == location_identifier)
.first()
)
if not location:
logger.warning(f"Location '{location_identifier}' not found")
return []
discovery_service = QueueDiscoveryService(session)
if operation_type:
# Process specific operation type
return [discovery_service.get_queue_name_for_location(location, operation_type)]
else:
# Process all operations for this location
return discovery_service.get_queues_for_location(location)
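# Example usage (placeholder values): resolve the queues a worker should consume
# from before launching it, e.g.
#
#     db = DatabaseConnection()
#     session, _engine = db.get_connection()
#     for name in get_worker_queues(session, "cologne", operation_type="staging"):
#         start_celery_worker(queue=name, concurrency=2)
#
# "cologne" and "staging" are placeholders; valid values come from the
# DataLocation table and the QueueDiscoveryService.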
# Connect to both logger setup signals
after_setup_logger.connect(configure_logger)
after_setup_task_logger.connect(configure_logger)
# Base SQLAlchemyTask: provides a shared session factory and a session_scope context manager
class SQLAlchemyTask(app.Task):
_session_factory = None
@classmethod
def init_session_factory(cls, engine):
cls._session_factory = sessionmaker(bind=engine)
@contextmanager
def session_scope(self):
if self._session_factory is None:
raise RuntimeError("Session factory not initialized")
session = self._session_factory()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
def after_return(self, status, retval, task_id, args, kwargs, einfo):
pass # No need to remove session, as it's handled by the context manager
def on_failure(self, exc, task_id, args, kwargs, einfo):
pass # No need to rollback, as it's handled by the context manager
# Enhanced task base class factory with error tracking, retries, and state management
def make_celery_task(test_session_factory=None):
"""
Create a base task class with unified error handling, state tracking, and SQLAlchemy support.
Args:
test_session_factory: Optional session factory for testing
Returns:
A base Celery task class with enhanced error handling and SQLAlchemy integration
"""
# Initialize services
redis_client = get_redis_connection()
task_state_manager = TaskStateManager(redis_client)
notification_client = NotificationClient(redis_client=redis_client)
logger = get_structured_logger(__name__)
class CCATEnhancedSQLAlchemyTask(SQLAlchemyTask):
"""Enhanced SQLAlchemy task with resilient error handling and state tracking."""
abstract = True
# Operation type and retry settings
operation_type = None # Must be set by subclasses
max_retries = 3
@classmethod
def init_session_factory(cls, session_factory=None):
"""Initialize the SQLAlchemy session factory."""
if session_factory:
cls._session_factory = session_factory
else:
db = DatabaseConnection()
session, engine = db.get_connection()
cls._session_factory = sessionmaker(bind=engine)
def get_operation_id(self, args):
"""Extract operation ID from task arguments."""
if args and len(args) > 0:
return args[0]
return None
def get_operation_info(self, args, kwargs):
"""Get additional operation info for task tracking."""
return {}
def should_retry(self, exc, operation_id, retry_count):
"""
Determine if task should be retried based on exception and retry count.
Args:
exc: The exception that was raised
operation_id: ID of the operation
retry_count: Current retry count
Returns:
bool: True if task should be retried, False otherwise
"""
logger = get_structured_logger(__name__)
# List of exceptions that should never be retried
non_retryable_exceptions = (
FileNotFoundError,
PermissionError,
NotADirectoryError,
IsADirectoryError,
)
# Check if exception is in the non-retryable list
if isinstance(exc, non_retryable_exceptions):
logger.info(
"Non-retryable exception encountered",
operation_id=operation_id,
exception_type=type(exc).__name__,
error=str(exc),
)
return False
# Check if exception has explicit retry information
if hasattr(exc, "is_retryable"):
is_retryable = exc.is_retryable
max_allowed_retries = getattr(exc, "max_retries", self.max_retries)
else:
is_retryable = True # Default to retryable
max_allowed_retries = self.max_retries
# Check retry count against max allowed retries
if retry_count >= max_allowed_retries:
logger.info(
"Max retries exceeded",
operation_id=operation_id,
retry_count=retry_count,
max_retries=max_allowed_retries,
)
return False
return is_retryable
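        # Illustrative sketch (hypothetical exception class): should_retry() honours
        # is_retryable / max_retries attributes set on the raised exception, so a
        # custom error can opt out of retries explicitly, e.g.
        #
        #     class ChecksumMismatchError(Exception):
        #         is_retryable = False  # never retry a corrupted package
        #         max_retries = 0
        #
        # The exception name above is invented for illustration only.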
def on_failure(self, exc, task_id, args, kwargs, einfo):
"""Handle task failure with uniform approach for all operation types."""
operation_id = self.get_operation_id(args)
if not operation_id:
logger.error("No operation ID found in task arguments")
return
try:
with self.session_scope() as session:
# Get current retry count
retry_count = self.get_retry_count(session, operation_id)
# Determine if we should retry based on retry count and exception
should_retry = self.should_retry(exc, operation_id, retry_count)
# Get operation details for notification
operation_details = self._get_operation_details(
session, operation_id
)
operation_info = self.get_operation_info(args, kwargs)
# Prepare notification subject with development mode indicator
is_dev = ccat_data_transfer_settings.DEVELOPMENT_MODE
subject_prefix = "[DEV]" if is_dev else ""
# Build base notification body
notification_body = (
f"Task Details:\n"
f"- Task Name: {self.name}\n"
f"- Task ID: {task_id}\n"
f"- Operation Type: {self.operation_type or 'unknown'}\n"
f"- Operation ID: {operation_id}\n"
f"- Retry Count: {retry_count}\n"
f"- Max Retries: {self.max_retries}\n"
f"- Will Retry: {should_retry}\n\n"
f"Error Information:\n"
f"- Error Type: {type(exc).__name__}\n"
f"- Error Message: {str(exc)}\n\n"
f"Operation Details:\n{self._format_operation_details(operation_details)}\n\n"
f"Additional Context:\n{json.dumps(operation_info, indent=2)}\n"
)
# Add development mode extras
if is_dev:
notification_body += (
f"\nTraceback:\n{''.join(einfo.traceback)}\n\n"
f"Task Arguments:\n"
f"- Args: {args}\n"
f"- Kwargs: {kwargs}\n"
)
if should_retry and hasattr(self, "reset_state_on_failure"):
# Reset state for retry
self.reset_state_on_failure(session, operation_id, exc)
logger.info(
"Task scheduled for retry",
operation_id=operation_id,
retry_count=retry_count + 1,
)
# Only send retry notification in development mode
if is_dev:
subject = f"{subject_prefix} Task Retry in {self.name}"
notification_client.send_notification(
subject=subject,
body=notification_body,
level="DEBUG",
)
elif hasattr(self, "mark_permanent_failure"):
# Mark as permanently failed
self.mark_permanent_failure(session, operation_id, exc)
# Always send notification for permanent failures
subject = (
f"{subject_prefix} Permanent Task Failure in {self.name}"
)
notification_client.send_notification(
subject=subject,
body=notification_body,
level="ERROR",
)
except Exception as e:
logger.error(
"Error in failure handling",
task_id=task_id,
operation_id=operation_id,
error=str(e),
)
# Call parent's on_failure (which is a no-op in SQLAlchemyTask)
super().on_failure(exc, task_id, args, kwargs, einfo)
def get_retry_count(self, session, operation_id):
"""
Get current retry count for this operation.
Should be implemented by subclasses to access the appropriate database field.
Args:
session: SQLAlchemy session
operation_id: ID of the operation
Returns:
int: Current retry count, defaults to 0
"""
            raise NotImplementedError(
                "get_retry_count must be implemented by subclasses"
            )
def on_success(self, retval, task_id, args, kwargs):
"""Handle successful task completion."""
task_state_manager.complete_task(task_id)
def after_return(self, status, retval, task_id, args, kwargs, einfo):
"""Cleanup after task execution."""
# If task succeeded, cleanup is handled by on_success
# If task failed, cleanup is handled by on_failure
super().after_return(status, retval, task_id, args, kwargs, einfo)
def __call__(self, *args, **kwargs):
"""Run the task with state tracking, heartbeat, and SQLAlchemy session."""
# Initialize session factory if not done yet
if not self._session_factory:
self.init_session_factory(test_session_factory)
task_id = self.request.id
operation_id = self.get_operation_id(args)
operation_type = self.operation_type or self.name.split(":")[-1]
# Get additional operation context
operation_info = self.get_operation_info(args, kwargs)
# Register task in Redis
task_state_manager.register_task(
task_id=task_id,
operation_type=operation_type,
operation_id=operation_id,
additional_info=operation_info,
max_retries=self.max_retries,
)
# Set up periodic heartbeat with proper cleanup
stop_heartbeat = threading.Event()
heartbeat_failed = threading.Event()
heartbeat_thread = None
def heartbeat_worker():
heartbeat_logger = get_structured_logger(__name__ + ".heartbeat")
consecutive_failures = 0
max_consecutive_failures = 3
while not stop_heartbeat.is_set():
try:
if stop_heartbeat.is_set():
break
task_state_manager.update_heartbeat(task_id)
consecutive_failures = 0 # Reset on success
except Exception as e:
consecutive_failures += 1
heartbeat_logger.error(
"Heartbeat update failed",
error=str(e),
consecutive_failures=consecutive_failures,
)
if consecutive_failures >= max_consecutive_failures:
heartbeat_logger.error(
"Too many consecutive heartbeat failures, marking task as failed",
task_id=task_id,
)
heartbeat_failed.set()
break
# Use a shorter sleep interval and check stop flag more frequently
for _ in range(6): # 6 * 10 seconds = 60 seconds total
if stop_heartbeat.is_set():
break
time.sleep(10)
try:
# Start heartbeat thread
heartbeat_thread = threading.Thread(target=heartbeat_worker)
heartbeat_thread.daemon = True
heartbeat_thread.start()
# Execute the task
result = super().__call__(*args, **kwargs)
# Check if heartbeat failed during task execution
if heartbeat_failed.is_set():
raise RuntimeError("Task heartbeat failed during execution")
return result
except Exception as e:
# Log the error and re-raise
logger.error(
"Task execution failed",
task_id=task_id,
error=str(e),
operation_type=operation_type,
operation_id=operation_id,
)
raise
finally:
# Ensure proper cleanup of heartbeat thread
if heartbeat_thread and heartbeat_thread.is_alive():
stop_heartbeat.set()
try:
                        # Give the heartbeat thread up to 10 seconds to clean up
                        heartbeat_thread.join(timeout=10.0)
if heartbeat_thread.is_alive():
logger.warning(
"Heartbeat thread did not stop gracefully - forcing cleanup",
task_id=task_id,
)
# Force cleanup of task state since heartbeat thread is stuck
try:
task_state_manager.complete_task(task_id)
except Exception as cleanup_error:
logger.error(
"Failed to force cleanup task state",
task_id=task_id,
error=str(cleanup_error),
)
except Exception as e:
logger.error(
"Error stopping heartbeat thread",
task_id=task_id,
error=str(e),
)
# Clean up task state
try:
task_state_manager.complete_task(task_id)
except Exception as e:
logger.error(
"Failed to clean up task state",
task_id=task_id,
error=str(e),
)
# Default implementations to be overridden by subclasses
def reset_state_on_failure(self, session, operation_id, exc):
"""self.is_retryable = is_retryable
Reset operation state for retry. To be implemented by subclasses.
This default implementation logs a warning and raises an error to ensure
subclasses properly implement their own retry logic.
Args:
session: SQLAlchemy session
operation_id: ID of the operation
exc: The exception that caused the failure
Raises:
NotImplementedError: Always raised to ensure subclasses implement their own logic
"""
logger.warning(
"reset_state_on_failure not implemented for task",
task_name=self.name,
operation_id=operation_id,
error=str(exc),
)
raise NotImplementedError(
f"Task {self.name} must implement reset_state_on_failure to handle retries properly"
)
def mark_permanent_failure(self, session, operation_id, exc):
"""
Mark operation as permanently failed. To be implemented by subclasses.
This default implementation logs a warning and raises an error to ensure
subclasses properly implement their own failure handling logic.
Args:
session: SQLAlchemy session
operation_id: ID of the operation
exc: The exception that caused the failure
Raises:
NotImplementedError: Always raised to ensure subclasses implement their own logic
"""
logger.warning(
"mark_permanent_failure not implemented for task",
task_name=self.name,
operation_id=operation_id,
error=str(exc),
)
raise NotImplementedError(
f"Task {self.name} must implement mark_permanent_failure to handle permanent failures properly"
)
def _get_operation_details(self, session, operation_id):
"""Get detailed information about the operation from the database."""
try:
if self.operation_type == "package":
return session.query(models.DataTransferPackage).get(operation_id)
elif self.operation_type == "transfer":
return session.query(models.DataTransfer).get(operation_id)
elif self.operation_type == "unpack":
return session.query(models.DataTransfer).get(operation_id)
elif self.operation_type == "archive":
return session.query(models.LongTermArchiveTransfer).get(
operation_id
)
elif self.operation_type == "delete":
return session.query(models.PhysicalCopy).get(operation_id)
return None
except Exception as e:
logger.error(
"Failed to get operation details",
operation_type=self.operation_type,
operation_id=operation_id,
error=str(e),
)
return None
def _format_operation_details(self, operation):
"""Format operation details for notification message."""
if not operation:
return "No operation details available"
details = []
for key, value in operation.__dict__.items():
if not key.startswith("_"):
details.append(f"- {key}: {value}")
return "\n".join(details)
return CCATEnhancedSQLAlchemyTask
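# Illustrative sketch (not part of this module): a concrete task class would
# subclass the base returned by make_celery_task(), set operation_type, and
# implement the retry hooks. Field names such as retry_count are assumptions.
#
#     BaseTask = make_celery_task()
#
#     class TransferTask(BaseTask):
#         abstract = True
#         operation_type = "transfer"
#
#         def get_retry_count(self, session, operation_id):
#             transfer = session.query(models.DataTransfer).get(operation_id)
#             return transfer.retry_count if transfer else 0
#
#         def reset_state_on_failure(self, session, operation_id, exc):
#             ...  # reset the DataTransfer state and increment its retry count
#
#         def mark_permanent_failure(self, session, operation_id, exc):
#             ...  # flag the DataTransfer as permanently failed
#
# Concrete task functions are registered in the modules listed in app.conf.imports.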
def start_celery_worker(queue=None, concurrency=None):
"""
Starts a Celery worker for a specific queue with controlled concurrency.
Parameters
----------
queue : str, optional
The name of the queue to process. If None, the worker will consume from all queues.
concurrency : int, optional
Number of worker processes/threads. If None, defaults to the number of CPU cores.
Returns
-------
None
Raises
------
ServiceExit
When a termination signal is received.
Notes
-----
- Uses subprocess to spawn a new Celery worker process.
- Configures SIGTERM signal handler for graceful shutdown.
- Redirects standard output and error to /dev/null.
"""
import os
import signal
import subprocess
# Set up signal handler for graceful shutdown
signal.signal(signal.SIGTERM, service_shutdown)
with open(os.devnull, "w") as _:
health_check = HealthCheck(
service_type="data_transfer",
service_name=f"celery_worker_{queue}",
)
health_check.start()
logging.info(f"Starting celery worker for queue {queue}")
try:
env = os.environ
            # Build the worker command; bind it strictly to the requested queue
            command = [
                "celery",
                "-A",
                "ccat_data_transfer.setup_celery_app.app",
                "worker",
                "--loglevel=INFO",
                "--purge",  # Clear any existing messages on startup
                "--without-mingle",  # Skip startup synchronization with other workers
                "--without-gossip",  # Disable worker-to-worker event gossip
            ]
            if queue is not None:
                # Consume only from the requested queue and give the worker a unique name
                command.extend(["-Q", queue, "-n", f"worker.{queue}@%h"])
if concurrency is not None:
command.extend(["-c", str(concurrency)])
# Start the Celery worker as a subprocess
            proc = subprocess.Popen(command, env=env, stdout=devnull, stderr=devnull)
# Wait for the worker process to complete
proc.wait()
except ServiceExit:
# Send SIGTERM for warm shutdown to the celery worker
proc.terminate()
proc.wait()
finally:
health_check.stop()
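# Example (illustrative): run a dedicated worker for the FYST transfer queue
# with two worker processes:
#
#     start_celery_worker(queue="fyst-transfer", concurrency=2)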
# Periodic task schedule for disk-usage monitoring (Celery beat)
app.conf.beat_schedule = {
"monitor-disk-usage-cologne": {
"task": "ccat:data_transfer:monitor_disk_usage:cologne",
"schedule": 60.0, # run every minute
},
"monitor-disk-usage-us": {
"task": "ccat:data_transfer:monitor_disk_usage:us",
"schedule": 60.0,
},
"monitor-disk-usage-fyst": {
"task": "ccat:data_transfer:monitor_disk_usage:fyst",
"schedule": 60.0,
},
}
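# Beat entries can also use crontab schedules instead of a fixed interval, e.g.
# (illustrative only):
#
#     from celery.schedules import crontab
#
#     app.conf.beat_schedule["monitor-disk-usage-fyst"]["schedule"] = crontab(minute="*/5")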
def start_celery_beat():
"""
Starts the Celery beat scheduler.
"""
import os
import signal
import subprocess
# Set up signal handler for graceful shutdown
signal.signal(signal.SIGTERM, service_shutdown)
try:
env = os.environ
command = [
"celery",
"-A",
"ccat_data_transfer.setup_celery_app.app",
"beat",
"--loglevel=INFO",
]
# Start the Celery beat scheduler as a subprocess
proc = subprocess.Popen(command, env=env)
# Wait for the process to complete
proc.wait()
except ServiceExit:
# Send SIGTERM for warm shutdown
proc.terminate()
proc.wait()