import datetime
import logging
import os
import json
import subprocess
from typing import List, Tuple, Dict, Any, Optional
import time
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session
from ccat_ops_db import models
from .database import DatabaseConnection
from .setup_celery_app import app, make_celery_task
from .utils import (
    create_local_folder,
    create_remote_folder,
    get_redis_connection,
    make_bbcp_command,
    parse_bbcp_output,
)
from .config.config import ccat_data_transfer_settings
from .decorators import track_metrics
from .metrics import HousekeepingMetrics
from .logging_utils import BBCPLogHandler, get_structured_logger
from .exceptions import (
BBCPError,
NetworkError,
SegmentationFaultError,
DestinationFileExistsError,
)
from .bbcp_settings import BBCPSettings
from .notification_service import NotificationClient
from .buffer_manager import buffer_manager
from .queue_discovery import route_task_by_location
from .operation_types import OperationType
# Structured logger for this module; the Celery task logger below is reserved
# for messages from the ccat_ops_db layer.
logger = get_structured_logger(__name__)
ops_logger = get_task_logger("ccat_ops_db")
redis_ = get_redis_connection()
class DataTransferTask(make_celery_task()):
"""Base class for data transfer tasks."""
def __init__(self):
super().__init__()
self.operation_type = "transfer"
self.notification_client = NotificationClient()
def get_retry_count(self, session, data_transfer_id):
"""Get current retry count for this data transfer operation."""
data_transfer = session.query(models.DataTransfer).get(data_transfer_id)
if data_transfer and hasattr(data_transfer, "retry_count"):
return data_transfer.retry_count
raise ValueError("Data transfer not found or retry count not available")
def reset_state_on_failure(self, session, data_transfer_id, exc):
"""Reset data transfer state for retry."""
data_transfer = session.query(models.DataTransfer).get(data_transfer_id)
if data_transfer:
data_transfer.status = models.Status.PENDING
for (
raw_data_package
) in data_transfer.data_transfer_package.raw_data_packages:
raw_data_package.state = models.PackageState.TRANSFERRING
data_transfer.failure_error_message = None
data_transfer.retry_count += 1
logger.info(
"Reset transfer for retry",
data_transfer_id=data_transfer_id,
retry_count=data_transfer.retry_count,
)
def mark_permanent_failure(self, session, data_transfer_id, exc):
"""Mark data transfer as permanently failed."""
data_transfer = session.query(models.DataTransfer).get(data_transfer_id)
if data_transfer:
data_transfer.status = models.Status.FAILED
for (
raw_data_package
) in data_transfer.data_transfer_package.raw_data_packages:
raw_data_package.state = models.PackageState.FAILED
data_transfer.failure_error_message = str(exc)
logger.info(
"Marked transfer as permanently failed",
data_transfer_id=data_transfer_id,
)
def get_operation_info(self, args, kwargs):
"""Get additional context for data transfer tasks."""
        if not args:
return {}
with self.session_scope() as session:
try:
data_transfer = session.query(models.DataTransfer).get(args[0])
if data_transfer:
return {
"source_location": data_transfer.origin_location.name,
"destination_location": data_transfer.destination_location.name,
"package_id": str(data_transfer.data_transfer_package_id),
}
except Exception as e:
logger.error(f"Error getting transfer info: {e}")
return {}
def on_failure(self, exc, task_id, args, kwargs, einfo):
"""Handle task failure with recovery for specific error cases."""
operation_id = self.get_operation_id(args)
if not operation_id:
logger.error("No operation ID found in task arguments")
return
try:
with self.session_scope() as session:
# Handle specific error cases
if isinstance(exc, DestinationFileExistsError):
# Remove the existing destination file
try:
_remove_destination_file(
exc.destination_path,
session.query(models.DataTransfer).get(operation_id),
)
# Reset state for retry
self.reset_state_on_failure(session, operation_id, exc)
logger.info(
"Removed existing destination file and scheduled for retry",
operation_id=operation_id,
)
return
except Exception as e:
logger.error(
"Failed to remove destination file",
operation_id=operation_id,
error=str(e),
)
except Exception as e:
logger.error(
"Error in failure handling",
task_id=task_id,
operation_id=operation_id,
error=str(e),
)
# Call parent's on_failure to handle the standard failure logic
super().on_failure(exc, task_id, args, kwargs, einfo)
@app.task(
base=DataTransferTask,
name="ccat:data_transfer:transfer",
bind=True,
)
def transfer_files_bbcp(
self, data_transfer_id: int, session: Optional[Session] = None
) -> None:
"""Transfer files using BBCP with dynamic queue routing."""
if session is None:
with self.session_scope() as session:
return _transfer_files_bbcp_internal(session, data_transfer_id)
return _transfer_files_bbcp_internal(session, data_transfer_id)
@track_metrics(
operation_type="data_transfer",
additional_tags={
"transfer_method": "bbcp",
},
)
def _transfer_files_bbcp_internal(session: Session, data_transfer_id: int) -> None:
"""
Internal function to transfer files using BBCP.
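
    This function is normally reached through the ``transfer_files_bbcp``
    Celery task rather than being called directly, for example (the transfer
    id and queue name below are hypothetical)::

        transfer_files_bbcp.apply_async(args=(42,), queue="transfer-site-a")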
"""
logger = get_structured_logger(__name__)
start_time = datetime.datetime.now()
data_transfer = _get_data_transfer(session, data_transfer_id)
_log_transfer_start(data_transfer)
data_transfer.start_time = start_time
session.commit()
# Create destination directory if needed (for disk locations)
if isinstance(data_transfer.destination_location, models.DiskDataLocation):
destination_path = os.path.join(
data_transfer.destination_location.path,
data_transfer.data_transfer_package.relative_path,
)
destination_dir = os.path.dirname(destination_path)
logger.info(
"creating_destination_folder",
folder=destination_dir,
)
if data_transfer.destination_location.host == "localhost":
create_local_folder(destination_dir)
else:
create_remote_folder(
data_transfer.destination_location.user,
data_transfer.destination_location.host,
destination_dir,
)
source_url, destination_url = _construct_transfer_urls(data_transfer)
result, transfer_metrics = _execute_bbcp_command(
session, data_transfer, source_url, destination_url
)
end_time = datetime.datetime.now()
data_transfer.end_time = end_time
session.commit()
bbcp_settings = BBCPSettings()
current_settings = bbcp_settings.get_all_settings()
# Send metrics to InfluxDB
metrics = HousekeepingMetrics()
try:
metrics.send_transfer_metrics(
operation="bbcp_transfer",
source_path=source_url,
destination_path=destination_url,
file_size=transfer_metrics["bytes_transferred"],
duration=transfer_metrics["duration"],
success=(result.returncode == 0),
error_message=result.stderr if result.returncode != 0 else None,
additional_fields={
"peak_transfer_rate_mbps": transfer_metrics["peak_transfer_rate_mbps"],
"average_transfer_rate_mbps": transfer_metrics[
"average_transfer_rate_mbps"
],
"number_of_streams": transfer_metrics["number_of_streams"],
"network_errors": transfer_metrics["network_errors"],
"retry_count": data_transfer.retry_count,
},
additional_tags={
"source_location": data_transfer.origin_location.name,
"destination_location": data_transfer.destination_location.name,
"transfer_id": str(data_transfer_id),
"transfer_method": "bbcp",
**{
k.lower(): v for k, v in current_settings.items()
}, # Add BBCP settings as tags
},
)
except Exception as e:
logger.error("metrics_send_failed", error=e, transfer_id=data_transfer_id)
finally:
metrics.close()
# Update status
_update_data_transfer_status(
session, data_transfer, result, start_time, end_time, transfer_metrics
)
session.commit()
redis_.publish(
"transfer:overview",
json.dumps({"type": "transfer_completed", "data": data_transfer.id}),
)
def _get_data_transfer(session: Session, data_transfer_id: int) -> models.DataTransfer:
"""
Retrieve the data transfer object from the database.
Parameters
----------
session : sqlalchemy.orm.Session
The database session.
data_transfer_id : int
The ID of the data transfer to retrieve.
Returns
-------
models.DataTransfer
The retrieved DataTransfer object.
"""
return session.get(models.DataTransfer, data_transfer_id)
def _log_transfer_start(data_transfer: models.DataTransfer) -> None:
"""Log the start of the data transfer."""
logger.debug(
"transfer_start",
transfer_id=data_transfer.id,
file=data_transfer.data_transfer_package.file_name,
source=data_transfer.origin_location.name,
destination=data_transfer.destination_location.name,
)
def _construct_transfer_urls(data_transfer: models.DataTransfer) -> Tuple[str, str]:
"""
Construct the source and destination URLs for the BBCP command.
Handles different storage types (disk, S3, tape) polymorphically.
Parameters
----------
data_transfer : models.DataTransfer
The DataTransfer object containing the location information.
Returns
-------
Tuple[str, str]
The source and destination URLs.
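
    Examples
    --------
    For a transfer from a remote acquisition disk to a local buffer disk the
    returned pair typically looks like this (all values hypothetical)::

        ("observer@acq-host:/data/raw/2024/obs_001.h5",
         "/buffer/raw/2024/obs_001.h5")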
"""
# Get source and destination paths based on storage type
source_path = _get_location_path(
data_transfer.origin_location, data_transfer.data_transfer_package
)
destination_path = _get_location_path(
data_transfer.destination_location, data_transfer.data_transfer_package
)
# Construct URLs based on storage type
source_url = _construct_url_for_location(data_transfer.origin_location, source_path)
destination_url = _construct_url_for_location(
data_transfer.destination_location, destination_path
)
return source_url, destination_url
def _get_location_path(
data_location: models.DataLocation,
data_transfer_package: models.DataTransferPackage,
) -> str:
"""
Get the full path for a data transfer package at a specific location.
Parameters
----------
data_location : models.DataLocation
The data location.
data_transfer_package : models.DataTransferPackage
The data transfer package.
Returns
-------
str
The full path to the package at this location.
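
    Examples
    --------
    A sketch with hypothetical locations, assuming the mapped columns can be
    set directly on the model instances::

        disk = models.DiskDataLocation(path="/data/ccat")
        s3 = models.S3DataLocation(prefix="raw/")
        pkg = models.DataTransferPackage(relative_path="2024/obs_001.h5")

        _get_location_path(disk, pkg)  # "/data/ccat/2024/obs_001.h5"
        _get_location_path(s3, pkg)    # "raw/2024/obs_001.h5"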
"""
if isinstance(data_location, models.DiskDataLocation):
return os.path.join(data_location.path, data_transfer_package.relative_path)
elif isinstance(data_location, models.S3DataLocation):
return f"{data_location.prefix}{data_transfer_package.relative_path}"
elif isinstance(data_location, models.TapeDataLocation):
return os.path.join(
data_location.mount_path, data_transfer_package.relative_path
)
else:
raise ValueError(f"Unsupported storage type: {data_location.storage_type}")
def _construct_url_for_location(data_location: models.DataLocation, path: str) -> str:
"""
Construct a URL for a specific location type.
Parameters
----------
data_location : models.DataLocation
The data location.
path : str
The file path.
Returns
-------
str
The constructed URL.
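
    Examples
    --------
    Illustrative only; the location objects and their attribute values are
    hypothetical::

        # Local disk: the bare path is returned
        _construct_url_for_location(local_disk, "/buffer/obs_001.h5")
        # -> "/buffer/obs_001.h5"

        # Remote disk: an scp-style user@host:path URL
        _construct_url_for_location(remote_disk, "/data/obs_001.h5")
        # -> "observer@acq-host:/data/obs_001.h5"

        # S3: an s3://bucket/key URL
        _construct_url_for_location(s3_location, "raw/obs_001.h5")
        # -> "s3://ccat-archive/raw/obs_001.h5"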
"""
if isinstance(data_location, models.DiskDataLocation):
# Handle localhost connections directly without SSH
if data_location.host == "localhost":
return path
else:
return f"{data_location.user}@{data_location.host}:{path}"
elif isinstance(data_location, models.S3DataLocation):
        # For S3 (AWS or a custom S3-compatible endpoint), construct an s3:// URL
        return f"s3://{data_location.bucket_name}/{path}"
elif isinstance(data_location, models.TapeDataLocation):
# For tape, use local path (tape is mounted locally)
return path
else:
raise ValueError(f"Unsupported storage type: {data_location.storage_type}")
def _remove_destination_file(
destination_url: str, data_transfer: models.DataTransfer
) -> None:
"""Remove destination file if it exists."""
logger = get_structured_logger(__name__)
try:
if isinstance(data_transfer.destination_location, models.DiskDataLocation):
if data_transfer.destination_location.host == "localhost":
if os.path.exists(destination_url):
os.remove(destination_url)
logger.info(
"removed_existing_destination_file",
path=destination_url,
transfer_id=data_transfer.id,
)
            else:
                # For remote hosts, use SSH to remove the file.
                # destination_url may be a bare path or a user@host:/path URL;
                # keep only the path portion for the remote rm.
                remote_path = destination_url.split(":", 1)[-1]
                ssh_command = [
                    "ssh",
                    f"{data_transfer.destination_location.user}@{data_transfer.destination_location.host}",
                    f"rm -f {remote_path}",
                ]
result = subprocess.run(
ssh_command, capture_output=True, text=True, check=False
)
if result.returncode == 0:
logger.info(
"removed_existing_destination_file",
path=destination_url,
transfer_id=data_transfer.id,
)
else:
logger.warning(
"failed_to_remove_destination_file",
path=destination_url,
error=result.stderr,
transfer_id=data_transfer.id,
)
elif isinstance(data_transfer.destination_location, models.S3DataLocation):
# For S3, use AWS CLI or boto3 to remove file
# This would need to be implemented based on your S3 access method
logger.info(
"S3 file removal not yet implemented",
path=destination_url,
transfer_id=data_transfer.id,
)
else:
logger.warning(
"File removal not implemented for storage type",
storage_type=data_transfer.destination_location.storage_type,
transfer_id=data_transfer.id,
)
except Exception as e:
logger.error(
"error_removing_destination_file",
path=destination_url,
error=str(e),
transfer_id=data_transfer.id,
)
raise
def _execute_bbcp_command(
session: Session,
data_transfer: models.DataTransfer,
source_url: str,
destination_url: str,
):
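    """
    Run the BBCP command (or a plain cp for purely local disk-to-disk copies)
    and return the completed process together with parsed transfer metrics.

    Raises NetworkError, SegmentationFaultError, DestinationFileExistsError or
    BBCPError depending on the failure mode reported by the command.
    """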
logger = get_structured_logger(__name__)
    # Use a plain cp for purely local disk-to-disk transfers (localhost
    # destination and a local source path); everything else goes through BBCP.
    use_cp = (
        isinstance(data_transfer.destination_location, models.DiskDataLocation)
        and data_transfer.destination_location.host == "localhost"
        and "@" not in source_url  # ensure the source is also local
    )
if use_cp:
return _execute_cp_command(session, data_transfer, source_url, destination_url)
bbcp_command = make_bbcp_command(source_url, destination_url)
logger.info(f"Executing BBCP command: {bbcp_command}")
log_handler = BBCPLogHandler()
start_time = time.time()
    # Check whether the remote destination file already exists (disk locations
    # only) and remove it so BBCP does not fail with "already exists".
    if (
        isinstance(data_transfer.destination_location, models.DiskDataLocation)
        and data_transfer.destination_location.host != "localhost"
    ):
        # destination_url has the form user@host:/path for remote disks; the
        # ssh commands below need only the path portion.
        remote_path = destination_url.split(":", 1)[-1]
        remote_login = (
            f"{data_transfer.destination_location.user}"
            f"@{data_transfer.destination_location.host}"
        )
        result = subprocess.run(
            ["ssh", remote_login, f"ls {remote_path}"],
            capture_output=True,
            text=True,
            check=False,
        )
        if result.returncode == 0:
            logger.info(
                "destination_file_already_exists",
                path=destination_url,
                transfer_id=data_transfer.id,
            )
            # remove the stale file from the destination
            subprocess.run(
                ["ssh", remote_login, f"rm -f {remote_path}"],
                capture_output=True,
                text=True,
                check=False,
            )
# Separate try block just for subprocess execution
try:
result = subprocess.run(
bbcp_command,
capture_output=True,
text=True,
check=False,
)
except subprocess.SubprocessError as e:
raise BBCPError(
message="Failed to execute BBCP command",
returncode=-1,
stderr=str(e),
transfer_id=data_transfer.id,
)
if result.returncode != 0:
if (
"Connection refused" in result.stderr
or "Connection timed out" in result.stderr
):
raise NetworkError(
result.stderr,
host=(
data_transfer.destination_location.host
if isinstance(
data_transfer.destination_location, models.DiskDataLocation
)
else "unknown"
),
transfer_id=data_transfer.id,
)
elif result.returncode == -11:
raise SegmentationFaultError(
message="Segmentation Fault",
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
)
elif "already exists" in result.stderr:
# Handle destination file exists error
raise DestinationFileExistsError(
message=result.stderr,
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
destination_path=destination_url,
)
else:
raise BBCPError(
message=result.stderr,
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
)
end_time = time.time()
duration = end_time - start_time
logger.debug(
"bbcp_complete",
return_code=result.returncode,
transfer_id=data_transfer.id,
duration=duration,
)
# Store command output
log_handler.store_bbcp_output(
session=session,
data_transfer=data_transfer,
stdout=result.stdout,
stderr=result.stderr,
success=(result.returncode == 0),
)
# Parse metrics
metrics = parse_bbcp_output(result.stdout, result.stderr, duration)
return result, metrics
def _execute_cp_command(
session: Session,
data_transfer: models.DataTransfer,
source_path: str,
destination_path: str,
):
"""Execute cp command for local transfers in development mode."""
logger = get_structured_logger(__name__)
log_handler = BBCPLogHandler()
start_time = time.time()
try:
# Get file size before transfer
file_size = os.path.getsize(source_path)
cp_command = ["cp", source_path, destination_path]
        logger.info(
            "executing_cp_command",
            command=" ".join(cp_command),
            file_size=file_size,
        )
# Execute cp command
result = subprocess.run(
cp_command,
capture_output=True,
text=True,
check=False,
)
end_time = time.time()
duration = end_time - start_time
if result.returncode != 0:
raise BBCPError(
message="Failed to execute cp command",
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
)
# Calculate metrics similar to bbcp
transfer_rate = (file_size / duration) / (1024 * 1024) # MB/s
metrics = {
"bytes_transferred": file_size,
"duration": duration,
"peak_transfer_rate_mbps": transfer_rate,
"average_transfer_rate_mbps": transfer_rate,
"number_of_streams": 1,
"network_errors": 0,
}
# Store command output
log_handler.store_bbcp_output(
session=session,
data_transfer=data_transfer,
stdout=f"Copied {file_size} bytes in {duration:.2f} seconds ({transfer_rate:.2f} MB/s)",
stderr=result.stderr,
success=(result.returncode == 0),
)
logger.debug(
"cp_complete",
return_code=result.returncode,
transfer_id=data_transfer.id,
duration=duration,
)
return result, metrics
except subprocess.SubprocessError as e:
raise BBCPError(
message="Failed to execute cp command",
returncode=-1,
stderr=str(e),
transfer_id=data_transfer.id,
)
def _update_data_transfer_status(
session: Session,
data_transfer: models.DataTransfer,
result: subprocess.CompletedProcess,
start_time: datetime.datetime,
end_time: datetime.datetime,
transfer_metrics: Dict[str, Any],
) -> None:
"""Update the status of the data transfer in the database."""
    logger = get_structured_logger(__name__)
    # Resolve the destination path once; it is needed both for the physical-copy
    # record on success and for error reporting on failure.
    destination_path = _get_location_path(
        data_transfer.destination_location, data_transfer.data_transfer_package
    )
    if result.returncode == 0:
logger.info(
"transfer_complete",
transfer_id=data_transfer.id,
file=data_transfer.data_transfer_package.file_name,
)
data_transfer.start_time = start_time
data_transfer.end_time = end_time
data_transfer.status = models.Status.COMPLETED
        # Create physical copy record at destination
physical_copy = models.DataTransferPackagePhysicalCopy(
data_transfer_package=data_transfer.data_transfer_package,
data_location=data_transfer.destination_location,
status=models.PhysicalCopyStatus.PRESENT,
checksum=data_transfer.data_transfer_package.checksum,
)
session.add(physical_copy)
logger.info("Storing Information on copy")
else:
logger.error(
"transfer_error",
transfer_id=data_transfer.id,
file=data_transfer.data_transfer_package.file_name,
return_code=result.returncode,
error=result.stderr.strip(),
)
# Instead of directly updating status, raise an exception to trigger retry handling
if (
"Connection refused" in result.stderr
or "Connection timed out" in result.stderr
):
raise NetworkError(
result.stderr,
host=(
data_transfer.destination_location.host
if isinstance(
data_transfer.destination_location, models.DiskDataLocation
)
else "unknown"
),
transfer_id=data_transfer.id,
)
elif result.returncode == -11:
raise SegmentationFaultError(
message="Segmentation Fault",
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
)
elif "already exists" in result.stderr:
raise DestinationFileExistsError(
message=result.stderr,
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
destination_path=destination_path,
)
else:
raise BBCPError(
message=result.stderr,
returncode=result.returncode,
stderr=result.stderr,
transfer_id=data_transfer.id,
)
session.commit()
redis_.publish(
"transfer:overview",
json.dumps({"type": "transfer_completed", "data": data_transfer.id}),
)
def _get_pending_transfers(session: Session) -> List[models.DataTransfer]:
"""
Retrieve all pending data transfers from the database.
Parameters
----------
session : sqlalchemy.orm.Session
The database session.
Returns
-------
List[models.DataTransfer]
A list of pending DataTransfer objects.
"""
pending_transfers = (
session.query(models.DataTransfer).filter_by(status=models.Status.PENDING).all()
)
return pending_transfers
def _filter_supported_transfers(
transfers: List[models.DataTransfer],
) -> Tuple[List[models.DataTransfer], List[str]]:
"""
Filter transfers based on supported transfer methods.
Parameters
----------
transfers : List[models.DataTransfer]
A list of DataTransfer objects to filter.
Returns
-------
Tuple[List[models.DataTransfer], List[str]]
A tuple containing a list of supported transfers and a list of unsupported transfer methods.
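
    Examples
    --------
    With a hypothetical ``SUPPORTED_DATA_TRANSFER_METHODS = ["bbcp"]``, a
    batch containing bbcp and gridftp transfers splits as::

        supported, unsupported = _filter_supported_transfers(transfers)
        # supported   -> the transfers whose data_transfer_method is "bbcp"
        # unsupported -> ["gridftp"]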
"""
supported = []
unsupported_methods = set()
for transfer in transfers:
if (
transfer.data_transfer_method
in ccat_data_transfer_settings.SUPPORTED_DATA_TRANSFER_METHODS
):
supported.append(transfer)
else:
unsupported_methods.add(transfer.data_transfer_method)
return supported, list(unsupported_methods)
def _process_transfer(transfer: models.DataTransfer, session: Session) -> None:
"""Process a single data transfer."""
logger = get_structured_logger(__name__)
# Check if we can create new data based on buffer state
if not buffer_manager.can_create_data():
logger.warning(
"Buffer in emergency state, postponing transfer",
transfer_id=transfer.id,
file=transfer.data_transfer_package.file_name,
)
transfer.status = models.Status.PENDING
session.commit()
redis_.publish(
"transfer:overview",
json.dumps({"type": "transfer_pending", "data": transfer.id}),
)
return
if transfer.data_transfer_method == "bbcp":
# Use dynamic queue routing based on origin location
queue_name = route_task_by_location(
OperationType.DATA_TRANSFER, transfer.origin_location
)
task_args = {
"args": (transfer.id,),
"queue": queue_name,
}
# Get max parallel transfers based on buffer state
max_transfers = buffer_manager.get_max_parallel_transfers()
# Apply rate limiting based on buffer state
if max_transfers < ccat_data_transfer_settings.DATA_TRANSFER_WORKERS:
logger.info(
"Reducing parallel transfers due to buffer state",
max_transfers=max_transfers,
normal_workers=ccat_data_transfer_settings.DATA_TRANSFER_WORKERS,
)
# Use a rate-limited queue
task_args["queue"] = f"{queue_name}-limited"
# Apply the task using the unified transfer function
transfer_files_bbcp.apply_async(**task_args)
transfer.status = models.Status.SCHEDULED
session.commit()
redis_.publish(
"transfer:overview",
json.dumps({"type": "transfer_scheduled", "data": transfer.id}),
)
logger.debug(
"transfer_scheduled",
transfer_id=transfer.id,
file=transfer.data_transfer_package.file_name,
source=transfer.origin_location.name,
queue=queue_name,
)
def transfer_transfer_packages(verbose: bool = False, session: Optional[Session] = None) -> None:
"""
Find not yet transferred data transfer packages and schedule their transfer.
Parameters
----------
    verbose : bool, optional
        If True, sets the logging level to DEBUG. Default is False.
    session : sqlalchemy.orm.Session, optional
        Existing database session to use. If None, a new connection is opened
        and closed when the function returns.
Returns
-------
None
Notes
-----
- Updates the logging level if verbose is True.
- Retrieves pending data transfers from the database.
- Schedules Celery tasks for file transfers.
- Updates data transfer statuses in the database.
- Logs information about the transfer process.
- Handles database errors and unexpected exceptions.
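
    Examples
    --------
    Typically invoked periodically by the service loop or a scheduler; it can
    also be called directly, optionally with an existing session::

        transfer_transfer_packages(verbose=True)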
"""
if verbose or ccat_data_transfer_settings.VERBOSE:
logger.setLevel(logging.DEBUG)
logger.debug("verbose_mode_enabled")
    should_close_session = False
    if session is None:
        db = DatabaseConnection()
        session, _ = db.get_connection()
        should_close_session = True
try:
pending_transfers = _get_pending_transfers(session)
supported_transfers, unsupported_methods = _filter_supported_transfers(
pending_transfers
)
if unsupported_methods:
logger.error("unsupported_methods", methods=",".join(unsupported_methods))
if len(pending_transfers) > 0:
logger.debug(
"pending_transfers_found",
count=len(pending_transfers),
first_id=pending_transfers[0].id,
)
else:
logger.debug("no_pending_transfers")
for transfer in supported_transfers:
_process_transfer(transfer, session)
except Exception as e:
logger.error("service_loop_error", error=str(e))
finally:
if should_close_session:
session.close()