Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure k8s pod names/labels are RFC 1123 compliant #3639

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ radical_local_test:

.PHONY: config_local_test
config_local_test: $(CCTOOLS_INSTALL)
pip3 install ".[monitoring,visualization,proxystore]"
pip3 install ".[monitoring,visualization,proxystore,kubernetes]"
PYTHONPATH=/tmp/cctools/lib/python3.8/site-packages pytest parsl/tests/ -k "not cleannet" --config local --random-order --durations 10

.PHONY: site_test
Expand Down
63 changes: 35 additions & 28 deletions parsl/providers/kubernetes/kube.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import logging
import time

from parsl.providers.kubernetes.template import template_string

logger = logging.getLogger(__name__)

import uuid
from typing import Any, Dict, List, Optional, Tuple

import typeguard

from parsl.errors import OptionalModuleMissing
from parsl.jobs.states import JobState, JobStatus
from parsl.providers.base import ExecutionProvider
from parsl.utils import RepresentationMixin
from parsl.providers.kubernetes.template import template_string
from parsl.utils import RepresentationMixin, sanitize_dns_subdomain_rfc1123

try:
from kubernetes import client, config
_kubernetes_enabled = True
except (ImportError, NameError, FileNotFoundError):
_kubernetes_enabled = False

logger = logging.getLogger(__name__)

