From 47ca48d234235f334dbbec066d2973de58485b2d Mon Sep 17 00:00:00 2001 From: Peter Svensson Date: Thu, 1 Feb 2024 13:45:09 +0100 Subject: [PATCH] feat: context to handlerFunc --- connection.go | 18 ++++++++++-------- connection_options_test.go | 22 +++++++++++----------- connection_test.go | 20 ++++++++++---------- example_test.go | 3 ++- examples/event-stream/example_test.go | 4 ++-- examples/request-response/example_test.go | 4 ++-- 6 files changed, 37 insertions(+), 34 deletions(-) diff --git a/connection.go b/connection.go index 624042c..18b59bc 100644 --- a/connection.go +++ b/connection.go @@ -42,7 +42,7 @@ import ( // HandlerFunc is used to process an incoming message // If processing fails, an error should be returned and the message will be re-queued // The optional response is used automatically when setting up a RequestResponseHandler, otherwise ignored -type HandlerFunc func(msg any, headers Headers) (response any, err error) +type HandlerFunc func(ctx context.Context, msg any, headers Headers) (response any, err error) // Connection is a wrapper around the actual amqp.Connection and amqp.Channel type Connection struct { @@ -171,7 +171,7 @@ func (c *Connection) Close() error { } func (c *Connection) TypeMappingHandler(handler HandlerFunc) HandlerFunc { - return func(msg any, headers Headers) (response any, err error) { + return func(ctx context.Context, msg any, headers Headers) (response any, err error) { routingKey := headers["routing-key"].(string) typ, exists := c.keyToType[routingKey] if !exists { @@ -182,7 +182,7 @@ func (c *Connection) TypeMappingHandler(handler HandlerFunc) HandlerFunc { if err != nil { return nil, err } else { - if resp, err := handler(message, headers); err == nil { + if resp, err := handler(ctx, message, headers); err == nil { return resp, nil } else { return nil, err @@ -242,8 +242,8 @@ func amqpVersion() string { } func responseWrapper(handler HandlerFunc, routingKey string, publisher ServiceResponsePublisher) HandlerFunc { - return func(msg any, headers Headers) (response any, err error) { - resp, err := handler(msg, headers) + return func(ctx context.Context, msg any, headers Headers) (response any, err error) { + resp, err := handler(ctx, msg, headers) if err != nil { return nil, errors.Wrap(err, "failed to process message") } @@ -351,13 +351,13 @@ func (c *Connection) addHandler(queueName, routingKey string, eventType eventTyp return c.queueHandlers.Add(queueName, routingKey, mHI) } -func (c *Connection) handleMessage(d amqp.Delivery, handler HandlerFunc, eventType eventType) { +func (c *Connection) handleMessage(ctx context.Context, d amqp.Delivery, handler HandlerFunc, eventType eventType) { message, err := c.parseMessage(d.Body, eventType) if err != nil { c.errorLog(fmt.Sprintf("failed to parse message %s", err)) _ = d.Reject(false) } else { - if _, err := handler(message, headers(d.Headers, d.RoutingKey)); err == nil { + if _, err := handler(ctx, message, headers(d.Headers, d.RoutingKey)); err == nil { _ = d.Ack(false) } else { c.errorLog(fmt.Sprintf("failed to process message %s", err)) @@ -400,8 +400,10 @@ func (c *Connection) publishMessage(ctx context.Context, msg any, routingKey, ex func (c *Connection) divertToMessageHandlers(deliveries <-chan amqp.Delivery, handlers *handlers.Handlers[messageHandlerInvoker]) { for d := range deliveries { if h, ok := handlers.Get(d.RoutingKey); ok { + // TODO More here.. + ctx := context.Background() c.messageLogger(d.Body, h.eventType, d.RoutingKey, false) - c.handleMessage(d, h.msgHandler, h.eventType) + c.handleMessage(ctx, d, h.msgHandler, h.eventType) } else { // Unhandled message, drop it _ = d.Reject(false) diff --git a/connection_options_test.go b/connection_options_test.go index 6b5b44d..8e37f58 100644 --- a/connection_options_test.go +++ b/connection_options_test.go @@ -183,7 +183,7 @@ func Test_UseMessageLogger_Default(t *testing.T) { func Test_EventStreamConsumer(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) - err := conn.Start(context.Background(), EventStreamConsumer("key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), EventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{})) require.NoError(t, err) @@ -203,7 +203,7 @@ func Test_EventStreamConsumer(t *testing.T) { func Test_EventStreamConsumerWithOptFunc(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) - err := conn.Start(context.Background(), EventStreamConsumer("key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), EventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{}, AddQueueNameSuffix("suffix"))) require.NoError(t, err) @@ -223,7 +223,7 @@ func Test_EventStreamConsumerWithOptFunc(t *testing.T) { func Test_EventStreamConsumerWithFailingOptFunc(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) - err := conn.Start(context.Background(), EventStreamConsumer("key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), EventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{}, AddQueueNameSuffix(""))) require.ErrorContains(t, err, "failed, empty queue suffix not allowed") @@ -232,7 +232,7 @@ func Test_EventStreamConsumerWithFailingOptFunc(t *testing.T) { func Test_ServiceRequestConsumer_Ok(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) - err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{})) @@ -256,7 +256,7 @@ func Test_ServiceRequestConsumer_ExchangeDeclareError(t *testing.T) { declareError := errors.New("failed") channel.ExchangeDeclarationError = &declareError conn := mockConnection(channel) - err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{})) @@ -266,7 +266,7 @@ func Test_ServiceRequestConsumer_ExchangeDeclareError(t *testing.T) { func Test_ServiceResponseConsumer_Ok(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) - err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{})) @@ -289,7 +289,7 @@ func Test_ServiceResponseConsumer_ExchangeDeclareError(t *testing.T) { declareError := errors.New("actual error message") channel.ExchangeDeclarationError = &declareError conn := mockConnection(channel) - err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(i any, headers Headers) (any, error) { + err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, nil }, TestMessage{})) @@ -299,7 +299,7 @@ func Test_ServiceResponseConsumer_ExchangeDeclareError(t *testing.T) { func Test_RequestResponseHandler(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) - err := RequestResponseHandler("key", func(msg any, headers Headers) (response any, err error) { + err := RequestResponseHandler("key", func(ctx context.Context, msg any, headers Headers) (response any, err error) { return nil, nil }, Message{})(conn) require.NoError(t, err) @@ -412,7 +412,7 @@ func Test_TransientEventStreamConsumer_Ok(t *testing.T) { channel := NewMockAmqpChannel() conn := mockConnection(channel) uuid.SetRand(badRand{}) - err := TransientEventStreamConsumer("key", func(i any, headers Headers) (any, error) { + err := TransientEventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, errors.New("failed") }, Message{})(conn) @@ -437,7 +437,7 @@ func Test_TransientEventStreamConsumer_HandlerForRoutingKeyAlreadyExists(t *test require.NoError(t, conn.queueHandlers.Add("events.topic.exchange.queue.svc-00010203-0405-4607-8809-0a0b0c0d0e0f", "root.key", &messageHandlerInvoker{})) uuid.SetRand(badRand{}) - err := TransientEventStreamConsumer("root.#", func(i any, headers Headers) (any, error) { + err := TransientEventStreamConsumer("root.#", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, errors.New("failed") }, Message{})(conn) @@ -463,7 +463,7 @@ func testTransientEventStreamConsumerFailure(t *testing.T, channel *MockAmqpChan conn := mockConnection(channel) uuid.SetRand(badRand{}) - err := TransientEventStreamConsumer("key", func(i any, headers Headers) (any, error) { + err := TransientEventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, errors.New("failed") }, Message{})(conn) diff --git a/connection_test.go b/connection_test.go index 09e6e1a..50bca5c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -102,7 +102,7 @@ func Test_Start_SetupFails(t *testing.T) { queueHandlers: &handlers2.QueueHandlers[messageHandlerInvoker]{}, } err := conn.Start(context.Background(), - EventStreamConsumer("test", func(i any, headers Headers) (any, error) { + EventStreamConsumer("test", func(ctx context.Context, i any, headers Headers) (any, error) { return nil, errors.New("failed") }, Message{})) require.Error(t, err) @@ -392,9 +392,9 @@ func TestResponseWrapper(t *testing.T) { if tt.headers != nil { headers = *tt.headers } - resp, err := responseWrapper(func(i any, headers Headers) (any, error) { + resp, err := responseWrapper(func(ctx context.Context, i any, headers Headers) (any, error) { return tt.handlerResp, tt.handlerErr - }, "key", p.publish)(&Message{}, headers) + }, "key", p.publish)(context.Background(), &Message{}, headers) p.checkPublished(t, tt.published) require.Equal(t, tt.wantResp, resp) @@ -416,7 +416,7 @@ func Test_DivertToMessageHandler(t *testing.T) { handlers := &handlers2.QueueHandlers[messageHandlerInvoker]{} msgInvoker := &messageHandlerInvoker{ eventType: reflect.TypeOf(Message{}), - msgHandler: func(i any, headers Headers) (any, error) { + msgHandler: func(ctx context.Context, i any, headers Headers) (any, error) { if i.(*Message).Ok { return nil, nil } @@ -500,7 +500,7 @@ func testHandleMessage(json string, handle bool) MockAcknowledger { messageLogger: noOpMessageLogger(), errorLog: noOpLogger, } - c.handleMessage(delivery, func(i any, headers Headers) (any, error) { + c.handleMessage(context.Background(), delivery, func(ctx context.Context, i any, headers Headers) (any, error) { if handle { return nil, nil } @@ -570,7 +570,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) { msg: []byte(`{"a":true}`), key: "unknown", handler: func(t *testing.T) HandlerFunc { - return func(msg any, headers Headers) (response any, err error) { + return func(ctx context.Context, msg any, headers Headers) (response any, err error) { return nil, nil } }, @@ -589,7 +589,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) { msg: []byte(`{"a:}`), key: "known", handler: func(t *testing.T) HandlerFunc { - return func(msg any, headers Headers) (response any, err error) { + return func(ctx context.Context, msg any, headers Headers) (response any, err error) { return nil, nil } }, @@ -610,7 +610,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) { msg: []byte(`{"a":true}`), key: "known", handler: func(t *testing.T) HandlerFunc { - return func(msg any, headers Headers) (response any, err error) { + return func(ctx context.Context, msg any, headers Headers) (response any, err error) { assert.IsType(t, &TestMessage{}, msg) return nil, fmt.Errorf("handler-error") } @@ -632,7 +632,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) { msg: []byte(`{"a":true}`), key: "known", handler: func(t *testing.T) HandlerFunc { - return func(msg any, headers Headers) (response any, err error) { + return func(ctx context.Context, msg any, headers Headers) (response any, err error) { assert.IsType(t, &TestMessage{}, msg) return "OK", nil } @@ -650,7 +650,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) { } handler := c.TypeMappingHandler(tt.args.handler(t)) - res, err := handler(&tt.args.msg, headers(make(amqp.Table), tt.args.key)) + res, err := handler(context.Background(), &tt.args.msg, headers(make(amqp.Table), tt.args.key)) if !tt.wantErr(t, err) { return } diff --git a/example_test.go b/example_test.go index 994e302..683bf95 100644 --- a/example_test.go +++ b/example_test.go @@ -32,6 +32,7 @@ var amqpURL = "amqp://user:password@localhost:5672/" func Example() { ctx := context.Background() + if urlFromEnv := os.Getenv("AMQP_URL"); urlFromEnv != "" { amqpURL = urlFromEnv } @@ -58,7 +59,7 @@ func checkError(err error) { } } -func process(m any, headers Headers) (any, error) { +func process(ctx context.Context, m any, headers Headers) (any, error) { fmt.Printf("Called process with %v\n", m.(*IncomingMessage).Data) return nil, nil } diff --git a/examples/event-stream/example_test.go b/examples/event-stream/example_test.go index 60bbf65..2e675da 100644 --- a/examples/event-stream/example_test.go +++ b/examples/event-stream/example_test.go @@ -77,7 +77,7 @@ func (s *StatService) Start(ctx context.Context) error { ) } -func (s *StatService) handleOrderEvent(msg any, headers Headers) (response any, err error) { +func (s *StatService) handleOrderEvent(ctx context.Context, msg any, headers Headers) (response any, err error) { switch msg.(type) { case *OrderCreated: // Just to make sure the Output is correct in the example... @@ -106,7 +106,7 @@ func (s *ShippingService) Start(ctx context.Context) error { ) } -func (s *ShippingService) handleOrderEvent(msg any, headers Headers) (response any, err error) { +func (s *ShippingService) handleOrderEvent(ctx context.Context, msg any, headers Headers) (response any, err error) { switch msg.(type) { case *OrderCreated: fmt.Println("Order created") diff --git a/examples/request-response/example_test.go b/examples/request-response/example_test.go index b57d076..32e8198 100644 --- a/examples/request-response/example_test.go +++ b/examples/request-response/example_test.go @@ -65,14 +65,14 @@ func checkError(err error) { } } -func handleRequest(m any, headers Headers) (any, error) { +func handleRequest(ctx context.Context, m any, headers Headers) (any, error) { request := m.(*Request) response := Response{Data: request.Data} fmt.Printf("Called process with %v, returning response %v\n", request.Data, response) return response, nil } -func handleResponse(m any, headers Headers) (any, error) { +func handleResponse(ctx context.Context, m any, headers Headers) (any, error) { response := m.(*Response) fmt.Printf("Got response, returning response %v\n", response.Data) return nil, nil