Skip to content

Commit

Permalink
feat: Add user GPU quota retrieval to AwsedClient
Browse files Browse the repository at this point in the history
  • Loading branch information
trn024 committed Jul 22, 2024
1 parent 83d2024 commit bb6b015
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 15 deletions.
23 changes: 19 additions & 4 deletions src/dsmlp/app/gpu_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

class GPUValidator(ComponentValidator):

def __init__(self, kube: KubeClient, logger: Logger) -> None:
def __init__(self, awsed: AwsedClient, kube: KubeClient, logger: Logger) -> None:
self.kube = kube
self.logger = logger
self.awsed = awsed

def validate_pod(self, request: Request):
"""
Expand All @@ -32,7 +33,20 @@ def validate_pod(self, request: Request):

namespace = self.kube.get_namespace(request.namespace)
curr_gpus = self.kube.get_gpus_in_namespace(request.namespace)

awsed_gpu_quota = self.awsed.get_user_gpu_quota(request.namespace)
"""
Use AWSED GPU quota if it is not None and greater than 0
else use namespace GPU quota if it is not None and greater than 0
else use 1 as default
"""

gpu_quota = 1
if awsed_gpu_quota is not None and awsed_gpu_quota > 0:
gpu_quota = awsed_gpu_quota
elif namespace.gpu_quota is not None and namespace.gpu_quota > 0:
gpu_quota = namespace.gpu_quota

# Calculate the number of GPUs requested for kube client
utilized_gpus = 0
for container in request.object.spec.containers:
requested, limit = 0, 0
Expand All @@ -51,6 +65,7 @@ def validate_pod(self, request: Request):
if utilized_gpus == 0:
return

if utilized_gpus + curr_gpus > namespace.gpu_quota:
# Check if the total number of utilized GPUs exceeds the GPU quota
if utilized_gpus + curr_gpus > gpu_quota:
raise ValidationFailure(
f"GPU quota exceeded. Wanted {utilized_gpus} but with {curr_gpus} already in use, the quota of {namespace.gpu_quota} would be exceeded.")
f"GPU quota exceeded. Wanted {utilized_gpus} but with {curr_gpus} already in use, the quota of {gpu_quota} would be exceeded.")
2 changes: 2 additions & 0 deletions src/dsmlp/app/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses_json import dataclass_json
from abc import ABCMeta, abstractmethod

# Kubernetes API types

@dataclass_json
@dataclass
class SecurityContext:
Expand Down
2 changes: 1 addition & 1 deletion src/dsmlp/app/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Validator:
def __init__(self, awsed: AwsedClient, kube: KubeClient, logger: Logger) -> None:
self.awsed = awsed
self.logger = logger
self.component_validators = [IDValidator(awsed, logger), GPUValidator(kube, logger)]
self.component_validators = [IDValidator(awsed, logger), GPUValidator(awsed, kube, logger)]

def validate_request(self, admission_review_json):
self.logger.debug("request=" + json.dumps(admission_review_json, indent=2))
Expand Down
14 changes: 12 additions & 2 deletions src/dsmlp/ext/awsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import requests
from dacite import from_dict

from dsmlp.plugin.awsed import AwsedClient, ListTeamsResponse, TeamJson, UnsuccessfulRequest, UserResponse
from dsmlp.plugin.awsed import AwsedClient, ListTeamsResponse, TeamJson, UnsuccessfulRequest, UserResponse, UserGpuQuotaResponse

import awsed.client
import awsed.types

class ExternalAwsedClient(AwsedClient):
class ExternalAwsedClient(AwsedClient):
def __init__(self):
self.client = awsed.client.DefaultAwsedClient(endpoint=os.environ.get('AWSED_ENDPOINT'),
awsed_api_key=os.environ.get('AWSED_API_KEY'))
Expand All @@ -27,3 +27,13 @@ def list_user_teams(self, username: str) -> ListTeamsResponse:
teams.append(TeamJson(gid=team.gid))

return ListTeamsResponse(teams=teams)

# Fetch the user's GPU quota from the AWSED API and wrap it in a UserGpuQuotaResponse object
def get_user_gpu_quota(self, username: str) -> UserGpuQuotaResponse:
    """Look up a user's GPU quota through the AWSED API.

    Args:
        username: Name of the user whose quota is requested.

    Returns:
        A UserGpuQuotaResponse built from the ``gpuQuota`` attribute of the
        AWSED reply, or None when the reply is falsy or the lookup fails.
    """
    # NOTE(review): GPUValidator.validate_pod compares the value returned
    # here with ``> 0``. A plain @dataclass instance defines no ordering
    # against int, so the caller presumably needs ``.gpu_quota`` -- TODO confirm.
    try:
        user_gpu_quota = self.client.fetch_user_gpu_quota(username)
        if not user_gpu_quota:
            return None
        return UserGpuQuotaResponse(gpu_quota=user_gpu_quota.gpuQuota)
    except Exception:
        # Best-effort lookup: any failure is reported as "no quota" so the
        # caller can fall back to another source. Unlike a bare ``except:``,
        # this lets KeyboardInterrupt and SystemExit propagate.
        return None
11 changes: 9 additions & 2 deletions src/dsmlp/plugin/awsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@ class UserResponse:
uid: int
enrollments: List[str]

@dataclass
class UserGpuQuotaResponse:
# Response object carrying a single user's GPU quota.
# gpu_quota: the quota value; ExternalAwsedClient fills it from the gpuQuota attribute of the AWSED reply.
gpu_quota: int

class AwsedClient(metaclass=ABCMeta):
@abstractmethod
def list_user_teams(self, username: str) -> ListTeamsResponse:
"""Return the groups of a course"""
# Return the groups of a course
pass

@abstractmethod
def describe_user(self, username: str) -> UserResponse:
pass


@abstractmethod
def get_user_gpu_quota(self, username: str) -> UserGpuQuotaResponse:
# Return the gpu quota of a user
pass

class UnsuccessfulRequest(Exception):
pass
62 changes: 61 additions & 1 deletion tests/app/test_gpu_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ def setup_method(self) -> None:
self.kube_client.add_namespace('user10', Namespace(
name='user10', labels={'k8s-sync': 'true'}, gpu_quota=10))
self.kube_client.set_existing_gpus('user10', 0)

# Set gpu quota for user 10 with AWSED client
self.awsed_client.add_user_gpu_quota('user10', 10)

# Set up user11 without any quota & namespace.
self.awsed_client.add_user(
'user11', UserResponse(uid=11, enrollments=[]))
self.awsed_client.add_teams('user11', ListTeamsResponse(
teams=[TeamJson(gid=1001)]
))

def test_no_gpus_requested(self):
self.try_validate(
Expand Down Expand Up @@ -74,5 +84,55 @@ def test_low_priority_overcap(self):
gen_request(), expected=True)

def try_validate(self, json, expected: bool, message: str = None):
try_val_with_component(GPUValidator(
try_val_with_component(GPUValidator(self.awsed_client,
self.kube_client, self.logger), json, expected, message)

# Test correct response for get_user_gpu_quota method
def test_awsed_gpu_quota_correct_response(self):
# Register a quota of 5 for user11 on the AWSED client, then check that reading it back yields 5.
self.awsed_client.add_user_gpu_quota('user11', 5)
user_gpu_quota = self.awsed_client.get_user_gpu_quota('user11')
assert_that(user_gpu_quota, equal_to(5))

# No quota set for user 11 from both kube and awsed, should return default value 1
def test_gpu_validator_default_limit(self):
# The namespace quota is 0 and this test registers no AWSED quota for user11,
# so the expected message reports the default quota of 1.
self.kube_client.add_namespace('user11', Namespace(
name='user11', labels={'k8s-sync': 'true'}, gpu_quota=0))

self.kube_client.set_existing_gpus('user11', 0)
# Requesting 11 GPUs with 0 in use must be rejected.
self.try_validate(
gen_request(gpu_req=11, username='user11'), expected=False, message="GPU quota exceeded. Wanted 11 but with 0 already in use, the quota of 1 would be exceeded."
)

# No quota set for user 11 from AWSED, but set from the kube client, so the namespace quota of 5 should be used
def test_no_awsed_gpu_quota(self):
# Only the namespace carries a quota here (5); this test registers no AWSED quota for user11.
self.kube_client.add_namespace('user11', Namespace(
name='user11', labels={'k8s-sync': 'true'}, gpu_quota=5))

self.kube_client.set_existing_gpus('user11', 0)
# Requesting 11 GPUs with 0 in use must be rejected against the quota of 5.
self.try_validate(
gen_request(gpu_req=11, username='user11'), expected=False, message="GPU quota exceeded. Wanted 11 but with 0 already in use, the quota of 5 would be exceeded."
)

# Quota both set for user 11 from kube and awsed, should prioritize AWSED quota
def test_gpu_quota_client_priority(self):
# Namespace quota is 8 and AWSED quota is 6; the expected message reports 6, i.e. the AWSED value.
self.kube_client.add_namespace('user11', Namespace(
name='user11', labels={'k8s-sync': 'true'}, gpu_quota=8))

self.kube_client.set_existing_gpus('user11', 3)
self.awsed_client.add_user_gpu_quota('user11', 6)
# 3 in use + 6 requested = 9, which exceeds 6, so the request must be rejected.
self.try_validate(
gen_request(gpu_req=6, username='user11'), expected=False, message="GPU quota exceeded. Wanted 6 but with 3 already in use, the quota of 6 would be exceeded."
)

# AWSED quota (18) is larger than the kube quota (12): the AWSED quota still takes priority, so a request that fits within 18 is allowed
def test_gpu_quota_client_priority2(self):
# AWSED quota is 18 and namespace quota is 12; the request is expected to be allowed.
self.awsed_client.add_user_gpu_quota('user11', 18)
self.kube_client.add_namespace('user11', Namespace(
name='user11', labels={'k8s-sync': 'true'}, gpu_quota=12))

# existing GPU usage equals the kube namespace quota (12)
self.kube_client.set_existing_gpus('user11', 12)

# 12 in use + 6 requested = 18, which does not exceed the AWSED quota of 18.
# NOTE(review): the message text says "5 already in use" although 12 are set above;
# presumably the message is not checked when expected=True -- confirm against try_val_with_component.
self.try_validate(
gen_request(gpu_req=6, username='user11'), expected=True, message="GPU quota exceeded. Wanted 6 but with 5 already in use, the quota of 18 would be exceeded."
)
26 changes: 25 additions & 1 deletion tests/app/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,31 @@ def test_log_allowed_requests(self):

assert_that(self.logger.messages, has_item(
"INFO Allowed request username=user10 namespace=user10 uid=705ab4f5-6393-11e8-b7cc-42010a800002"))


# def test_gpu_quota_request(self):
# self.awsed_client.add_user_gpu_quota('user10', 10)
# self.awsed_client.get_user_gpu_quota('user10')

# response = self.when_validate(
# {
# "request": {
# "uid": "705ab4f5-6393-11e8-b7cc-42010a800002",
# "namespace": "user10",
# "userInfo": {
# "username": "user10"
# },
# "object": {
# "metadata": {
# "labels": {}
# },
# "spec": {
# "containers": [{}]
# }
# }
# }
# }
# )

def when_validate(self, json):
validator = Validator(self.awsed_client, self.kube_client, self.logger)
response = validator.validate_request(json)
Expand Down
8 changes: 6 additions & 2 deletions tests/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@


def gen_request(gpu_req: int = 0, gpu_lim: int = 0, low_priority: bool = False, uid: str = "705ab4f5-6393-11e8-b7cc-42010a800002", course: str = None,
run_as_user: int = None, run_as_group: int = None, fs_group: int = None, supplemental_groups: List[int] = None, username: str = "user10", has_container: bool = True,
run_as_user: int = None, run_as_group: int = None, fs_group: int = None, supplemental_groups: List[int] = None, username: str = None, has_container: bool = True,
container_override: List[Container] = None, init_containers: List[Container] = None) -> Request:


# default the username to 'user10' unless a test specifies one
if username is None:
username = 'user10'

res_req = None
if gpu_req > 0:
if res_req is None:
Expand Down
15 changes: 13 additions & 2 deletions tests/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dacite import from_dict

from dsmlp.plugin.awsed import AwsedClient, ListTeamsResponse, UnsuccessfulRequest, UserResponse
from dsmlp.plugin.awsed import AwsedClient, ListTeamsResponse, UnsuccessfulRequest, UserResponse, UserGpuQuotaResponse
from dsmlp.plugin.kube import KubeClient, Namespace, NotFound
from dsmlp.plugin.logger import Logger

Expand All @@ -14,6 +14,7 @@ class FakeAwsedClient(AwsedClient):
def __init__(self):
self.teams: Dict[str, ListTeamsResponse] = {}
self.users: Dict[str, UserResponse] = {}
self.user_gpu_quota: Dict[str, UserGpuQuotaResponse] = {}

def list_user_teams(self, username: str) -> ListTeamsResponse:
try:
Expand All @@ -28,7 +29,17 @@ def describe_user(self, username: str) -> UserResponse:
return self.users[username]
except KeyError:
return None


# Get user GPU quota. If user does not exist, return 0
def get_user_gpu_quota(self, username: str) -> UserGpuQuotaResponse:
    """Return the quota registered for ``username``, or 0 if none is registered."""
    # Unknown users fall back to 0 instead of raising KeyError.
    return self.user_gpu_quota.get(username, 0)

def add_user_gpu_quota(self, username, gpu_quota: UserGpuQuotaResponse):
# Test helper: store gpu_quota as the value get_user_gpu_quota returns for username.
# NOTE(review): annotated as UserGpuQuotaResponse, but the tests in this commit pass plain ints
# (e.g. add_user_gpu_quota('user10', 10)) -- confirm the intended type.
self.user_gpu_quota[username] = gpu_quota

def add_user(self, username, user: UserResponse):
self.users[username] = user

Expand Down

0 comments on commit bb6b015

Please sign in to comment.