translate_table = {
'Running': JobState.RUNNING,
'Pending': JobState.PENDING,
Expand Down Expand Up @@ -161,7 +159,7 @@ def __init__(self,
self.resources: Dict[object, Dict[str, Any]]
self.resources = {}

def submit(self, cmd_string, tasks_per_node, job_name="parsl"):
def submit(self, cmd_string: str, tasks_per_node: int, job_name: str = "parsl.kube"):
""" Submit a job
Args:
- cmd_string :(String) - Name of the container to initiate
Expand All @@ -173,30 +171,34 @@ def submit(self, cmd_string, tasks_per_node, job_name="parsl"):
Returns:
- job_id: (string) Identifier for the job
"""
job_id = uuid.uuid4().hex[:8]

cur_timestamp = str(time.time() * 1000).split(".")[0]
job_name = "{0}-{1}".format(job_name, cur_timestamp)

if not self.pod_name:
pod_name = '{}'.format(job_name)
else:
pod_name = '{}-{}'.format(self.pod_name,
cur_timestamp)
pod_name = self.pod_name or job_name
try:
pod_name = sanitize_dns_subdomain_rfc1123(pod_name)
except ValueError:
logger.warning(
f"Invalid pod name '{pod_name}' for job '{job_id}', falling back to 'parsl.kube'"
)
pod_name = "parsl.kube"
pod_name = pod_name[:253 - 1 - len(job_id)] # Leave room for the job ID
pod_name = pod_name.rstrip(".-") # Remove trailing dot or hyphen after trim
pod_name = f"{pod_name}.{job_id}"
rjmello marked this conversation as resolved.
Show resolved Hide resolved

formatted_cmd = template_string.format(command=cmd_string,
worker_init=self.worker_init)

logger.debug("Pod name: %s", pod_name)
self._create_pod(image=self.image,
pod_name=pod_name,
job_name=job_name,
job_id=job_id,
cmd_string=formatted_cmd,
volumes=self.persistent_volumes,
service_account_name=self.service_account_name,
annotations=self.annotations)
self.resources[pod_name] = {'status': JobStatus(JobState.RUNNING)}
self.resources[job_id] = {'status': JobStatus(JobState.RUNNING), 'pod_name': pod_name}

return pod_name
return job_id

def status(self, job_ids):
""" Get the status of a list of jobs identified by the job identifiers
Expand All @@ -212,6 +214,9 @@ def status(self, job_ids):
self._status()
return [self.resources[jid]['status'] for jid in job_ids]

def _get_pod_name(self, job_id: str) -> str:
return self.resources[job_id]['pod_name']

def cancel(self, job_ids):
""" Cancels the jobs specified by a list of job ids
Args:
Expand All @@ -221,7 +226,8 @@ def cancel(self, job_ids):
"""
for job in job_ids:
logger.debug("Terminating job/pod: {0}".format(job))
self._delete_pod(job)
pod_name = self._get_pod_name(job)
self._delete_pod(pod_name)

self.resources[job]['status'] = JobStatus(JobState.CANCELLED)
rets = [True for i in job_ids]
Expand All @@ -242,7 +248,8 @@ def _status(self):
for jid in to_poll_job_ids:
phase = None
try:
pod = self.kube_client.read_namespaced_pod(name=jid, namespace=self.namespace)
pod_name = self._get_pod_name(jid)
pod = self.kube_client.read_namespaced_pod(name=pod_name, namespace=self.namespace)
except Exception:
logger.exception("Failed to poll pod {} status, most likely because pod was terminated".format(jid))
if self.resources[jid]['status'] is JobStatus(JobState.RUNNING):
Expand All @@ -257,10 +264,10 @@ def _status(self):
self.resources[jid]['status'] = JobStatus(status)

def _create_pod(self,
image,
pod_name,
job_name,
port=80,
image: str,
pod_name: str,
job_id: str,
port: int = 80,
cmd_string=None,
volumes=[],
service_account_name=None,
Expand All @@ -269,7 +276,7 @@ def _create_pod(self,
Args:
- image (string) : Docker image to launch
- pod_name (string) : Name of the pod
- job_name (string) : App label
- job_id (string) : Job ID
KWargs:
- port (integer) : Container port
Returns:
Expand Down Expand Up @@ -299,7 +306,7 @@ def _create_pod(self,
)
# Configure Pod template container
container = client.V1Container(
name=pod_name,
name=job_id,
image=image,
resources=resources,
ports=[client.V1ContainerPort(container_port=port)],
Expand All @@ -322,7 +329,7 @@ def _create_pod(self,
claim_name=volume[0])))

metadata = client.V1ObjectMeta(name=pod_name,
labels={"app": job_name},
labels={"parsl-job-id": job_id},
annotations=annotations)
spec = client.V1PodSpec(containers=[container],
image_pull_secrets=[secret],
Expand Down
102 changes: 102 additions & 0 deletions parsl/tests/test_providers/test_kubernetes_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import re
from unittest import mock

import pytest

from parsl.providers.kubernetes.kube import KubernetesProvider
from parsl.tests.test_utils.test_sanitize_dns import DNS_SUBDOMAIN_REGEX

_MOCK_BASE = "parsl.providers.kubernetes.kube"


@pytest.fixture(autouse=True)
def mock_kube_config():
with mock.patch(f"{_MOCK_BASE}.config") as mock_config:
mock_config.load_kube_config.return_value = None
yield mock_config


@pytest.fixture
def mock_kube_client():
mock_client = mock.MagicMock()
with mock.patch(f"{_MOCK_BASE}.client.CoreV1Api") as mock_api:
mock_api.return_value = mock_client
yield mock_client


@pytest.mark.local
def test_submit_happy_path(mock_kube_client: mock.MagicMock):
image = "test-image"
namespace = "test-namespace"
cmd_string = "test-command"
volumes = [("test-volume", "test-mount-path")]
service_account_name = "test-service-account"
annotations = {"test-annotation": "test-value"}
max_cpu = 2
max_mem = "2Gi"
init_cpu = 1
init_mem = "1Gi"
provider = KubernetesProvider(
image=image,
persistent_volumes=volumes,
namespace=namespace,
service_account_name=service_account_name,
annotations=annotations,
max_cpu=max_cpu,
max_mem=max_mem,
init_cpu=init_cpu,
init_mem=init_mem,
)

job_name = "test.job.name"
job_id = provider.submit(cmd_string=cmd_string, tasks_per_node=1, job_name=job_name)

assert job_id in provider.resources
assert mock_kube_client.create_namespaced_pod.call_count == 1

call_args = mock_kube_client.create_namespaced_pod.call_args[1]
pod = call_args["body"]
container = pod.spec.containers[0]
volume = container.volume_mounts[0]

assert image == container.image
assert namespace == call_args["namespace"]
assert any(cmd_string in arg for arg in container.args)
assert volumes[0] == (volume.name, volume.mount_path)
assert service_account_name == pod.spec.service_account_name
assert annotations == pod.metadata.annotations
assert str(max_cpu) == container.resources.limits["cpu"]
assert max_mem == container.resources.limits["memory"]
assert str(init_cpu) == container.resources.requests["cpu"]
assert init_mem == container.resources.requests["memory"]
assert job_id == pod.metadata.labels["parsl-job-id"]
assert job_id == container.name
assert f"{job_name}.{job_id}" == pod.metadata.name


@pytest.mark.local
@mock.patch(f"{_MOCK_BASE}.KubernetesProvider._create_pod")
@pytest.mark.parametrize("char", (".", "-"))
def test_submit_pod_name_includes_job_id(mock_create_pod: mock.MagicMock, char: str):
provider = KubernetesProvider(image="test-image")

job_name = "a." * 121 + f"a{char}" + "a" * 9
assert len(job_name) == 253 # Max length for pod name
job_id = provider.submit(cmd_string="test-command", tasks_per_node=1, job_name=job_name)

expected_pod_name = job_name[:253 - len(job_id) - 2] + f".{job_id}"
actual_pod_name = mock_create_pod.call_args[1]["pod_name"]
assert re.match(DNS_SUBDOMAIN_REGEX, actual_pod_name)
assert expected_pod_name == actual_pod_name


@pytest.mark.local
@mock.patch(f"{_MOCK_BASE}.KubernetesProvider._create_pod")
@mock.patch(f"{_MOCK_BASE}.logger")
@pytest.mark.parametrize("job_name", ("", ".", "-", "a.-.a", "$$$"))
def test_submit_invalid_job_name(mock_logger: mock.MagicMock, mock_create_pod: mock.MagicMock, job_name: str):
provider = KubernetesProvider(image="test-image")
job_id = provider.submit(cmd_string="test-command", tasks_per_node=1, job_name=job_name)
assert mock_logger.warning.call_count == 1
assert f"Invalid pod name '{job_name}' for job '{job_id}'" in mock_logger.warning.call_args[0][0]
assert f"parsl.kube.{job_id}" == mock_create_pod.call_args[1]["pod_name"]
76 changes: 76 additions & 0 deletions parsl/tests/test_utils/test_sanitize_dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import random
import re

import pytest

from parsl.utils import sanitize_dns_label_rfc1123, sanitize_dns_subdomain_rfc1123

# Ref: https://datatracker.ietf.org/doc/html/rfc1123
DNS_LABEL_REGEX = r'^[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?$'
DNS_SUBDOMAIN_REGEX = r'^[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?(\.[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?)*$'

test_labels = [
"example-label-123", # Valid label
"EXAMPLE", # Case sensitivity
"!@#example*", # Remove invalid characters
"--leading-and-trailing--", # Leading and trailing hyphens
"..leading.and.trailing..", # Leading and tailing dots
"multiple..dots", # Consecutive dots
"valid--label", # Consecutive hyphens
"a" * random.randint(64, 70), # Longer than 63 characters
f"{'a' * 62}-a", # Trailing hyphen at max length
]


def _generate_test_subdomains(num_subdomains: int):
subdomains = []
for _ in range(num_subdomains):
num_labels = random.randint(1, 5)
labels = [test_labels[random.randint(0, num_labels - 1)] for _ in range(num_labels)]
subdomain = ".".join(labels)
subdomains.append(subdomain)
return subdomains


@pytest.mark.local
@pytest.mark.parametrize("raw_string", test_labels)
def test_sanitize_dns_label_rfc1123(raw_string: str):
print(sanitize_dns_label_rfc1123(raw_string))
assert re.match(DNS_LABEL_REGEX, sanitize_dns_label_rfc1123(raw_string))


@pytest.mark.local
@pytest.mark.parametrize("raw_string", ("", "-", "@", "$$$"))
def test_sanitize_dns_label_rfc1123_empty(raw_string: str):
with pytest.raises(ValueError) as e_info:
sanitize_dns_label_rfc1123(raw_string)
assert str(e_info.value) == f"Sanitized DNS label is empty for input '{raw_string}'"


@pytest.mark.local
@pytest.mark.parametrize("raw_string", _generate_test_subdomains(10))
def test_sanitize_dns_subdomain_rfc1123(raw_string: str):
assert re.match(DNS_SUBDOMAIN_REGEX, sanitize_dns_subdomain_rfc1123(raw_string))


@pytest.mark.local
@pytest.mark.parametrize("char", ("-", "."))
def test_sanitize_dns_subdomain_rfc1123_trailing_non_alphanumeric_at_max_length(char: str):
raw_string = (f"{'a' * 61}." * 4) + f".aaaa{char}a"
assert re.match(DNS_SUBDOMAIN_REGEX, sanitize_dns_subdomain_rfc1123(raw_string))


@pytest.mark.local
@pytest.mark.parametrize("raw_string", ("", ".", "..."))
def test_sanitize_dns_subdomain_rfc1123_empty(raw_string: str):
with pytest.raises(ValueError) as e_info:
sanitize_dns_subdomain_rfc1123(raw_string)
assert str(e_info.value) == f"Sanitized DNS subdomain is empty for input '{raw_string}'"


@pytest.mark.local
@pytest.mark.parametrize(
"raw_string", ("a" * 253, "a" * random.randint(254, 300)), ids=("254 chars", ">253 chars")
)
def test_sanitize_dns_subdomain_rfc1123_max_length(raw_string: str):
assert len(sanitize_dns_subdomain_rfc1123(raw_string)) <= 253
Loading
Loading