Skip to content

Commit

Permalink
Merge branch 'master' of github.com:armadaproject/armada into f/chris…
Browse files Browse the repository at this point in the history
…ma/scxhedulerobjects-internal
  • Loading branch information
d80tb7 committed Jan 29, 2025
2 parents ce0aa39 + fcee91c commit f6717ae
Show file tree
Hide file tree
Showing 172 changed files with 4,091 additions and 891 deletions.
27 changes: 21 additions & 6 deletions client/python/armada_client/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
event_timeout: timedelta = timedelta(minutes=15),
) -> None:
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.queue_stub = submit_pb2_grpc.QueueServiceStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.job_stub = job_pb2_grpc.JobsStub(channel)
self.event_timeout = event_timeout
Expand Down Expand Up @@ -390,7 +391,7 @@ async def create_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to create.
"""

response = await self.submit_stub.CreateQueue(queue)
response = await self.queue_stub.CreateQueue(queue)
return response

async def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
Expand All @@ -400,7 +401,7 @@ async def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to update.
"""

response = await self.submit_stub.UpdateQueue(queue)
response = await self.queue_stub.UpdateQueue(queue)
return response

async def create_queues(
Expand All @@ -413,7 +414,7 @@ async def create_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = await self.submit_stub.CreateQueues(queue_list)
response = await self.queue_stub.CreateQueues(queue_list)
return response

async def update_queues(
Expand All @@ -426,7 +427,7 @@ async def update_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = await self.submit_stub.UpdateQueues(queue_list)
response = await self.queue_stub.UpdateQueues(queue_list)
return response

async def delete_queue(self, name: str) -> None:
Expand All @@ -438,7 +439,7 @@ async def delete_queue(self, name: str) -> None:
:return: None
"""
request = submit_pb2.QueueDeleteRequest(name=name)
await self.submit_stub.DeleteQueue(request)
await self.queue_stub.DeleteQueue(request)

async def get_queue(self, name: str) -> submit_pb2.Queue:
"""Get the queue by name.
Expand All @@ -449,9 +450,23 @@ async def get_queue(self, name: str) -> submit_pb2.Queue:
:return: A queue object. See the api definition.
"""
request = submit_pb2.QueueGetRequest(name=name)
response = await self.submit_stub.GetQueue(request)
response = await self.queue_stub.GetQueue(request)
return response

async def get_queues(self) -> list:
"""Retrieves all queues
:return: List containing all queues.
"""
queues = []
request = submit_pb2.StreamingQueueGetRequest()
async for message in self.queue_stub.GetQueues(request):
event_type = message.WhichOneof("event")
if event_type == "queue":
queues.append(message.queue)
elif event_type == "end":
break
return queues

@staticmethod
def unwatch_events(event_stream) -> None:
"""Closes gRPC event streams
Expand Down
34 changes: 27 additions & 7 deletions client/python/armada_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ class ArmadaClient:
def __init__(self, channel, event_timeout: timedelta = timedelta(minutes=15)):
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.event_timeout = event_timeout
self.job_stub = job_pb2_grpc.JobsStub(channel)
self.queue_stub = submit_pb2_grpc.QueueServiceStub(channel)
self.event_timeout = event_timeout

def get_job_events_stream(
self,
Expand Down Expand Up @@ -290,6 +291,7 @@ def cancel_jobset(
) -> empty_pb2.Empty:
"""Cancel jobs in a given queue.
Uses the CancelJobSet RPC to cancel jobs.
A filter is used to only cancel jobs in certain states.
Expand Down Expand Up @@ -378,7 +380,7 @@ def create_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to create.
"""

response = self.submit_stub.CreateQueue(queue)
response = self.queue_stub.CreateQueue(queue)
return response

def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
Expand All @@ -388,7 +390,7 @@ def update_queue(self, queue: submit_pb2.Queue) -> empty_pb2.Empty:
:param queue: A queue to update.
"""

response = self.submit_stub.UpdateQueue(queue)
response = self.queue_stub.UpdateQueue(queue)
return response

def create_queues(
Expand All @@ -401,7 +403,7 @@ def create_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = self.submit_stub.CreateQueues(queue_list)
response = self.queue_stub.CreateQueues(queue_list)
return response

def update_queues(
Expand All @@ -414,7 +416,7 @@ def update_queues(
"""

queue_list = submit_pb2.QueueList(queues=queues)
response = self.submit_stub.UpdateQueues(queue_list)
response = self.queue_stub.UpdateQueues(queue_list)
return response

def delete_queue(self, name: str) -> None:
Expand All @@ -426,7 +428,7 @@ def delete_queue(self, name: str) -> None:
:return: None
"""
request = submit_pb2.QueueDeleteRequest(name=name)
self.submit_stub.DeleteQueue(request)
self.queue_stub.DeleteQueue(request)

def get_queue(self, name: str) -> submit_pb2.Queue:
"""Get the queue by name.
Expand All @@ -437,9 +439,27 @@ def get_queue(self, name: str) -> submit_pb2.Queue:
:return: A queue object. See the api definition.
"""
request = submit_pb2.QueueGetRequest(name=name)
response = self.submit_stub.GetQueue(request)
response = self.queue_stub.GetQueue(request)
return response

def get_queues(self) -> List[submit_pb2.Queue]:
"""Get all queues.
Uses the GetQueues RPC to get the queues.
:return: list containing all queues
"""
queues = []

request = submit_pb2.StreamingQueueGetRequest()

for message in self.queue_stub.GetQueues(request):
if message.HasField("queue"):
queues.append(message.queue)
elif message.HasField("end"):
break
return queues

@staticmethod
def unwatch_events(event_stream) -> None:
"""Closes gRPC event streams
Expand Down
2 changes: 1 addition & 1 deletion client/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "armada_client"
version = "0.4.10"
version = "0.4.11"
description = "Armada gRPC API python client"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
58 changes: 35 additions & 23 deletions client/python/tests/unit/server_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from armada_client.armada.submit_pb2 import JobState


class SubmitService(submit_pb2_grpc.SubmitServicer):
class QueueService(submit_pb2_grpc.QueueServiceServicer):
def CreateQueue(self, request, context):
return empty_pb2.Empty()

Expand All @@ -25,6 +25,40 @@ def DeleteQueue(self, request, context):
def GetQueue(self, request, context):
return submit_pb2.Queue(name=request.name)

def GetQueues(self, request, context):
queue_names = ["test_queue1", "test_queue2", "test_queue3"]
for name in queue_names:
queue_message = submit_pb2.StreamingQueueMessage(
queue=submit_pb2.Queue(name=name)
)
yield queue_message

yield submit_pb2.StreamingQueueMessage(end=submit_pb2.EndMarker())

def GetQueueInfo(self, request, context):
return submit_pb2.QueueInfo(name=request.name)

def CreateQueues(self, request, context):
return submit_pb2.BatchQueueCreateResponse(
failed_queues=[
submit_pb2.QueueCreateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def UpdateQueues(self, request, context):
return submit_pb2.BatchQueueUpdateResponse(
failed_queues=[
submit_pb2.QueueUpdateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def UpdateQueue(self, request, context):
return empty_pb2.Empty()


class SubmitService(submit_pb2_grpc.SubmitServicer):
def SubmitJobs(self, request, context):
# read job_ids from request.job_request_items
job_ids = [f"job-{i}" for i in range(1, len(request.job_request_items) + 1)]
Expand All @@ -35,9 +69,6 @@ def SubmitJobs(self, request, context):

return submit_pb2.JobSubmitResponse(job_response_items=job_response_items)

def GetQueueInfo(self, request, context):
return submit_pb2.QueueInfo(name=request.name)

def CancelJobs(self, request, context):
return submit_pb2.CancellationResult(
cancelled_ids=["job-1"],
Expand Down Expand Up @@ -72,25 +103,6 @@ def ReprioritizeJobs(self, request, context):

return submit_pb2.JobReprioritizeResponse(reprioritization_results=results)

def UpdateQueue(self, request, context):
return empty_pb2.Empty()

def CreateQueues(self, request, context):
return submit_pb2.BatchQueueCreateResponse(
failed_queues=[
submit_pb2.QueueCreateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def UpdateQueues(self, request, context):
return submit_pb2.BatchQueueUpdateResponse(
failed_queues=[
submit_pb2.QueueUpdateResponse(queue=submit_pb2.Queue(name=queue.name))
for queue in request.queues
]
)

def Health(self, request, context):
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.SERVING
Expand Down
10 changes: 9 additions & 1 deletion client/python/tests/unit/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from armada_client.typings import JobState
from armada_client.armada.job_pb2 import JobRunState
from server_mock import EventService, SubmitService, QueryAPIService
from server_mock import EventService, SubmitService, QueueService, QueryAPIService

from armada_client.armada import (
event_pb2_grpc,
Expand All @@ -28,6 +28,7 @@
def server_mock():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server)
submit_pb2_grpc.add_QueueServiceServicer_to_server(QueueService(), server)
event_pb2_grpc.add_EventServicer_to_server(EventService(), server)
job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server)
server.add_insecure_port("[::]:50051")
Expand Down Expand Up @@ -175,6 +176,13 @@ async def test_get_queue(aio_client):
assert queue.name == "test"


@pytest.mark.asyncio
async def test_get_queues(aio_client):
queues = await aio_client.get_queues()
queue_names = [q.name for q in queues]
assert queue_names == ["test_queue1", "test_queue2", "test_queue3"]


@pytest.mark.asyncio
async def test_delete_queue(aio_client):
await aio_client.delete_queue("test")
Expand Down
9 changes: 8 additions & 1 deletion client/python/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from armada_client.typings import JobState
from armada_client.armada.job_pb2 import JobRunState
from server_mock import EventService, SubmitService, QueryAPIService
from server_mock import EventService, SubmitService, QueryAPIService, QueueService

from armada_client.armada import (
event_pb2_grpc,
Expand All @@ -27,6 +27,7 @@
def server_mock():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server)
submit_pb2_grpc.add_QueueServiceServicer_to_server(QueueService(), server)
event_pb2_grpc.add_EventServicer_to_server(EventService(), server)
job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server)
server.add_insecure_port("[::]:50051")
Expand Down Expand Up @@ -165,6 +166,12 @@ def test_get_queue():
assert tester.get_queue("test").name == "test"


def test_get_queues():
queues = tester.get_queues()
queue_names = [q.name for q in queues]
assert queue_names == ["test_queue1", "test_queue2", "test_queue3"]


def test_delete_queue():
tester.delete_queue("test")

Expand Down
4 changes: 2 additions & 2 deletions cmd/armada-load-tester/cmd/loadtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"strings"
"time"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/viper"

log "github.com/armadaproject/armada/internal/common/logging"
"github.com/armadaproject/armada/pkg/client"
"github.com/armadaproject/armada/pkg/client/domain"
"github.com/armadaproject/armada/pkg/client/util"
Expand Down Expand Up @@ -72,7 +72,7 @@ var loadtestCmd = &cobra.Command{
loadTestSpec := &domain.LoadTestSpecification{}
err := util.BindJsonOrYaml(filePath, loadTestSpec)
if err != nil {
log.Error(err)
log.Error(err.Error())
os.Exit(1)
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/armada-load-tester/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package cmd
import (
"os"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"

log "github.com/armadaproject/armada/internal/common/logging"
"github.com/armadaproject/armada/pkg/client"
)

Expand All @@ -28,7 +28,7 @@ The location of this file can be passed in using --config argument or picked fro
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
log.Error(err)
log.Error(err.Error())
os.Exit(1)
}
}
Expand All @@ -37,7 +37,7 @@ var cfgFile string

func initConfig() {
if err := client.LoadCommandlineArgsFromConfigFile(cfgFile); err != nil {
log.Error(err)
log.Error(err.Error())
os.Exit(1)
}
}
2 changes: 1 addition & 1 deletion cmd/binoculars/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"syscall"

"github.com/grpc-ecosystem/grpc-gateway/runtime"
log "github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"google.golang.org/grpc"
Expand All @@ -19,6 +18,7 @@ import (
"github.com/armadaproject/armada/internal/common/armadacontext"
gateway "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
log "github.com/armadaproject/armada/internal/common/logging"
"github.com/armadaproject/armada/internal/common/profiling"
api "github.com/armadaproject/armada/pkg/api/binoculars"
)
Expand Down
Loading

0 comments on commit f6717ae

Please sign in to comment.