-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Refactored out ID validating section * Compliant to existing tests now * GPU Validator tests done * Fixed case where metadata does not have annotations * int typecast fix * Fixed key error with collecting existing gpus * Fixed issue with getting container gpus * Handles limits as well now * Unit tests for k8s client * Merge done * Test rework * Metadata optional, id_validator reformat * Permit no-gpu pods if gpu usage is overcommitted --------- Co-authored-by: D0rkKnight <[email protected]>
- Loading branch information
1 parent
00b26e9
commit 83d2024
Showing
21 changed files
with
1,560 additions
and
906 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,5 @@ PyHamcrest | |
requests_mock | ||
dataclasses-json | ||
python-dotenv | ||
pytest | ||
git+https://github.com/ucsd-ets/[email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
GPU_LABEL = "nvidia.com/gpu" | ||
GPU_LIMIT_ANNOTATION = 'gpu-limit' | ||
LOW_PRIORITY_CLASS = "low" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from dataclasses import dataclass | ||
import json | ||
from typing import List, Optional | ||
|
||
from dataclasses_json import dataclass_json | ||
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest | ||
from dsmlp.plugin.console import Console | ||
from dsmlp.plugin.course import ConfigProvider | ||
from dsmlp.plugin.kube import KubeClient, NotFound | ||
import jsonify | ||
|
||
from dsmlp.plugin.logger import Logger | ||
from dsmlp.app.types import * | ||
from dsmlp.app.config import * | ||
|
||
|
||
class GPUValidator(ComponentValidator): | ||
|
||
def __init__(self, kube: KubeClient, logger: Logger) -> None: | ||
self.kube = kube | ||
self.logger = logger | ||
|
||
def validate_pod(self, request: Request): | ||
""" | ||
Validate pods for namespaces with the 'k8s-sync' label | ||
""" | ||
|
||
# Low priority pods pass through | ||
priority = request.object.spec.priorityClassName | ||
if priority is not None and priority == LOW_PRIORITY_CLASS: | ||
return | ||
|
||
namespace = self.kube.get_namespace(request.namespace) | ||
curr_gpus = self.kube.get_gpus_in_namespace(request.namespace) | ||
|
||
utilized_gpus = 0 | ||
for container in request.object.spec.containers: | ||
requested, limit = 0, 0 | ||
try: | ||
requested = int(container.resources.requests[GPU_LABEL]) | ||
except (KeyError, AttributeError, TypeError): | ||
pass | ||
try: | ||
limit = int(container.resources.limits[GPU_LABEL]) | ||
except (KeyError, AttributeError, TypeError): | ||
pass | ||
|
||
utilized_gpus += max(requested, limit) | ||
|
||
# Short circuit if no GPUs requested (permits overcap) | ||
if utilized_gpus == 0: | ||
return | ||
|
||
if utilized_gpus + curr_gpus > namespace.gpu_quota: | ||
raise ValidationFailure( | ||
f"GPU quota exceeded. Wanted {utilized_gpus} but with {curr_gpus} already in use, the quota of {namespace.gpu_quota} would be exceeded.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from dataclasses import dataclass | ||
import json | ||
from typing import List, Optional | ||
|
||
from dataclasses_json import dataclass_json | ||
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest | ||
from dsmlp.plugin.console import Console | ||
from dsmlp.plugin.course import ConfigProvider | ||
from dsmlp.plugin.kube import KubeClient, NotFound | ||
import jsonify | ||
|
||
from dsmlp.plugin.logger import Logger | ||
from dsmlp.app.types import * | ||
|
||
|
||
class IDValidator(ComponentValidator): | ||
|
||
def __init__(self, awsed: AwsedClient, logger: Logger) -> None: | ||
self.awsed = awsed | ||
self.logger = logger | ||
|
||
def validate_pod(self, request: Request): | ||
""" | ||
Validate pods for namespaces with the 'k8s-sync' label | ||
""" | ||
username = request.namespace | ||
# namespace = self.kube.get_namespace(request.namespace) | ||
|
||
# if 'k8s-sync' in namespace.labels: | ||
user = self.awsed.describe_user(username) | ||
if not user: | ||
raise ValidationFailure( | ||
f"namespace: no AWSEd user found with username {username}") | ||
allowed_uid = user.uid | ||
allowed_courses = user.enrollments | ||
|
||
team_response = self.awsed.list_user_teams(username) | ||
allowed_gids = [team.gid for team in team_response.teams] | ||
allowed_gids.append(0) | ||
allowed_gids.append(100) | ||
|
||
metadata = request.object.metadata | ||
spec = request.object.spec | ||
|
||
if metadata is not None and metadata.labels is not None: | ||
self.validate_course_enrollment(allowed_courses, metadata.labels) | ||
|
||
self.validate_pod_security_context( | ||
allowed_uid, allowed_gids, spec.securityContext) | ||
self.validate_containers(allowed_uid, allowed_gids, spec) | ||
|
||
def validate_course_enrollment(self, allowed_courses: List[str], labels: Dict[str, str]): | ||
if not 'dsmlp/course' in labels: | ||
return | ||
if not labels['dsmlp/course'] in allowed_courses: | ||
raise ValidationFailure( | ||
f"metadata.labels: dsmlp/course must be in range {allowed_courses}") | ||
|
||
def validate_pod_security_context( | ||
self, | ||
authorized_uid: int, | ||
allowed_teams: List[int], | ||
securityContext: PodSecurityContext): | ||
|
||
if securityContext is None: | ||
return | ||
|
||
if securityContext.runAsUser is not None and authorized_uid != securityContext.runAsUser: | ||
raise ValidationFailure( | ||
f"spec.securityContext: uid must be in range [{authorized_uid}]") | ||
|
||
if securityContext.runAsGroup is not None and securityContext.runAsGroup not in allowed_teams: | ||
raise ValidationFailure( | ||
f"spec.securityContext: gid must be in range {allowed_teams}") | ||
|
||
if securityContext.fsGroup is not None and securityContext.fsGroup not in allowed_teams: | ||
raise ValidationFailure( | ||
f"spec.securityContext: gid must be in range {allowed_teams}") | ||
|
||
if securityContext.supplementalGroups is not None: | ||
for sgroup in securityContext.supplementalGroups: | ||
if not sgroup in allowed_teams: | ||
raise ValidationFailure( | ||
f"spec.securityContext: gid must be in range {allowed_teams}") | ||
|
||
def validate_containers( | ||
self, | ||
authorized_uid: int, | ||
allowed_teams: List[int], | ||
spec: PodSpec | ||
): | ||
""" | ||
Validate the security context of containers and initContainers | ||
""" | ||
self.validate_security_contexts( | ||
authorized_uid, allowed_teams, spec.containers, "containers") | ||
self.validate_security_contexts( | ||
authorized_uid, allowed_teams, spec.initContainers, "initContainers") | ||
|
||
def validate_security_contexts( | ||
self, authorized_uid: int, allowed_teams: List[int], | ||
containers: List[Container], | ||
context: str): | ||
""" | ||
Validate the security context of a container. | ||
""" | ||
|
||
if containers is None: | ||
return | ||
|
||
for i, container in enumerate(containers): | ||
securityContext = container.securityContext | ||
if securityContext is None: | ||
continue | ||
|
||
self.validate_security_context( | ||
authorized_uid, allowed_teams, securityContext, f"{context}[{i}]") | ||
|
||
def validate_security_context( | ||
self, | ||
authorized_uid: int, | ||
allowed_teams: List[int], | ||
securityContext: SecurityContext, | ||
context: str): | ||
|
||
if securityContext.runAsUser is not None and authorized_uid != securityContext.runAsUser: | ||
raise ValidationFailure( | ||
f"spec.{context}.securityContext: uid must be in range [{authorized_uid}]") | ||
|
||
if securityContext.runAsGroup is not None and securityContext.runAsGroup not in allowed_teams: | ||
raise ValidationFailure( | ||
f"spec.{context}.securityContext: gid must be in range {allowed_teams}") | ||
|
||
def admission_response(self, uid, allowed, message): | ||
return { | ||
"apiVersion": "admission.k8s.io/v1", | ||
"kind": "AdmissionReview", | ||
"response": { | ||
"uid": uid, | ||
"allowed": allowed, | ||
"status": { | ||
"message": message | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
|
||
from dataclasses import dataclass | ||
from typing import List, Optional, Dict | ||
from dataclasses_json import dataclass_json | ||
from abc import ABCMeta, abstractmethod | ||
|
||
@dataclass_json | ||
@dataclass | ||
class SecurityContext: | ||
"""Each Container has a SecurityContext""" | ||
runAsUser: Optional[int] = None | ||
runAsGroup: Optional[int] = None | ||
|
||
@dataclass_json | ||
@dataclass | ||
class ResourceRequirements: | ||
requests: Optional[Dict[str, int]] = None | ||
limits: Optional[Dict[str, int]] = None | ||
|
||
@dataclass_json | ||
@dataclass | ||
class Container: | ||
securityContext: Optional[SecurityContext] = None | ||
resources: Optional[ResourceRequirements] = None | ||
|
||
@dataclass_json | ||
@dataclass | ||
class PodSecurityContext: | ||
"""Each Pod has a SecurityContext""" | ||
runAsUser: Optional[int] = None | ||
runAsGroup: Optional[int] = None | ||
fsGroup: Optional[int] = None | ||
supplementalGroups: Optional[List[int]] = None | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class PodSpec: | ||
containers: List[Container] | ||
initContainers: Optional[List[Container]] = None | ||
securityContext: Optional[PodSecurityContext] = None | ||
priorityClassName: Optional[str] = None | ||
|
||
@dataclass_json | ||
@dataclass | ||
class ObjectMeta: | ||
labels: Dict[str, str] | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class Object: | ||
metadata: ObjectMeta | ||
spec: PodSpec | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class UserInfo: | ||
username: str | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class Request: | ||
uid: str | ||
namespace: str | ||
object: Object | ||
userInfo: UserInfo | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class AdmissionReview: | ||
request: Request | ||
|
||
class ValidationFailure(Exception): | ||
def __init__(self, message: str) -> None: | ||
self.message = message | ||
super().__init__(self.message) | ||
|
||
class ComponentValidator: | ||
@abstractmethod | ||
def validate_pod(self, request: Request): | ||
pass |
Oops, something went wrong.