Skip to content

Commit

Permalink
refactor: remote graphite trigger validation (#1127)
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksandrMatsko authored Dec 10, 2024
1 parent e742999 commit 83c63a0
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 30 deletions.
2 changes: 1 addition & 1 deletion api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type FeatureFlags struct {
IsPlottingAvailable bool `json:"isPlottingAvailable" example:"true"`
IsSubscriptionToAllTagsAvailable bool `json:"isSubscriptionToAllTagsAvailable" example:"false"`
IsReadonlyEnabled bool `json:"isReadonlyEnabled" example:"false"`
CelebrationMode CelebrationMode `json:"celebrationMode" example:"new_year"`
CelebrationMode CelebrationMode `json:"celebrationMode" swaggertype:"string" example:"new_year"`
}

// CelebrationMode is type for celebrate Moira.
Expand Down
2 changes: 1 addition & 1 deletion api/handler/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func deleteNotification(writer http.ResponseWriter, request *http.Request) {
// @success 200 {object} dto.NotificationsList "Notification have been deleted"
// @failure 403 {object} api.ErrorForbiddenExample "Forbidden"
// @failure 500 {object} api.ErrorInternalServerExample "Internal server error"
// @router /notification [delete]
// @router /notification/all [delete]
func deleteAllNotifications(writer http.ResponseWriter, request *http.Request) {
if errorResponse := controller.DeleteAllNotifications(database); errorResponse != nil {
render.Render(writer, request, errorResponse) //nolint
Expand Down
27 changes: 21 additions & 6 deletions api/handler/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,15 @@ func getTriggerFromRequest(request *http.Request) (*dto.Trigger, *api.ErrorRespo
return nil, api.ErrorInvalidRequest(fmt.Errorf("invalid expression: %s", err.Error()))
case api.ErrInvalidRequestContent:
return nil, api.ErrorInvalidRequest(err)
case remote.ErrRemoteTriggerResponse:
case remote.ErrRemoteUnavailable:
response := api.ErrorRemoteServerUnavailable(err)
middleware.GetLoggerEntry(request).Error().
String("status", response.StatusText).
Error(err).
Msg("Remote server unavailable")
return nil, response
case remote.ErrRemoteTriggerResponse:
return nil, api.ErrorInvalidRequest(fmt.Errorf("error from graphite remote: %w", err))
case *json.UnmarshalTypeError:
return nil, api.ErrorInvalidRequest(fmt.Errorf("invalid payload: %s", err.Error()))
case *prometheus.Error:
Expand Down Expand Up @@ -232,10 +234,11 @@ func getMetricTTLByTrigger(request *http.Request, trigger *dto.Trigger) (time.Du
// @tags trigger
// @accept json
// @produce json
// @param trigger body dto.Trigger true "Trigger data"
// @success 200 {object} dto.TriggerCheckResponse "Validation is done, see response body for validation result"
// @failure 400 {object} api.ErrorInvalidRequestExample "Bad request from client"
// @failure 500 {object} api.ErrorInternalServerExample "Internal server error"
// @param trigger body dto.Trigger true "Trigger data"
// @success 200 {object} dto.TriggerCheckResponse "Validation is done, see response body for validation result"
// @failure 400 {object} api.ErrorInvalidRequestExample "Bad request from client"
// @failure 500 {object} api.ErrorInternalServerExample "Internal server error"
// @failure 503 {object} api.ErrorRemoteServerUnavailableExample "Remote server unavailable"
// @router /trigger/check [put]
func triggerCheck(writer http.ResponseWriter, request *http.Request) {
trigger := &dto.Trigger{}
Expand All @@ -246,10 +249,22 @@ func triggerCheck(writer http.ResponseWriter, request *http.Request) {
case expression.ErrInvalidExpression, local.ErrParseExpr, local.ErrEvalExpr, local.ErrUnknownFunction:
// TODO: move ErrInvalidExpression to separate case

// These errors are skipped because if there are error from local source then it will be caught in
// Errors above are skipped because if there is an error from local source then it will be caught in
// dto.TargetVerification and will be explained in detail.
case remote.ErrRemoteUnavailable:
errRsp := api.ErrorRemoteServerUnavailable(err)
middleware.GetLoggerEntry(request).Error().
String("status", errRsp.StatusText).
Error(err).
Msg("Remote server unavailable")
render.Render(writer, request, errRsp) //nolint
return
case remote.ErrRemoteTriggerResponse:
render.Render(writer, request, api.ErrorInvalidRequest(fmt.Errorf("error from graphite remote: %w", err))) //nolint
return
case *prometheus.Error:
render.Render(writer, request, errorResponseOnPrometheusError(typedErr)) //nolint
return
default:
render.Render(writer, request, api.ErrorInvalidRequest(err)) //nolint
return
Expand Down
67 changes: 46 additions & 21 deletions api/handler/triggers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ func TestGetTriggerFromRequest(t *testing.T) {
fetchResult.EXPECT().GetPatterns().Return(make([]string, 0), nil).AnyTimes()
fetchResult.EXPECT().GetMetricsData().Return([]metricSource.MetricData{*metricSource.MakeMetricData("", []float64{}, 0, 0)}).AnyTimes()

setValuesToRequestCtx := func(
ctx context.Context,
metricSourceProvider *metricSource.SourceProvider,
limits api.LimitsConfig,
) context.Context {
ctx = middleware.SetContextValueForTest(ctx, "metricSourceProvider", metricSourceProvider)
ctx = middleware.SetContextValueForTest(ctx, "limits", limits)

return ctx
}

Convey("Given a correct payload", t, func() {
triggerWarnValue := 0.0
triggerErrorValue := 1.0
Expand Down Expand Up @@ -106,8 +117,7 @@ func TestGetTriggerFromRequest(t *testing.T) {

request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", sourceProvider))
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig()))
request = request.WithContext(setValuesToRequestCtx(request.Context(), sourceProvider, api.GetTestLimitsConfig()))

triggerDTO.Schedule.Days = moira.GetFilledScheduleDataDays(false)
triggerDTO.Schedule.Days[0].Enabled = true
Expand Down Expand Up @@ -148,8 +158,7 @@ func TestGetTriggerFromRequest(t *testing.T) {

request := httptest.NewRequest(http.MethodPut, "/trigger", strings.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", sourceProvider))
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig()))
request = request.WithContext(setValuesToRequestCtx(request.Context(), sourceProvider, api.GetTestLimitsConfig()))

Convey("Parser should return en error", func() {
_, err := getTriggerFromRequest(request)
Expand Down Expand Up @@ -198,25 +207,43 @@ func TestGetTriggerFromRequest(t *testing.T) {
Convey("for graphite remote", func() {
triggerDTO.TriggerSource = moira.GraphiteRemote
body, _ := json.Marshal(triggerDTO)
testLogger, _ := logging.GetLogger("Test")

request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider))
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig()))
Convey("when ErrRemoteTriggerResponse returned", func() {
request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig()))

testLogger, _ := logging.GetLogger("Test")
request = middleware.WithLogEntry(request, middleware.NewLogEntry(testLogger, request))

request = middleware.WithLogEntry(request, middleware.NewLogEntry(testLogger, request))
var returnedErr error = remote.ErrRemoteTriggerResponse{
InternalError: fmt.Errorf(""),
}

var returnedErr error = remote.ErrRemoteTriggerResponse{
InternalError: fmt.Errorf(""),
}
graphiteRemoteSrc.EXPECT().Fetch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, returnedErr)

graphiteRemoteSrc.EXPECT().Fetch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, returnedErr)
_, errRsp := getTriggerFromRequest(request)
So(errRsp, ShouldResemble, api.ErrorInvalidRequest(fmt.Errorf("error from graphite remote: %w", returnedErr)))
})

Convey("when ErrRemoteUnavailable", func() {
request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig()))

request = middleware.WithLogEntry(request, middleware.NewLogEntry(testLogger, request))

_, errRsp := getTriggerFromRequest(request)
So(errRsp, ShouldResemble, api.ErrorRemoteServerUnavailable(returnedErr))
var returnedErr error = remote.ErrRemoteUnavailable{
InternalError: fmt.Errorf(""),
}

graphiteRemoteSrc.EXPECT().Fetch(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, returnedErr)

_, errRsp := getTriggerFromRequest(request)
So(errRsp, ShouldResemble, api.ErrorRemoteServerUnavailable(returnedErr))
})
})

Convey("for prometheus remote", func() {
Expand All @@ -226,8 +253,7 @@ func TestGetTriggerFromRequest(t *testing.T) {
Convey("with error type = bad_data got bad request", func() {
request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider))
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig()))
request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig()))

var returnedErr error = &prometheus.Error{
Type: prometheus.ErrBadData,
Expand All @@ -253,8 +279,7 @@ func TestGetTriggerFromRequest(t *testing.T) {
for _, errType := range otherTypes {
request := httptest.NewRequest(http.MethodPut, "/trigger", bytes.NewReader(body))
request.Header.Add("content-type", "application/json")
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "metricSourceProvider", allSourceProvider))
request = request.WithContext(middleware.SetContextValueForTest(request.Context(), "limits", api.GetTestLimitsConfig()))
request = request.WithContext(setValuesToRequestCtx(request.Context(), allSourceProvider, api.GetTestLimitsConfig()))

var returnedErr error = &prometheus.Error{
Type: errType,
Expand Down
2 changes: 1 addition & 1 deletion datatypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ type ScheduleData struct {
// ScheduleDataDay represents week day of schedule.
type ScheduleDataDay struct {
Enabled bool `json:"enabled" example:"true"`
Name DayName `json:"name,omitempty" example:"Mon" validate:"oneof=Mon Tue Wed Thu Fri Sat Sun"`
Name DayName `json:"name,omitempty" example:"Mon" swaggertype:"string" validate:"oneof=Mon Tue Wed Thu Fri Sat Sun"`
}

// DayName represents the day name used in ScheduleDataDay.
Expand Down

0 comments on commit 83c63a0

Please sign in to comment.