Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context #37

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import (
// HandlerFunc is used to process an incoming message
// If processing fails, an error should be returned and the message will be re-queued
// The optional response is used automatically when setting up a RequestResponseHandler, otherwise ignored
type HandlerFunc func(msg any, headers Headers) (response any, err error)
type HandlerFunc func(ctx context.Context, msg any, headers Headers) (response any, err error)

// Connection is a wrapper around the actual amqp.Connection and amqp.Channel
type Connection struct {
Expand Down Expand Up @@ -171,7 +171,7 @@ func (c *Connection) Close() error {
}

func (c *Connection) TypeMappingHandler(handler HandlerFunc) HandlerFunc {
return func(msg any, headers Headers) (response any, err error) {
return func(ctx context.Context, msg any, headers Headers) (response any, err error) {
routingKey := headers["routing-key"].(string)
typ, exists := c.keyToType[routingKey]
if !exists {
Expand All @@ -182,7 +182,7 @@ func (c *Connection) TypeMappingHandler(handler HandlerFunc) HandlerFunc {
if err != nil {
return nil, err
} else {
if resp, err := handler(message, headers); err == nil {
if resp, err := handler(ctx, message, headers); err == nil {
return resp, nil
} else {
return nil, err
Expand Down Expand Up @@ -242,8 +242,8 @@ func amqpVersion() string {
}

func responseWrapper(handler HandlerFunc, routingKey string, publisher ServiceResponsePublisher) HandlerFunc {
return func(msg any, headers Headers) (response any, err error) {
resp, err := handler(msg, headers)
return func(ctx context.Context, msg any, headers Headers) (response any, err error) {
resp, err := handler(ctx, msg, headers)
if err != nil {
return nil, errors.Wrap(err, "failed to process message")
}
Expand Down Expand Up @@ -351,13 +351,13 @@ func (c *Connection) addHandler(queueName, routingKey string, eventType eventTyp
return c.queueHandlers.Add(queueName, routingKey, mHI)
}

func (c *Connection) handleMessage(d amqp.Delivery, handler HandlerFunc, eventType eventType) {
func (c *Connection) handleMessage(ctx context.Context, d amqp.Delivery, handler HandlerFunc, eventType eventType) {
message, err := c.parseMessage(d.Body, eventType)
if err != nil {
c.errorLog(fmt.Sprintf("failed to parse message %s", err))
_ = d.Reject(false)
} else {
if _, err := handler(message, headers(d.Headers, d.RoutingKey)); err == nil {
if _, err := handler(ctx, message, headers(d.Headers, d.RoutingKey)); err == nil {
_ = d.Ack(false)
} else {
c.errorLog(fmt.Sprintf("failed to process message %s", err))
Expand Down Expand Up @@ -400,8 +400,10 @@ func (c *Connection) publishMessage(ctx context.Context, msg any, routingKey, ex
func (c *Connection) divertToMessageHandlers(deliveries <-chan amqp.Delivery, handlers *handlers.Handlers[messageHandlerInvoker]) {
for d := range deliveries {
if h, ok := handlers.Get(d.RoutingKey); ok {
// TODO More here..
ctx := context.Background()
c.messageLogger(d.Body, h.eventType, d.RoutingKey, false)
c.handleMessage(d, h.msgHandler, h.eventType)
c.handleMessage(ctx, d, h.msgHandler, h.eventType)
} else {
// Unhandled message, drop it
_ = d.Reject(false)
Expand Down
22 changes: 11 additions & 11 deletions connection_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func Test_UseMessageLogger_Default(t *testing.T) {
func Test_EventStreamConsumer(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
err := conn.Start(context.Background(), EventStreamConsumer("key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), EventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}))
require.NoError(t, err)
Expand All @@ -203,7 +203,7 @@ func Test_EventStreamConsumer(t *testing.T) {
func Test_EventStreamConsumerWithOptFunc(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
err := conn.Start(context.Background(), EventStreamConsumer("key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), EventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}, AddQueueNameSuffix("suffix")))
require.NoError(t, err)
Expand All @@ -223,7 +223,7 @@ func Test_EventStreamConsumerWithOptFunc(t *testing.T) {
func Test_EventStreamConsumerWithFailingOptFunc(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
err := conn.Start(context.Background(), EventStreamConsumer("key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), EventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}, AddQueueNameSuffix("")))
require.ErrorContains(t, err, "failed, empty queue suffix not allowed")
Expand All @@ -232,7 +232,7 @@ func Test_EventStreamConsumerWithFailingOptFunc(t *testing.T) {
func Test_ServiceRequestConsumer_Ok(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}))

Expand All @@ -256,7 +256,7 @@ func Test_ServiceRequestConsumer_ExchangeDeclareError(t *testing.T) {
declareError := errors.New("failed")
channel.ExchangeDeclarationError = &declareError
conn := mockConnection(channel)
err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), ServiceRequestConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}))

Expand All @@ -266,7 +266,7 @@ func Test_ServiceRequestConsumer_ExchangeDeclareError(t *testing.T) {
func Test_ServiceResponseConsumer_Ok(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}))

Expand All @@ -289,7 +289,7 @@ func Test_ServiceResponseConsumer_ExchangeDeclareError(t *testing.T) {
declareError := errors.New("actual error message")
channel.ExchangeDeclarationError = &declareError
conn := mockConnection(channel)
err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(i any, headers Headers) (any, error) {
err := conn.Start(context.Background(), ServiceResponseConsumer("targetService", "key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, nil
}, TestMessage{}))

Expand All @@ -299,7 +299,7 @@ func Test_ServiceResponseConsumer_ExchangeDeclareError(t *testing.T) {
func Test_RequestResponseHandler(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
err := RequestResponseHandler("key", func(msg any, headers Headers) (response any, err error) {
err := RequestResponseHandler("key", func(ctx context.Context, msg any, headers Headers) (response any, err error) {
return nil, nil
}, Message{})(conn)
require.NoError(t, err)
Expand Down Expand Up @@ -412,7 +412,7 @@ func Test_TransientEventStreamConsumer_Ok(t *testing.T) {
channel := NewMockAmqpChannel()
conn := mockConnection(channel)
uuid.SetRand(badRand{})
err := TransientEventStreamConsumer("key", func(i any, headers Headers) (any, error) {
err := TransientEventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, errors.New("failed")
}, Message{})(conn)

Expand All @@ -437,7 +437,7 @@ func Test_TransientEventStreamConsumer_HandlerForRoutingKeyAlreadyExists(t *test
require.NoError(t, conn.queueHandlers.Add("events.topic.exchange.queue.svc-00010203-0405-4607-8809-0a0b0c0d0e0f", "root.key", &messageHandlerInvoker{}))

uuid.SetRand(badRand{})
err := TransientEventStreamConsumer("root.#", func(i any, headers Headers) (any, error) {
err := TransientEventStreamConsumer("root.#", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, errors.New("failed")
}, Message{})(conn)

Expand All @@ -463,7 +463,7 @@ func testTransientEventStreamConsumerFailure(t *testing.T, channel *MockAmqpChan
conn := mockConnection(channel)

uuid.SetRand(badRand{})
err := TransientEventStreamConsumer("key", func(i any, headers Headers) (any, error) {
err := TransientEventStreamConsumer("key", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, errors.New("failed")
}, Message{})(conn)

Expand Down
20 changes: 10 additions & 10 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func Test_Start_SetupFails(t *testing.T) {
queueHandlers: &handlers2.QueueHandlers[messageHandlerInvoker]{},
}
err := conn.Start(context.Background(),
EventStreamConsumer("test", func(i any, headers Headers) (any, error) {
EventStreamConsumer("test", func(ctx context.Context, i any, headers Headers) (any, error) {
return nil, errors.New("failed")
}, Message{}))
require.Error(t, err)
Expand Down Expand Up @@ -392,9 +392,9 @@ func TestResponseWrapper(t *testing.T) {
if tt.headers != nil {
headers = *tt.headers
}
resp, err := responseWrapper(func(i any, headers Headers) (any, error) {
resp, err := responseWrapper(func(ctx context.Context, i any, headers Headers) (any, error) {
return tt.handlerResp, tt.handlerErr
}, "key", p.publish)(&Message{}, headers)
}, "key", p.publish)(context.Background(), &Message{}, headers)
p.checkPublished(t, tt.published)

require.Equal(t, tt.wantResp, resp)
Expand All @@ -416,7 +416,7 @@ func Test_DivertToMessageHandler(t *testing.T) {
handlers := &handlers2.QueueHandlers[messageHandlerInvoker]{}
msgInvoker := &messageHandlerInvoker{
eventType: reflect.TypeOf(Message{}),
msgHandler: func(i any, headers Headers) (any, error) {
msgHandler: func(ctx context.Context, i any, headers Headers) (any, error) {
if i.(*Message).Ok {
return nil, nil
}
Expand Down Expand Up @@ -500,7 +500,7 @@ func testHandleMessage(json string, handle bool) MockAcknowledger {
messageLogger: noOpMessageLogger(),
errorLog: noOpLogger,
}
c.handleMessage(delivery, func(i any, headers Headers) (any, error) {
c.handleMessage(context.Background(), delivery, func(ctx context.Context, i any, headers Headers) (any, error) {
if handle {
return nil, nil
}
Expand Down Expand Up @@ -570,7 +570,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) {
msg: []byte(`{"a":true}`),
key: "unknown",
handler: func(t *testing.T) HandlerFunc {
return func(msg any, headers Headers) (response any, err error) {
return func(ctx context.Context, msg any, headers Headers) (response any, err error) {
return nil, nil
}
},
Expand All @@ -589,7 +589,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) {
msg: []byte(`{"a:}`),
key: "known",
handler: func(t *testing.T) HandlerFunc {
return func(msg any, headers Headers) (response any, err error) {
return func(ctx context.Context, msg any, headers Headers) (response any, err error) {
return nil, nil
}
},
Expand All @@ -610,7 +610,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) {
msg: []byte(`{"a":true}`),
key: "known",
handler: func(t *testing.T) HandlerFunc {
return func(msg any, headers Headers) (response any, err error) {
return func(ctx context.Context, msg any, headers Headers) (response any, err error) {
assert.IsType(t, &TestMessage{}, msg)
return nil, fmt.Errorf("handler-error")
}
Expand All @@ -632,7 +632,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) {
msg: []byte(`{"a":true}`),
key: "known",
handler: func(t *testing.T) HandlerFunc {
return func(msg any, headers Headers) (response any, err error) {
return func(ctx context.Context, msg any, headers Headers) (response any, err error) {
assert.IsType(t, &TestMessage{}, msg)
return "OK", nil
}
Expand All @@ -650,7 +650,7 @@ func TestConnection_TypeMappingHandler(t *testing.T) {
}

handler := c.TypeMappingHandler(tt.args.handler(t))
res, err := handler(&tt.args.msg, headers(make(amqp.Table), tt.args.key))
res, err := handler(context.Background(), &tt.args.msg, headers(make(amqp.Table), tt.args.key))
if !tt.wantErr(t, err) {
return
}
Expand Down
3 changes: 2 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var amqpURL = "amqp://user:password@localhost:5672/"

func Example() {
ctx := context.Background()

if urlFromEnv := os.Getenv("AMQP_URL"); urlFromEnv != "" {
amqpURL = urlFromEnv
}
Expand All @@ -58,7 +59,7 @@ func checkError(err error) {
}
}

func process(m any, headers Headers) (any, error) {
func process(ctx context.Context, m any, headers Headers) (any, error) {
fmt.Printf("Called process with %v\n", m.(*IncomingMessage).Data)
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions examples/event-stream/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (s *StatService) Start(ctx context.Context) error {
)
}

func (s *StatService) handleOrderEvent(msg any, headers Headers) (response any, err error) {
func (s *StatService) handleOrderEvent(ctx context.Context, msg any, headers Headers) (response any, err error) {
switch msg.(type) {
case *OrderCreated:
// Just to make sure the Output is correct in the example...
Expand Down Expand Up @@ -106,7 +106,7 @@ func (s *ShippingService) Start(ctx context.Context) error {
)
}

func (s *ShippingService) handleOrderEvent(msg any, headers Headers) (response any, err error) {
func (s *ShippingService) handleOrderEvent(ctx context.Context, msg any, headers Headers) (response any, err error) {
switch msg.(type) {
case *OrderCreated:
fmt.Println("Order created")
Expand Down
4 changes: 2 additions & 2 deletions examples/request-response/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ func checkError(err error) {
}
}

func handleRequest(m any, headers Headers) (any, error) {
func handleRequest(ctx context.Context, m any, headers Headers) (any, error) {
request := m.(*Request)
response := Response{Data: request.Data}
fmt.Printf("Called process with %v, returning response %v\n", request.Data, response)
return response, nil
}

func handleResponse(m any, headers Headers) (any, error) {
func handleResponse(ctx context.Context, m any, headers Headers) (any, error) {
response := m.(*Response)
fmt.Printf("Got response, returning response %v\n", response.Data)
return nil, nil
Expand Down
Loading