Skip to content

Commit

Permalink
Enforce supplying queue + jobset on Cancel + Reprioritise endpoints (a…
Browse files Browse the repository at this point in the history
…rmadaproject#3560)

We want to enforce these, so we don't need to look the values up in Redis when they aren't supplied

The motivation here is to so we can remove redis

Signed-off-by: JamesMurkin <[email protected]>
  • Loading branch information
JamesMurkin authored May 2, 2024
1 parent 7f2085d commit 70fd866
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 101 deletions.
5 changes: 3 additions & 2 deletions config/armada/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ corsAllowedOrigins:
- http://localhost:10000
grpcGatewayPath: "/"
cancelJobsBatchSize: 1000
QueueRepositoryUsesPostgres: false
QueueCacheRefreshPeriod: 10s
queueRepositoryUsesPostgres: false
queueCacheRefreshPeriod: 10s
requireQueueAndJobSet: true
schedulerApiConnection:
armadaUrl: "localhost:50052"
grpc:
Expand Down
2 changes: 2 additions & 0 deletions internal/armada/configuration/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type ArmadaConfig struct {

CancelJobsBatchSize int

RequireQueueAndJobSet bool

Redis redis.UniversalOptions
EventsApiRedis redis.UniversalOptions
Pulsar PulsarConfig
Expand Down
3 changes: 2 additions & 1 deletion internal/armada/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healt
config.Submission,
submit.NewDeduplicator(store),
submitChecker,
authorizer)
authorizer,
config.RequireQueueAndJobSet)

