Source code for ccat_data_transfer.setup_celery_app

from celery import Celery
from celery.signals import after_setup_logger
from celery.signals import after_setup_task_logger
from kombu import Exchange, Queue
from celery.signals import worker_init
from ccat_ops_db import models
from typing import List

from contextlib import contextmanager
from sqlalchemy.orm import Session, sessionmaker
from builtins import NotImplementedError  # NotImplementedError is a builtin; explicit import kept for clarity

from .config.config import ccat_data_transfer_settings
from .utils import service_shutdown, get_redis_connection
from .task_state_manager import TaskStateManager
from .exceptions import ServiceExit
from .database import DatabaseConnection
from .health_check import HealthCheck
from .logging_utils import get_structured_logger, setup_celery_logging
from .queue_discovery import QueueDiscoveryService
import logging
import time
import threading
from .notification_service import NotificationClient
import json


# Setup logging before creating the Celery app
setup_celery_logging()

# Configure Redis connection with SSL and password
redis_connection = "rediss://:{}@{}:{}/{}".format(
    ccat_data_transfer_settings.REDIS_PASSWORD,
    ccat_data_transfer_settings.REDIS_HOST,
    ccat_data_transfer_settings.REDIS_PORT,
    ccat_data_transfer_settings.REDIS_DB,
)

# Initialize Celery app with SSL broker options
app = Celery(
    "ccat",
    broker=redis_connection,
    backend=redis_connection,
    broker_connection_retry_on_startup=True,
    broker_use_ssl={
        "ssl_cert_reqs": "required",
        "ssl_ca_certs": ccat_data_transfer_settings.REDIS_CA_CERT,
        "ssl_certfile": ccat_data_transfer_settings.REDIS_CERTFILE,
        "ssl_keyfile": ccat_data_transfer_settings.REDIS_KEYFILE,
    },
    redis_backend_use_ssl={
        "ssl_cert_reqs": "required",
        "ssl_ca_certs": ccat_data_transfer_settings.REDIS_CA_CERT,
        "ssl_certfile": ccat_data_transfer_settings.REDIS_CERTFILE,
        "ssl_keyfile": ccat_data_transfer_settings.REDIS_KEYFILE,
    },
)

# Configure Celery logging
app.conf.update(
    worker_hijack_root_logger=False,  # Prevent Celery from hijacking root logger
    worker_redirect_stdouts=False,  # Prevent Celery from redirecting stdout/stderr
    worker_redirect_stdouts_level="INFO",  # Log level for stdout/stderr
    # Redis broker settings
    broker_connection_timeout=30,
    broker_connection_retry=True,
    broker_connection_max_retries=10,
    broker_pool_limit=10,
    # Redis backend settings
    redis_socket_timeout=10,
    redis_socket_connect_timeout=10,
    redis_retry_on_timeout=True,
    redis_max_connections=10,
    # Broker transport options - enable EVALSHA for better performance
    broker_transport_options={
        "visibility_timeout": 3600,  # 1 hour
        "fanout_prefix": True,
        "fanout_patterns": True,
        "retry_policy": {
            "timeout": 5.0,
            "max_retries": 3,
        },
        "socket_connect_timeout": 10,
        "socket_timeout": 10,
        "use_evalsha": True,  # Enable EVALSHA for better performance
    },
    # Task settings
    task_acks_late=True,
    task_reject_on_worker_lost=True,
    task_track_started=True,
    # Worker settings
    worker_prefetch_multiplier=10,  # Prefetch up to 10 messages per worker process
    worker_max_tasks_per_child=100,  # Restart worker after 100 tasks
    worker_max_memory_per_child=512000,  # Restart if memory exceeds 512MB
)

# Disable propagation for the main Celery logger
logging.getLogger("celery").propagate = False


# Database initialization is deferred to the worker_init signal
@worker_init.connect
def init_worker(**kwargs):
    """Initialize database connection and BBCP settings when the worker starts."""
    from .database import DatabaseConnection
    from .bbcp_settings import BBCPSettings
    from .logging_utils import get_structured_logger

    logger = get_structured_logger(__name__)

    # Initialize database
    db = DatabaseConnection()
    session, engine = db.get_connection()
    SQLAlchemyTask.init_session_factory(engine)

    # Configure dynamic queues
    configure_dynamic_queues(session)

    # Initialize BBCP settings
    logger.info("Initializing BBCP settings")
    settings = BBCPSettings()
    logger.info(f"BBCP settings: {settings.get_all_settings()}")
# Define exchanges
data_transfer_exchange = Exchange("data_transfer", type="direct")

# Define static queues (these will be merged with dynamic queues)
STATIC_QUEUES = (
    Queue("fyst", data_transfer_exchange, routing_key="fyst"),
    Queue("cologne", data_transfer_exchange, routing_key="cologne"),
    Queue("us", data_transfer_exchange, routing_key="us"),
    Queue(
        "fyst-transfer",
        data_transfer_exchange,
        routing_key="fyst-transfer",
        queue_arguments={"x-queue-type": "classic"},
    ),
    Queue(
        "cologne-transfer",
        data_transfer_exchange,
        routing_key="cologne-transfer",
        queue_arguments={"x-queue-type": "classic"},
    ),
    Queue(
        "us-transfer",
        data_transfer_exchange,
        routing_key="us-transfer",
        queue_arguments={"x-queue-type": "classic"},
    ),
)

app.conf.imports = (
    "ccat_data_transfer.archive_manager",
    "ccat_data_transfer.data_integrity_manager",
    "ccat_data_transfer.transfer_manager",
    "ccat_data_transfer.deletion_manager",
    "ccat_data_transfer.disk_monitor",
    "ccat_data_transfer.staging_manager",
    "ccat_data_transfer.raw_data_package_manager",
    "ccat_data_transfer.data_transfer_package_manager",
)

# Define static routes (these will be merged with dynamic routes)
STATIC_ROUTES = {
    "ccat:data_transfer:transfer:fyst": {
        "queue": "fyst-transfer",
        "routing_key": "fyst-transfer",
    },
    "ccat:data_transfer:create:transfer_package": {
        "queue": "fyst",
        "routing_key": "fyst",
    },
    "ccat:data_transfer:transfer:cologne": {
        "queue": "cologne-transfer",
        "routing_key": "cologne-transfer",
    },
    "ccat:data_transfer:transfer:us": {
        "queue": "us-transfer",
        "routing_key": "us-transfer",
    },
    "ccat:data_transfer:unpack_data_transfer_package:cologne": {
        "queue": "cologne",
        "routing_key": "cologne",
    },
    "ccat:data_transfer:unpack_data_transfer_package:us": {
        "queue": "us",
        "routing_key": "us",
    },
    "ccat:data_transfer:archive:cologne": {
        "queue": "cologne",
        "routing_key": "cologne",
    },
    "ccat:data_transfer:archive:us": {
        "queue": "us",
        "routing_key": "us",
    },
    "ccat:data_transfer:delete:cologne": {
        "queue": "cologne",
        "routing_key": "cologne",
    },
    "ccat:data_transfer:delete:us": {
        "queue": "us",
        "routing_key": "us",
    },
    "ccat:data_transfer:delete:fyst": {
        "queue": "fyst",
        "routing_key": "fyst",
    },
    "ccat:data_transfer:monitor_disk_usage:fyst": {
        "queue": "fyst",
        "routing_key": "fyst",
    },
    "ccat:data_transfer:staging:ramses": {
        "queue": "ramses-staging",
        "routing_key": "ramses-staging",
    },
    "ccat:data_transfer:staging:cheops": {
        "queue": "cheops-staging",
        "routing_key": "cheops-staging",
    },
}
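With this routing table in place, a task published under one of the static names is delivered to its mapped queue automatically. A minimal sketch, assuming the task is invoked by name; only the task name and the resulting queue come from the routes above, the argument is hypothetical:

# Sketch: Celery consults app.conf.task_routes, so this task name
# is delivered to the "fyst-transfer" queue.
result = app.send_task(
    "ccat:data_transfer:transfer:fyst",
    args=[42],  # hypothetical operation ID
)
print(result.id)  # ID of the queued task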
def configure_dynamic_queues(session: Session) -> None:
    """Configure dynamic queues from database and merge with static configuration."""
    logger = get_structured_logger(__name__)
    logger.info("Configuring dynamic Celery queues from database")

    # Discover queues from database
    discovery_service = QueueDiscoveryService(session)
    dynamic_queue_names = discovery_service.discover_all_queues()

    # Create dynamic queue objects
    dynamic_queues = []
    for queue_name in dynamic_queue_names:
        queue = Queue(
            queue_name,
            data_transfer_exchange,
            routing_key=queue_name,
            queue_arguments={"x-queue-type": "classic"},
        )
        dynamic_queues.append(queue)

    # Merge static and dynamic queues
    all_queues = list(STATIC_QUEUES) + dynamic_queues
    app.conf.task_queues = tuple(all_queues)

    # Create dynamic routes
    dynamic_routes = {}
    for queue_name in dynamic_queue_names:
        dynamic_routes[f"ccat:data_transfer:{queue_name}"] = {
            "queue": queue_name,
            "routing_key": queue_name,
        }

    # Merge static and dynamic routes
    all_routes = {**STATIC_ROUTES, **dynamic_routes}
    app.conf.task_routes = all_routes

    logger.info(f"Configured {len(all_queues)} queues ({len(dynamic_queues)} dynamic)")
def get_worker_queues(
    session: Session, location_identifier: str, operation_type: str = None
) -> List[str]:
    """Get queue names for a specific location and optional operation type."""
    logger = get_structured_logger(__name__)

    # Resolve location from identifier
    location = (
        session.query(models.DataLocation)
        .filter(models.DataLocation.name == location_identifier)
        .first()
    )
    if not location:
        logger.warning(f"Location '{location_identifier}' not found")
        return []

    discovery_service = QueueDiscoveryService(session)

    if operation_type:
        # Process specific operation type
        return [
            discovery_service.get_queue_name_for_location(location, operation_type)
        ]
    else:
        # Process all operations for this location
        return discovery_service.get_queues_for_location(location)
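A usage sketch for get_worker_queues; the location name "cologne" and the operation type "transfer" are illustrative values, not guaranteed database contents:

# Sketch: resolving queues for one location (values are illustrative).
db = DatabaseConnection()
session, _engine = db.get_connection()

all_queues_for_site = get_worker_queues(session, "cologne")
transfer_queues = get_worker_queues(session, "cologne", operation_type="transfer")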
def configure_logger(logger, *args, **kwargs):
    """Configure logger to prevent duplicate logging."""
    # Remove all existing handlers
    logger.handlers = []
    # Prevent propagation to parent loggers
    logger.propagate = False

    # Create a single handler with our desired format
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        "%Y-%m-%d %H:%M:%S",  # No milliseconds, for cleaner output
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
# Connect to both logger setup signals
after_setup_logger.connect(configure_logger)
after_setup_task_logger.connect(configure_logger)


# Base SQLAlchemyTask: session-factory management and a scoped-session helper
class SQLAlchemyTask(app.Task):
    _session_factory = None

    @classmethod
    def init_session_factory(cls, engine):
        cls._session_factory = sessionmaker(bind=engine)

    @contextmanager
    def session_scope(self):
        if self._session_factory is None:
            raise RuntimeError("Session factory not initialized")
        session = self._session_factory()
        try:
            yield session
            session.commit()
        except:
            session.rollback()
            raise
        finally:
            session.close()

    def after_return(self, status, retval, task_id, args, kwargs, einfo):
        pass  # No need to remove session, as it's handled by the context manager

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        pass  # No need to rollback, as it's handled by the context manager
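A minimal sketch of a task built directly on SQLAlchemyTask, assuming the session factory has already been initialized by init_worker; the task name and the query are illustrative, not part of the module:

# Sketch: a task that uses the scoped session (name and query are illustrative).
@app.task(base=SQLAlchemyTask, bind=True, name="ccat:data_transfer:example:count_transfers")
def count_transfers(self):
    with self.session_scope() as session:
        # Commit or rollback is handled by session_scope.
        return session.query(models.DataTransfer).count()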
# Enhanced task base class with error tracking and recovery
def make_celery_task(test_session_factory=None):
    """
    Create a base task class with unified error handling, state tracking,
    and SQLAlchemy support.

    Args:
        test_session_factory: Optional session factory for testing

    Returns:
        A base Celery task class with enhanced error handling and SQLAlchemy integration
    """
    # Initialize services
    redis_client = get_redis_connection()
    task_state_manager = TaskStateManager(redis_client)
    notification_client = NotificationClient(redis_client=redis_client)
    logger = get_structured_logger(__name__)

    class CCATEnhancedSQLAlchemyTask(SQLAlchemyTask):
        """Enhanced SQLAlchemy task with resilient error handling and state tracking."""

        abstract = True

        # Operation type and retry settings
        operation_type = None  # Must be set by subclasses
        max_retries = 3

        @classmethod
        def init_session_factory(cls, session_factory=None):
            """Initialize the SQLAlchemy session factory."""
            if session_factory:
                cls._session_factory = session_factory
            else:
                db = DatabaseConnection()
                session, engine = db.get_connection()
                cls._session_factory = sessionmaker(bind=engine)

        def get_operation_id(self, args):
            """Extract operation ID from task arguments."""
            if args and len(args) > 0:
                return args[0]
            return None

        def get_operation_info(self, args, kwargs):
            """Get additional operation info for task tracking."""
            return {}

        def should_retry(self, exc, operation_id, retry_count):
            """
            Determine if task should be retried based on exception and retry count.

            Args:
                exc: The exception that was raised
                operation_id: ID of the operation
                retry_count: Current retry count

            Returns:
                bool: True if task should be retried, False otherwise
            """
            logger = get_structured_logger(__name__)

            # List of exceptions that should never be retried
            non_retryable_exceptions = (
                FileNotFoundError,
                PermissionError,
                NotADirectoryError,
                IsADirectoryError,
            )

            # Check if exception is in the non-retryable list
            if isinstance(exc, non_retryable_exceptions):
                logger.info(
                    "Non-retryable exception encountered",
                    operation_id=operation_id,
                    exception_type=type(exc).__name__,
                    error=str(exc),
                )
                return False

            # Check if exception has explicit retry information
            if hasattr(exc, "is_retryable"):
                is_retryable = exc.is_retryable
                max_allowed_retries = getattr(exc, "max_retries", self.max_retries)
            else:
                is_retryable = True  # Default to retryable
                max_allowed_retries = self.max_retries

            # Check retry count against max allowed retries
            if retry_count >= max_allowed_retries:
                logger.info(
                    "Max retries exceeded",
                    operation_id=operation_id,
                    retry_count=retry_count,
                    max_retries=max_allowed_retries,
                )
                return False

            return is_retryable

        def on_failure(self, exc, task_id, args, kwargs, einfo):
            """Handle task failure with a uniform approach for all operation types."""
            operation_id = self.get_operation_id(args)
            if not operation_id:
                logger.error("No operation ID found in task arguments")
                return

            try:
                with self.session_scope() as session:
                    # Get current retry count
                    retry_count = self.get_retry_count(session, operation_id)

                    # Determine if we should retry based on retry count and exception
                    should_retry = self.should_retry(exc, operation_id, retry_count)

                    # Get operation details for notification
                    operation_details = self._get_operation_details(
                        session, operation_id
                    )
                    operation_info = self.get_operation_info(args, kwargs)

                    # Prepare notification subject with development mode indicator
                    is_dev = ccat_data_transfer_settings.DEVELOPMENT_MODE
                    subject_prefix = "[DEV]" if is_dev else ""

                    # Build base notification body
                    notification_body = (
                        f"Task Details:\n"
                        f"- Task Name: {self.name}\n"
                        f"- Task ID: {task_id}\n"
                        f"- Operation Type: {self.operation_type or 'unknown'}\n"
                        f"- Operation ID: {operation_id}\n"
                        f"- Retry Count: {retry_count}\n"
                        f"- Max Retries: {self.max_retries}\n"
                        f"- Will Retry: {should_retry}\n\n"
                        f"Error Information:\n"
                        f"- Error Type: {type(exc).__name__}\n"
                        f"- Error Message: {str(exc)}\n\n"
                        f"Operation Details:\n{self._format_operation_details(operation_details)}\n\n"
                        f"Additional Context:\n{json.dumps(operation_info, indent=2)}\n"
                    )

                    # Add development mode extras
                    if is_dev:
                        notification_body += (
                            f"\nTraceback:\n{''.join(einfo.traceback)}\n\n"
                            f"Task Arguments:\n"
                            f"- Args: {args}\n"
                            f"- Kwargs: {kwargs}\n"
                        )

                    if should_retry and hasattr(self, "reset_state_on_failure"):
                        # Reset state for retry
                        self.reset_state_on_failure(session, operation_id, exc)
                        logger.info(
                            "Task scheduled for retry",
                            operation_id=operation_id,
                            retry_count=retry_count + 1,
                        )

                        # Only send retry notification in development mode
                        if is_dev:
                            subject = f"{subject_prefix} Task Retry in {self.name}"
                            notification_client.send_notification(
                                subject=subject,
                                body=notification_body,
                                level="DEBUG",
                            )
                    elif hasattr(self, "mark_permanent_failure"):
                        # Mark as permanently failed
                        self.mark_permanent_failure(session, operation_id, exc)

                        # Always send notification for permanent failures
                        subject = (
                            f"{subject_prefix} Permanent Task Failure in {self.name}"
                        )
                        notification_client.send_notification(
                            subject=subject,
                            body=notification_body,
                            level="ERROR",
                        )
            except Exception as e:
                logger.error(
                    "Error in failure handling",
                    task_id=task_id,
                    operation_id=operation_id,
                    error=str(e),
                )

            # Call parent's on_failure (which is a no-op in SQLAlchemyTask)
            super().on_failure(exc, task_id, args, kwargs, einfo)

        def get_retry_count(self, session, operation_id):
            """
            Get current retry count for this operation.

            Should be implemented by subclasses to access the appropriate
            database field.

            Args:
                session: SQLAlchemy session
                operation_id: ID of the operation

            Returns:
                int: Current retry count, defaults to 0
            """
            raise NotImplementedError(
                "get_retry_count must be implemented by subclasses"
            )

        def on_success(self, retval, task_id, args, kwargs):
            """Handle successful task completion."""
            task_state_manager.complete_task(task_id)

        def after_return(self, status, retval, task_id, args, kwargs, einfo):
            """Cleanup after task execution."""
            # If task succeeded, cleanup is handled by on_success
            # If task failed, cleanup is handled by on_failure
            super().after_return(status, retval, task_id, args, kwargs, einfo)

        def __call__(self, *args, **kwargs):
            """Run the task with state tracking, heartbeat, and SQLAlchemy session."""
            # Initialize session factory if not done yet
            if not self._session_factory:
                self.init_session_factory(test_session_factory)

            task_id = self.request.id
            operation_id = self.get_operation_id(args)
            operation_type = self.operation_type or self.name.split(":")[-1]

            # Get additional operation context
            operation_info = self.get_operation_info(args, kwargs)

            # Register task in Redis
            task_state_manager.register_task(
                task_id=task_id,
                operation_type=operation_type,
                operation_id=operation_id,
                additional_info=operation_info,
                max_retries=self.max_retries,
            )

            # Set up periodic heartbeat with proper cleanup
            stop_heartbeat = threading.Event()
            heartbeat_failed = threading.Event()
            heartbeat_thread = None

            def heartbeat_worker():
                heartbeat_logger = get_structured_logger(__name__ + ".heartbeat")
                consecutive_failures = 0
                max_consecutive_failures = 3

                while not stop_heartbeat.is_set():
                    try:
                        if stop_heartbeat.is_set():
                            break
                        task_state_manager.update_heartbeat(task_id)
                        consecutive_failures = 0  # Reset on success
                    except Exception as e:
                        consecutive_failures += 1
                        heartbeat_logger.error(
                            "Heartbeat update failed",
                            error=str(e),
                            consecutive_failures=consecutive_failures,
                        )
                        if consecutive_failures >= max_consecutive_failures:
                            heartbeat_logger.error(
                                "Too many consecutive heartbeat failures, marking task as failed",
                                task_id=task_id,
                            )
                            heartbeat_failed.set()
                            break

                    # Use a shorter sleep interval and check stop flag more frequently
                    for _ in range(6):  # 6 * 10 seconds = 60 seconds total
                        if stop_heartbeat.is_set():
                            break
                        time.sleep(10)

            try:
                # Start heartbeat thread
                heartbeat_thread = threading.Thread(target=heartbeat_worker)
                heartbeat_thread.daemon = True
                heartbeat_thread.start()

                # Execute the task
                result = super().__call__(*args, **kwargs)

                # Check if heartbeat failed during task execution
                if heartbeat_failed.is_set():
                    raise RuntimeError("Task heartbeat failed during execution")

                return result
            except Exception as e:
                # Log the error and re-raise
                logger.error(
                    "Task execution failed",
                    task_id=task_id,
                    error=str(e),
                    operation_type=operation_type,
                    operation_id=operation_id,
                )
                raise
            finally:
                # Ensure proper cleanup of heartbeat thread
                if heartbeat_thread and heartbeat_thread.is_alive():
                    stop_heartbeat.set()
                    try:
                        # Give the thread a reasonable time to clean up
                        heartbeat_thread.join(timeout=10.0)
                        if heartbeat_thread.is_alive():
                            logger.warning(
                                "Heartbeat thread did not stop gracefully - forcing cleanup",
                                task_id=task_id,
                            )
                            # Force cleanup of task state since heartbeat thread is stuck
                            try:
                                task_state_manager.complete_task(task_id)
                            except Exception as cleanup_error:
                                logger.error(
                                    "Failed to force cleanup task state",
                                    task_id=task_id,
                                    error=str(cleanup_error),
                                )
                    except Exception as e:
                        logger.error(
                            "Error stopping heartbeat thread",
                            task_id=task_id,
                            error=str(e),
                        )

                # Clean up task state
                try:
                    task_state_manager.complete_task(task_id)
                except Exception as e:
                    logger.error(
                        "Failed to clean up task state",
                        task_id=task_id,
                        error=str(e),
                    )

        # Default implementations to be overridden by subclasses
        def reset_state_on_failure(self, session, operation_id, exc):
            """
            Reset operation state for retry. To be implemented by subclasses.

            This default implementation logs a warning and raises an error to
            ensure subclasses properly implement their own retry logic.

            Args:
                session: SQLAlchemy session
                operation_id: ID of the operation
                exc: The exception that caused the failure

            Raises:
                NotImplementedError: Always raised to ensure subclasses implement
                    their own logic
            """
            logger.warning(
                "reset_state_on_failure not implemented for task",
                task_name=self.name,
                operation_id=operation_id,
                error=str(exc),
            )
            raise NotImplementedError(
                f"Task {self.name} must implement reset_state_on_failure to handle retries properly"
            )

        def mark_permanent_failure(self, session, operation_id, exc):
            """
            Mark operation as permanently failed. To be implemented by subclasses.

            This default implementation logs a warning and raises an error to
            ensure subclasses properly implement their own failure handling logic.

            Args:
                session: SQLAlchemy session
                operation_id: ID of the operation
                exc: The exception that caused the failure

            Raises:
                NotImplementedError: Always raised to ensure subclasses implement
                    their own logic
            """
            logger.warning(
                "mark_permanent_failure not implemented for task",
                task_name=self.name,
                operation_id=operation_id,
                error=str(exc),
            )
            raise NotImplementedError(
                f"Task {self.name} must implement mark_permanent_failure to handle permanent failures properly"
            )

        def _get_operation_details(self, session, operation_id):
            """Get detailed information about the operation from the database."""
            try:
                if self.operation_type == "package":
                    return session.query(models.DataTransferPackage).get(operation_id)
                elif self.operation_type == "transfer":
                    return session.query(models.DataTransfer).get(operation_id)
                elif self.operation_type == "unpack":
                    return session.query(models.DataTransfer).get(operation_id)
                elif self.operation_type == "archive":
                    return session.query(models.LongTermArchiveTransfer).get(
                        operation_id
                    )
                elif self.operation_type == "delete":
                    return session.query(models.PhysicalCopy).get(operation_id)
                return None
            except Exception as e:
                logger.error(
                    "Failed to get operation details",
                    operation_type=self.operation_type,
                    operation_id=operation_id,
                    error=str(e),
                )
                return None

        def _format_operation_details(self, operation):
            """Format operation details for notification message."""
            if not operation:
                return "No operation details available"

            details = []
            for key, value in operation.__dict__.items():
                if not key.startswith("_"):
                    details.append(f"- {key}: {value}")
            return "\n".join(details)

    return CCATEnhancedSQLAlchemyTask
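A hedged sketch of how a concrete task class might fill in the subclass hooks. The retry_count and status attributes on the DataTransfer model are assumptions for illustration, not confirmed fields of the ccat_ops_db schema:

# Sketch: concrete subclass of the enhanced base (model columns are assumed).
BaseTask = make_celery_task()


class TransferTask(BaseTask):
    """Illustrative subclass; the model columns below are assumptions."""

    operation_type = "transfer"

    def get_retry_count(self, session, operation_id):
        transfer = session.query(models.DataTransfer).get(operation_id)
        return getattr(transfer, "retry_count", 0) if transfer else 0

    def reset_state_on_failure(self, session, operation_id, exc):
        transfer = session.query(models.DataTransfer).get(operation_id)
        if transfer is not None:
            transfer.retry_count = getattr(transfer, "retry_count", 0) + 1  # assumed column
            transfer.status = "pending"  # assumed column/value

    def mark_permanent_failure(self, session, operation_id, exc):
        transfer = session.query(models.DataTransfer).get(operation_id)
        if transfer is not None:
            transfer.status = "failed"  # assumed column/value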
def start_celery_worker(queue=None, concurrency=None):
    """
    Starts a Celery worker for a specific queue with controlled concurrency.

    Parameters
    ----------
    queue : str, optional
        The name of the queue to process. If None, the worker will consume
        from all queues.
    concurrency : int, optional
        Number of worker processes/threads. If None, defaults to the number
        of CPU cores.

    Returns
    -------
    None

    Raises
    ------
    ServiceExit
        When a termination signal is received.

    Notes
    -----
    - Uses subprocess to spawn a new Celery worker process.
    - Configures SIGTERM signal handler for graceful shutdown.
    """
    import os
    import signal
    import subprocess

    # Set up signal handler for graceful shutdown
    signal.signal(signal.SIGTERM, service_shutdown)

    with open(os.devnull, "w") as _:
        health_check = HealthCheck(
            service_type="data_transfer",
            service_name=f"celery_worker_{queue}",
        )
        health_check.start()

        logging.info(f"Starting celery worker for queue {queue}")
        try:
            env = os.environ
            # Command arguments ensure strict queue binding
            command = [
                "celery",
                "-A",
                "ccat_data_transfer.setup_celery_app.app",
                "worker",
                "--loglevel=INFO",
                "-Q",
                queue,
                "--purge",  # Clear any existing messages on startup
                "-n",
                f"worker.{queue}@%h",  # Give unique name to worker
                "--without-mingle",  # Skip startup state synchronization with other workers
                "--without-gossip",  # Disable worker-worker communication
            ]

            if concurrency is not None:
                command.extend(["-c", str(concurrency)])

            # Start the Celery worker as a subprocess
            proc = subprocess.Popen(command, env=env)

            # Wait for the worker process to complete
            proc.wait()
        except ServiceExit:
            # Send SIGTERM for warm shutdown to the celery worker
            proc.terminate()
            proc.wait()
        finally:
            health_check.stop()
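Usage sketch: starting a dedicated worker for one of the static queues from a launcher script; the queue name and concurrency are illustrative:

# Sketch: run a worker bound only to the "fyst-transfer" queue with 2 processes.
if __name__ == "__main__":
    start_celery_worker(queue="fyst-transfer", concurrency=2)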
# Add to existing Celery configuration
app.conf.beat_schedule = {
    "monitor-disk-usage-cologne": {
        "task": "ccat:data_transfer:monitor_disk_usage:cologne",
        "schedule": 60.0,  # run every minute
    },
    "monitor-disk-usage-us": {
        "task": "ccat:data_transfer:monitor_disk_usage:us",
        "schedule": 60.0,
    },
    "monitor-disk-usage-fyst": {
        "task": "ccat:data_transfer:monitor_disk_usage:fyst",
        "schedule": 60.0,
    },
}
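If a calendar-based schedule were preferred over a fixed interval, Celery's crontab schedules could be used instead; a sketch (the entry name is hypothetical, the task name reuses an existing route):

from celery.schedules import crontab

# Sketch: run the FYST disk-usage check at the top of every hour instead of every minute.
app.conf.beat_schedule["monitor-disk-usage-fyst-hourly"] = {
    "task": "ccat:data_transfer:monitor_disk_usage:fyst",
    "schedule": crontab(minute=0),
}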
def start_celery_beat():
    """
    Starts the Celery beat scheduler.
    """
    import os
    import signal
    import subprocess

    # Set up signal handler for graceful shutdown
    signal.signal(signal.SIGTERM, service_shutdown)

    try:
        env = os.environ
        command = [
            "celery",
            "-A",
            "ccat_data_transfer.setup_celery_app.app",
            "beat",
            "--loglevel=INFO",
        ]

        # Start the Celery beat scheduler as a subprocess
        proc = subprocess.Popen(command, env=env)

        # Wait for the process to complete
        proc.wait()
    except ServiceExit:
        # Send SIGTERM for warm shutdown
        proc.terminate()
        proc.wait()