Skip to content

Commit

Permalink
The key should be nvidia.com/gpu instead of gpu when fetching awsed gpu …
Browse files Browse the repository at this point in the history
…quota
  • Loading branch information
trn024 committed Aug 29, 2024
1 parent c3f6393 commit bdd3773
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/dsmlp/ext/awsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import awsed.client
import awsed.types
import logging
from dsmlp.plugin.logger import Logger

# added logging to check if API has an error getting GPU quota
logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
Expand Down Expand Up @@ -38,7 +39,7 @@ def get_user_gpu_quota(self, username: str) -> UserQuotaResponse:
usrGpuQuota = self.client.get_user_quota(username)
if not usrGpuQuota:
return None
gpu_quota = usrGpuQuota.get("gpu", 0)
gpu_quota = usrGpuQuota.get("nvidia.com/gpu", 0)
quota = Quota(user=username, resources=gpu_quota)
return UserQuotaResponse(quota=quota)
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions tests/app/test_gpu_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def try_validate(self, json, expected: bool, message: str = None):

# Test correct response for get_user_gpu_quota method
def test_awsed_gpu_quota_correct_response(self):
self.awsed_client.assign_user_gpu_quota('user11', {"gpu": 5})
self.awsed_client.assign_user_gpu_quota('user11', {"nvidia.com/gpu": 5})
user_gpu_quota = self.awsed_client.get_user_gpu_quota('user11')
assert_that(user_gpu_quota, equal_to(5))

Expand Down Expand Up @@ -123,7 +123,7 @@ def test_gpu_quota_client_priority(self):

self.kube_client.set_existing_gpus('user11', 3)
# add awsed quota
self.awsed_client.assign_user_gpu_quota('user11', {"gpu": 6})
self.awsed_client.assign_user_gpu_quota('user11', {"nvidia.com/gpu": 6})
self.try_validate(
gen_request(gpu_req=6, username='user11'), expected=False, message="GPU quota exceeded. Wanted 6 but with 3 already in use, the quota of 6 would be exceeded."
)
Expand All @@ -133,7 +133,7 @@ def test_gpu_quota_client_priority2(self):
self.kube_client.add_namespace('user11', Namespace(
name='user11', labels={'k8s-sync': 'true'}, gpu_quota=12))
# add awsed quota
self.awsed_client.assign_user_gpu_quota('user11', {"gpu": 18})
self.awsed_client.assign_user_gpu_quota('user11', {"nvidia.com/gpu": 18})

# set existing gpu = kube client quota
self.kube_client.set_existing_gpus('user11', 12)
Expand Down
2 changes: 1 addition & 1 deletion tests/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def describe_user(self, username: str) -> UserResponse:
def get_user_gpu_quota(self, username: str) -> int:
try:
user_quota_response = self.user_quota[username]
return user_quota_response.quota.resources.get("gpu", 0)
return user_quota_response.quota.resources.get("nvidia.com/gpu", 0)
except KeyError:
return 0

Expand Down

0 comments on commit bdd3773

Please sign in to comment.