diff --git a/config/scheduler/config.yaml b/config/scheduler/config.yaml index 71e84baf7a4..856a284d885 100644 --- a/config/scheduler/config.yaml +++ b/config/scheduler/config.yaml @@ -122,9 +122,15 @@ scheduling: maxUnacknowledgedJobsPerExecutor: 2500 alwaysAttemptScheduling: false executorUpdateFrequency: "1m" - failureEstimatorConfig: + failureProbabilityEstimation: # Optimised default parameters. numInnerIterations: 10 innerOptimiserStepSize: 0.05 outerOptimiserStepSize: 0.05 outerOptimiserNesterovAcceleration: 0.2 + nodeQuarantining: + failureProbabilityQuarantineThreshold: 0.95 + failureProbabilityEstimateTimeout: "10m" + queueQuarantining: + quarantineFactorMultiplier: 0.5 # At most halve the scheduling rate of misbehaving queues. + failureProbabilityEstimateTimeout: "10m" diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index ea15750e77f..16f375d7c4f 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -251,8 +251,12 @@ type SchedulingConfig struct { AlwaysAttemptScheduling bool // The frequency at which the scheduler updates the cluster state. ExecutorUpdateFrequency time.Duration - // Controls node and queue failure probability estimation. - FailureEstimatorConfig FailureEstimatorConfig + // Controls node and queue success probability estimation. + FailureProbabilityEstimation FailureEstimatorConfig + // Controls node quarantining, i.e., removing from consideration for scheduling misbehaving nodes. + NodeQuarantining NodeQuarantinerConfig + // Controls queue quarantining, i.e., rate-limiting scheduling from misbehaving queues. + QueueQuarantining QueueQuarantinerConfig } const ( @@ -310,8 +314,8 @@ type WellKnownNodeType struct { Taints []v1.Taint } -// FailureEstimatorConfig contains config controlling node and queue success probability estimation. -// See the internal/scheduler/failureestimator package for details. +// FailureEstimatorConfig controls node and queue success probability estimation. +// See internal/scheduler/failureestimator.go for details. type FailureEstimatorConfig struct { Disabled bool NumInnerIterations int `validate:"gt=0"` @@ -320,6 +324,20 @@ type FailureEstimatorConfig struct { OuterOptimiserNesterovAcceleration float64 `validate:"gte=0"` } +// NodeQuarantinerConfig controls how nodes are quarantined, i.e., removed from consideration when scheduling new jobs. +// See internal/scheduler/quarantine/node_quarantiner.go for details. +type NodeQuarantinerConfig struct { + FailureProbabilityQuarantineThreshold float64 `validate:"gte=0,lte=1"` + FailureProbabilityEstimateTimeout time.Duration `validate:"gte=0"` +} + +// QueueQuarantinerConfig controls how scheduling from misbehaving queues is rate-limited. +// See internal/scheduler/quarantine/queue_quarantiner.go for details. 
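+// E.g., a QuarantineFactorMultiplier of 0.5 means the scheduling rate of a misbehaving queue is at most halved.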
+type QueueQuarantinerConfig struct { + QuarantineFactorMultiplier float64 `validate:"gte=0,lte=1"` + FailureProbabilityEstimateTimeout time.Duration `validate:"gte=0"` +} + // TODO: we can probably just typedef this to map[string]string type PostgresConfig struct { Connection map[string]string diff --git a/internal/scheduler/failureestimator/failureestimator.go b/internal/scheduler/failureestimator/failureestimator.go index 724556a6c01..c80f5849317 100644 --- a/internal/scheduler/failureestimator/failureestimator.go +++ b/internal/scheduler/failureestimator/failureestimator.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "sync" + "time" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -55,15 +56,12 @@ type FailureEstimator struct { gradient *mat.VecDense // Maps node (queue) names to the corresponding index of parameters. - // E.g., if parameterIndexByNode["myNode"] = 10, then parameters[10] is the estimated success probability of myNode. - parameterIndexByNode map[string]int - parameterIndexByQueue map[string]int - - // Maps node names to the cluster they belong to. - clusterByNode map[string]string + // E.g., if nodeByName["myNode"].parameterIndex = 10, then parameters[10] is the estimated success probability of myNode. + nodeByName map[string]node + queueByName map[string]queue // Samples that have not been processed yet. - samples []Sample + samples []sample // Optimisation settings. numInnerIterations int @@ -82,7 +80,18 @@ type FailureEstimator struct { mu sync.Mutex } -type Sample struct { +type node struct { + parameterIndex int + cluster string + timeOfMostRecentSample time.Time +} + +type queue struct { + parameterIndex int + timeOfMostRecentSample time.Time +} + +type sample struct { i int j int c bool @@ -106,10 +115,8 @@ func New( intermediateParameters: mat.NewVecDense(32, armadaslices.Zeros[float64](32)), gradient: mat.NewVecDense(32, armadaslices.Zeros[float64](32)), - parameterIndexByNode: make(map[string]int, 16), - parameterIndexByQueue: make(map[string]int, 16), - - clusterByNode: make(map[string]string), + nodeByName: make(map[string]node, 16), + queueByName: make(map[string]queue, 16), numInnerIterations: numInnerIterations, innerOptimiser: innerOptimiser, @@ -146,25 +153,34 @@ func (fe *FailureEstimator) IsDisabled() bool { // Push adds a sample to the internal buffer of the failure estimator. // Samples added via Push are processed on the next call to Update. -func (fe *FailureEstimator) Push(node, queue, cluster string, success bool) { +// The timestamp t should be the time at which the success or failure happened. 
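+// Sample timestamps are used to expire stale estimates, e.g., by the node and queue quarantiners.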
+func (fe *FailureEstimator) Push(nodeName, queueName, clusterName string, success bool, t time.Time) {
 	fe.mu.Lock()
 	defer fe.mu.Unlock()
 
-	fe.clusterByNode[node] = cluster
-	i, ok := fe.parameterIndexByNode[node]
+	node, ok := fe.nodeByName[nodeName]
 	if !ok {
-		i = len(fe.parameterIndexByNode) + len(fe.parameterIndexByQueue)
-		fe.parameterIndexByNode[node] = i
+		node.parameterIndex = len(fe.nodeByName) + len(fe.queueByName)
+	}
+	node.cluster = clusterName
+	if node.timeOfMostRecentSample.Compare(t) == -1 {
+		node.timeOfMostRecentSample = t
 	}
-	j, ok := fe.parameterIndexByQueue[queue]
+	fe.nodeByName[nodeName] = node
+
+	queue, ok := fe.queueByName[queueName]
 	if !ok {
-		j = len(fe.parameterIndexByNode) + len(fe.parameterIndexByQueue)
-		fe.parameterIndexByQueue[queue] = j
+		queue.parameterIndex = len(fe.nodeByName) + len(fe.queueByName)
 	}
-	fe.extendParameters(armadamath.Max(i, j) + 1)
-	fe.samples = append(fe.samples, Sample{
-		i: i,
-		j: j,
+	if queue.timeOfMostRecentSample.Compare(t) == -1 {
+		queue.timeOfMostRecentSample = t
+	}
+	fe.queueByName[queueName] = queue
+
+	fe.extendParameters(armadamath.Max(node.parameterIndex, queue.parameterIndex) + 1)
+	fe.samples = append(fe.samples, sample{
+		i: node.parameterIndex,
+		j: queue.parameterIndex,
 		c: success,
 	})
 }
@@ -259,6 +275,54 @@ func (fe *FailureEstimator) negLogLikelihoodGradient(nodeSuccessProbability, que
 	}
 }
 
+// FailureProbabilityFromNodeName returns the failure probability estimate of the named node
+// and the timestamp of the most recent success or failure observed for this node.
+// The most recent sample may not be reflected in the estimate if Update has not been called since the last call to Push.
+// If there is no estimate for nodeName, the final return value is false.
+func (fe *FailureEstimator) FailureProbabilityFromNodeName(nodeName string) (float64, time.Time, bool) {
+	node, ok := fe.nodeByName[nodeName]
+	if !ok {
+		return 0, time.Time{}, false
+	}
+	return 1 - fe.parameters.AtVec(node.parameterIndex), node.timeOfMostRecentSample, true
+}
+
+// FailureProbabilityFromQueueName returns the failure probability estimate of the named queue
+// and the timestamp of the most recent success or failure observed for this queue.
+// The most recent sample may not be reflected in the estimate if Update has not been called since the last call to Push.
+// If there is no estimate for queueName, the final return value is false.
+func (fe *FailureEstimator) FailureProbabilityFromQueueName(queueName string) (float64, time.Time, bool) {
+	queue, ok := fe.queueByName[queueName]
+	if !ok {
+		return 0, time.Time{}, false
+	}
+	return 1 - fe.parameters.AtVec(queue.parameterIndex), queue.timeOfMostRecentSample, true
+}
+
+func (fe *FailureEstimator) ApplyNodes(f func(nodeName, cluster string, failureProbability float64, timeOfLastUpdate time.Time)) {
+	fe.mu.Lock()
+	defer fe.mu.Unlock()
+	for nodeName, node := range fe.nodeByName {
+		// Report failure probability rounded to nearest multiple of 0.01.
+		// (As it's unlikely the estimate is accurate to within less than this.)
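+		// Parameters store estimated success probabilities, so the failure probability is 1 minus the parameter.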
+ failureProbability := 1 - fe.parameters.AtVec(node.parameterIndex) + failureProbability = math.Round(failureProbability*100) / 100 + f(nodeName, node.cluster, failureProbability, node.timeOfMostRecentSample) + } +} + +func (fe *FailureEstimator) ApplyQueues(f func(queueName string, failureProbability float64, timeOfLastUpdate time.Time)) { + fe.mu.Lock() + defer fe.mu.Unlock() + for queueName, queue := range fe.queueByName { + // Report failure probability rounded to nearest multiple of 0.01. + // (As it's unlikely the estimate is accurate to within less than this.) + failureProbability := 1 - fe.parameters.AtVec(queue.parameterIndex) + failureProbability = math.Round(failureProbability*100) / 100 + f(queueName, failureProbability, queue.timeOfMostRecentSample) + } +} + func (fe *FailureEstimator) Describe(ch chan<- *prometheus.Desc) { if fe.IsDisabled() { return @@ -271,19 +335,10 @@ func (fe *FailureEstimator) Collect(ch chan<- prometheus.Metric) { if fe.IsDisabled() { return } - fe.mu.Lock() - defer fe.mu.Unlock() - - // Report failure probability rounded to nearest multiple of 0.01. - // (As it's unlikely the estimate is accurate to within less than this.) - for k, i := range fe.parameterIndexByNode { - failureProbability := 1 - fe.parameters.AtVec(i) - failureProbability = math.Round(failureProbability*100) / 100 - ch <- prometheus.MustNewConstMetric(fe.failureProbabilityByNodeDesc, prometheus.GaugeValue, failureProbability, k, fe.clusterByNode[k]) - } - for k, j := range fe.parameterIndexByQueue { - failureProbability := 1 - fe.parameters.AtVec(j) - failureProbability = math.Round(failureProbability*100) / 100 - ch <- prometheus.MustNewConstMetric(fe.failureProbabilityByQueueDesc, prometheus.GaugeValue, failureProbability, k) - } + fe.ApplyNodes(func(nodeName, cluster string, failureProbability float64, timeOfLastUpdate time.Time) { + ch <- prometheus.MustNewConstMetric(fe.failureProbabilityByNodeDesc, prometheus.GaugeValue, failureProbability, nodeName, cluster) + }) + fe.ApplyQueues(func(queueName string, failureProbability float64, timeOfLastUpdate time.Time) { + ch <- prometheus.MustNewConstMetric(fe.failureProbabilityByQueueDesc, prometheus.GaugeValue, failureProbability, queueName) + }) } diff --git a/internal/scheduler/failureestimator/failureestimator_test.go b/internal/scheduler/failureestimator/failureestimator_test.go index 9c5e80dac30..04825f16b18 100644 --- a/internal/scheduler/failureestimator/failureestimator_test.go +++ b/internal/scheduler/failureestimator/failureestimator_test.go @@ -3,6 +3,7 @@ package failureestimator import ( "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,25 +21,29 @@ func TestUpdate(t *testing.T) { require.NoError(t, err) // Test initialisation. 
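+	// Record the sample timestamp so it can be asserted against timeOfMostRecentSample below.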
- fe.Push("node", "queue", "cluster", false) - nodeParameterIndex, ok := fe.parameterIndexByNode["node"] + now := time.Now() + fe.Push("node", "queue", "cluster", false, now) + node, ok := fe.nodeByName["node"] require.True(t, ok) - queueParameterIndex, ok := fe.parameterIndexByQueue["queue"] + queue, ok := fe.queueByName["queue"] require.True(t, ok) - require.Equal(t, 0, nodeParameterIndex) - require.Equal(t, 1, queueParameterIndex) + require.Equal(t, 0, node.parameterIndex) + require.Equal(t, 1, queue.parameterIndex) require.Equal(t, 0.5, fe.parameters.AtVec(0)) require.Equal(t, 0.5, fe.parameters.AtVec(1)) + require.Equal(t, now, node.timeOfMostRecentSample) + require.Equal(t, now, queue.timeOfMostRecentSample) for i := 0; i < 100; i++ { - fe.Push(fmt.Sprintf("node-%d", i), "queue-0", "cluster", false) + now := time.Now() + fe.Push(fmt.Sprintf("node-%d", i), "queue-0", "cluster", false, now) } - nodeParameterIndex, ok = fe.parameterIndexByNode["node-99"] + node, ok = fe.nodeByName["node-99"] require.True(t, ok) - queueParameterIndex, ok = fe.parameterIndexByQueue["queue-0"] + queue, ok = fe.queueByName["queue-0"] require.True(t, ok) - require.Equal(t, 2+100, nodeParameterIndex) - require.Equal(t, 3, queueParameterIndex) + require.Equal(t, 2+100, node.parameterIndex) + require.Equal(t, 3, queue.parameterIndex) require.Equal(t, 0.5, fe.parameters.AtVec(102)) require.Equal(t, 0.5, fe.parameters.AtVec(3)) @@ -51,15 +56,15 @@ func TestUpdate(t *testing.T) { assert.Less(t, nodeSuccessProbability, 0.5-eps) assert.Less(t, queueSuccessProbability, 0.5-eps) - // Test that the estimates move in the expected direction on success. - fe.Push("node", "queue", "cluster", true) + // Test that the estimates move in the expected direction after observing successes and failures. + fe.Push("node", "queue", "cluster", true, now) fe.Update() assert.Greater(t, fe.parameters.AtVec(0), nodeSuccessProbability) assert.Greater(t, fe.parameters.AtVec(1), queueSuccessProbability) for i := 0; i < 1000; i++ { for i := 0; i < 10; i++ { - fe.Push("node", "queue", "cluster", false) + fe.Push("node", "queue", "cluster", false, now) } fe.Update() } @@ -70,7 +75,7 @@ func TestUpdate(t *testing.T) { for i := 0; i < 1000; i++ { for i := 0; i < 10; i++ { - fe.Push("node", "queue", "cluster", true) + fe.Push("node", "queue", "cluster", true, now) } fe.Update() } diff --git a/internal/scheduler/quarantine/node_quarantiner.go b/internal/scheduler/quarantine/node_quarantiner.go new file mode 100644 index 00000000000..3bc5c161a26 --- /dev/null +++ b/internal/scheduler/quarantine/node_quarantiner.go @@ -0,0 +1,123 @@ +package quarantine + +import ( + "fmt" + "time" + + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + v1 "k8s.io/api/core/v1" + + "github.com/armadaproject/armada/internal/common/armadaerrors" + "github.com/armadaproject/armada/internal/scheduler/failureestimator" +) + +const ( + namespace = "armada" + subsystem = "scheduler" + + highFailureProbabilityQuarantineReason = "highFailureProbability" +) + +var highFailureProbabilityTaint = v1.Taint{ + Key: "armadaproject.io/schedulerInternal/quarantined", + Value: highFailureProbabilityQuarantineReason, + Effect: v1.TaintEffectNoSchedule, +} + +// NodeQuarantiner determines whether nodes should be quarantined, +// i.e., removed from consideration when scheduling new jobs, +// based on the estimated failure probability of the node. +// +// Specifically, any node for which the following is true is quarantined: +// 1. 
The estimated failure probability exceeds failureProbabilityQuarantineThreshold. +// 2. The failure probability estimate was updated at most failureProbabilityEstimateTimeout ago. +type NodeQuarantiner struct { + // Quarantine nodes with a failure probability greater than this threshold. + failureProbabilityQuarantineThreshold float64 + // Ignore failure probability estimates with no updates for at least this amount of time. + failureProbabilityEstimateTimeout time.Duration + // Provides failure probability estimates. + failureEstimator *failureestimator.FailureEstimator + + // Prometheus metrics. + isQuarantinedDesc *prometheus.Desc +} + +func NewNodeQuarantiner( + failureProbabilityQuarantineThreshold float64, + failureProbabilityEstimateTimeout time.Duration, + failureEstimator *failureestimator.FailureEstimator, +) (*NodeQuarantiner, error) { + if failureProbabilityQuarantineThreshold < 0 || failureProbabilityQuarantineThreshold > 1 { + return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "failureProbabilityQuarantineThreshold", + Value: failureProbabilityQuarantineThreshold, + Message: fmt.Sprintf("outside allowed range [0, 1]"), + }) + } + if failureProbabilityEstimateTimeout < 0 { + return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "failureProbabilityEstimateTimeout", + Value: failureProbabilityEstimateTimeout, + Message: fmt.Sprintf("outside allowed range [0, Inf)"), + }) + } + return &NodeQuarantiner{ + failureProbabilityQuarantineThreshold: failureProbabilityQuarantineThreshold, + failureProbabilityEstimateTimeout: failureProbabilityEstimateTimeout, + failureEstimator: failureEstimator, + isQuarantinedDesc: prometheus.NewDesc( + fmt.Sprintf("%s_%s_node_quarantined", namespace, subsystem), + "Indicates which nodes are quarantined and for what reason.", + []string{"node", "cluster", "reason"}, + nil, + ), + }, nil +} + +// IsQuarantined returns true if the node is quarantined and a taint expressing the reason why, and false otherwise. +func (nq *NodeQuarantiner) IsQuarantined(t time.Time, nodeName string) (taint v1.Taint, isQuarantined bool) { + if nq.failureEstimator.IsDisabled() { + return + } + failureProbability, timeOfLastUpdate, ok := nq.failureEstimator.FailureProbabilityFromNodeName(nodeName) + if !ok { + // No estimate available for this node. + return + } + if !nq.isQuarantined(t, failureProbability, timeOfLastUpdate) { + return + } + return highFailureProbabilityTaint, true +} + +func (nq *NodeQuarantiner) isQuarantined(t time.Time, failureProbability float64, timeOfLastUpdate time.Time) bool { + if failureProbability < nq.failureProbabilityQuarantineThreshold { + // Failure probability does not exceed threshold. + return false + } + if t.Sub(timeOfLastUpdate) > nq.failureProbabilityEstimateTimeout { + // Failure probability estimate hasn't been updated recently. 
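+		// Do not quarantine nodes based on stale estimates.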
+		return false
+	}
+	return true
+}
+
+func (nq *NodeQuarantiner) Describe(ch chan<- *prometheus.Desc) {
+	ch <- nq.isQuarantinedDesc
+}
+
+func (nq *NodeQuarantiner) Collect(ch chan<- prometheus.Metric) {
+	if nq.failureEstimator.IsDisabled() {
+		return
+	}
+	t := time.Now()
+	nq.failureEstimator.ApplyNodes(func(nodeName, cluster string, failureProbability float64, timeOfLastUpdate time.Time) {
+		v := 0.0
+		if nq.isQuarantined(t, failureProbability, timeOfLastUpdate) {
+			v = 1.0
+		}
+		ch <- prometheus.MustNewConstMetric(nq.isQuarantinedDesc, prometheus.GaugeValue, v, nodeName, cluster, highFailureProbabilityQuarantineReason)
+	})
+}
diff --git a/internal/scheduler/quarantine/queue_quarantiner.go b/internal/scheduler/quarantine/queue_quarantiner.go
new file mode 100644
index 00000000000..ec4ed6e3c4d
--- /dev/null
+++ b/internal/scheduler/quarantine/queue_quarantiner.go
@@ -0,0 +1,104 @@
+package quarantine
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/pkg/errors"
+	"github.com/prometheus/client_golang/prometheus"
+
+	"github.com/armadaproject/armada/internal/common/armadaerrors"
+	"github.com/armadaproject/armada/internal/scheduler/failureestimator"
+)
+
+// QueueQuarantiner determines whether queues should be quarantined,
+// i.e., whether we should reduce the rate at which we schedule jobs from the queue,
+// based on the estimated failure probability of the queue.
+//
+// Specifically, each queue has a quarantine factor associated with it equal to:
+// - Zero, if the failure probability estimate was last updated more than failureProbabilityEstimateTimeout ago.
+// - Failure probability estimate of the queue multiplied by quarantineFactorMultiplier otherwise.
+type QueueQuarantiner struct {
+	// Multiply the failure probability by this value to produce the quarantineFactor.
+	quarantineFactorMultiplier float64
+	// Ignore failure probability estimates with no updates for at least this amount of time.
+	failureProbabilityEstimateTimeout time.Duration
+	// Provides failure probability estimates.
+	failureEstimator *failureestimator.FailureEstimator
+
+	// Prometheus metrics.
+	isQuarantinedDesc *prometheus.Desc
+}
+
+func NewQueueQuarantiner(
+	quarantineFactorMultiplier float64,
+	failureProbabilityEstimateTimeout time.Duration,
+	failureEstimator *failureestimator.FailureEstimator,
+) (*QueueQuarantiner, error) {
+	if quarantineFactorMultiplier < 0 || quarantineFactorMultiplier > 1 {
+		return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{
+			Name:    "quarantineFactorMultiplier",
+			Value:   quarantineFactorMultiplier,
+			Message: fmt.Sprintf("outside allowed range [0, 1]"),
+		})
+	}
+	if failureProbabilityEstimateTimeout < 0 {
+		return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{
+			Name:    "failureProbabilityEstimateTimeout",
+			Value:   failureProbabilityEstimateTimeout,
+			Message: fmt.Sprintf("outside allowed range [0, Inf)"),
+		})
+	}
+	return &QueueQuarantiner{
+		quarantineFactorMultiplier:        quarantineFactorMultiplier,
+		failureProbabilityEstimateTimeout: failureProbabilityEstimateTimeout,
+		failureEstimator:                  failureEstimator,
+		isQuarantinedDesc: prometheus.NewDesc(
+			fmt.Sprintf("%s_%s_queue_quarantined", namespace, subsystem),
+			"Indicates which queues are quarantined and for what reason.",
+			[]string{"queue", "reason"},
+			nil,
+		),
+	}, nil
+}
+
+// QuarantineFactor returns a value in [0, 1] indicating to which extent the queue should be quarantined,
+// where 0.0 indicates not at all and 1.0 completely.
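+// The scheduler uses this to scale the per-queue scheduling rate limit by (1 - QuarantineFactor).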
+func (qq *QueueQuarantiner) QuarantineFactor(t time.Time, queueName string) float64 {
+	if qq.failureEstimator.IsDisabled() {
+		return 0
+	}
+	failureProbability, timeOfLastUpdate, ok := qq.failureEstimator.FailureProbabilityFromQueueName(queueName)
+	if !ok {
+		// No estimate available for this queue.
+		return 0
+	}
+	return qq.quarantineFactor(t, failureProbability, timeOfLastUpdate)
+}
+
+func (qq *QueueQuarantiner) quarantineFactor(t time.Time, failureProbability float64, timeOfLastUpdate time.Time) float64 {
+	if t.Sub(timeOfLastUpdate) > qq.failureProbabilityEstimateTimeout {
+		// Failure probability estimate hasn't been updated recently.
+		return 0
+	}
+	return failureProbability * qq.quarantineFactorMultiplier
+}
+
+func (qq *QueueQuarantiner) Describe(ch chan<- *prometheus.Desc) {
+	ch <- qq.isQuarantinedDesc
+}
+
+func (qq *QueueQuarantiner) Collect(ch chan<- prometheus.Metric) {
+	if qq.failureEstimator.IsDisabled() {
+		return
+	}
+	t := time.Now()
+	qq.failureEstimator.ApplyQueues(func(queueName string, failureProbability float64, timeOfLastUpdate time.Time) {
+		ch <- prometheus.MustNewConstMetric(
+			qq.isQuarantinedDesc, prometheus.GaugeValue,
+			qq.quarantineFactor(t, failureProbability, timeOfLastUpdate),
+			queueName,
+			highFailureProbabilityQuarantineReason,
+		)
+	})
+}
diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go
index 0191fd4aab7..e475cb45819 100644
--- a/internal/scheduler/scheduler.go
+++ b/internal/scheduler/scheduler.go
@@ -288,11 +288,17 @@ func (s *Scheduler) cycle(ctx *armadacontext.Context, updateAll bool, leaderToke
 		if run == nil {
 			continue
 		}
+		var t time.Time
+		if terminatedTime := run.TerminatedTime(); terminatedTime != nil {
+			t = *terminatedTime
+		} else {
+			t = time.Now()
+		}
 		if jst.Failed {
-			s.failureEstimator.Push(run.NodeName(), jst.Job.GetQueue(), run.Executor(), false)
+			s.failureEstimator.Push(run.NodeName(), jst.Job.GetQueue(), run.Executor(), false, t)
 		}
 		if jst.Succeeded {
-			s.failureEstimator.Push(run.NodeName(), jst.Job.GetQueue(), run.Executor(), true)
+			s.failureEstimator.Push(run.NodeName(), jst.Job.GetQueue(), run.Executor(), true, t)
 		}
 	}
 	s.failureEstimator.Update()
diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go
index af695e7ce9d..c33961b8144 100644
--- a/internal/scheduler/schedulerapp.go
+++ b/internal/scheduler/schedulerapp.go
@@ -37,6 +37,7 @@ import (
 	"github.com/armadaproject/armada/internal/scheduler/failureestimator"
 	"github.com/armadaproject/armada/internal/scheduler/jobdb"
 	"github.com/armadaproject/armada/internal/scheduler/metrics"
+	"github.com/armadaproject/armada/internal/scheduler/quarantine"
 	"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
 	"github.com/armadaproject/armada/pkg/executorapi"
 )
@@ -195,12 +196,55 @@ func Run(config schedulerconfig.Configuration) error {
 	schedulingReportServer := NewLeaderProxyingSchedulingReportsServer(schedulingContextRepository, leaderClientConnectionProvider)
 	schedulerobjects.RegisterSchedulerReportingServer(grpcServer, schedulingReportServer)
 
+	// Set up failure estimation and quarantining.
+	failureEstimator, err := failureestimator.New(
+		config.Scheduling.FailureProbabilityEstimation.NumInnerIterations,
+		// Invalid config will have failed validation.
+		descent.MustNew(config.Scheduling.FailureProbabilityEstimation.InnerOptimiserStepSize),
+		// Invalid config will have failed validation.
+ nesterov.MustNew( + config.Scheduling.FailureProbabilityEstimation.OuterOptimiserStepSize, + config.Scheduling.FailureProbabilityEstimation.OuterOptimiserNesterovAcceleration, + ), + ) + if err != nil { + return err + } + failureEstimator.Disable(config.Scheduling.FailureProbabilityEstimation.Disabled) + if err := prometheus.Register(failureEstimator); err != nil { + return errors.WithStack(err) + } + nodeQuarantiner, err := quarantine.NewNodeQuarantiner( + config.Scheduling.NodeQuarantining.FailureProbabilityQuarantineThreshold, + config.Scheduling.NodeQuarantining.FailureProbabilityEstimateTimeout, + failureEstimator, + ) + if err != nil { + return err + } + if err := prometheus.Register(nodeQuarantiner); err != nil { + return errors.WithStack(err) + } + queueQuarantiner, err := quarantine.NewQueueQuarantiner( + config.Scheduling.QueueQuarantining.QuarantineFactorMultiplier, + config.Scheduling.QueueQuarantining.FailureProbabilityEstimateTimeout, + failureEstimator, + ) + if err != nil { + return err + } + if err := prometheus.Register(queueQuarantiner); err != nil { + return errors.WithStack(err) + } + schedulingAlgo, err := NewFairSchedulingAlgo( config.Scheduling, config.MaxSchedulingDuration, executorRepository, queueRepository, schedulingContextRepository, + nodeQuarantiner, + queueQuarantiner, ) if err != nil { return errors.WithMessage(err, "error creating scheduling algo") @@ -222,24 +266,6 @@ func Run(config schedulerconfig.Configuration) error { return errors.WithStack(err) } - failureEstimator, err := failureestimator.New( - config.Scheduling.FailureEstimatorConfig.NumInnerIterations, - // Invalid config will have failed validation. - descent.MustNew(config.Scheduling.FailureEstimatorConfig.InnerOptimiserStepSize), - // Invalid config will have failed validation. - nesterov.MustNew( - config.Scheduling.FailureEstimatorConfig.OuterOptimiserStepSize, - config.Scheduling.FailureEstimatorConfig.OuterOptimiserNesterovAcceleration, - ), - ) - if err != nil { - return err - } - failureEstimator.Disable(config.Scheduling.FailureEstimatorConfig.Disabled) - if err := prometheus.Register(failureEstimator); err != nil { - return errors.WithStack(err) - } - scheduler, err := NewScheduler( jobDb, jobRepository, diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index 3a5d18933ff..1b9c27c2625 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -27,6 +27,7 @@ import ( "github.com/armadaproject/armada/internal/scheduler/interfaces" "github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/nodedb" + "github.com/armadaproject/armada/internal/scheduler/quarantine" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/pkg/client/queue" ) @@ -54,6 +55,10 @@ type FairSchedulingAlgo struct { // Order in which to schedule executor groups. // Executors are grouped by either id (i.e., individually) or by pool. executorGroupsToSchedule []string + // Used to avoid scheduling onto broken nodes. + nodeQuarantiner *quarantine.NodeQuarantiner + // Used to reduce the rate at which jobs are scheduled from misbehaving queues. + queueQuarantiner *quarantine.QueueQuarantiner // Function that is called every time an executor is scheduled. Useful for testing. onExecutorScheduled func(executor *schedulerobjects.Executor) // rand and clock injected here for repeatable testing. 
@@ -67,6 +72,8 @@ func NewFairSchedulingAlgo(
 	executorRepository database.ExecutorRepository,
 	queueRepository repository.QueueRepository,
 	schedulingContextRepository *SchedulingContextRepository,
+	nodeQuarantiner *quarantine.NodeQuarantiner,
+	queueQuarantiner *quarantine.QueueQuarantiner,
 ) (*FairSchedulingAlgo, error) {
 	if _, ok := config.PriorityClasses[config.DefaultPriorityClassName]; !ok {
 		return nil, errors.Errorf(
@@ -82,9 +89,11 @@ func NewFairSchedulingAlgo(
 		limiter:               rate.NewLimiter(rate.Limit(config.MaximumSchedulingRate), config.MaximumSchedulingBurst),
 		limiterByQueue:        make(map[string]*rate.Limiter),
 		maxSchedulingDuration: maxSchedulingDuration,
+		nodeQuarantiner:       nodeQuarantiner,
+		queueQuarantiner:      queueQuarantiner,
+		onExecutorScheduled:   func(executor *schedulerobjects.Executor) {},
 		rand:                  util.NewThreadsafeRand(time.Now().UnixNano()),
 		clock:                 clock.RealClock{},
-		onExecutorScheduled:   func(executor *schedulerobjects.Executor) {},
 	}, nil
 }
 
@@ -246,6 +255,7 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con
 	}
 	executors = l.filterStaleExecutors(executors)
+	// TODO(albin): Skip queues with a high failure rate.
 	queues, err := l.queueRepository.GetAllQueues()
 	if err != nil {
 		return nil, err
 	}
@@ -380,6 +390,8 @@ func (l *FairSchedulingAlgo) scheduleOnExecutors(
 		l.limiter,
 		totalResources,
 	)
+
+	now := time.Now()
 	for queue, priorityFactor := range fsctx.priorityFactorByQueue {
 		if !fsctx.isActiveByQueueName[queue] {
 			// To ensure fair share is computed only from active queues, i.e., queues with jobs queued or running.
@@ -393,15 +405,24 @@ func (l *FairSchedulingAlgo) scheduleOnExecutors(
 		if priorityFactor > 0 {
 			weight = 1 / priorityFactor
 		}
+
+		// Create per-queue limiters lazily.
 		queueLimiter, ok := l.limiterByQueue[queue]
 		if !ok {
-			// Create per-queue limiters lazily.
 			queueLimiter = rate.NewLimiter(
 				rate.Limit(l.schedulingConfig.MaximumPerQueueSchedulingRate),
 				l.schedulingConfig.MaximumPerQueueSchedulingBurst,
 			)
 			l.limiterByQueue[queue] = queueLimiter
 		}
+
+		// Reduce the maximum scheduling rate of misbehaving queues by adjusting the per-queue rate-limiter limit.
+		quarantineFactor := 0.0
+		if l.queueQuarantiner != nil {
+			quarantineFactor = l.queueQuarantiner.QuarantineFactor(now, queue)
+		}
+		queueLimiter.SetLimitAt(now, rate.Limit(l.schedulingConfig.MaximumPerQueueSchedulingRate*(1-quarantineFactor)))
+
 		if err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByPriorityClass, queueLimiter); err != nil {
 			return nil, nil, err
 		}
@@ -531,7 +552,16 @@ func (l *FairSchedulingAlgo) addExecutorToNodeDb(nodeDb *nodedb.NodeDb, jobs []*
 		}
 		jobsByNodeId[nodeId] = append(jobsByNodeId[nodeId], job)
 	}
+
+	now := time.Now()
 	for _, node := range nodes {
+		// Taint quarantined nodes to avoid scheduling new jobs onto them.
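+		// The taint only affects the scheduler's in-memory copy of the node; nodes in the cluster are not modified.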
+ if l.nodeQuarantiner != nil { + if taint, ok := l.nodeQuarantiner.IsQuarantined(now, node.Name); ok { + node.Taints = append(node.Taints, taint) + } + } + if err := nodeDb.CreateAndInsertWithJobDbJobsWithTxn(txn, jobsByNodeId[node.Id], node); err != nil { return err } diff --git a/internal/scheduler/scheduling_algo_test.go b/internal/scheduler/scheduling_algo_test.go index b3cac9d6a26..ce1bbe24b65 100644 --- a/internal/scheduler/scheduling_algo_test.go +++ b/internal/scheduler/scheduling_algo_test.go @@ -390,6 +390,8 @@ func TestSchedule(t *testing.T) { mockExecutorRepo, mockQueueRepo, schedulingContextRepo, + nil, + nil, ) require.NoError(t, err) @@ -548,6 +550,8 @@ func BenchmarkNodeDbConstruction(b *testing.B) { nil, nil, nil, + nil, + nil, ) require.NoError(b, err) b.StartTimer()