Skip to content

Commit

Permalink
simplify context error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ewollesen committed Jan 16, 2025
1 parent 7d21a4a commit 5b852bf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 30 deletions.
30 changes: 3 additions & 27 deletions asyncevents/sarama.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,17 @@ func NewSaramaEventsConsumer(consumerGroup sarama.ConsumerGroup,
// Run is stopped by its context being canceled. When its context is canceled,
// it returns nil.
func (p *SaramaEventsConsumer) Run(ctx context.Context) (err error) {
defer canceledContextReturnsNil(&err)

for {
err := p.ConsumerGroup.Consume(ctx, p.Topics, p.Handler)
if err != nil {
return err
}
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
return nil
}
}
}

// canceledContextReturnsNil checks for a context.Canceled error, and when
// found, returns nil instead.
//
// It is meant to be called via defer.
func canceledContextReturnsNil(err *error) {
if err != nil && errors.Is(*err, context.Canceled) {
*err = nil
}
}

// SaramaConsumerGroupHandler implements sarama.ConsumerGroupHandler.
type SaramaConsumerGroupHandler struct {
Consumer SaramaMessageConsumer
Expand Down Expand Up @@ -161,7 +149,7 @@ type Logger interface {
}

func (c *NTimesRetryingConsumer) Consume(ctx context.Context,
session sarama.ConsumerGroupSession, message *sarama.ConsumerMessage) error {
session sarama.ConsumerGroupSession, message *sarama.ConsumerMessage) (err error) {

var joinedErrors error
var tries int = 0
Expand All @@ -176,18 +164,10 @@ func (c *NTimesRetryingConsumer) Consume(ctx context.Context,
for tries < c.Times {
select {
case <-done:
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
return nil
case <-time.After(delay):
err := c.Consumer.Consume(ctx, session, message)
if err == nil {
return nil
}
if c.isContextErr(err) {
return err
} else if errors.Is(err, nil) {
if errors.Is(err, nil) || errors.Is(err, context.Canceled) {
return nil
}
delay = c.Delay(tries)
Expand All @@ -205,10 +185,6 @@ func (c *NTimesRetryingConsumer) Consume(ctx context.Context,
return errors.Join(joinedErrors, c.retryLimitError())
}

func (c *NTimesRetryingConsumer) isContextErr(err error) bool {
return errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)
}

func (c *NTimesRetryingConsumer) retryLimitError() error {
return fmt.Errorf("%w (%d)", ErrRetriesLimitExceeded, c.Times)
}
Expand Down
6 changes: 3 additions & 3 deletions asyncevents/sarama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestNTimesRetryingConsumer(s *testing.T) {
}
})

s.Run("aborts when the context deadline exceeded", func(t *testing.T) {
s.Run("retries when the context deadline is exceeded", func(t *testing.T) {
testConsumer := newCountingSaramaMessageConsumer(context.DeadlineExceeded)
c := &NTimesRetryingConsumer{
Times: testTimes,
Expand All @@ -115,8 +115,8 @@ func TestNTimesRetryingConsumer(s *testing.T) {
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected %s, got %v", context.DeadlineExceeded, err)
}
if testConsumer.Count >= testTimes {
t.Errorf("expected < %d tries, got %d", testTimes, testConsumer.Count)
if testConsumer.Count != testTimes {
t.Errorf("expected %d tries, got %d", testTimes, testConsumer.Count)
}
})

Expand Down

0 comments on commit 5b852bf

Please sign in to comment.