diff --git a/header_wrp.go b/header_wrp.go index 26feeaf..4123652 100644 --- a/header_wrp.go +++ b/header_wrp.go @@ -17,10 +17,6 @@ package wrp -import ( - "errors" -) - // Constant HTTP header strings representing WRP fields const ( MsgTypeHeader = "X-Midt-Msg-Type" @@ -34,7 +30,7 @@ const ( SourceHeader = "X-Midt-Source" ) -var ErrInvalidMsgType = errors.New("Invalid Message Type") +// var ErrInvalidMsgType = errors.New("Invalid Message Type") // Map string to MessageType int /* diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..dcdf91d --- /dev/null +++ b/validator.go @@ -0,0 +1,93 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "errors" + "fmt" +) + +var ( + ErrInvalidMsgTypeValidator = errors.New("invalid WRP message type validator") + ErrInvalidMsgType = errors.New("invalid WRP message type") +) + +// Validator is a WRP validator that allows access to the Validate function. +type Validator interface { + Validate(m Message) error +} + +// Validators is a WRP validator that ensures messages are valid based on +// message type and each validator in the list. +type Validators []Validator + +// ValidatorFunc is a WRP validator that takes messages and validates them +// against functions. +type ValidatorFunc func(Message) error + +// Validate executes its own ValidatorFunc receiver and returns the result. +func (vf ValidatorFunc) Validate(m Message) error { + return vf(m) +} + +// MsgTypeValidator is a WRP validator that validates based on message type +// or using the defaultValidator if message type is unknown +type MsgTypeValidator struct { + m map[MessageType]Validators + defaultValidator Validator +} + +// Validate validates messages based on message type or using the defaultValidator +// if message type is unknown +func (m MsgTypeValidator) Validate(msg Message) error { + vs, ok := m.m[msg.MessageType()] + if !ok { + return m.defaultValidator.Validate(msg) + } + + for _, v := range vs { + err := v.Validate(msg) + if err != nil { + return err + } + } + + return nil +} + +// NewMsgTypeValidator returns a MsgTypeValidator +func NewMsgTypeValidator(m map[MessageType]Validators, defaultValidator Validator) (MsgTypeValidator, error) { + if m == nil { + return MsgTypeValidator{}, fmt.Errorf("%w: %v", ErrInvalidMsgTypeValidator, m) + } + if defaultValidator == nil { + defaultValidator = alwaysInvalidMsg() + } + + return MsgTypeValidator{ + m: m, + defaultValidator: defaultValidator, + }, nil +} + +// AlwaysInvalid doesn't validate anything about the message and always returns an error. +func alwaysInvalidMsg() ValidatorFunc { + return func(m Message) error { + return fmt.Errorf("%w: %v", ErrInvalidMsgType, m.MessageType().String()) + } +} diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 0000000..8be4dde --- /dev/null +++ b/validator_test.go @@ -0,0 +1,199 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package wrp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func testMsgTypeValidatorValidate(t *testing.T) { + type Test struct { + m map[MessageType]Validators + defaultValidator Validator + msg Message + } + + var alwaysValidMsg ValidatorFunc = func(msg Message) error { return nil } + tests := []struct { + description string + value Test + expectedErr error + }{ + // Success case + { + description: "known message type with successful Validators", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {alwaysValidMsg}, + }, + msg: Message{Type: SimpleEventMessageType}, + }, + }, + { + description: "unknown message type with provided default Validator", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {alwaysValidMsg}, + }, + defaultValidator: alwaysValidMsg, + msg: Message{Type: CreateMessageType}, + }, + }, + // Failure case + { + description: "known message type with failing Validators", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {alwaysInvalidMsg()}, + }, + msg: Message{Type: SimpleEventMessageType}, + }, + expectedErr: ErrInvalidMsgType, + }, + { + description: "unknown message type without provided default Validator", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {alwaysValidMsg}, + }, + msg: Message{Type: CreateMessageType}, + }, + expectedErr: ErrInvalidMsgType, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + msgv, err := NewMsgTypeValidator(tc.value.m, tc.value.defaultValidator) + assert.NotNil(msgv) + assert.Nil(err) + err = msgv.Validate(tc.value.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.Nil(err) + }) + } +} + +func testNewMsgTypeValidator(t *testing.T) { + type Test struct { + m map[MessageType]Validators + defaultValidator Validator + } + + var alwaysValidMsg ValidatorFunc = func(msg Message) error { return nil } + tests := []struct { + description string + value Test + expectedErr error + }{ + // Success case + { + description: "with provided default Validator", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {alwaysValidMsg}, + }, + defaultValidator: alwaysValidMsg, + }, + expectedErr: nil, + }, + { + description: "without provided default Validator", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {alwaysValidMsg}, + }, + }, + expectedErr: nil, + }, + { + description: "empty list of message type Validators", + value: Test{ + m: map[MessageType]Validators{ + SimpleEventMessageType: {}, + }, + defaultValidator: alwaysValidMsg, + }, + expectedErr: nil, + }, + // Failure case + { + description: "missing message type Validators", + value: Test{}, + expectedErr: ErrInvalidMsgTypeValidator, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + msgv, err := NewMsgTypeValidator(tc.value.m, tc.value.defaultValidator) + assert.NotNil(msgv) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.Nil(err) + }) + } +} + +func testAlwaysInvalidMsg(t *testing.T) { + assert := assert.New(t) + msg := Message{} + v := alwaysInvalidMsg() + + assert.NotNil(v) + err := v(msg) + + assert.NotNil(err) + assert.ErrorIs(err, ErrInvalidMsgType) + +} + +func TestHelperValidators(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + }{ + {"alwaysInvalidMsg", testAlwaysInvalidMsg}, + } + + for _, tc := range tests { + t.Run(tc.name, tc.test) + } +} + +func TestMsgTypeValidator(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + }{ + {"MsgTypeValidator validate", testMsgTypeValidatorValidate}, + {"MsgTypeValidator factory", testNewMsgTypeValidator}, + } + + for _, tc := range tests { + t.Run(tc.name, tc.test) + } +}