diff --git a/escrow/state_service.proto b/escrow/state_service.proto index 84ea14a7..7a06958d 100644 --- a/escrow/state_service.proto +++ b/escrow/state_service.proto @@ -1,3 +1,9 @@ +// +// FIXME: All changes all this file should manually be copied to the `snet-cli` +// repo until https://github.com/singnet/snet-daemon/issues/99 and +// https://github.com/singnet/snet-cli/issues/88 are fixed. +// + syntax = "proto3"; package escrow; diff --git a/handler/interceptors.go b/handler/interceptors.go index 452b8fc1..c8561f46 100644 --- a/handler/interceptors.go +++ b/handler/interceptors.go @@ -197,8 +197,8 @@ func (interceptor *paymentValidationInterceptor) intercept(srv interface{}, ss g defer func() { if !handlerSucceed { if r := recover(); r != nil { - e = r.(error) - paymentHandler.CompleteAfterError(payment, e) + log.WithField("panicValue", r).Warn("Service handler called panic(panicValue)") + paymentHandler.CompleteAfterError(payment, fmt.Errorf("Service handler called panic(%v)", r)) panic("re-panic after payment handler error handling") } else if e != nil { err = paymentHandler.CompleteAfterError(payment, e) diff --git a/handler/interceptors_test.go b/handler/interceptors_test.go index 3be4a755..cd5b9756 100644 --- a/handler/interceptors_test.go +++ b/handler/interceptors_test.go @@ -1,94 +1,202 @@ package handler import ( + "context" + "errors" "math/big" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" ) -func TestGetBytesFromHexString(t *testing.T) { +type serverStreamMock struct { + context context.Context +} + +func (m *serverStreamMock) Context() context.Context { + return m.context +} + +func (m *serverStreamMock) SetHeader(metadata.MD) error { + return errors.New("not implemented in mock") +} + +func (m *serverStreamMock) SendHeader(metadata.MD) error { + return errors.New("not implemented in mock") +} + +func (m *serverStreamMock) SetTrailer(metadata.MD) { +} + +func (m *serverStreamMock) SendMsg(interface{}) error { + return errors.New("not implemented in mock") +} + +func (m *serverStreamMock) RecvMsg(interface{}) error { + return errors.New("not implemented in mock") +} + +const ( + defaultPaymentHandlerType = "test-default-payment-handler" + testPaymentHandlerType = "test-payment-handler" +) + +type paymentHandlerMock struct { + typ string + completeAfterErrorCalled bool + completeCalled bool +} + +func (handler *paymentHandlerMock) Type() string { + return handler.typ +} + +func (handler *paymentHandlerMock) Payment(context *GrpcStreamContext) (payment Payment, err *GrpcError) { + return +} + +func (handler *paymentHandlerMock) Complete(payment Payment) (err *GrpcError) { + handler.completeCalled = true + return +} + +func (handler *paymentHandlerMock) CompleteAfterError(payment Payment, result error) (err *GrpcError) { + handler.completeAfterErrorCalled = true + return +} + +type InterceptorsSuite struct { + suite.Suite + + returnErrorHandler grpc.StreamHandler + panicHandler grpc.StreamHandler + defaultPaymentHandler *paymentHandlerMock + paymentHandler *paymentHandlerMock + interceptor grpc.StreamServerInterceptor + serverStream *serverStreamMock +} + +func (suite *InterceptorsSuite) SetupSuite() { + suite.returnErrorHandler = func(srv interface{}, stream grpc.ServerStream) error { + return errors.New("some error") + } + suite.panicHandler = func(srv interface{}, stream grpc.ServerStream) error { + panic("some panic") + } + suite.defaultPaymentHandler = &paymentHandlerMock{typ: defaultPaymentHandlerType} + suite.paymentHandler = &paymentHandlerMock{typ: testPaymentHandlerType} + suite.interceptor = GrpcPaymentValidationInterceptor(suite.defaultPaymentHandler, suite.paymentHandler) + suite.serverStream = &serverStreamMock{context: metadata.NewIncomingContext(context.Background(), metadata.Pairs(PaymentTypeHeader, testPaymentHandlerType))} +} + +func TestIntersecptorsSuite(t *testing.T) { + suite.Run(t, new(InterceptorsSuite)) +} + +func (suite *InterceptorsSuite) TestGetBytesFromHexString() { md := metadata.Pairs("test-key", "0xfFfE0100") bytes, err := GetBytesFromHex(md, "test-key") - assert.Nil(t, err) - assert.Equal(t, []byte{255, 254, 1, 0}, bytes) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), []byte{255, 254, 1, 0}, bytes) } -func TestGetBytesFromHexStringNoPrefix(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBytesFromHexStringNoPrefix() { md := metadata.Pairs("test-key", "fFfE0100") bytes, err := GetBytesFromHex(md, "test-key") - assert.Nil(t, err) - assert.Equal(t, []byte{255, 254, 1, 0}, bytes) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), []byte{255, 254, 1, 0}, bytes) } -func TestGetBytesFromHexStringNoValue(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBytesFromHexStringNoValue() { md := metadata.Pairs("unknown-key", "fFfE0100") _, err := GetBytesFromHex(md, "test-key") - assert.Equal(t, NewGrpcErrorf(codes.InvalidArgument, "missing \"test-key\""), err) + assert.Equal(suite.T(), NewGrpcErrorf(codes.InvalidArgument, "missing \"test-key\""), err) } -func TestGetBytesFromHexStringTooManyValues(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBytesFromHexStringTooManyValues() { md := metadata.Pairs("test-key", "0x123", "test-key", "FED") _, err := GetBytesFromHex(md, "test-key") - assert.Equal(t, NewGrpcErrorf(codes.InvalidArgument, "too many values for key \"test-key\": [0x123 FED]"), err) + assert.Equal(suite.T(), NewGrpcErrorf(codes.InvalidArgument, "too many values for key \"test-key\": [0x123 FED]"), err) } -func TestGetBigInt(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBigInt() { md := metadata.Pairs("big-int-key", "12345") value, err := GetBigInt(md, "big-int-key") - assert.Nil(t, err) - assert.Equal(t, big.NewInt(12345), value) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), big.NewInt(12345), value) } -func TestGetBigIntIncorrectValue(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBigIntIncorrectValue() { md := metadata.Pairs("big-int-key", "12345abc") _, err := GetBigInt(md, "big-int-key") - assert.Equal(t, NewGrpcErrorf(codes.InvalidArgument, "incorrect format \"big-int-key\": \"12345abc\""), err) + assert.Equal(suite.T(), NewGrpcErrorf(codes.InvalidArgument, "incorrect format \"big-int-key\": \"12345abc\""), err) } -func TestGetBigIntNoValue(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBigIntNoValue() { md := metadata.Pairs() _, err := GetBigInt(md, "big-int-key") - assert.Equal(t, NewGrpcErrorf(codes.InvalidArgument, "missing \"big-int-key\""), err) + assert.Equal(suite.T(), NewGrpcErrorf(codes.InvalidArgument, "missing \"big-int-key\""), err) } -func TestGetBigIntTooManyValues(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBigIntTooManyValues() { md := metadata.Pairs("big-int-key", "12345", "big-int-key", "54321") _, err := GetBigInt(md, "big-int-key") - assert.Equal(t, NewGrpcErrorf(codes.InvalidArgument, "too many values for key \"big-int-key\": [12345 54321]"), err) + assert.Equal(suite.T(), NewGrpcErrorf(codes.InvalidArgument, "too many values for key \"big-int-key\": [12345 54321]"), err) } -func TestGetBytes(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBytes() { md := metadata.Pairs("binary-key-bin", string([]byte{0x00, 0x01, 0xFE, 0xFF})) value, err := GetBytes(md, "binary-key-bin") - assert.Nil(t, err) - assert.Equal(t, []byte{0, 1, 254, 255}, value) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), []byte{0, 1, 254, 255}, value) } -func TestGetBytesIncorrectBinaryKey(t *testing.T) { +func (suite *InterceptorsSuite) TestGetBytesIncorrectBinaryKey() { md := metadata.Pairs("binary-key", string([]byte{0x00, 0x01, 0xFE, 0xFF})) _, err := GetBytes(md, "binary-key") - assert.Equal(t, NewGrpcErrorf(codes.InvalidArgument, "incorrect binary key name \"binary-key\""), err) + assert.Equal(suite.T(), NewGrpcErrorf(codes.InvalidArgument, "incorrect binary key name \"binary-key\""), err) +} + +func (suite *InterceptorsSuite) TestCompleteOnHandlerError() { + suite.interceptor(nil, suite.serverStream, nil, suite.returnErrorHandler) + + assert.True(suite.T(), suite.paymentHandler.completeAfterErrorCalled) + assert.False(suite.T(), suite.paymentHandler.completeCalled) +} + +func (suite *InterceptorsSuite) TestCompleteOnHandlerPanic() { + defer func() { + if r := recover(); r == nil { + assert.Fail(suite.T(), "panic() call expected") + } + }() + + suite.interceptor(nil, suite.serverStream, nil, suite.panicHandler) + + assert.True(suite.T(), suite.paymentHandler.completeAfterErrorCalled) + assert.False(suite.T(), suite.paymentHandler.completeCalled) }