Skip to content

Commit

Permalink
Gpu quota (#4)
Browse files Browse the repository at this point in the history
* Refactored out ID validating section

* Compliant with existing tests now

* GPU Validator tests done

* Fixed case where metadata does not have annotations

* int typecast fix

* Fixed key error with collecting existing gpus

* Fixed issue with getting container gpus

* Handles limits as well now

* Unit tests for k8s client

* Merge done

* Test rework

* Metadata optional, id_validator reformat

* Permit no-gpu pods if gpu usage is overcommitted

---------

Co-authored-by: D0rkKnight <[email protected]>
  • Loading branch information
heshou198 and shouhanzen authored Mar 12, 2024
1 parent 00b26e9 commit 83d2024
Show file tree
Hide file tree
Showing 21 changed files with 1,560 additions and 906 deletions.
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@
],
"files.exclude": {
"**/__pycache__": true
}
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
621 changes: 621 additions & 0 deletions ref.json

Large diffs are not rendered by default.

Empty file added ref.txt
Empty file.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ PyHamcrest
requests_mock
dataclasses-json
python-dotenv
pytest
git+https://github.com/ucsd-ets/[email protected]
2 changes: 1 addition & 1 deletion src/dsmlp/admission_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def create_app(test_config=None):
logging.getLogger('waitress').setLevel(logging.INFO)
logging.getLogger('dsmlp').setLevel(logging.DEBUG)
logger = PythonLogger(None)
validator = Validator(factory.awsed_client, logger)
validator = Validator(factory.awsed_client, factory.kube_client, logger)

@app.route('/validate', methods=['POST'])
def validate_request():
Expand Down
3 changes: 3 additions & 0 deletions src/dsmlp/app/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Resource name under which GPUs appear in a container's requests/limits.
GPU_LABEL = "nvidia.com/gpu"

# NOTE(review): presumably the namespace annotation that stores the GPU quota;
# it is not referenced in the code visible here -- confirm against the kube client.
GPU_LIMIT_ANNOTATION = "gpu-limit"

# Pods whose priorityClassName equals this value bypass the GPU quota check.
LOW_PRIORITY_CLASS = "low"
56 changes: 56 additions & 0 deletions src/dsmlp/app/gpu_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from dataclasses import dataclass
import json
from typing import List, Optional

from dataclasses_json import dataclass_json
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest
from dsmlp.plugin.console import Console
from dsmlp.plugin.course import ConfigProvider
from dsmlp.plugin.kube import KubeClient, NotFound
import jsonify

from dsmlp.plugin.logger import Logger
from dsmlp.app.types import *
from dsmlp.app.config import *


class GPUValidator(ComponentValidator):
    """Rejects pods whose GPU demand would exceed the GPU quota of their namespace."""

    def __init__(self, kube: KubeClient, logger: Logger) -> None:
        self.kube = kube
        self.logger = logger

    def validate_pod(self, request: Request):
        """
        Check the GPUs asked for by a pod against its namespace's GPU quota.

        A pod is admitted without consulting the cluster when it uses the
        low priority class or when none of its containers ask for GPUs.
        Otherwise the pod's demand plus the GPUs already in use in the
        namespace must not exceed the namespace's ``gpu_quota``.

        Only ``spec.containers`` is inspected; ``spec.initContainers`` is not
        counted.

        Raises:
            ValidationFailure: if admitting the pod would exceed the quota.
        """

        # Low priority pods pass through
        if request.object.spec.priorityClassName == LOW_PRIORITY_CLASS:
            return

        utilized_gpus = sum(
            self._container_gpus(container)
            for container in request.object.spec.containers)

        # Short circuit if no GPUs requested (permits overcap). This happens
        # before any cluster lookup so that non-GPU pods neither pay for nor
        # fail on queries whose results they would not use.
        if utilized_gpus == 0:
            return

        namespace = self.kube.get_namespace(request.namespace)
        curr_gpus = self.kube.get_gpus_in_namespace(request.namespace)

        if utilized_gpus + curr_gpus > namespace.gpu_quota:
            raise ValidationFailure(
                f"GPU quota exceeded. Wanted {utilized_gpus} but with {curr_gpus} already in use, the quota of {namespace.gpu_quota} would be exceeded.")

    @staticmethod
    def _container_gpus(container) -> int:
        """Return the larger of a container's GPU request and GPU limit (0 if neither is set)."""
        requested, limit = 0, 0
        # A container may lack `resources`, lack `requests`/`limits`, or lack
        # the GPU entry; each of those cases simply counts as zero GPUs.
        try:
            requested = int(container.resources.requests[GPU_LABEL])
        except (KeyError, AttributeError, TypeError):
            pass
        try:
            limit = int(container.resources.limits[GPU_LABEL])
        except (KeyError, AttributeError, TypeError):
            pass

        return max(requested, limit)
