# Module: ccat_workflow_manager.hpc.slurm

"""SLURM HPC backend - submits jobs via sbatch."""

import math
import re
import shlex
import subprocess
from typing import Optional

from .base import HPCBackend, HPCJobInfo, HPCJobStatus
from ..config.config import ccat_workflow_manager_settings
from ..logging_utils import get_structured_logger
from ..exceptions import HPCSubmissionError, HPCStatusError

logger = get_structured_logger(__name__)


class SLURMBackend(HPCBackend):
    """Submit and monitor jobs on a SLURM cluster."""

    def __init__(self):
        """Read SLURM connection settings from the application config."""
        settings = ccat_workflow_manager_settings
        self.partition = settings.SLURM_PARTITION
        self.account = settings.SLURM_ACCOUNT
        self.apptainer_cache = settings.APPTAINER_CACHE_DIR
[docs] def submit( self, execution_command: str, image_ref: str, sif_path: str, input_dir: str, output_dir: str, workspace_dir: str, manifest_path: str, resource_requirements: dict, environment_variables: dict, job_name: str, ) -> str: cpu = resource_requirements.get("cpu", "1") memory_gb = resource_requirements.get("memory_gb", "4") time_hours = resource_requirements.get("time_hours", "24") gpu = resource_requirements.get("gpu", "0") memory_slurm = f"{memory_gb}G" time_limit = f"{int(float(time_hours)):02d}:00:00" # Build environment exports env_exports = "\n".join( f'export {k}="{v}"' for k, v in environment_variables.items() ) script = f"""#!/bin/bash #SBATCH --job-name={job_name} #SBATCH --partition={self.partition} #SBATCH --account={self.account} #SBATCH --cpus-per-task={cpu} #SBATCH --mem={memory_slurm} #SBATCH --time={time_limit} #SBATCH --output={output_dir}/{job_name}_%j.out #SBATCH --error={output_dir}/{job_name}_%j.err {"#SBATCH --gres=gpu:" + str(gpu) if int(gpu) > 0 else ""} {env_exports} echo "Starting pipeline: {job_name}" echo "SIF: {sif_path}" {execution_command} echo "Pipeline completed with exit code: $?" """ try: result = subprocess.run( ["sbatch"], input=script, capture_output=True, text=True, timeout=30, ) if result.returncode != 0: raise HPCSubmissionError( f"sbatch failed: {result.stderr}", ) match = re.search(r"Submitted batch job (\d+)", result.stdout) if not match: raise HPCSubmissionError( f"Could not parse job ID from sbatch output: {result.stdout}", ) job_id = match.group(1) logger.info("slurm_job_submitted", job_id=job_id, image=image_ref) return job_id except subprocess.TimeoutExpired: raise HPCSubmissionError("sbatch timed out") except FileNotFoundError: raise HPCSubmissionError("sbatch command not found")
[docs] def get_status(self, job_id: str) -> HPCJobInfo: try: result = subprocess.run( [ "sacct", "-j", job_id, "--format=JobID,State,ExitCode,Start,End,NodeList,Elapsed,MaxRSS,TotalCPU", "--noheader", "--parsable2", ], capture_output=True, text=True, timeout=15, ) except (subprocess.TimeoutExpired, FileNotFoundError) as e: raise HPCStatusError(f"sacct failed: {e}") if result.returncode != 0: raise HPCStatusError(f"sacct error: {result.stderr}") for line in result.stdout.strip().split("\n"): parts = line.split("|") if len(parts) >= 6 and parts[0] == job_id: state = parts[1] exit_code_str = parts[2] start_time = parts[3] if parts[3] != "Unknown" else None end_time = parts[4] if parts[4] != "Unknown" else None node = parts[5] exit_code = None if ":" in exit_code_str: exit_code = int(exit_code_str.split(":")[0]) # Extract metrics if available wall_time = None peak_memory = None cpu_time = None if len(parts) >= 9: wall_time = self._parse_elapsed(parts[6]) peak_memory = self._parse_memory_gb(parts[7]) cpu_time = self._parse_cpu_hours(parts[8]) status_map = { "PENDING": HPCJobStatus.PENDING, "RUNNING": HPCJobStatus.RUNNING, "COMPLETED": HPCJobStatus.COMPLETED, "FAILED": HPCJobStatus.FAILED, "CANCELLED": HPCJobStatus.CANCELLED, "TIMEOUT": HPCJobStatus.FAILED, "NODE_FAIL": HPCJobStatus.FAILED, "OUT_OF_MEMORY": HPCJobStatus.FAILED, } return HPCJobInfo( job_id=job_id, status=status_map.get(state, HPCJobStatus.UNKNOWN), exit_code=exit_code, start_time=start_time, end_time=end_time, node=node, wall_time_seconds=wall_time, peak_memory_gb=peak_memory, cpu_hours=cpu_time, ) return HPCJobInfo(job_id=job_id, status=HPCJobStatus.UNKNOWN)
[docs] def get_logs(self, job_id: str) -> str: try: result = subprocess.run( ["scontrol", "show", "job", job_id], capture_output=True, text=True, timeout=10, ) return result.stdout except Exception as e: raise HPCStatusError(f"Failed to get SLURM job logs: {e}")
[docs] def cancel(self, job_id: str) -> bool: try: result = subprocess.run( ["scancel", job_id], capture_output=True, text=True, timeout=10, ) if result.returncode == 0: logger.info("slurm_job_cancelled", job_id=job_id) return True logger.error("slurm_cancel_failed", job_id=job_id, stderr=result.stderr) return False except Exception as e: logger.error("slurm_cancel_error", job_id=job_id, error=str(e)) return False
def _parse_elapsed(self, elapsed_str: str) -> Optional[float]: """Parse SLURM elapsed time (HH:MM:SS or D-HH:MM:SS) to seconds.""" if not elapsed_str or elapsed_str == "": return None try: parts = elapsed_str.split("-") if len(parts) == 2: days = int(parts[0]) time_parts = parts[1].split(":") else: days = 0 time_parts = parts[0].split(":") hours, minutes, seconds = [int(p) for p in time_parts] return days * 86400 + hours * 3600 + minutes * 60 + seconds except (ValueError, IndexError): return None def _parse_memory_gb(self, mem_str: str) -> Optional[float]: """Parse SLURM MaxRSS (e.g., '4096K', '2G') to GB.""" if not mem_str or mem_str == "": return None try: if mem_str.endswith("K"): return float(mem_str[:-1]) / (1024 * 1024) elif mem_str.endswith("M"): return float(mem_str[:-1]) / 1024 elif mem_str.endswith("G"): return float(mem_str[:-1]) return float(mem_str) / (1024 * 1024 * 1024) except ValueError: return None def _parse_cpu_hours(self, cpu_str: str) -> Optional[float]: """Parse SLURM TotalCPU (HH:MM:SS) to hours.""" if not cpu_str or cpu_str == "": return None try: parts = cpu_str.split(":") hours, minutes, seconds = [float(p) for p in parts] return hours + minutes / 60 + seconds / 3600 except (ValueError, IndexError): return None