import hashlib
import os
import subprocess
import time
import zipfile
from typing import List, Tuple, Optional, Dict, Any, Union
import logging
import redis
import re
import tempfile
import boto3
from .config.config import ccat_data_transfer_settings
from .exceptions import ServiceExit, ArchiveCorruptionError
from .logging_utils import get_structured_logger
from ccat_ops_db import models
# from .bbcp_settings import BBCPSettings
logger = get_structured_logger(__name__)
# Singleton Redis client
_redis_client = None
# Singleton S3 client
_s3_client = None
def get_redis_connection() -> redis.StrictRedis:
"""
Establish a connection to the Redis server.
This function implements a singleton pattern to reuse the same Redis connection.
Returns
-------
redis.StrictRedis
The shared Redis client instance.
"""
global _redis_client
if _redis_client is None:
# Create the Redis client only once
_redis_client = redis.Redis(
host=ccat_data_transfer_settings.REDIS_HOST,
port=ccat_data_transfer_settings.REDIS_PORT,
db=0,
decode_responses=True,
ssl=True,
ssl_cert_reqs="required",
ssl_ca_certs=ccat_data_transfer_settings.REDIS_CA_CERT,
password=ccat_data_transfer_settings.REDIS_PASSWORD,
ssl_certfile=ccat_data_transfer_settings.REDIS_CERTFILE,
ssl_keyfile=ccat_data_transfer_settings.REDIS_KEYFILE,
# Add connection pooling settings
max_connections=10, # Limit the number of connections
socket_timeout=5, # Set a reasonable timeout
socket_connect_timeout=5,
retry_on_timeout=True,
)
logger.info("Created new Redis connection")
return _redis_client
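# Illustrative usage sketch (assumes the Redis settings above point at a
# reachable server; the key name is hypothetical):
#
#     client = get_redis_connection()
#     client.set("data_transfer:last_heartbeat", str(time.time()))
#     assert get_redis_connection() is client  # the cached client is reused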
def get_s3_client(
location: Optional[models.S3DataLocation] = None, site_name: Optional[str] = None
) -> boto3.client:
"""
Establish a connection to the S3 server.
The default client is created once and reused (singleton pattern); when a location is provided, a new client is created for that specific location.
Parameters
----------
location : Optional[models.S3DataLocation]
S3DataLocation object to get specific configuration for.
If None, uses default configuration.
site_name : Optional[str]
Name of the site for credential loading.
Required if location is provided.
Returns
-------
boto3.client
An S3 client configured for the given location, or the shared default client.
"""
global _s3_client
# Get configuration for the specific location or use default
if location and site_name:
# Get location-specific credentials
access_key_id, secret_access_key = location.get_s3_credentials(site_name)
# Create a new client for this specific location
client = boto3.client(
"s3",
endpoint_url=location.endpoint_url,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
region_name=location.region,
config=boto3.session.Config(signature_version="s3v4"),
)
logger.info(f"Created new S3 connection for location: {location.name}")
return client
# Use default configuration (singleton pattern)
if _s3_client is None:
# Create the S3 client only once
_s3_client = boto3.client(
"s3",
endpoint_url=ccat_data_transfer_settings.s3_endpoint_url,
aws_access_key_id=ccat_data_transfer_settings.s3_access_key_id,
aws_secret_access_key=ccat_data_transfer_settings.s3_secret_access_key,
region_name=ccat_data_transfer_settings.s3_region_name,
config=boto3.session.Config(signature_version="s3v4"),
)
logger.info("Created new S3 connection")
return _s3_client
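# Illustrative usage sketch (bucket and key names are hypothetical):
#
#     s3 = get_s3_client()  # default client, cached across calls
#     s3.upload_file("/tmp/archive.tar", "example-bucket", "archive.tar")
#
#     # Passing a location and site name returns a fresh, location-specific client:
#     # s3 = get_s3_client(location=s3_location, site_name="example-site")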
def service_shutdown(signum: int, frame) -> None:
"""
Handle service shutdown signal.
Parameters
----------
signum : int
The signal number.
frame : frame
Current stack frame.
Raises
------
ServiceExit
Raised to initiate the service exit process.
"""
logger.info(f"Caught signal {signum}")
raise ServiceExit
def unique_id() -> str:
"""
Generate a unique ID based on the current timestamp.
Returns
-------
str
A hexadecimal string representing the unique ID.
"""
current_time = str(time.time())
hash_object = hashlib.sha1(current_time.encode())
return hash_object.hexdigest()
def create_archive(
files: List, archive_name: str, base_path: str
) -> Tuple[str, List[str]]:
"""
Create a tar archive optimized for high-speed transfer using the system tar command.
Parameters
----------
files : List
A list of RawDataPackage objects, each with a 'relative_path' attribute.
archive_name : str
The name (including path) of the tar archive to be created.
base_path : str
The base path to prepend to the relative paths.
Returns
-------
Tuple[str, List[str]]
A tuple containing the archive name and a list of file names included in the archive.
"""
# Add timing instrumentation
start_time = time.time()
# Ensure the archive ends with .tar
if not archive_name.endswith(".tar"):
archive_name = archive_name.rsplit(".", 1)[0] + ".tar"
# Log the archive name being used - helps debug any name mismatch issues
logger.info(f"Creating archive with name: {os.path.basename(archive_name)}")
file_names = []
existing_files = []
# First, collect all valid files
for file in files:
full_path = os.path.join(base_path, file.relative_path)
if os.path.exists(full_path):
existing_files.append(
(file.relative_path, full_path)
) # Store relative path first
file_names.append(file.relative_path)
else:
logger.error(
f"File {full_path} does not exist and will not be added to the archive."
)
if not existing_files:
logger.error("No valid files to add to archive")
return archive_name, []
# Create a temporary file listing relative paths for tar
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
for rel_path, _ in existing_files:
temp_file.write(f"{rel_path}\n")
temp_file_path = temp_file.name
try:
# Use system tar command for much better performance
tar_cmd = [
"tar",
"-cf", # Create uncompressed tar file
archive_name, # Output file
"-C",
base_path, # Change to base directory
"--files-from",
temp_file_path, # Read file list from temp file
]
logger.info(f"Running tar command: {' '.join(tar_cmd)}")
# Add detailed timing for subprocess
cmd_start = time.time()
result = subprocess.run(tar_cmd, capture_output=True, text=True, check=False)
cmd_end = time.time()
logger.info(f"Tar command execution time: {cmd_end - cmd_start:.2f} seconds")
if result.returncode != 0:
logger.error(f"Error creating tar archive: {result.stderr}")
return archive_name, []
logger.info(f"Successfully created tar archive with {len(file_names)} files")
except Exception as e:
logger.error(f"Error running tar command: {str(e)}")
return archive_name, []
finally:
# Clean up temp file
try:
os.unlink(temp_file_path)
except Exception as e:
logger.error(f"Error cleaning up temp file: {str(e)}")
end_time = time.time()
logger.info(f"Total archive creation time: {end_time - start_time:.2f} seconds")
return archive_name, file_names
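# Illustrative usage sketch (paths and the package list are hypothetical; each
# package only needs a 'relative_path' attribute):
#
#     archive, included = create_archive(
#         files=raw_data_packages,
#         archive_name="/data/outgoing/session_0001.tar",
#         base_path="/data/raw",
#     )
#     if not included:
#         pass  # nothing was archived; check the logs for missing files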
def unpack_local(archive_path: str, destination: str) -> Tuple[bool, List[str]]:
"""
Unpack an archive (tar or zip) locally.
Parameters
----------
archive_path : str
The path to the archive file (tar or zip)
destination : str
The path where the archive should be extracted.
Returns
-------
Tuple[bool, List[str]]
A tuple containing a boolean indicating success (True) or failure (False), and a
list of unpacked files (empty if failed).
Raises
------
ArchiveCorruptionError
If the archive is corrupted or incomplete.
"""
try:
extracted_files = []
if archive_path.endswith(".tar"):
# First verify the archive integrity
verify_cmd = ["tar", "-tf", archive_path]
verify_result = subprocess.run(
verify_cmd, capture_output=True, text=True, check=False
)
if verify_result.returncode != 0:
error_msg = verify_result.stderr.strip()
if "Unexpected EOF" in error_msg:
raise ArchiveCorruptionError(
f"Archive is corrupted or incomplete: {error_msg}",
archive_path=archive_path,
)
logger.error(f"Error verifying tar archive: {error_msg}")
return False, []
# Use system tar command for better performance and control
tar_cmd = [
"tar",
"-xf", # Extract files
archive_path, # Input file
"-C", # Change to directory
destination, # Destination directory
]
logger.info(f"Running tar extraction command: {' '.join(tar_cmd)}")
result = subprocess.run(
tar_cmd, capture_output=True, text=True, check=False
)
if result.returncode != 0:
error_msg = result.stderr.strip()
if "Unexpected EOF" in error_msg:
raise ArchiveCorruptionError(
f"Archive is corrupted or incomplete: {error_msg}",
archive_path=archive_path,
)
logger.error(f"Error extracting tar archive: {error_msg}")
return False, []
# Get the list of files from the tar archive
list_cmd = ["tar", "-tf", archive_path]
list_result = subprocess.run(
list_cmd, capture_output=True, text=True, check=False
)
if list_result.returncode == 0:
extracted_files = list_result.stdout.strip().split("\n")
elif archive_path.endswith(".zip"):
try:
with zipfile.ZipFile(archive_path, "r") as zip_ref:
# Test the zip file integrity
if zip_ref.testzip() is not None:
raise ArchiveCorruptionError(
"ZIP file is corrupted", archive_path=archive_path
)
zip_ref.extractall(destination)
extracted_files = zip_ref.namelist()
except zipfile.BadZipFile as e:
raise ArchiveCorruptionError(
f"ZIP file is corrupted: {str(e)}", archive_path=archive_path
)
else:
logger.error(
"Unsupported archive format",
archive_path=archive_path,
supported_formats=".tar, .zip",
)
return False, []
logger.info(
"Successfully unpacked archive locally",
archive_path=archive_path,
destination=destination,
)
return True, extracted_files
except ArchiveCorruptionError:
raise
except Exception as e:
logger.error(
"Error unpacking archive locally",
archive_path=archive_path,
error=str(e),
)
return False, []
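# Illustrative usage sketch (paths are hypothetical). ArchiveCorruptionError is
# raised for truncated or corrupted archives so callers can trigger a re-transfer:
#
#     try:
#         ok, extracted = unpack_local("/data/incoming/session_0001.tar", "/data/raw")
#     except ArchiveCorruptionError:
#         pass  # request the archive again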
def unpack_remote(
user: Optional[str], host: Optional[str], archive_path: str, destination: str
) -> Tuple[bool, List[str]]:
"""Unpack a file on a remote host or locally. UNUSED
Parameters
----------
user : Optional[str]
The username for SSH connection. Use None for local operations.
host : Optional[str]
The hostname or IP address of the remote machine. Use None for local operations.
archive_path : str
The path to the archive file (tar or zip).
destination : str
The path where the archive should be extracted.
Returns
-------
Tuple[bool, List[str]]
A tuple containing a boolean indicating success (True) or failure (False), and a
list of unpacked files (empty if failed).
"""
if user is None or host is None:
return unpack_local(archive_path, destination)
# Choose command based on file extension
if archive_path.endswith(".tar"):
extract_cmd = f"tar xf {archive_path} -C {destination}"
elif archive_path.endswith(".zip"):
extract_cmd = f"unzip -o {archive_path} -d {destination}"
else:
logger.error("Unsupported archive format", archive_path=archive_path)
return False, []
ssh_command = [
"ssh",
f"{user}@{host}",
f"{extract_cmd} && echo 'SUCCESS' || (echo 'FAILURE'; exit 1)",
]
logger.info("Running SSH command", command=" ".join(ssh_command))
try:
result = subprocess.run(
ssh_command, capture_output=True, text=True, check=True, timeout=300
)
output = result.stdout.strip()
if "SUCCESS" in output:
logger.info(
"Successfully unpacked archive remotely",
archive_path=archive_path,
destination=destination,
user=user,
host=host,
)
# Extract file list
file_list = []
if archive_path.endswith(".zip"):
file_list = [
line.split(":")[-1].strip()
for line in output.splitlines()
if line.startswith(" inflating: ")
]
# For tar extraction, we just note success but don't parse the file list currently
return True, file_list
else:
error_message = output.splitlines()[-1] if output else "Unknown error"
logger.error(
"Failed to unpack archive remotely",
archive_path=archive_path,
destination=destination,
user=user,
host=host,
error=error_message,
)
return False, []
except subprocess.CalledProcessError as e:
logger.error(
"Error unpacking archive remotely",
archive_path=archive_path,
destination=destination,
user=user,
host=host,
exit_code=e.returncode,
error=e.stderr.strip(),
)
return False, []
except subprocess.TimeoutExpired:
logger.error(
"Timeout while unpacking archive remotely",
archive_path=archive_path,
destination=destination,
user=user,
host=host,
)
return False, []
except Exception as e:
logger.error(
"Unexpected error while unpacking archive remotely",
archive_path=archive_path,
destination=destination,
user=user,
host=host,
error=str(e),
)
return False, []
def calculate_checksum(filepath: str) -> Optional[str]:
"""Calculate xxHash64 checksum of a file for fast integrity verification."""
start_time = time.time()
try:
result = subprocess.run(
["xxh64sum", filepath], capture_output=True, text=True, check=True
)
checksum = result.stdout.strip().split()[0]
elapsed = time.time() - start_time
logger.info(
f"xxHash checksum for file '{filepath}': {checksum}, took {elapsed:.2f} seconds"
)
return checksum
except Exception as e:
logger.error(f"Error calculating checksum for '{filepath}': {e}")
return None
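# Illustrative usage sketch (requires the xxh64sum binary on PATH; the path is
# hypothetical):
#
#     checksum = calculate_checksum("/data/outgoing/session_0001.tar")
#     if checksum is None:
#         pass  # checksum tool unavailable or file unreadable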
def make_bbcp_command(source_url: str, destination_url: str) -> List[str]:
"""
Construct the bbcp command.
Parameters
----------
source_url : str
The source URL for the bbcp transfer.
destination_url : str
The destination URL for the bbcp transfer.
Returns
-------
List[str]
A list of strings representing the bbcp command and its arguments.
"""
command = ["/usr/bin/bbcp"]
# Preserve source mode, ownership, and dates
command += ["-p"]
# Emit progress messages every 2 seconds
command += ["-P", "2"]
# Add verbose options
if ccat_data_transfer_settings.get("BBCP_VERBOSE") == 1:
command += ["-v"]
elif ccat_data_transfer_settings.get("BBCP_VERBOSE") == 2:
command += ["-V"]
# Add window size if specified
if window_size := ccat_data_transfer_settings.get("BBCP_WINDOW_SIZE"):
command.extend(["-w", str(window_size)])
# Add parallel streams if specified
if streams := ccat_data_transfer_settings.get("BBCP_PARALLEL_STREAMS"):
command.extend(["-s", str(streams)])
# Add source path if specified
if source_path := ccat_data_transfer_settings.get("BBCP_SOURCE_PATH"):
command.extend(["-S", str(source_path)])
# Add target path if specified
if target_path := ccat_data_transfer_settings.get("BBCP_TARGET_PATH"):
command.extend(["-T", str(target_path)])
# Add source and destination URLs
command.extend([source_url, destination_url])
# Log the command after ensuring all elements are strings
command_str = " ".join(str(arg) for arg in command)
logger.debug(f"BBCP command: {command_str}")
return command
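# Illustrative result (assumes BBCP_WINDOW_SIZE="8m" and BBCP_PARALLEL_STREAMS=8
# are set and the other optional settings are unset; URLs are hypothetical):
#
#     make_bbcp_command("user@source:/data/a.tar", "user@archive:/lta/a.tar")
#     # -> ["/usr/bin/bbcp", "-p", "-P", "2", "-w", "8m", "-s", "8",
#     #     "user@source:/data/a.tar", "user@archive:/lta/a.tar"]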
def create_local_folder(folder: str) -> bool:
"""
Create a local folder.
Parameters
----------
folder : str
The path of the folder to be created.
Returns
-------
bool
True if the folder was created successfully or already exists, False otherwise.
"""
try:
os.makedirs(folder, exist_ok=True)
logger.info(f"Local folder {folder} created or already exists")
return True
except Exception as e:
logger.error(f"Error creating local folder {folder}: {str(e)}")
return False
def create_remote_folder(user: str, host: str, folder: str) -> bool:
"""
Create a remote folder.
Parameters
----------
user : str
The username for SSH connection.
host : str
The hostname or IP address of the remote machine.
folder : str
The path of the folder to be created.
Returns
-------
bool
True if the folder was created successfully or already exists, False otherwise.
"""
try:
ssh_command = ["ssh", f"{user}@{host}", f"mkdir -p {folder}"]
subprocess.run(ssh_command, check=True)
logger.info(
f"Remote folder {folder} created or already exists on {user}@{host}"
)
return True
except Exception as e:
logger.error(
f"Error creating remote folder {folder} on {user}@{host}: {str(e)}"
)
return False
def make_long_term_archive_copy_command(
source_url: str, destination_url: str
) -> List[str]:
"""
Construct the bbcp command for copying to long term archive.
Parameters
----------
source_url : str
The source URL for the bbcp transfer.
destination_url : str
The destination URL for the bbcp transfer.
Returns
-------
List[str]
A list of strings representing the bbcp command and its arguments.
"""
command = make_bbcp_command(source_url, destination_url)
logger.info(f"Copying to long term archive command: {' '.join(command)}")
return command
def run_ssh_command(user, host, command):
"""Run an SSH command on a remote host and return the output."""
log = logging.getLogger(__name__)
ssh_command = ["ssh", f"{user}@{host}", command]
result = subprocess.run(
ssh_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode == 0:
return result.stdout.strip()
else:
log.error(
f"SSH command failed: {command} on {user}@{host}: {result.stderr.strip()}"
)
return None
def check_remote_folder_size_gb(user, host, parent_path):
"""Check the size of a remote folder and return it in gigabytes."""
log = logging.getLogger(__name__)
size_command = f"du -sb {parent_path} | cut -f1"
size_output = run_ssh_command(user, host, size_command)
if size_output:
try:
size_gb = int(size_output) / (1024**3) # Convert bytes to GB
return round(size_gb, 2) # Round to 2 decimal places
except ValueError:
log.error(f"Error parsing folder size: {size_output}")
return None
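# Illustrative usage sketch (user, host, and path are hypothetical):
#
#     size_gb = check_remote_folder_size_gb("ccat", "transfer-node", "/data/raw/session_0001")
#     if size_gb is not None and size_gb > 1.0:
#         pass  # folder is large enough to schedule a transfer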
def parse_bbcp_output(
stdout: Union[str, bytes], stderr: Union[str, bytes], duration: float
) -> Dict[str, Any]:
"""
Parse BBCP command output to extract transfer metrics.
Parameters
----------
stdout : bytes or str
Standard output from the BBCP command
stderr : bytes or str
Standard error from the BBCP command
duration : float
Total duration of the transfer in seconds
Returns
-------
Dict[str, Any]
Dictionary containing parsed metrics
"""
metrics = {
"duration": duration,
"peak_transfer_rate_mbps": 0,
"average_transfer_rate_mbps": 0,
"bytes_transferred": 0,
"number_of_streams": 0,
"network_errors": 0,
}
def convert_to_mbps(value: float, unit: str) -> float:
"""Convert various transfer rates to MB/s"""
unit = unit.upper()
if unit == "GB/S":
return value * 1024
elif unit == "KB/S":
return value / 1024
elif unit == "MB/S":
return value
return 0
def extract_rate(rate_str: str) -> float:
"""Extract numeric rate and unit, convert to MB/s"""
try:
# Match number and unit (e.g., "1.6 GB/s" or "15.8 MB/s")
match = re.search(r"([\d.]+)\s*([KMG]B/s)", rate_str, re.IGNORECASE)
if match:
value = float(match.group(1))
unit = match.group(2)
return convert_to_mbps(value, unit)
except (ValueError, AttributeError):
pass
return 0
# Normalize to text: callers may pass either raw bytes or already-decoded strings
stdout_text = stdout.decode() if isinstance(stdout, bytes) else stdout
stderr_text = stderr.decode() if isinstance(stderr, bytes) else stderr
combined_output = stdout_text + stderr_text
# Parse BBCP output
for line in combined_output.split("\n"):
# Look for bytes transferred and transfer rate
if "created;" in line and "bytes at" in line:
try:
# Extract bytes value
bytes_str = line.split("created;")[1].split("bytes")[0].strip()
metrics["bytes_transferred"] = int(bytes_str)
logger.debug(
"Extracted bytes transferred",
bytes_transferred=metrics["bytes_transferred"],
raw_line=line,
)
# Extract peak rate - look for pattern "bytes at X MB/s"
try:
rate_part = line.split("bytes at")[-1].strip()
metrics["peak_transfer_rate_mbps"] = extract_rate(rate_part)
logger.debug(
"Extracted peak transfer rate",
peak_transfer_rate_mbps=metrics["peak_transfer_rate_mbps"],
raw_rate=rate_part,
raw_line=line,
)
except (ValueError, IndexError) as e:
logger.error(
"Failed to parse peak transfer rate",
error=e,
raw_line=line,
)
except (ValueError, IndexError) as e:
logger.error(
"Failed to parse bytes transferred or peak rate",
error=e,
raw_line=line,
)
# Look for effective transfer rate
if "effectively" in line:
try:
rate_str = line.split("effectively")[1].strip()
metrics["average_transfer_rate_mbps"] = extract_rate(rate_str)
logger.debug(
"Extracted average transfer rate",
average_transfer_rate_mbps=metrics["average_transfer_rate_mbps"],
raw_rate=rate_str,
raw_line=line,
)
except (ValueError, IndexError) as e:
logger.error(
"Failed to parse average transfer rate",
error=e,
raw_line=line,
)
# Look for network errors
if "error" in line.lower() or "failed" in line.lower():
metrics["network_errors"] += 1
logger.warning(
"Detected network error",
network_errors_count=metrics["network_errors"],
error_line=line,
)
logger.debug(
"BBCP output parsing complete"
+ " "
+ " ".join(f"{k}={v}" for k, v in metrics.items()),
stdout_length=len(stdout_text),
stderr_length=len(stderr_text),
)
return metrics
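# Illustrative example (the summary line below mimics typical bbcp output and is
# not captured from a real transfer):
#
#     out = "File /lta/a.tar created; 1073741824 bytes at 512.0 MB/s"
#     metrics = parse_bbcp_output(out, "", duration=2.1)
#     # metrics["bytes_transferred"] == 1073741824
#     # metrics["peak_transfer_rate_mbps"] == 512.0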
def calculate_transfer_rate(file_size: int, duration: int) -> float:
"""
Calculate transfer rate in Mbps with float precision.
Parameters
----------
file_size : int
File size in bytes
duration : int
Transfer duration in seconds
Returns
-------
float
Transfer rate in Mbps
"""
if duration <= 0:
return 0.0
# Convert bytes to bits (multiply by 8)
# Convert to Mbps (divide by 1,000,000)
# Maintain float precision
return (file_size * 8.0) / (duration * 1_000_000.0)
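# Worked example: a 1 GiB file (1_073_741_824 bytes) transferred in 10 seconds
# gives (1_073_741_824 * 8) / (10 * 1_000_000) ≈ 859.0 Mbps.
#
#     rate = calculate_transfer_rate(1_073_741_824, 10)  # ~859.0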
def generate_readable_filename(
raw_data_package, hash_value, file_type="raw", extension="tar"
):
"""
Generate a human-readable filename that includes metadata and a hash suffix.
Parameters
----------
raw_data_package : models.RawDataPackage
The raw data package containing metadata
hash_value : str
The original hash or UUID used for uniqueness
file_type : str
Type of file (e.g., "raw" or "transfer")
extension : str
File extension (without the dot)
Returns
-------
str
A human-readable filename with hash suffix
"""
# Extract date from package metadata or creation date
date_str = (
raw_data_package.created_at.strftime("%Y%m%d")
if hasattr(raw_data_package, "created_at") and raw_data_package.created_at
else time.strftime("%Y%m%d")
)
# Use first 8 chars of hash (still provides good uniqueness)
short_hash = hash_value[:8]
# Build filename with consistent extension
filename = f"{date_str}_{file_type}_{short_hash}.{extension}"
# Replace invalid characters
return re.sub(r"[^\w\.-]", "_", filename)
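# Illustrative example (package and hash values are hypothetical): a package
# created on 2024-01-01 with hash "a1b2c3d4e5f6..." yields
# "20240101_raw_a1b2c3d4.tar".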
def get_s3_key_for_package(
data_location: models.S3DataLocation, raw_data_package: models.RawDataPackage
) -> str:
"""
Construct S3 object key for a raw data package using consistent logic.
This function implements the same S3 key construction logic used in archive_manager.py
to ensure consistency between upload and download operations.
Parameters
----------
data_location : models.S3DataLocation
The S3 data location where the package is stored
raw_data_package : models.RawDataPackage
The raw data package to construct the key for
Returns
-------
str
The S3 object key for the package
Notes
-----
The S3 key is constructed as follows:
1. Replace underscores with slashes in the location name
2. Join the result with the package's relative path and normalize it
3. Strip any leading slash
4. Replace the remaining slashes with underscores
"""
# Use the same logic as in archive_manager.py
location_path = data_location.name.replace("_", "/")
# Construct the destination path similar to the old implementation
destination = os.path.normpath(
os.path.join(
location_path,
raw_data_package.relative_path,
)
)
# Apply the same reformatting: replace / with _ and remove leading /
s3_key = destination.lstrip("/").replace("/", "_")
logger.debug(
"constructed_s3_key",
location_name=data_location.name,
location_path=location_path,
package_relative_path=raw_data_package.relative_path,
destination=destination,
s3_key=s3_key,
)
return s3_key
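# Illustrative example (names are hypothetical): for a location named
# "site_archive_bucket" and a package relative path "2024/01/session_0001",
# the destination becomes "site/archive/bucket/2024/01/session_0001" and the
# resulting key is "site_archive_bucket_2024_01_session_0001".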
def get_s3_key_for_file(
data_location: models.S3DataLocation, raw_data_file: models.RawDataFile
) -> str:
"""
Construct S3 object key for a raw data file using consistent logic.
This function implements the same S3 key construction logic for individual files.
Parameters
----------
data_location : models.S3DataLocation
The S3 data location where the file is stored
raw_data_file : models.RawDataFile
The raw data file to construct the key for
Returns
-------
str
The S3 object key for the file
"""
# Use the same logic as for packages
location_path = data_location.name.replace("_", "/")
# Construct the destination path
destination = os.path.normpath(
os.path.join(
location_path,
raw_data_file.relative_path,
)
)
# Apply the same reformatting: replace / with _ and remove leading /
s3_key = destination.lstrip("/").replace("/", "_")
logger.debug(
"constructed_s3_key_for_file",
location_name=data_location.name,
file_relative_path=raw_data_file.relative_path,
s3_key=s3_key,
)
return s3_key