Skip to content

Commit

Permalink
Merge pull request #213 from rusq/i212
Browse files Browse the repository at this point in the history
merge network package from v3 (fixes #212)
  • Loading branch information
rusq authored Apr 22, 2023
2 parents 57966d8 + b5ba49c commit c971b1b
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 49 deletions.
2 changes: 1 addition & 1 deletion export/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func New(sd *slackdump.Session, fs fsadapter.FS, cfg Options) *Export {
if cfg.Logger == nil {
cfg.Logger = logger.Default
}
network.Logger = cfg.Logger
network.SetLogger(cfg.Logger)

se := &Export{
fs: fs,
Expand Down
6 changes: 5 additions & 1 deletion internal/network/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ const (
Tier2 Tier = 20
Tier3 Tier = 50
// Tier4 Tier = 100

// secPerMin is the number of seconds in a minute, it is here to allow easy
// modification of the program, should this value change.
secPerMin = 60.0
)

// NewLimiter returns throttler with rateLimit requests per minute.
// optionally caller may specify the boost
func NewLimiter(t Tier, burst uint, boost int) *rate.Limiter {
callsPerSec := float64(int(t)+boost) / 60.0
callsPerSec := float64(int(t)+boost) / secPerMin
l := rate.NewLimiter(rate.Limit(callsPerSec), int(burst))
return l
}
38 changes: 38 additions & 0 deletions internal/network/limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package network

import (
"testing"

"golang.org/x/time/rate"
)

func TestNewLimiter(t *testing.T) {
type args struct {
t Tier
burst uint
boost int
}
tests := []struct {
name string
args args
want *rate.Limiter
wantPerSec rate.Limit
}{
{
name: "tier 2",
args: args{
t: Tier2,
burst: 10,
boost: 0,
},
wantPerSec: 0.3333333333333333,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewLimiter(tt.args.t, tt.args.burst, tt.args.boost); got.Limit() != tt.wantPerSec {
t.Errorf("NewLimiter() = %v, want %v", got.Limit(), tt.wantPerSec)
}
})
}
}
73 changes: 49 additions & 24 deletions internal/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package network

import (
"context"
"errors"
"fmt"
"net/http"
"runtime/trace"
"sync"
"time"

"errors"

"github.com/slack-go/slack"
"golang.org/x/time/rate"

Expand All @@ -20,30 +20,36 @@ const (
defNumAttempts = 3
)

// MaxAllowedWaitTime is the maximum time to wait for a transient error. The
// wait time for a transient error depends on the current retry attempt number
// and is calculated as: (attempt+2)^3 seconds, capped at MaxAllowedWaitTime.
var MaxAllowedWaitTime = 5 * time.Minute
var (
// maxAllowedWaitTime is the maximum time to wait for a transient error.
// The wait time for a transient error depends on the current retry
// attempt number and is calculated as: (attempt+2)^3 seconds, capped at
// maxAllowedWaitTime.
maxAllowedWaitTime = 5 * time.Minute
lg logger.Interface = logger.Default
// waitFn returns the amount of time to wait before retrying depending on
// the current attempt. This variable exists to reduce the test time.
waitFn = cubicWait

// Logger is the package logger.
var Logger logger.Interface = logger.Default
mu sync.RWMutex
)

// ErrRetryFailed is returned if number of retry attempts exceeded the retry attempts limit and
// function wasn't able to complete without errors.
var ErrRetryFailed = errors.New("callback was not able to complete without errors within the allowed number of retries")
var ErrRetryFailed = errors.New("callback was unable to complete without errors within the allowed number of retries")

// withRetry will run the callback function fn. If the function returns
// slack.RateLimitedError, it will delay, and then call it again up to
// maxAttempts times. It will return an error if it runs out of attempts.
func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() error) error {
func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func() error) error {
var ok bool
if maxAttempts == 0 {
maxAttempts = defNumAttempts
}
for attempt := 0; attempt < maxAttempts; attempt++ {
var err error
trace.WithRegion(ctx, "withRetry.wait", func() {
err = l.Wait(ctx)
err = lim.Wait(ctx)
})
if err != nil {
return err
Expand All @@ -65,9 +71,9 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
time.Sleep(rle.RetryAfter)
continue
} else if errors.As(cbErr, &sce) {
if sce.Code >= http.StatusInternalServerError && sce.Code <= 599 {
if isRecoverable(sce.Code) {
// possibly transient error
delay := waitTime(attempt)
delay := waitFn(attempt)
tracelogf(ctx, "info", "got server error %d, sleeping %s", sce.Code, delay)
time.Sleep(delay)
continue
Expand All @@ -82,27 +88,46 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
return nil
}

// waitTime returns the amount of time to wait before retrying depending on
// the current attempt. The wait time is calculated as (x+2)^3 seconds, where
// x is the current attempt number. The maximum wait time is capped at 5
// isRecoverable returns true if the status code is a recoverable error.
func isRecoverable(statusCode int) bool {
return (statusCode >= http.StatusInternalServerError && statusCode <= 599) || statusCode == 408
}

// cubicWait is the wait time function. Time is calculated as (x+2)^3 seconds,
// where x is the current attempt number. The maximum wait time is capped at 5
// minutes.
func waitTime(attempt int) time.Duration {
func cubicWait(attempt int) time.Duration {
x := attempt + 2 // this is to ensure that we sleep at least 8 seconds.
delay := time.Duration(x*x*x) * time.Second
if delay > MaxAllowedWaitTime {
return MaxAllowedWaitTime
if delay > maxAllowedWaitTime {
return maxAllowedWaitTime
}
return delay
}

func tracelogf(ctx context.Context, category string, fmt string, a ...any) {
mu.RLock()
defer mu.RUnlock()

trace.Logf(ctx, category, fmt, a...)
l().Debugf(fmt, a...)
lg.Debugf(fmt, a...)
}

func l() logger.Interface {
if Logger == nil {
return logger.Default
// SetLogger sets the package logger.
func SetLogger(l logger.Interface) {
mu.Lock()
defer mu.Unlock()
if l == nil {
l = logger.Default
return
}
return Logger
lg = l
}

// SetMaxAllowedWaitTime sets the maximum time to wait for a transient error.
func SetMaxAllowedWaitTime(d time.Duration) {
mu.Lock()
defer mu.Unlock()

maxAllowedWaitTime = d
}
48 changes: 40 additions & 8 deletions internal/network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,16 @@ func Test_withRetry(t *testing.T) {
}

func Test500ErrorHandling(t *testing.T) {
t.Parallel()
waitFn = func(attempt int) time.Duration { return 50 * time.Millisecond }
defer func() {
waitFn = cubicWait
}()

var codes = []int{500, 502, 503, 504, 598}
for _, code := range codes {
var thisCode = code
// This test is to ensure that we handle 500 errors correctly.
t.Run(fmt.Sprintf("%d error", code), func(t *testing.T) {
t.Parallel()

const (
testRetryCount = 1
Expand Down Expand Up @@ -185,8 +187,8 @@ func Test500ErrorHandling(t *testing.T) {
}

dur := time.Since(start)
if dur < waitTime(testRetryCount-1)-waitThreshold || waitTime(testRetryCount-1)+waitThreshold < dur {
t.Errorf("expected duration to be around %s, got %s", waitTime(testRetryCount), dur)
if dur < waitFn(testRetryCount-1)-waitThreshold || waitFn(testRetryCount-1)+waitThreshold < dur {
t.Errorf("expected duration to be around %s, got %s", waitFn(testRetryCount), dur)
}

})
Expand Down Expand Up @@ -227,7 +229,7 @@ func Test500ErrorHandling(t *testing.T) {
})
}

func Test_waitTime(t *testing.T) {
func Test_cubicWait(t *testing.T) {
type args struct {
attempt int
}
Expand All @@ -240,14 +242,44 @@ func Test_waitTime(t *testing.T) {
{"attempt 1", args{1}, 27 * time.Second},
{"attempt 2", args{2}, 64 * time.Second},
{"attempt 2", args{4}, 216 * time.Second},
{"attempt 100", args{5}, MaxAllowedWaitTime}, // check if capped properly
{"attempt 100", args{1000}, MaxAllowedWaitTime},
{"attempt 100", args{5}, maxAllowedWaitTime}, // check if capped properly
{"attempt 100", args{1000}, maxAllowedWaitTime},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := waitTime(tt.args.attempt); !reflect.DeepEqual(got, tt.want) {
if got := cubicWait(tt.args.attempt); !reflect.DeepEqual(got, tt.want) {
t.Errorf("waitTime() = %v, want %v", got, tt.want)
}
})
}
}

func Test_isRecoverable(t *testing.T) {
type args struct {
statusCode int
}
tests := []struct {
name string
args args
want bool
}{
{"500", args{500}, true},
{"502", args{502}, true},
{"503", args{503}, true},
{"504", args{504}, true},
{"598", args{598}, true},
{"599", args{599}, true},
{"200", args{200}, false},
{"400", args{400}, false},
{"404", args{404}, false},
{"408", args{408}, true},
{"429", args{429}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isRecoverable(tt.args.statusCode); got != tt.want {
t.Errorf("isRecoverable() = %v, want %v", got, tt.want)
}
})
}
}
16 changes: 1 addition & 15 deletions slackdump.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"runtime/trace"
"time"
Expand Down Expand Up @@ -107,7 +106,7 @@ func NewWithOptions(ctx context.Context, authProvider auth.Provider, opts Option
fs: fsadapter.NewDirectory("."), // default is to save attachments to the current directory.
}

sd.propagateLogger(sd.l())
network.SetLogger(sd.l())

if err := os.MkdirAll(opts.CacheDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create the cache directory: %s", err)
Expand Down Expand Up @@ -173,14 +172,6 @@ func (sd *Session) SetFS(fs fsadapter.FS) {
sd.fs = fs
}

func toPtrCookies(cc []http.Cookie) []*http.Cookie {
var ret = make([]*http.Cookie, len(cc))
for i := range cc {
ret[i] = &cc[i]
}
return ret
}

func (sd *Session) limiter(t network.Tier) *rate.Limiter {
return network.NewLimiter(t, sd.options.Tier3Burst, int(sd.options.Tier3Boost))
}
Expand Down Expand Up @@ -224,8 +215,3 @@ func (sd *Session) l() logger.Interface {
}
return sd.options.Logger
}

// propagateLogger propagates the slackdump logger to some dumb packages.
func (sd *Session) propagateLogger(l logger.Interface) {
network.Logger = l
}

0 comments on commit c971b1b

Please sign in to comment.