145 changes: 145 additions & 0 deletions src/dsmlp/app/id_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from dataclasses import dataclass
import json
from typing import List, Optional

from dataclasses_json import dataclass_json
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest
from dsmlp.plugin.console import Console
from dsmlp.plugin.course import ConfigProvider
from dsmlp.plugin.kube import KubeClient, NotFound
import jsonify

from dsmlp.plugin.logger import Logger
from dsmlp.app.types import *


class IDValidator(ComponentValidator):
    """Checks a pod's course label, UIDs and GIDs against what AWSEd allows for its user."""

    def __init__(self, awsed: AwsedClient, logger: Logger) -> None:
        self.awsed = awsed
        self.logger = logger

    def validate_pod(self, request: Request):
        """
        Validate a pod against the AWSEd record of the user owning its namespace.

        The namespace name is used as the AWSEd username. The user's uid,
        enrollments and team gids (plus gids 0 and 100, which are always
        allowed) bound what the pod may declare in its labels and security
        contexts.

        Raises:
            ValidationFailure: if no AWSEd user exists for the namespace or
                any of the individual checks fails.
        """
        username = request.namespace

        user = self.awsed.describe_user(username)
        if not user:
            raise ValidationFailure(
                f"namespace: no AWSEd user found with username {username}")
        allowed_uid, allowed_courses = user.uid, user.enrollments

        teams = self.awsed.list_user_teams(username).teams
        allowed_gids = [team.gid for team in teams] + [0, 100]

        pod = request.object
        metadata, spec = pod.metadata, pod.spec

        # Labels are optional; the course check only applies when they exist.
        if not (metadata is None or metadata.labels is None):
            self.validate_course_enrollment(allowed_courses, metadata.labels)

        self.validate_pod_security_context(
            allowed_uid, allowed_gids, spec.securityContext)
        self.validate_containers(allowed_uid, allowed_gids, spec)

    def validate_course_enrollment(self, allowed_courses: List[str], labels: Dict[str, str]):
        """Raise ValidationFailure if a 'dsmlp/course' label names a course outside allowed_courses."""
        if 'dsmlp/course' in labels and labels['dsmlp/course'] not in allowed_courses:
            raise ValidationFailure(
                f"metadata.labels: dsmlp/course must be in range {allowed_courses}")

    def validate_pod_security_context(
            self,
            authorized_uid: int,
            allowed_teams: List[int],
            securityContext: PodSecurityContext):
        """
        Validate the pod-level security context.

        A missing context, or a missing field inside it, is accepted.
        runAsUser must equal authorized_uid; runAsGroup, fsGroup and every
        supplemental group must be one of allowed_teams.
        """
        if securityContext is None:
            return

        def bad_gid() -> ValidationFailure:
            # Every group-related failure shares the same message.
            return ValidationFailure(
                f"spec.securityContext: gid must be in range {allowed_teams}")

        run_as_user = securityContext.runAsUser
        if run_as_user is not None and authorized_uid != run_as_user:
            raise ValidationFailure(
                f"spec.securityContext: uid must be in range [{authorized_uid}]")

        run_as_group = securityContext.runAsGroup
        if run_as_group is not None and run_as_group not in allowed_teams:
            raise bad_gid()

        fs_group = securityContext.fsGroup
        if fs_group is not None and fs_group not in allowed_teams:
            raise bad_gid()

        supplemental = securityContext.supplementalGroups
        if supplemental is not None:
            for gid in supplemental:
                if gid not in allowed_teams:
                    raise bad_gid()

    def validate_containers(
            self,
            authorized_uid: int,
            allowed_teams: List[int],
            spec: PodSpec):
        """
        Validate the security context of containers and initContainers
        """
        # The attribute name doubles as the label used in error messages.
        for field_name in ("containers", "initContainers"):
            self.validate_security_contexts(
                authorized_uid, allowed_teams, getattr(spec, field_name), field_name)

    def validate_security_contexts(
            self, authorized_uid: int, allowed_teams: List[int],
            containers: List[Container],
            context: str):
        """
        Validate the security context of each container in a list.

        `context` names the list in error messages and is suffixed with the
        index of the offending container. A missing list, or a container
        without a security context, is accepted.
        """
        if containers is None:
            return

        for index, container in enumerate(containers):
            container_context = container.securityContext
            if container_context is not None:
                self.validate_security_context(
                    authorized_uid, allowed_teams, container_context, f"{context}[{index}]")

    def validate_security_context(
            self,
            authorized_uid: int,
            allowed_teams: List[int],
            securityContext: SecurityContext,
            context: str):
        """Check a single container's runAsUser / runAsGroup; unset fields are accepted."""
        run_as_user = securityContext.runAsUser
        if run_as_user is not None and authorized_uid != run_as_user:
            raise ValidationFailure(
                f"spec.{context}.securityContext: uid must be in range [{authorized_uid}]")

        run_as_group = securityContext.runAsGroup
        if run_as_group is not None and run_as_group not in allowed_teams:
            raise ValidationFailure(
                f"spec.{context}.securityContext: gid must be in range {allowed_teams}")

    def admission_response(self, uid, allowed, message):
        """Build an admission.k8s.io/v1 AdmissionReview response dictionary."""
        status = {"message": message}
        response = {"uid": uid, "allowed": allowed, "status": status}
        return {
            "apiVersion": "admission.k8s.io/v1",
            "kind": "AdmissionReview",
            "response": response,
        }
