Skip to content

Commit

Permalink
Merge branch 'master' into fix_queue_cache_startup_race
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesMurkin authored Jan 30, 2025
2 parents ed8b62b + fcee91c commit d33da75
Show file tree
Hide file tree
Showing 47 changed files with 1,565 additions and 121 deletions.
27 changes: 21 additions & 6 deletions client/python/armada_client/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
event_timeout: timedelta = timedelta(minutes=15),
) -> None:
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.queue_stub = submit_pb2_grpc.QueueServiceStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.job_stub = job_pb2_grpc.JobsStub(channel)
self.event_timeout = event_timeout
Expand Down Expand Up @@ -390,7 +391,7 @@ async def create_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to create.
"""

response = await self.submit_stub.CreateQueue(queue)
response = await self.queue_stub.CreateQueue(queue)
return response

async def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
Expand All @@ -400,7 +401,7 @@ async def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to update.
"""

response = await self.submit_stub.UpdateQueue(queue)
response = await self.queue_stub.UpdateQueue(queue)
return response

async def create_queues(
Expand All @@ -413,7 +414,7 @@ async def create_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = await self.submit_stub.CreateQueues(queue_list)
response = await self.queue_stub.CreateQueues(queue_list)
return response

async def update_queues(
Expand All @@ -426,7 +427,7 @@ async def update_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = await self.submit_stub.UpdateQueues(queue_list)
response = await self.queue_stub.UpdateQueues(queue_list)
return response

async def delete_queue(self, name: str) -> None:
Expand All @@ -438,7 +439,7 @@ async def delete_queue(self, name: str) -> None:
:return: None
"""
request = submit_pb2.QueueDeleteRequest(name=name)
await self.submit_stub.DeleteQueue(request)
await self.queue_stub.DeleteQueue(request)

async def get_queue(self, name: str) -> submit_pb2.Queue:
"""Get the queue by name.
Expand All @@ -449,9 +450,23 @@ async def get_queue(self, name: str) -> submit_pb2.Queue:
:return: A queue object. See the api definition.
"""
request = submit_pb2.QueueGetRequest(name=name)
response = await self.submit_stub.GetQueue(request)
response = await self.queue_stub.GetQueue(request)
return response

async def get_queues(self) -> list:
"""Retrieves all queues
:return: List containing all queues.
"""
queues = []
request = submit_pb2.StreamingQueueGetRequest()
async for message in self.queue_stub.GetQueues(request):
event_type = message.WhichOneof("event")
if event_type == "queue":
queues.append(message.queue)
elif event_type == "end":
break
return queues

@staticmethod
def unwatch_events(event_stream) -> None:
"""Closes gRPC event streams
Expand Down
34 changes: 27 additions & 7 deletions client/python/armada_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ class ArmadaClient:
def __init__(self, channel, event_timeout: timedelta = timedelta(minutes=15)):
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.event_timeout = event_timeout
self.job_stub = job_pb2_grpc.JobsStub(channel)
self.queue_stub = submit_pb2_grpc.QueueServiceStub(channel)
self.event_timeout = event_timeout

def get_job_events_stream(
self,
Expand Down Expand Up @@ -290,6 +291,7 @@ def cancel_jobset(
) -> empty_pb2.Empty:
"""Cancel jobs in a given queue.
Uses the CancelJobSet RPC to cancel jobs.
A filter is used to only cancel jobs in certain states.
Expand Down Expand Up @@ -378,7 +380,7 @@ def create_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to create.
"""

response = self.submit_stub.CreateQueue(queue)
response = self.queue_stub.CreateQueue(queue)
return response

def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
Expand All @@ -388,7 +390,7 @@ def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to update.
"""

response = self.submit_stub.UpdateQueue(queue)
response = self.queue_stub.UpdateQueue(queue)
return response

def create_queues(
Expand All @@ -401,7 +403,7 @@ def create_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = self.submit_stub.CreateQueues(queue_list)
response = self.queue_stub.CreateQueues(queue_list)
return response

def update_queues(
Expand All @@ -414,7 +416,7 @@ def update_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = self.submit_stub.UpdateQueues(queue_list)
response = self.queue_stub.UpdateQueues(queue_list)
return response

def delete_queue(self, name: str) -> None:
Expand All @@ -426,7 +428,7 @@ def delete_queue(self, name: str) -> None:
:return: None
"""
request = submit_pb2.QueueDeleteRequest(name=name)
self.submit_stub.DeleteQueue(request)
self.queue_stub.DeleteQueue(request)

def get_queue(self, name: str) -> submit_pb2.Queue:
"""Get the queue by name.
Expand All @@ -437,9 +439,27 @@ def get_queue(self, name: str) -> submit_pb2.Queue:
:return: A queue object. See the api definition.
"""
request = submit_pb2.QueueGetRequest(name=name)
response = self.submit_stub.GetQueue(request)
response = self.queue_stub.GetQueue(request)
return response

def get_queues(self) -> List[submit_pb2.Queue]:
"""Get all queues.
Uses the GetQueues RPC to get the queues.
:return: list containing all queues
"""
queues = []

request = submit_pb2.StreamingQueueGetRequest()

for message in self.queue_stub.GetQueues(request):
if message.HasField("queue"):
queues.append(message.queue)
elif message.HasField("end"):
break
return queues

@staticmethod
def unwatch_events(event_stream) -> None:
"""Closes gRPC event streams
Expand Down
2 changes: 1 addition & 1 deletion client/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "armada_client"
version = "0.4.10"
version = "0.4.11"
description = "Armada gRPC API python client"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
58 changes: 35 additions & 23 deletions client/python/tests/unit/server_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from armada_client.armada.submit_pb2 import JobState


class SubmitService(submit_pb2_grpc.SubmitServicer):
class QueueService(submit_pb2_grpc.QueueServiceServicer):
def CreateQueue(self, request, context):
return empty_pb2.Empty()

Expand All @@ -25,6 +25,40 @@ def DeleteQueue(self, request, context):
def GetQueue(self, request, context):
return submit_pb2.Queue(name=request.name)

def GetQueues(self, request, context):
queue_names = ["test_queue1", "test_queue2", "test_queue3"]
for name in queue_names:
queue_message = submit_pb2.StreamingQueueMessage(
queue=submit_pb2.Queue(name=name)
)
yield queue_message

yield submit_pb2.StreamingQueueMessage(end=submit_pb2.EndMarker())

def GetQueueInfo(self, request, context):
return submit_pb2.QueueInfo(name=request.name)

def CreateQueues(self, request, context):
return submit_pb2.BatchQueueCreateResponse(
failed_queues=[
submit_pb2.QueueCreateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def UpdateQueues(self, request, context):
return submit_pb2.BatchQueueUpdateResponse(
failed_queues=[
submit_pb2.QueueUpdateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def UpdateQueue(self, request, context):
return empty_pb2.Empty()


class SubmitService(submit_pb2_grpc.SubmitServicer):
def SubmitJobs(self, request, context):
# read job_ids from request.job_request_items
job_ids = [f"job-{i}" for i in range(1, len(request.job_request_items) + 1)]
Expand All @@ -35,9 +69,6 @@ def SubmitJobs(self, request, context):

return submit_pb2.JobSubmitResponse(job_response_items=job_response_items)

def GetQueueInfo(self, request, context):
return submit_pb2.QueueInfo(name=request.name)

def CancelJobs(self, request, context):
return submit_pb2.CancellationResult(
cancelled_ids=["job-1"],
Expand Down Expand Up @@ -72,25 +103,6 @@ def ReprioritizeJobs(self, request, context):

return submit_pb2.JobReprioritizeResponse(reprioritization_results=results)

def UpdateQueue(self, request, context):
return empty_pb2.Empty()

def CreateQueues(self, request, context):
return submit_pb2.BatchQueueCreateResponse(
failed_queues=[
submit_pb2.QueueCreateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def UpdateQueues(self, request, context):
return submit_pb2.BatchQueueUpdateResponse(
failed_queues=[
submit_pb2.QueueUpdateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def Health(self, request, context):
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.SERVING
Expand Down
10 changes: 9 additions & 1 deletion client/python/tests/unit/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from armada_client.typings import JobState
from armada_client.armada.job_pb2 import JobRunState
from server_mock import EventService, SubmitService, QueryAPIService
from server_mock import EventService, SubmitService, QueueService, QueryAPIService

from armada_client.armada import (
event_pb2_grpc,
Expand All @@ -28,6 +28,7 @@
def server_mock():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server)
submit_pb2_grpc.add_QueueServiceServicer_to_server(QueueService(), server)
event_pb2_grpc.add_EventServicer_to_server(EventService(), server)
job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server)
server.add_insecure_port("[::]:50051")
Expand Down Expand Up @@ -175,6 +176,13 @@ async def test_get_queue(aio_client):
assert queue.name == "test"


@pytest.mark.asyncio
async def test_get_queues(aio_client):
queues = await aio_client.get_queues()
queue_names = [q.name for q in queues]
assert queue_names == ["test_queue1", "test_queue2", "test_queue3"]


@pytest.mark.asyncio
async def test_delete_queue(aio_client):
await aio_client.delete_queue("test")
Expand Down
9 changes: 8 additions & 1 deletion client/python/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from armada_client.typings import JobState
from armada_client.armada.job_pb2 import JobRunState
from server_mock import EventService, SubmitService, QueryAPIService
from server_mock import EventService, SubmitService, QueryAPIService, QueueService

from armada_client.armada import (
event_pb2_grpc,
Expand All @@ -27,6 +27,7 @@
def server_mock():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server)
submit_pb2_grpc.add_QueueServiceServicer_to_server(QueueService(), server)
event_pb2_grpc.add_EventServicer_to_server(EventService(), server)
job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server)
server.add_insecure_port("[::]:50051")
Expand Down Expand Up @@ -165,6 +166,12 @@ def test_get_queue():
assert tester.get_queue("test").name == "test"


def test_get_queues():
queues = tester.get_queues()
queue_names = [q.name for q in queues]
assert queue_names == ["test_queue1", "test_queue2", "test_queue3"]


def test_delete_queue():
tester.delete_queue("test")

Expand Down
3 changes: 2 additions & 1 deletion config/scheduler/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pulsar:
armadaApi:
armadaUrl: "server:50051"
forceNoTls: true
priorityMultiplier:
enabled: false
postgres:
connection:
host: postgres
Expand Down Expand Up @@ -118,4 +120,3 @@ scheduling:
experimentalIndicativePricing:
basePrice: 100.0
basePriority: 500.0

27 changes: 27 additions & 0 deletions docs/floating_resources.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Floating Resources

Floating resources are designed to constrain the usage of resources that are not tied to nodes. For example, if you have a fileserver outside your Kubernetes clusters, you may want to limit how many connections to the fileserver can exist at once. In that case you would add config like the below (this goes under the `scheduling` section of the Armada scheduler config).

```
floatingResources:
- name: fileserver-connections
resolution: "1"
pools:
- name: cpu
quantity: 1000
- name: gpu
quantity: 500
```
When submitting a job, floating resources are specified in the same way as normal Kubernetes resources such as `cpu`. For example if a job needs 3 cpu cores and opens 10 connections to the fileserver, the job should specify
```
resources:
requests:
cpu: "3"
fileserver-connections: "10"
limits:
cpu: "3"
fileserver-connections: "10"
```
The `requests` section is used for scheduling. For floating resources, the `limits` section is not enforced by Armada (this it not possible in the general case). Instead the workload must be trusted to respect its limit.

If the jobs submitted to Armada request more of a floating resource than is available, they queue just as if they had exceeded the amount available of a standard Kubernetes resource (e.g. `cpu`). Floating resources generally behave like standard Kubernetes resources. They use the same code for queue ordering, pre-emption, etc.
18 changes: 18 additions & 0 deletions docs/python_armada_client.md
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,24 @@ Uses the GetQueue RPC to get the queue.



#### get_queues()
Get all queues.

Uses the GetQueues RPC to get the queues.


* **Returns**

list containing all queues



* **Return type**

*List*[armada.submit_pb2.Queue]



#### preempt_jobs(queue, job_set_id, job_id)
Preempt jobs in a given queue.

Expand Down
Loading

0 comments on commit d33da75

Please sign in to comment.