diff --git a/export/export.go b/export/export.go
index 51a1afcf..2fd4449e 100644
--- a/export/export.go
+++ b/export/export.go
@@ -37,7 +37,7 @@ func New(sd *slackdump.Session, fs fsadapter.FS, cfg Options) *Export {
 	if cfg.Logger == nil {
 		cfg.Logger = logger.Default
 	}
-	network.Logger = cfg.Logger
+	network.SetLogger(cfg.Logger)
 
 	se := &Export{
 		fs: fs,
diff --git a/internal/network/limiter.go b/internal/network/limiter.go
index 98a47a50..27511b2c 100644
--- a/internal/network/limiter.go
+++ b/internal/network/limiter.go
@@ -14,12 +14,16 @@ const (
 	Tier2 Tier = 20
 	Tier3 Tier = 50
 	// Tier4 Tier = 100
+
+	// secPerMin is the number of seconds in a minute, it is here to allow easy
+	// modification of the program, should this value change.
+	secPerMin = 60.0
 )
 
 // NewLimiter returns throttler with rateLimit requests per minute.
 // optionally caller may specify the boost
 func NewLimiter(t Tier, burst uint, boost int) *rate.Limiter {
-	callsPerSec := float64(int(t)+boost) / 60.0
+	callsPerSec := float64(int(t)+boost) / secPerMin
 	l := rate.NewLimiter(rate.Limit(callsPerSec), int(burst))
 	return l
 }
diff --git a/internal/network/limiter_test.go b/internal/network/limiter_test.go
new file mode 100644
index 00000000..2d6b42a0
--- /dev/null
+++ b/internal/network/limiter_test.go
@@ -0,0 +1,38 @@
+package network
+
+import (
+	"testing"
+
+	"golang.org/x/time/rate"
+)
+
+func TestNewLimiter(t *testing.T) {
+	type args struct {
+		t     Tier
+		burst uint
+		boost int
+	}
+	tests := []struct {
+		name       string
+		args       args
+		want       *rate.Limiter
+		wantPerSec rate.Limit
+	}{
+		{
+			name: "tier 2",
+			args: args{
+				t:     Tier2,
+				burst: 10,
+				boost: 0,
+			},
+			wantPerSec: 0.3333333333333333,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := NewLimiter(tt.args.t, tt.args.burst, tt.args.boost); got.Limit() != tt.wantPerSec {
+				t.Errorf("NewLimiter() = %v, want %v", got.Limit(), tt.wantPerSec)
+			}
+		})
+	}
+}
diff --git a/internal/network/network.go b/internal/network/network.go
index 972407cd..cd8e2a8c 100644
--- a/internal/network/network.go
+++ b/internal/network/network.go
@@ -2,13 +2,13 @@ package network
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"net/http"
 	"runtime/trace"
+	"sync"
 	"time"
 
-	"errors"
-
 	"github.com/slack-go/slack"
 	"golang.org/x/time/rate"
 
@@ -20,22 +20,28 @@ const (
 	defNumAttempts = 3
 )
 
-// MaxAllowedWaitTime is the maximum time to wait for a transient error. The
-// wait time for a transient error depends on the current retry attempt number
-// and is calculated as: (attempt+2)^3 seconds, capped at MaxAllowedWaitTime.
-var MaxAllowedWaitTime = 5 * time.Minute
+var (
+	// maxAllowedWaitTime is the maximum time to wait for a transient error.
+	// The wait time for a transient error depends on the current retry
+	// attempt number and is calculated as: (attempt+2)^3 seconds, capped at
+	// maxAllowedWaitTime.
+	maxAllowedWaitTime                  = 5 * time.Minute
+	lg                 logger.Interface = logger.Default
+	// waitFn returns the amount of time to wait before retrying depending on
+	// the current attempt. This variable exists to reduce the test time.
+	waitFn = cubicWait
 
-// Logger is the package logger.
-var Logger logger.Interface = logger.Default
+	mu sync.RWMutex
+)
 
 // ErrRetryFailed is returned if number of retry attempts exceeded the retry attempts limit and
 // function wasn't able to complete without errors.
-var ErrRetryFailed = errors.New("callback was not able to complete without errors within the allowed number of retries")
+var ErrRetryFailed = errors.New("callback was unable to complete without errors within the allowed number of retries")
 
 // withRetry will run the callback function fn. If the function returns
 // slack.RateLimitedError, it will delay, and then call it again up to
 // maxAttempts times. It will return an error if it runs out of attempts.
-func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() error) error {
+func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func() error) error {
 	var ok bool
 	if maxAttempts == 0 {
 		maxAttempts = defNumAttempts
@@ -43,7 +49,7 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
 	for attempt := 0; attempt < maxAttempts; attempt++ {
 		var err error
 		trace.WithRegion(ctx, "withRetry.wait", func() {
-			err = l.Wait(ctx)
+			err = lim.Wait(ctx)
 		})
 		if err != nil {
 			return err
@@ -65,9 +71,9 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
 			time.Sleep(rle.RetryAfter)
 			continue
 		} else if errors.As(cbErr, &sce) {
-			if sce.Code >= http.StatusInternalServerError && sce.Code <= 599 {
+			if isRecoverable(sce.Code) {
 				// possibly transient error
-				delay := waitTime(attempt)
+				delay := waitFn(attempt)
 				tracelogf(ctx, "info", "got server error %d, sleeping %s", sce.Code, delay)
 				time.Sleep(delay)
 				continue
@@ -82,27 +88,46 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
 	return nil
 }
 
-// waitTime returns the amount of time to wait before retrying depending on
-// the current attempt. The wait time is calculated as (x+2)^3 seconds, where
-// x is the current attempt number. The maximum wait time is capped at 5
+// isRecoverable returns true if the status code is a recoverable error.
+func isRecoverable(statusCode int) bool {
+	return (statusCode >= http.StatusInternalServerError && statusCode <= 599) || statusCode == 408
+}
+
+// cubicWait is the wait time function. Time is calculated as (x+2)^3 seconds,
+// where x is the current attempt number. The maximum wait time is capped at 5
 // minutes.
-func waitTime(attempt int) time.Duration {
+func cubicWait(attempt int) time.Duration {
 	x := attempt + 2 // this is to ensure that we sleep at least 8 seconds.
 	delay := time.Duration(x*x*x) * time.Second
-	if delay > MaxAllowedWaitTime {
-		return MaxAllowedWaitTime
+	if delay > maxAllowedWaitTime {
+		return maxAllowedWaitTime
 	}
 	return delay
 }
 
 func tracelogf(ctx context.Context, category string, fmt string, a ...any) {
+	mu.RLock()
+	defer mu.RUnlock()
+
 	trace.Logf(ctx, category, fmt, a...)
-	l().Debugf(fmt, a...)
+	lg.Debugf(fmt, a...)
 }
 
-func l() logger.Interface {
-	if Logger == nil {
-		return logger.Default
+// SetLogger sets the package logger. A nil logger resets the package logger
+// to logger.Default.
+func SetLogger(l logger.Interface) {
+	mu.Lock()
+	defer mu.Unlock()
+	if l == nil {
+		l = logger.Default
 	}
-	return Logger
+	lg = l
+}
+
+// SetMaxAllowedWaitTime sets the maximum time to wait for a transient error.
+func SetMaxAllowedWaitTime(d time.Duration) {
+	mu.Lock()
+	defer mu.Unlock()
+
+	maxAllowedWaitTime = d
 }
diff --git a/internal/network/network_test.go b/internal/network/network_test.go
index 0ef060fe..fbabdeaf 100644
--- a/internal/network/network_test.go
+++ b/internal/network/network_test.go
@@ -148,14 +148,16 @@ func Test_withRetry(t *testing.T) {
 }
 
 func Test500ErrorHandling(t *testing.T) {
-	t.Parallel()
+	waitFn = func(attempt int) time.Duration { return 50 * time.Millisecond }
+	defer func() {
+		waitFn = cubicWait
+	}()
 
 	var codes = []int{500, 502, 503, 504, 598}
 	for _, code := range codes {
 		var thisCode = code
 		// This test is to ensure that we handle 500 errors correctly.
 		t.Run(fmt.Sprintf("%d error", code), func(t *testing.T) {
-			t.Parallel()
 
 			const (
 				testRetryCount = 1
@@ -185,8 +187,8 @@ func Test500ErrorHandling(t *testing.T) {
 			}
 
 			dur := time.Since(start)
-			if dur < waitTime(testRetryCount-1)-waitThreshold || waitTime(testRetryCount-1)+waitThreshold < dur {
-				t.Errorf("expected duration to be around %s, got %s", waitTime(testRetryCount), dur)
+			if dur < waitFn(testRetryCount-1)-waitThreshold || waitFn(testRetryCount-1)+waitThreshold < dur {
+				t.Errorf("expected duration to be around %s, got %s", waitFn(testRetryCount-1), dur)
 			}
 		})
 
@@ -227,7 +229,7 @@ func Test500ErrorHandling(t *testing.T) {
 	})
 }
 
-func Test_waitTime(t *testing.T) {
+func Test_cubicWait(t *testing.T) {
 	type args struct {
 		attempt int
 	}
@@ -240,14 +242,44 @@ func Test_waitTime(t *testing.T) {
 		{"attempt 1", args{1}, 27 * time.Second},
 		{"attempt 2", args{2}, 64 * time.Second},
 		{"attempt 2", args{4}, 216 * time.Second},
-		{"attempt 100", args{5}, MaxAllowedWaitTime}, // check if capped properly
-		{"attempt 100", args{1000}, MaxAllowedWaitTime},
+		{"attempt 100", args{5}, maxAllowedWaitTime}, // check if capped properly
+		{"attempt 100", args{1000}, maxAllowedWaitTime},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			if got := waitTime(tt.args.attempt); !reflect.DeepEqual(got, tt.want) {
-				t.Errorf("waitTime() = %v, want %v", got, tt.want)
+			if got := cubicWait(tt.args.attempt); !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("cubicWait() = %v, want %v", got, tt.want)
 			}
 		})
 	}
 }
+
+func Test_isRecoverable(t *testing.T) {
+	type args struct {
+		statusCode int
+	}
+	tests := []struct {
+		name string
+		args args
+		want bool
+	}{
+		{"500", args{500}, true},
+		{"502", args{502}, true},
+		{"503", args{503}, true},
+		{"504", args{504}, true},
+		{"598", args{598}, true},
+		{"599", args{599}, true},
+		{"200", args{200}, false},
+		{"400", args{400}, false},
+		{"404", args{404}, false},
+		{"408", args{408}, true},
+		{"429", args{429}, false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := isRecoverable(tt.args.statusCode); got != tt.want {
+				t.Errorf("isRecoverable() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/slackdump.go b/slackdump.go
index b9f54b7d..16251397 100644
--- a/slackdump.go
+++ b/slackdump.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"fmt"
 	"io"
-	"net/http"
 	"os"
 	"runtime/trace"
 	"time"
@@ -107,7 +106,7 @@ func NewWithOptions(ctx context.Context, authProvider auth.Provider, opts Option
 		fs: fsadapter.NewDirectory("."), // default is to save attachments to the current directory.
 	}
 
-	sd.propagateLogger(sd.l())
+	network.SetLogger(sd.l())
 
 	if err := os.MkdirAll(opts.CacheDir, 0700); err != nil {
 		return nil, fmt.Errorf("failed to create the cache directory: %s", err)
@@ -173,14 +172,6 @@ func (sd *Session) SetFS(fs fsadapter.FS) {
 	sd.fs = fs
 }
 
-func toPtrCookies(cc []http.Cookie) []*http.Cookie {
-	var ret = make([]*http.Cookie, len(cc))
-	for i := range cc {
-		ret[i] = &cc[i]
-	}
-	return ret
-}
-
 func (sd *Session) limiter(t network.Tier) *rate.Limiter {
 	return network.NewLimiter(t, sd.options.Tier3Burst, int(sd.options.Tier3Boost))
 }
@@ -224,8 +215,3 @@ func (sd *Session) l() logger.Interface {
 	}
 	return sd.options.Logger
 }
-
-// propagateLogger propagates the slackdump logger to some dumb packages.
-func (sd *Session) propagateLogger(l logger.Interface) {
-	network.Logger = l
-}