85 changes: 85 additions & 0 deletions src/dsmlp/app/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

from dataclasses import dataclass
from typing import List, Optional, Dict
from dataclasses_json import dataclass_json
from abc import ABCMeta, abstractmethod

@dataclass_json
@dataclass
class SecurityContext:
    """Container-level security context; both fields are None when not set."""
    # UID the container asks to run as.
    runAsUser: Optional[int] = None
    # Primary GID the container asks to run as.
    runAsGroup: Optional[int] = None

@dataclass_json
@dataclass
class ResourceRequirements:
    """Resource requests and limits of a container, keyed by resource name."""
    # Mapping of resource name to requested amount; None when absent.
    requests: Optional[Dict[str, int]] = None
    # Mapping of resource name to limit; None when absent.
    limits: Optional[Dict[str, int]] = None

@dataclass_json
@dataclass
class Container:
    """The parts of a container spec the validators look at; both are optional."""
    # Container-level security settings, if any were given.
    securityContext: Optional[SecurityContext] = None
    # Resource requests/limits, if any were given.
    resources: Optional[ResourceRequirements] = None

@dataclass_json
@dataclass
class PodSecurityContext:
    """Pod-level security context; every field is None when not set."""
    # UID the pod asks to run as.
    runAsUser: Optional[int] = None
    # Primary GID the pod asks to run as.
    runAsGroup: Optional[int] = None
    # GID requested as fsGroup.
    fsGroup: Optional[int] = None
    # Additional GIDs requested for the pod.
    supplementalGroups: Optional[List[int]] = None


@dataclass_json
@dataclass
class PodSpec:
    """The parts of a pod spec the validators look at."""
    # Containers of the pod (the only required field).
    containers: List[Container]
    # Init containers, if any.
    initContainers: Optional[List[Container]] = None
    # Pod-level security settings, if any.
    securityContext: Optional[PodSecurityContext] = None
    # Name of the pod's priority class, if one was set.
    priorityClassName: Optional[str] = None

@dataclass_json
@dataclass
class ObjectMeta:
    """The parts of an object's metadata the validators look at."""
    # Labels of the object, or None when it has none. The field is optional
    # with a default because the ID validator already guards against
    # `metadata.labels is None`; as a required field, metadata without labels
    # could not be constructed at all.
    labels: Optional[Dict[str, str]] = None


@dataclass_json
@dataclass
class Object:
    """The object under review: its metadata and its pod spec."""
    # NOTE(review): the ID validator tolerates `metadata is None`, but this
    # field is required here -- confirm whether it should be optional.
    metadata: ObjectMeta
    spec: PodSpec


@dataclass_json
@dataclass
class UserInfo:
    """Identity information attached to a request."""
    # Name of the user associated with the request.
    username: str


@dataclass_json
@dataclass
class Request:
    """A single request handed to the component validators."""
    # Identifier of the request.
    uid: str
    # Namespace of the object; the ID validator uses it as the AWSEd username.
    namespace: str
    # The object (metadata + pod spec) being validated.
    object: Object
    # Information about the requesting user.
    userInfo: UserInfo


@dataclass_json
@dataclass
class AdmissionReview:
    """Top-level wrapper holding the request to be validated."""
    request: Request

class ValidationFailure(Exception):
    """Raised by a validator to reject a pod; carries the reason in `message`."""

    def __init__(self, message: str) -> None:
        # Pass the text to Exception so str(exc) and exc.args expose it too.
        super().__init__(message)
        self.message = message

class ComponentValidator(metaclass=ABCMeta):
    """
    Interface implemented by every individual pod validator.

    The ABCMeta metaclass is what makes ``@abstractmethod`` effective:
    without it the decorator is ignored and the base class, or a subclass
    that forgot to implement ``validate_pod``, could be instantiated.
    """

    @abstractmethod
    def validate_pod(self, request: Request):
        """
        Validate the pod described by `request`.

        Implementations return normally to accept the pod and raise
        ValidationFailure to reject it.
        """
        pass
Loading

0 comments on commit 83d2024

Please sign in to comment.