// Consumer that's used for deleting pulsarJob details
// Need to use the old config.Pulsar.RedisFromPulsarSubscription name so we continue processing where we left off
Expand Down
185 changes: 90 additions & 95 deletions internal/armada/submit/submit.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ import (
// Server is a service that accepts API calls according to the original Armada submit API and publishes messages
// to Pulsar based on those calls.
type Server struct {
publisher pulsarutils.Publisher
queueRepository repository.QueueRepository
queueCache repository.ReadOnlyQueueRepository
jobRepository repository.JobRepository
submissionConfig configuration.SubmissionConfig
deduplicator Deduplicator
submitChecker scheduler.SubmitScheduleChecker
authorizer server.ActionAuthorizer
publisher pulsarutils.Publisher
queueRepository repository.QueueRepository
queueCache repository.ReadOnlyQueueRepository
jobRepository repository.JobRepository
submissionConfig configuration.SubmissionConfig
deduplicator Deduplicator
submitChecker scheduler.SubmitScheduleChecker
authorizer server.ActionAuthorizer
requireQueueAndJobSet bool
// Below are used only for testing
clock clock.Clock
idGenerator func() *armadaevents.Uuid
Expand All @@ -59,17 +60,19 @@ func NewServer(
deduplicator Deduplicator,
submitChecker scheduler.SubmitScheduleChecker,
authorizer server.ActionAuthorizer,
requireQueueAndJobSet bool,
) *Server {
return &Server{
publisher: publisher,
queueRepository: queueRepository,
queueCache: queueCache,
jobRepository: jobRepository,
submissionConfig: submissionConfig,
deduplicator: deduplicator,
submitChecker: submitChecker,
authorizer: authorizer,
clock: clock.RealClock{},
publisher: publisher,
queueRepository: queueRepository,
queueCache: queueCache,
jobRepository: jobRepository,
submissionConfig: submissionConfig,
deduplicator: deduplicator,
submitChecker: submitChecker,
authorizer: authorizer,
requireQueueAndJobSet: requireQueueAndJobSet,
clock: clock.RealClock{},
idGenerator: func() *armadaevents.Uuid {
return armadaevents.MustProtoUuidFromUlidString(util.NewULID())
},
Expand Down Expand Up @@ -216,26 +219,38 @@ func (s *Server) CancelJobs(grpcCtx context.Context, req *api.JobCancelRequest)
}, nil
}

// resolve the queue and jobset of the job: we can't trust what the user has given us
resolvedQueue, resolvedJobset, err := s.resolveQueueAndJobsetForJob(ctx, req.JobId)
if err != nil {
return nil, err
}
resolvedQueue := ""
resolvedJobSet := ""

// If both a job id and queue or jobsetId is provided, return ErrNotFound if they don't match,
// since the job could not be found for the provided queue/jobSetId.
if req.Queue != "" && req.Queue != resolvedQueue {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: req.JobId,
Message: fmt.Sprintf("job not found in queue %s, try waiting", req.Queue),
if s.requireQueueAndJobSet {
err := validation.ValidateQueueAndJobSet(req)
if err != nil {
return nil, err
}
}
if req.JobSetId != "" && req.JobSetId != resolvedJobset {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: req.JobId,
Message: fmt.Sprintf("job not found in job set %s, try waiting", req.JobSetId),
resolvedQueue = req.Queue
resolvedJobSet = req.JobSetId
} else {
// resolve the queue and jobset of the job: we can't trust what the user has given us
resolvedQueue, resolvedJobSet, err := s.resolveQueueAndJobsetForJob(ctx, req.JobId)
if err != nil {
return nil, err
}

// If both a job id and queue or jobsetId is provided, return ErrNotFound if they don't match,
// since the job could not be found for the provided queue/jobSetId.
if req.Queue != "" && req.Queue != resolvedQueue {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: req.JobId,
Message: fmt.Sprintf("job not found in queue %s, try waiting", req.Queue),
}
}
if req.JobSetId != "" && req.JobSetId != resolvedJobSet {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: req.JobId,
Message: fmt.Sprintf("job not found in job set %s, try waiting", req.JobSetId),
}
}
}

Expand All @@ -251,7 +266,7 @@ func (s *Server) CancelJobs(grpcCtx context.Context, req *api.JobCancelRequest)

sequence := &armadaevents.EventSequence{
Queue: resolvedQueue,
JobSetName: resolvedJobset,
JobSetName: resolvedJobSet,
UserId: userId,
Groups: groups,
Events: []*armadaevents.EventSequence_Event{
Expand Down Expand Up @@ -346,44 +361,51 @@ func preemptJobEventSequenceForJobIds(jobIds []string, q, jobSet, userId string,
func (s *Server) ReprioritizeJobs(grpcCtx context.Context, req *api.JobReprioritizeRequest) (*api.JobReprioritizeResponse, error) {
ctx := armadacontext.FromGrpcCtx(grpcCtx)

if req.JobSetId == "" || req.Queue == "" {
ctx.
WithField("apidatamissing", "true").
Warnf("Reprioritize jobs called with missing data: jobId=%s, jobset=%s, queue=%s, user=%s", req.JobIds[0], req.JobSetId, req.Queue, s.GetUser(ctx))
}

// If either queue or jobSetId is missing, we get the job set and queue associated
// with the first job id in the request.
//
// This must be done before checking auth, since the auth check expects a queue.
if len(req.JobIds) > 0 && (req.Queue == "" || req.JobSetId == "") {
firstJobId := req.JobIds[0]

resolvedQueue, resolvedJobset, err := s.resolveQueueAndJobsetForJob(ctx, firstJobId)
if s.requireQueueAndJobSet {
err := validation.ValidateQueueAndJobSet(req)
if err != nil {
return nil, err
}
} else {
if req.JobSetId == "" || req.Queue == "" {
ctx.
WithField("apidatamissing", "true").
Warnf("Reprioritize jobs called with missing data: jobId=%s, jobset=%s, queue=%s, user=%s", req.JobIds[0], req.JobSetId, req.Queue, s.GetUser(ctx))
}

// If both a job id and queue or jobsetId is provided, return ErrNotFound if they don't match,
// since the job could not be found for the provided queue/jobSetId.
// If both a job id and queue or jobsetId is provided, return ErrNotFound if they don't match,
// since the job could not be found for the provided queue/jobSetId.
if req.Queue != "" && req.Queue != resolvedQueue {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: firstJobId,
Message: fmt.Sprintf("job not found in queue %s, try waiting", req.Queue),
// If either queue or jobSetId is missing, we get the job set and queue associated
// with the first job id in the request.
//
// This must be done before checking auth, since the auth check expects a queue.
if len(req.JobIds) > 0 && (req.Queue == "" || req.JobSetId == "") {
firstJobId := req.JobIds[0]

resolvedQueue, resolvedJobset, err := s.resolveQueueAndJobsetForJob(ctx, firstJobId)
if err != nil {
return nil, err
}
}
if req.JobSetId != "" && req.JobSetId != resolvedJobset {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: firstJobId,
Message: fmt.Sprintf("job not found in job set %s, try waiting", req.JobSetId),

// If both a job id and queue or jobsetId is provided, return ErrNotFound if they don't match,
// since the job could not be found for the provided queue/jobSetId.
// If both a job id and queue or jobsetId is provided, return ErrNotFound if they don't match,
// since the job could not be found for the provided queue/jobSetId.
if req.Queue != "" && req.Queue != resolvedQueue {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: firstJobId,
Message: fmt.Sprintf("job not found in queue %s, try waiting", req.Queue),
}
}
if req.JobSetId != "" && req.JobSetId != resolvedJobset {
return nil, &armadaerrors.ErrNotFound{
Type: "job",
Value: firstJobId,
Message: fmt.Sprintf("job not found in job set %s, try waiting", req.JobSetId),
}
}
req.Queue = resolvedQueue
req.JobSetId = resolvedJobset
}
req.Queue = resolvedQueue
req.JobSetId = resolvedJobset
}

// TODO: this is incorrect we only validate the permissions on the first job but the other jobs may belong to different queues
Expand Down Expand Up @@ -481,7 +503,7 @@ func (s *Server) CancelJobSet(grpcCtx context.Context, req *api.JobSetCancelRequ
}
}

err := validateJobSetFilter(req.Filter)
err := validation.ValidateJobSetFilter(req.Filter)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -607,33 +629,6 @@ func (s *Server) resolveQueueAndJobsetForJob(ctx *armadacontext.Context, jobId s
}
}

func validateJobSetFilter(filter *api.JobSetFilter) error {
if filter == nil {
return nil
}
providedStatesSet := map[string]bool{}
for _, state := range filter.States {
providedStatesSet[state.String()] = true
}
for _, state := range filter.States {
if state == api.JobState_PENDING {
if _, present := providedStatesSet[api.JobState_RUNNING.String()]; !present {
return fmt.Errorf("unsupported state combination - state %s and %s must always be used together",
api.JobState_PENDING, api.JobState_RUNNING)
}
}

if state == api.JobState_RUNNING {
if _, present := providedStatesSet[api.JobState_PENDING.String()]; !present {
return fmt.Errorf("unsupported state combination - state %s and %s must always be used together",
api.JobState_PENDING, api.JobState_RUNNING)
}
}
}

return nil
}

func (s *Server) CreateQueue(grpcCtx context.Context, req *api.Queue) (*types.Empty, error) {
ctx := armadacontext.FromGrpcCtx(grpcCtx)
err := s.authorizer.AuthorizeAction(ctx, permissions.CreateQueue)
Expand Down
3 changes: 2 additions & 1 deletion internal/armada/submit/submit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ func createTestServer(t *testing.T) (*Server, *mockObjects) {
testfixtures.DefaultSubmissionConfig(),
m.deduplicator,
m.submitChecker,
m.authorizer)
m.authorizer,
true)
server.clock = clock.NewFakeClock(testfixtures.DefaultTime)
server.idGenerator = testfixtures.TestUlidGenerator()
return server, m
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package submit
package validation

import (
"fmt"

"github.com/armadaproject/armada/internal/common/armadaerrors"
"github.com/armadaproject/armada/pkg/api"
)

Expand Down Expand Up @@ -32,3 +33,26 @@ func ValidateJobSetFilter(filter *api.JobSetFilter) error {

return nil
}

type JobSetRequest interface {
GetJobSetId() string
GetQueue() string
}

func ValidateQueueAndJobSet(req JobSetRequest) error {
if req.GetQueue() == "" {
return &armadaerrors.ErrInvalidArgument{
Name: "Queue",
Value: req.GetQueue(),
Message: "queue cannot be empty",
}
}
if req.GetJobSetId() == "" {
return &armadaerrors.ErrInvalidArgument{
Name: "JobSetId",
Value: req.GetJobSetId(),
Message: "jobset cannot be empty",
}
}
return nil
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package submit
package validation

import (
"testing"
Expand Down

0 comments on commit 70fd866

Please sign in to comment.