diff --git a/api/handler/triggers_test.go b/api/handler/triggers_test.go index f8650f486..ae13b5461 100644 --- a/api/handler/triggers_test.go +++ b/api/handler/triggers_test.go @@ -65,6 +65,17 @@ func TestGetTriggerFromRequest(t *testing.T) { fetchResult.EXPECT().GetPatterns().Return(make([]string, 0), nil).AnyTimes() fetchResult.EXPECT().GetMetricsData().Return([]metricSource.MetricData{*metricSource.MakeMetricData("", []float64{}, 0, 0)}).AnyTimes() + setValuesToRequestCtx := func( + ctx context.Context, + metricSourceProvider *metricSource.SourceProvider, + limits api.LimitsConfig, + ) context.Context { + ctx = middleware.SetContextValueForTest(ctx, "metricSourceProvider", metricSourceProvider) + ctx = middleware.SetContextValueForTest(ctx, "limits", limits) + + return ctx + } + Convey("Given a correct payload", t, func() { triggerWarnValue := 0.0 triggerErrorValue := 1.0 @@ -105,8 +116,7 @@ func TestGetTriggerFromRequest(t *testing.T) { request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body)) request.Header.Add("content-type", "application/json") - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", sourceProvider)) - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig())) + request = request.WithContext(setValuesToRequestCtx(request.Context(), sourceProvider, api.GetTestLimitsConfig())) triggerDTO.Schedule.Days = moira.GetFilledScheduleDataDays(false) triggerDTO.Schedule.Days[0].Enabled = true @@ -147,8 +157,7 @@ func TestGetTriggerFromRequest(t *testing.T) { request := httptest.NewRequest(http.MethodPut, "/trigger", strings.NewReader(body)) request.Header.Add("content-type", "application/json") - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", sourceProvider)) - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig())) + request = request.WithContext(setValuesToRequestCtx(request.Context(), sourceProvider, api.GetTestLimitsConfig())) Convey("Parser should return en error", func() { _, err := getTriggerFromRequest(request) @@ -195,14 +204,13 @@ func TestGetTriggerFromRequest(t *testing.T) { } Convey("for graphite remote", func() { - Convey("when ErrRemoteTriggerResponse returned", func() { - triggerDTO.TriggerSource = moira.GraphiteRemote - body, _ := json.Marshal(triggerDTO) + triggerDTO.TriggerSource = moira.GraphiteRemote + body, _ := json.Marshal(triggerDTO) + Convey("when ErrRemoteTriggerResponse returned", func() { request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body)) request.Header.Add("content-type", "application/json") - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider)) - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig())) + request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig())) testLogger, _ := logging.GetLogger("Test") @@ -220,13 +228,9 @@ func TestGetTriggerFromRequest(t *testing.T) { }) Convey("when ErrRemoteUnavailable", func() { - triggerDTO.TriggerSource = moira.GraphiteRemote - body, _ := json.Marshal(triggerDTO) - request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body)) request.Header.Add("content-type", "application/json") - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider)) - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig())) + request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig())) testLogger, _ := logging.GetLogger("Test") @@ -251,8 +255,7 @@ func TestGetTriggerFromRequest(t *testing.T) { Convey("with error type = bad_data got bad request", func() { request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body)) request.Header.Add("content-type", "application/json") - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider)) - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig())) + request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig())) var returnedErr error = &prometheus.Error{ Type: prometheus.ErrBadData, @@ -278,8 +281,7 @@ func TestGetTriggerFromRequest(t *testing.T) { for _, errType := range otherTypes { request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body)) request.Header.Add("content-type", "application/json") - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider)) - request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig())) + request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig())) var returnedErr error = &prometheus.Error{ Type: errType,