From 870b5a0861af62dd66c5430b59b09c06a96cca08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20V=C3=ADtov=C3=A1?= Date: Tue, 14 Nov 2023 15:46:33 +0100 Subject: [PATCH] fix(HMS-2933): Improve some HTTP error codes --- cmd/pbackend/statuser.go | 4 ++ .../clients/http/sources/sources_client.go | 8 ++- internal/payloads/validation/id.go | 26 ++++++++ internal/payloads/validation/id_test.go | 59 +++++++++++++++++++ internal/services/aws_iam_service.go | 19 ++++++ internal/services/launch_templates_service.go | 13 ++++ internal/services/pubkey_service.go | 8 +++ internal/services/sources_service.go | 5 ++ internal/services/sources_status_service.go | 5 ++ 9 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 internal/payloads/validation/id.go create mode 100644 internal/payloads/validation/id_test.go diff --git a/cmd/pbackend/statuser.go b/cmd/pbackend/statuser.go index cc5e7530..f609a4d7 100644 --- a/cmd/pbackend/statuser.go +++ b/cmd/pbackend/statuser.go @@ -92,6 +92,10 @@ func processMessage(msgCtx context.Context, message *kafka.GenericMessage) { logger.Warn().Err(err).Msg("Not found error from sources") return } + if errors.Is(err, clients.ErrBadRequest) { + logger.Warn().Err(err).Msg("Bad request error from sources") + return + } logger.Warn().Err(err).Msg("Could not get authentication") return } diff --git a/internal/clients/http/sources/sources_client.go b/internal/clients/http/sources/sources_client.go index 3af37249..255b3376 100644 --- a/internal/clients/http/sources/sources_client.go +++ b/internal/clients/http/sources/sources_client.go @@ -222,9 +222,15 @@ func (c *sourcesClient) GetAuthentication(ctx context.Context, sourceId string) if resp == nil { return nil, fmt.Errorf("get source authentication call: empty response: %w", clients.ErrUnexpectedBackendResponse) } - if resp.JSON404 != nil || resp.JSON400 != nil { + if resp.JSON404 != nil { logger.Warn().Bytes("body", resp.Body).Int("status", resp.StatusCode()).Msg("Get authentication from sources returned 4xx") + return nil, fmt.Errorf("get source authentication call: %w", clients.ErrNotFound) } + if resp.JSON400 != nil { + logger.Warn().Bytes("body", resp.Body).Int("status", resp.StatusCode()).Msg("Get authentication from sources returned 4xx") + return nil, fmt.Errorf("get source authentication call: %w", clients.ErrBadRequest) + } + if resp.JSON200 == nil { logger.Warn().Str("source_id", sourceId).RawJSON("response", resp.Body).Msg("Sources returned non-200 response") return nil, fmt.Errorf("get source authentication returned %d: %w", resp.StatusCode(), clients.ErrUnexpectedBackendResponse) diff --git a/internal/payloads/validation/id.go b/internal/payloads/validation/id.go new file mode 100644 index 00000000..2eb727c5 --- /dev/null +++ b/internal/payloads/validation/id.go @@ -0,0 +1,26 @@ +package validation + +import ( + "fmt" + "regexp" + "strconv" +) + +var ( + idValidation = regexp.MustCompile("^[0-9]+$") + ErrInvalidId = fmt.Errorf("invalid id") +) + +func DigitsOnly(id string) error { + if !idValidation.MatchString(id) { + return ErrInvalidId + } + + // Checking for out of range error + _, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return ErrInvalidId + } + + return nil +} diff --git a/internal/payloads/validation/id_test.go b/internal/payloads/validation/id_test.go new file mode 100644 index 00000000..e346c1cf --- /dev/null +++ b/internal/payloads/validation/id_test.go @@ -0,0 +1,59 @@ +package validation + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDigitsOnlySucceeds(t *testing.T) { + id1 := "1111" + err := DigitsOnly(id1) + require.NoError(t, err) + + id2 := "1" + err = DigitsOnly(id2) + require.NoError(t, err) + + id3 := "0" + err = DigitsOnly(id3) + require.NoError(t, err) +} + +func TestDigitsOnlyWithSymbolsFails(t *testing.T) { + id1 := "-1111" + err := DigitsOnly(id1) + if err == nil { + t.Errorf("no error for invalid input: %s", id1) + } + + id2 := "aa" + err = DigitsOnly(id2) + if err == nil { + t.Errorf("no error for invalid input: %s", id2) + } + + id3 := "7a" + err = DigitsOnly(id3) + if err == nil { + t.Errorf("no error for invalid input: %s", id3) + } + + id4 := "a7" + err = DigitsOnly(id4) + if err == nil { + t.Errorf("no error for invalid input: %s", id4) + } + + id5 := "7a7" + err = DigitsOnly(id5) + if err == nil { + t.Errorf("no error for invalid input: %s", id5) + } + + id6 := "zj{{=9243*9806}}zj" + err = DigitsOnly(id6) + if err == nil { + t.Errorf("no error for invalid input: %s", id6) + } +} diff --git a/internal/services/aws_iam_service.go b/internal/services/aws_iam_service.go index 313dab90..c39df897 100644 --- a/internal/services/aws_iam_service.go +++ b/internal/services/aws_iam_service.go @@ -1,8 +1,11 @@ package services import ( + "errors" "net/http" + "github.com/RHEnVision/provisioning-backend/internal/payloads/validation" + "github.com/RHEnVision/provisioning-backend/internal/clients" "github.com/RHEnVision/provisioning-backend/internal/config" "github.com/RHEnVision/provisioning-backend/internal/models" @@ -15,6 +18,10 @@ import ( func ValidatePermissions(w http.ResponseWriter, r *http.Request) { logger := zerolog.Ctx(r.Context()) sourceId := chi.URLParam(r, "ID") + if err := validation.DigitsOnly(sourceId); err != nil { + renderError(w, r, payloads.NewURLParsingError(r.Context(), "id parameter invalid", err)) + } + region := r.URL.Query().Get("region") if region == "" { @@ -31,6 +38,18 @@ func ValidatePermissions(w http.ResponseWriter, r *http.Request) { // Fetch arn from Sources authentication, err := sourcesClient.GetAuthentication(r.Context(), sourceId) if err != nil { + if err != nil { + if errors.Is(err, clients.ErrNotFound) { + renderError(w, r, payloads.NewNotFoundError(r.Context(), "unable to get authentication for sources", err)) + return + } + if errors.Is(err, clients.ErrBadRequest) { + renderError(w, r, payloads.NewResponseError(r.Context(), http.StatusBadRequest, "unable to get authentication from sources", err)) + return + } + renderError(w, r, payloads.NewClientError(r.Context(), err)) + return + } renderError(w, r, payloads.NewClientError(r.Context(), err)) return } diff --git a/internal/services/launch_templates_service.go b/internal/services/launch_templates_service.go index a97322b7..3ce55977 100644 --- a/internal/services/launch_templates_service.go +++ b/internal/services/launch_templates_service.go @@ -3,6 +3,8 @@ package services import ( "net/http" + "github.com/RHEnVision/provisioning-backend/internal/payloads/validation" + "github.com/RHEnVision/provisioning-backend/internal/clients" "github.com/RHEnVision/provisioning-backend/internal/models" "github.com/RHEnVision/provisioning-backend/internal/page" @@ -14,6 +16,9 @@ import ( //nolint:exhaustive func ListLaunchTemplates(w http.ResponseWriter, r *http.Request) { sourceId := chi.URLParam(r, "ID") + if err := validation.DigitsOnly(sourceId); err != nil { + renderError(w, r, payloads.NewURLParsingError(r.Context(), "id parameter invalid", err)) + } sourcesClient, err := clients.GetSourcesClient(r.Context()) if err != nil { @@ -41,6 +46,10 @@ func ListLaunchTemplates(w http.ResponseWriter, r *http.Request) { func ListLaunchTemplateAWS(w http.ResponseWriter, r *http.Request) { sourceId := chi.URLParam(r, "ID") + if err := validation.DigitsOnly(sourceId); err != nil { + renderError(w, r, payloads.NewURLParsingError(r.Context(), "id parameter invalid", err)) + } + region := r.URL.Query().Get("region") if region == "" { renderError(w, r, payloads.NewMissingRequestParameterError(r.Context(), "region parameter is missing")) @@ -80,6 +89,10 @@ func ListLaunchTemplateAWS(w http.ResponseWriter, r *http.Request) { func ListLaunchTemplateGCP(w http.ResponseWriter, r *http.Request) { sourceId := chi.URLParam(r, "ID") + if err := validation.DigitsOnly(sourceId); err != nil { + renderError(w, r, payloads.NewURLParsingError(r.Context(), "id parameter invalid", err)) + } + sourcesClient, err := clients.GetSourcesClient(r.Context()) if err != nil { renderError(w, r, payloads.NewClientError(r.Context(), err)) diff --git a/internal/services/pubkey_service.go b/internal/services/pubkey_service.go index ffe1ec40..8117ce9e 100644 --- a/internal/services/pubkey_service.go +++ b/internal/services/pubkey_service.go @@ -5,6 +5,9 @@ import ( "fmt" "net/http" + "github.com/go-playground/mold/v4" + "github.com/go-playground/validator/v10" + "github.com/RHEnVision/provisioning-backend/internal/clients" httpClients "github.com/RHEnVision/provisioning-backend/internal/clients/http" "github.com/RHEnVision/provisioning-backend/internal/dao" @@ -33,9 +36,14 @@ func CreatePubkey(w http.ResponseWriter, r *http.Request) { pk := payload.NewModel() err := pkDao.Create(r.Context(), pk) + var validationError *validator.ValidationErrors + var transformValueError *mold.ErrInvalidTransformValue + var transformationError *mold.ErrInvalidTransformation if err != nil { if db.IsPostgresError(err, db.UniqueConstraintErrorCode) != nil { renderError(w, r, payloads.PubkeyDuplicateError(r.Context(), "pubkey with such name or fingerprint already exists for this account", err)) + } else if errors.As(err, &validationError) || errors.As(err, &transformValueError) || errors.As(err, &transformationError) { + renderError(w, r, payloads.NewInvalidRequestError(r.Context(), "validation error", err)) } else { renderError(w, r, payloads.NewDAOError(r.Context(), "create pubkey", err)) } diff --git a/internal/services/sources_service.go b/internal/services/sources_service.go index 9db4258a..bd3081ba 100644 --- a/internal/services/sources_service.go +++ b/internal/services/sources_service.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" + "github.com/RHEnVision/provisioning-backend/internal/payloads/validation" + "github.com/RHEnVision/provisioning-backend/internal/cache" "github.com/RHEnVision/provisioning-backend/internal/clients" "github.com/RHEnVision/provisioning-backend/internal/models" @@ -81,6 +83,9 @@ func ListProvisioningSourcesByProvider(w http.ResponseWriter, r *http.Request, a func GetSourceUploadInfo(w http.ResponseWriter, r *http.Request) { sourceId := chi.URLParam(r, "ID") + if err := validation.DigitsOnly(sourceId); err != nil { + renderError(w, r, payloads.NewURLParsingError(r.Context(), "id parameter invalid", err)) + } sourcesClient, err := clients.GetSourcesClient(r.Context()) if err != nil { diff --git a/internal/services/sources_status_service.go b/internal/services/sources_status_service.go index 08ed6505..e0efc7ad 100644 --- a/internal/services/sources_status_service.go +++ b/internal/services/sources_status_service.go @@ -4,6 +4,8 @@ import ( "errors" stdhttp "net/http" + "github.com/RHEnVision/provisioning-backend/internal/payloads/validation" + "github.com/RHEnVision/provisioning-backend/internal/clients" "github.com/RHEnVision/provisioning-backend/internal/config" "github.com/RHEnVision/provisioning-backend/internal/models" @@ -18,6 +20,9 @@ var ErrUnknownProviderFromSources = errors.New("unknown provider returned from s // is no longer valid. func SourcesStatus(w stdhttp.ResponseWriter, r *stdhttp.Request) { sourceId := chi.URLParam(r, "ID") + if err := validation.DigitsOnly(sourceId); err != nil { + renderError(w, r, payloads.NewURLParsingError(r.Context(), "id parameter invalid", err)) + } sourcesClient, err := clients.GetSourcesClient(r.Context()) if err != nil {