"""Kubernetes HPC backend - submits jobs as K8s Jobs."""
from .base import HPCBackend, HPCJobInfo, HPCJobStatus
from ..config.config import ccat_workflow_manager_settings
from ..logging_utils import get_structured_logger
from ..exceptions import HPCSubmissionError, HPCStatusError
logger = get_structured_logger(__name__)
class KubernetesBackend(HPCBackend):
    """Submit and monitor pipeline jobs as Kubernetes ``batch/v1`` Jobs.

    Jobs are created in the namespace configured via
    ``ccat_workflow_manager_settings.K8S_NAMESPACE`` and run under the
    configured service account.  Kubernetes API clients are built lazily so
    that importing this module does not require a reachable cluster.
    """

    def __init__(self):
        # Namespace and service account come from application settings.
        self.namespace = ccat_workflow_manager_settings.K8S_NAMESPACE
        self.service_account = ccat_workflow_manager_settings.K8S_SERVICE_ACCOUNT
        # Lazily-initialised API clients (see the properties below).
        self._client = None
        self._core_client = None

    @property
    def client(self):
        """Lazily build and cache a ``BatchV1Api`` client.

        Prefers in-cluster configuration (when running inside a pod) and
        falls back to the local kubeconfig for out-of-cluster use.
        """
        if self._client is None:
            from kubernetes import client, config

            try:
                config.load_incluster_config()
            except config.ConfigException:
                # Not running inside a pod -- use ~/.kube/config instead.
                config.load_kube_config()
            self._client = client.BatchV1Api()
        return self._client

    @property
    def core_client(self):
        """Lazily build and cache a ``CoreV1Api`` client for pod/log access."""
        if self._core_client is None:
            # Touch the batch client first so cluster configuration is
            # loaded exactly once before CoreV1Api is constructed.
            _ = self.client
            from kubernetes import client

            self._core_client = client.CoreV1Api()
        return self._core_client

    def submit(
        self,
        execution_command: str,
        image_ref: str,
        sif_path: str,
        input_dir: str,
        output_dir: str,
        workspace_dir: str,
        manifest_path: str,
        resource_requirements: dict,
        environment_variables: dict,
        job_name: str,
    ) -> str:
        """Create a Kubernetes Job running *execution_command* in *image_ref*.

        Args:
            execution_command: Shell command executed inside the container
                (run via ``/bin/sh -c``).
            image_ref: Container image reference for the job pod.
            sif_path: Accepted for interface compatibility with other
                backends; not used on Kubernetes (no Singularity image).
            input_dir: Interface-compatibility parameter; data access on
                K8s goes through the ``ccat-data-pvc`` volume at ``/data``.
            output_dir: Interface-compatibility parameter (see input_dir).
            workspace_dir: Interface-compatibility parameter (see input_dir).
            manifest_path: Interface-compatibility parameter (see input_dir).
            resource_requirements: Optional ``cpu`` (cores) and
                ``memory_gb`` (GiB) keys; defaults are 1 CPU / 4 GiB.
                Requests and limits are set to the same values.
            environment_variables: Injected into the container environment;
                values are stringified.
            job_name: Used as the Job name and as the ``pipeline-job`` label
                that :meth:`get_logs` later relies on to find the pod.

        Returns:
            The name of the created Job, which serves as the job id.

        Raises:
            HPCSubmissionError: If the Kubernetes API rejects the Job.
        """
        from kubernetes import client

        cpu = str(resource_requirements.get("cpu", "1"))
        memory_k8s = f"{resource_requirements.get('memory_gb', '4')}Gi"

        env_vars = [
            client.V1EnvVar(name=key, value=str(value))
            for key, value in environment_variables.items()
        ]
        container = client.V1Container(
            name="pipeline",
            image=image_ref,
            command=["/bin/sh", "-c"],
            args=[execution_command],
            env=env_vars,
            # Requests == limits pins the pod to a Guaranteed-style QoS.
            resources=client.V1ResourceRequirements(
                requests={"cpu": cpu, "memory": memory_k8s},
                limits={"cpu": cpu, "memory": memory_k8s},
            ),
            volume_mounts=[
                client.V1VolumeMount(name="data", mount_path="/data"),
            ],
        )
        job_spec = client.V1Job(
            api_version="batch/v1",
            kind="Job",
            metadata=client.V1ObjectMeta(
                name=job_name,
                namespace=self.namespace,
                labels={
                    "app": "ccat-workflow",
                    "pipeline-job": job_name,
                },
            ),
            spec=client.V1JobSpec(
                template=client.V1PodTemplateSpec(
                    metadata=client.V1ObjectMeta(
                        labels={"app": "ccat-workflow", "pipeline-job": job_name}
                    ),
                    spec=client.V1PodSpec(
                        service_account_name=self.service_account,
                        restart_policy="Never",
                        containers=[container],
                        volumes=[
                            client.V1Volume(
                                name="data",
                                persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                                    claim_name="ccat-data-pvc"
                                ),
                            ),
                        ],
                    ),
                ),
                # No retries: a failed pod means a failed job.
                backoff_limit=0,
                # Let K8s garbage-collect finished jobs after 24h.
                ttl_seconds_after_finished=86400,
            ),
        )
        try:
            result = self.client.create_namespaced_job(
                namespace=self.namespace, body=job_spec
            )
        except Exception as e:
            raise HPCSubmissionError(
                f"Failed to submit K8s job: {e}",
                operation_id=None,
            ) from e
        job_id = result.metadata.name
        logger.info("k8s_job_submitted", job_id=job_id, image=image_ref)
        return job_id

    def get_status(self, job_id: str) -> HPCJobInfo:
        """Return an :class:`HPCJobInfo` snapshot for the given Job.

        Status mapping (first match wins): succeeded pods -> COMPLETED,
        failed pods -> FAILED, active pods -> RUNNING, otherwise PENDING.

        Raises:
            HPCStatusError: If the Job cannot be read from the API.
        """
        try:
            job = self.client.read_namespaced_job(
                name=job_id, namespace=self.namespace
            )
        except Exception as e:
            raise HPCStatusError(f"Failed to get K8s job status: {e}") from e

        status = job.status
        if status.succeeded and status.succeeded > 0:
            hpc_status = HPCJobStatus.COMPLETED
        elif status.failed and status.failed > 0:
            hpc_status = HPCJobStatus.FAILED
        elif status.active and status.active > 0:
            hpc_status = HPCJobStatus.RUNNING
        else:
            hpc_status = HPCJobStatus.PENDING

        # Wall time is only known once the job has both started and finished.
        wall_time = None
        if status.start_time and status.completion_time:
            wall_time = (status.completion_time - status.start_time).total_seconds()

        return HPCJobInfo(
            job_id=job_id,
            status=hpc_status,
            start_time=str(status.start_time) if status.start_time else None,
            end_time=str(status.completion_time) if status.completion_time else None,
            wall_time_seconds=wall_time,
        )

    def get_logs(self, job_id: str) -> str:
        """Fetch container logs from the first pod labelled for this job.

        Raises:
            HPCStatusError: If pods cannot be listed or logs cannot be read.
        """
        try:
            pods = self.core_client.list_namespaced_pod(
                namespace=self.namespace,
                label_selector=f"pipeline-job={job_id}",
            )
            if not pods.items:
                return "No pods found for job"
            # NOTE(review): only the first matching pod is inspected; with
            # backoff_limit=0 there should be exactly one.
            pod_name = pods.items[0].metadata.name
            return self.core_client.read_namespaced_pod_log(
                name=pod_name, namespace=self.namespace
            )
        except Exception as e:
            raise HPCStatusError(f"Failed to get K8s job logs: {e}") from e

    def cancel(self, job_id: str) -> bool:
        """Delete the Job (and, via Foreground propagation, its pods).

        Best-effort: returns ``True`` on success, ``False`` on any failure
        (the error is logged rather than raised).
        """
        try:
            from kubernetes import client

            self.client.delete_namespaced_job(
                name=job_id,
                namespace=self.namespace,
                body=client.V1DeleteOptions(propagation_policy="Foreground"),
            )
        except Exception as e:
            logger.error("k8s_job_cancel_failed", job_id=job_id, error=str(e))
            return False
        logger.info("k8s_job_cancelled", job_id=job_id)
        return True