Skip to content

Commit

Permalink
feat: add an optional burstable rate limiter
Browse files Browse the repository at this point in the history
The existing rate limiter was moved to a separate package and
renamed to IntervalLimiter. Added BurstLimiter which is a wrapper
around the "golang.org/x/time/rate" package.

The conf.Rate type now has a private `typ` field that indicates
if it is a "interval" or "burst" rate limiter. If the config value
is in the form of "<burst>/<rate>" we set it to "burst", otherwise
"interval". The conf.Rate.GetRateType() method is then called from
the ratelimit.New package to determine the underlying type of
`ratelimit.Limiter` returned from `ratelimit.New`.

Finally we changed `api.NewLimiterOptions` to call `ratelimit.New`
instead of creating a specific type of rate limiter.
  • Loading branch information
Chris Stockton committed Jan 22, 2025
1 parent 37e2349 commit 3746f70
Show file tree
Hide file tree
Showing 11 changed files with 521 additions and 199 deletions.
10 changes: 6 additions & 4 deletions internal/api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/ratelimit"
)

type Option interface {
apply(*API)
}

type LimiterOptions struct {
Email *RateLimiter
Phone *RateLimiter
Email ratelimit.Limiter
Phone ratelimit.Limiter

Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Expand All @@ -36,8 +37,9 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }
func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = newRateLimiter(gc.RateLimitEmailSent)
o.Phone = newRateLimiter(gc.RateLimitSmsSent)
o.Email = ratelimit.New(gc.RateLimitEmailSent)
o.Phone = ratelimit.New(gc.RateLimitSmsSent)

o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
Expand Down
49 changes: 0 additions & 49 deletions internal/api/ratelimits.go

This file was deleted.

125 changes: 0 additions & 125 deletions internal/api/ratelimits_test.go

This file was deleted.

17 changes: 12 additions & 5 deletions internal/conf/rate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,28 @@ import (

const defaultOverTime = time.Hour

const (
BurstRateType = "burst"
IntervalRateType = "interval"
)

type Rate struct {
Events float64 `json:"events,omitempty"`
OverTime time.Duration `json:"over_time,omitempty"`
typ string
}

func (r *Rate) EventsPerSecond() float64 {
d := r.OverTime
if d == 0 {
d = defaultOverTime
func (r *Rate) GetRateType() string {
if r.typ == "" {
return IntervalRateType
}
return r.Events / d.Seconds()
return r.typ
}

// Decode is used by envconfig to parse the env-config string to a Rate value.
func (r *Rate) Decode(value string) error {
if f, err := strconv.ParseFloat(value, 64); err == nil {
r.typ = IntervalRateType
r.Events = f
r.OverTime = defaultOverTime
return nil
Expand All @@ -45,6 +51,7 @@ func (r *Rate) Decode(value string) error {
return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err)
}

r.typ = BurstRateType
r.Events = float64(e)
r.OverTime = d
return nil
Expand Down
44 changes: 28 additions & 16 deletions internal/conf/rate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,39 @@ import (
func TestRateDecode(t *testing.T) {
cases := []struct {
str string
eps float64
exp Rate
err string
}{
{str: "1800", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}},
{str: "1800.0", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}},
{str: "3600/1h", eps: 1, exp: Rate{Events: 3600, OverTime: time.Hour}},
{str: "1800",
exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}},
{str: "1800.0",
exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}},
{str: "3600/1h",
exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}},
{str: "3600/1h0m0s",
exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}},
{str: "100/24h",
eps: 0.0011574074074074073,
exp: Rate{Events: 100, OverTime: time.Hour * 24}},
{str: "", eps: 1, exp: Rate{},
exp: Rate{Events: 100, OverTime: time.Hour * 24, typ: BurstRateType}},
{str: "", exp: Rate{},
err: `rate: value does not match`},
{str: "1h", eps: 1, exp: Rate{},
{str: "1h", exp: Rate{},
err: `rate: value does not match`},
{str: "/", eps: 1, exp: Rate{},
{str: "/", exp: Rate{},
err: `rate: events part of rate value`},
{str: "/1h", eps: 1, exp: Rate{},
{str: "/1h", exp: Rate{},
err: `rate: events part of rate value`},
{str: "3600.0/1h", eps: 1, exp: Rate{},
{str: "3600.0/1h", exp: Rate{},
err: `rate: events part of rate value "3600.0/1h" failed to parse`},
{str: "100/", eps: 1, exp: Rate{},
{str: "100/", exp: Rate{},
err: `rate: over-time part of rate value`},
{str: "100/1", eps: 1, exp: Rate{},
{str: "100/1", exp: Rate{},
err: `rate: over-time part of rate value`},

// zero events
{str: "0/1h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour}},
{str: "0/24h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour * 24}},
{str: "0/1h",
exp: Rate{Events: 0, OverTime: time.Hour, typ: BurstRateType}},
{str: "0/24h",
exp: Rate{Events: 0, OverTime: time.Hour * 24, typ: BurstRateType}},
}
for idx, tc := range cases {
var r Rate
Expand All @@ -51,6 +56,13 @@ func TestRateDecode(t *testing.T) {
}
require.NoError(t, err)
require.Equal(t, tc.exp, r)
require.Equal(t, tc.eps, r.EventsPerSecond())
require.Equal(t, tc.exp.typ, r.GetRateType())
}

// GetRateType() zero value
require.Equal(t, IntervalRateType, (&Rate{}).GetRateType())

// String()
require.Equal(t, "0.000000", (&Rate{}).String())
require.Equal(t, "100/1h0m0s", (&Rate{Events: 100, OverTime: time.Hour}).String())
}
60 changes: 60 additions & 0 deletions internal/ratelimit/burst.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package ratelimit

import (
"time"

"github.com/supabase/auth/internal/conf"
"golang.org/x/time/rate"
)

const defaultOverTime = time.Hour

// BurstLimiter wraps the golang.org/x/time/rate package.
type BurstLimiter struct {
rl *rate.Limiter
}

// NewBurstLimiter returns a rate limiter configured using the given conf.Rate.
//
// The returned Limiter will be configured with a token bucket containing a
// single token, which will fill up at a rate of 1 event per r.OverTime with
// an initial burst amount of r.Events.
//
// For example:
// - 1/10s is 1 events per 10 seconds with burst of 1.
// - 1/2s is 1 events per 2 seconds with burst of 1.
// - 10/10s is 1 events per 10 seconds with burst of 10.
//
// If Rate.Events is <= 0, the burst amount will be set to 1.
//
// See Example_newBurstLimiter for a visualization.
func NewBurstLimiter(r conf.Rate) *BurstLimiter {
// The rate limiter deals in events per second.
d := r.OverTime
if d <= 0 {
d = defaultOverTime
}

e := r.Events
if e <= 0 {
e = 1
}

// BurstLimiter will have an initial token bucket of size `e`. It will
// be refilled at a rate of 1 per duration `d` indefinitely.
rl := &BurstLimiter{
rl: rate.NewLimiter(rate.Every(d), int(e)),
}
return rl
}

// Allow implements Limiter by calling AllowAt with the current time.
func (l *BurstLimiter) Allow() bool {
return l.AllowAt(time.Now())
}

// AllowAt implements Limiter by calling the underlying x/time/rate.Limiter
// with the given time.
func (l *BurstLimiter) AllowAt(at time.Time) bool {
return l.rl.AllowN(at, 1)
}
Loading

0 comments on commit 3746f70

Please sign in to comment.