"""SLURM HPC backend - submits jobs via sbatch."""
import math
import re
import shlex
import subprocess
from typing import Optional

from .base import HPCBackend, HPCJobInfo, HPCJobStatus
from ..config.config import ccat_workflow_manager_settings
from ..exceptions import HPCSubmissionError, HPCStatusError
from ..logging_utils import get_structured_logger
logger = get_structured_logger(__name__)
[docs]
class SLURMBackend(HPCBackend):
    """Submit and monitor jobs on a SLURM cluster.

    Thin wrapper around the SLURM command-line tools: jobs are submitted
    with ``sbatch`` (job script piped on stdin), polled with ``sacct``,
    inspected with ``scontrol`` and cancelled with ``scancel``.  The
    partition, account, and Apptainer cache directory are read from
    ``ccat_workflow_manager_settings`` at construction time.
    """
[docs]
def __init__(self):
    """Capture SLURM-related settings from the package configuration."""
    settings = ccat_workflow_manager_settings
    self.partition = settings.SLURM_PARTITION
    self.account = settings.SLURM_ACCOUNT
    self.apptainer_cache = settings.APPTAINER_CACHE_DIR
[docs]
def submit(
    self,
    execution_command: str,
    image_ref: str,
    sif_path: str,
    input_dir: str,
    output_dir: str,
    workspace_dir: str,
    manifest_path: str,
    resource_requirements: dict,
    environment_variables: dict,
    job_name: str,
) -> str:
    """Submit a job to SLURM via ``sbatch`` and return the SLURM job ID.

    A batch script is generated from the resource requirements and the
    pre-built ``execution_command``, then piped to ``sbatch`` on stdin.
    (``input_dir``, ``workspace_dir`` and ``manifest_path`` are part of the
    backend interface; presumably they are already baked into
    ``execution_command`` upstream — not consumed directly here.)

    Args:
        execution_command: Shell command that runs the pipeline.
        image_ref: Container image reference (logged for traceability).
        sif_path: Path to the Apptainer SIF image (echoed in the script).
        output_dir: Directory receiving the SLURM stdout/stderr files.
        resource_requirements: Keys ``cpu``, ``memory_gb``, ``time_hours``,
            ``gpu`` (all optional, with defaults 1 / 4 / 24 / 0).
        environment_variables: Exported into the job environment.
        job_name: SLURM job name and log-file prefix.

    Returns:
        The numeric SLURM job ID as a string.

    Raises:
        HPCSubmissionError: if sbatch is missing, times out, exits
            non-zero, or its output cannot be parsed.
    """
    cpu = resource_requirements.get("cpu", "1")
    memory_gb = resource_requirements.get("memory_gb", "4")
    time_hours = resource_requirements.get("time_hours", "24")
    gpu = resource_requirements.get("gpu", "0")
    memory_slurm = f"{memory_gb}G"
    # Preserve fractional hours: 1.5 must become 01:30:00, not 01:00:00.
    # Round up to a whole minute so the job never gets less time than asked.
    total_minutes = max(1, math.ceil(float(time_hours) * 60))
    time_limit = f"{total_minutes // 60:02d}:{total_minutes % 60:02d}:00"
    # Quote each value so spaces/quotes/shell metacharacters cannot break
    # (or inject into) the generated script.
    env_exports = "\n".join(
        f"export {k}={shlex.quote(str(v))}"
        for k, v in environment_variables.items()
    )
    script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.partition}
#SBATCH --account={self.account}
#SBATCH --cpus-per-task={cpu}
#SBATCH --mem={memory_slurm}
#SBATCH --time={time_limit}
#SBATCH --output={output_dir}/{job_name}_%j.out
#SBATCH --error={output_dir}/{job_name}_%j.err
{"#SBATCH --gres=gpu:" + str(gpu) if int(gpu) > 0 else ""}
{env_exports}
echo "Starting pipeline: {job_name}"
echo "SIF: {sif_path}"
{execution_command}
echo "Pipeline completed with exit code: $?"
"""
    # Keep the try body minimal: only the subprocess call can raise here.
    try:
        result = subprocess.run(
            ["sbatch"],
            input=script,
            capture_output=True,
            text=True,
            timeout=30,
        )
    except subprocess.TimeoutExpired as e:
        raise HPCSubmissionError("sbatch timed out") from e
    except FileNotFoundError as e:
        raise HPCSubmissionError("sbatch command not found") from e
    if result.returncode != 0:
        raise HPCSubmissionError(
            f"sbatch failed: {result.stderr}",
        )
    match = re.search(r"Submitted batch job (\d+)", result.stdout)
    if not match:
        raise HPCSubmissionError(
            f"Could not parse job ID from sbatch output: {result.stdout}",
        )
    job_id = match.group(1)
    logger.info("slurm_job_submitted", job_id=job_id, image=image_ref)
    return job_id
[docs]
def get_status(self, job_id: str) -> HPCJobInfo:
    """Query ``sacct`` for the state and accounting metrics of a job.

    Args:
        job_id: SLURM job ID as returned by :meth:`submit`.

    Returns:
        An HPCJobInfo. Status is UNKNOWN when the job does not appear in
        sacct output or reports a state not covered by the mapping.

    Raises:
        HPCStatusError: if sacct is missing, times out, or exits non-zero.
    """
    try:
        result = subprocess.run(
            [
                "sacct",
                "-j", job_id,
                "--format=JobID,State,ExitCode,Start,End,NodeList,Elapsed,MaxRSS,TotalCPU",
                "--noheader",
                "--parsable2",
            ],
            capture_output=True,
            text=True,
            timeout=15,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError) as e:
        raise HPCStatusError(f"sacct failed: {e}") from e
    if result.returncode != 0:
        raise HPCStatusError(f"sacct error: {result.stderr}")
    # Terminal states other than COMPLETED all collapse to FAILED.
    status_map = {
        "PENDING": HPCJobStatus.PENDING,
        "RUNNING": HPCJobStatus.RUNNING,
        "COMPLETED": HPCJobStatus.COMPLETED,
        "FAILED": HPCJobStatus.FAILED,
        "CANCELLED": HPCJobStatus.CANCELLED,
        "TIMEOUT": HPCJobStatus.FAILED,
        "NODE_FAIL": HPCJobStatus.FAILED,
        "OUT_OF_MEMORY": HPCJobStatus.FAILED,
    }
    for line in result.stdout.strip().split("\n"):
        parts = line.split("|")
        # Only the parent job record matches job_id exactly; step records
        # ("<id>.batch", "<id>.extern") are skipped.
        if len(parts) < 6 or parts[0] != job_id:
            continue
        # sacct may append qualifiers to the state (e.g. "CANCELLED by
        # <uid>"); keep only the leading token so the map lookup works.
        state = parts[1].split()[0] if parts[1] else ""
        exit_code_str = parts[2]
        start_time = parts[3] if parts[3] != "Unknown" else None
        end_time = parts[4] if parts[4] != "Unknown" else None
        node = parts[5]
        exit_code = None
        # ExitCode is "<code>:<signal>"; callers only need the code.
        if ":" in exit_code_str:
            exit_code = int(exit_code_str.split(":")[0])
        # Accounting metrics are only present when all columns came back.
        wall_time = None
        peak_memory = None
        cpu_time = None
        if len(parts) >= 9:
            wall_time = self._parse_elapsed(parts[6])
            peak_memory = self._parse_memory_gb(parts[7])
            cpu_time = self._parse_cpu_hours(parts[8])
        return HPCJobInfo(
            job_id=job_id,
            status=status_map.get(state, HPCJobStatus.UNKNOWN),
            exit_code=exit_code,
            start_time=start_time,
            end_time=end_time,
            node=node,
            wall_time_seconds=wall_time,
            peak_memory_gb=peak_memory,
            cpu_hours=cpu_time,
        )
    return HPCJobInfo(job_id=job_id, status=HPCJobStatus.UNKNOWN)
[docs]
def get_logs(self, job_id: str) -> str:
    """Return ``scontrol show job`` output for the given job.

    NOTE(review): scontrol returns job metadata (including the StdOut/
    StdErr file paths), not the log file contents themselves — confirm
    that callers expect this.

    Raises:
        HPCStatusError: if scontrol is unavailable, times out, or exits
            non-zero.
    """
    try:
        result = subprocess.run(
            ["scontrol", "show", "job", job_id],
            capture_output=True,
            text=True,
            timeout=10,
        )
    # Narrowed from a bare `except Exception`: only the subprocess-level
    # failures are expected here (TimeoutExpired, missing binary / OSError).
    except (subprocess.TimeoutExpired, OSError) as e:
        raise HPCStatusError(f"Failed to get SLURM job logs: {e}") from e
    if result.returncode != 0:
        # Previously a failing scontrol silently returned its (usually
        # empty) stdout; surface the error instead.
        raise HPCStatusError(f"Failed to get SLURM job logs: {result.stderr}")
    return result.stdout
[docs]
def cancel(self, job_id: str) -> bool:
    """Cancel a SLURM job with ``scancel``.

    Best-effort: never raises. Returns True when scancel exits zero,
    False on any failure (non-zero exit, timeout, missing binary, ...).
    """
    try:
        proc = subprocess.run(
            ["scancel", job_id],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if proc.returncode != 0:
            logger.error("slurm_cancel_failed", job_id=job_id, stderr=proc.stderr)
            return False
        logger.info("slurm_job_cancelled", job_id=job_id)
        return True
    except Exception as e:
        # Deliberately broad: cancellation is best-effort and must not
        # propagate errors to the caller.
        logger.error("slurm_cancel_error", job_id=job_id, error=str(e))
        return False
def _parse_elapsed(self, elapsed_str: str) -> Optional[float]:
"""Parse SLURM elapsed time (HH:MM:SS or D-HH:MM:SS) to seconds."""
if not elapsed_str or elapsed_str == "":
return None
try:
parts = elapsed_str.split("-")
if len(parts) == 2:
days = int(parts[0])
time_parts = parts[1].split(":")
else:
days = 0
time_parts = parts[0].split(":")
hours, minutes, seconds = [int(p) for p in time_parts]
return days * 86400 + hours * 3600 + minutes * 60 + seconds
except (ValueError, IndexError):
return None
def _parse_memory_gb(self, mem_str: str) -> Optional[float]:
"""Parse SLURM MaxRSS (e.g., '4096K', '2G') to GB."""
if not mem_str or mem_str == "":
return None
try:
if mem_str.endswith("K"):
return float(mem_str[:-1]) / (1024 * 1024)
elif mem_str.endswith("M"):
return float(mem_str[:-1]) / 1024
elif mem_str.endswith("G"):
return float(mem_str[:-1])
return float(mem_str) / (1024 * 1024 * 1024)
except ValueError:
return None
def _parse_cpu_hours(self, cpu_str: str) -> Optional[float]:
"""Parse SLURM TotalCPU (HH:MM:SS) to hours."""
if not cpu_str or cpu_str == "":
return None
try:
parts = cpu_str.split(":")
hours, minutes, seconds = [float(p) for p in parts]
return hours + minutes / 60 + seconds / 3600
except (ValueError, IndexError):
return None