Skip to content

Commit

Permalink
feat: improve rate limiting with better customization (#1476)
Browse files Browse the repository at this point in the history
Co-authored-by: Dustin Deus <[email protected]>
  • Loading branch information
jensneuse and StarpTech authored Jan 7, 2025
1 parent c2f2131 commit ffcb634
Show file tree
Hide file tree
Showing 12 changed files with 540 additions and 38 deletions.
261 changes: 252 additions & 9 deletions router-tests/ratelimit_test.go

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion router/core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package core

import (
"context"
"github.com/wundergraph/cosmo/router/internal/expr"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/expr-lang/expr/vm"
"github.com/wundergraph/cosmo/router/internal/expr"

"github.com/wundergraph/astjson"
graphqlmetrics "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/graphqlmetrics/v1"
"github.com/wundergraph/cosmo/router/pkg/config"
Expand Down Expand Up @@ -252,6 +254,14 @@ type requestContext struct {
expressionContext expr.Context
}

func (c *requestContext) ResolveStringExpression(expression *vm.Program) (string, error) {
return expr.ResolveStringExpression(expression, c.expressionContext)
}

func (c *requestContext) ResolveBoolExpression(expression *vm.Program) (bool, error) {
return expr.ResolveBoolExpression(expression, c.expressionContext)
}

func (c *requestContext) Operation() OperationContext {
return c.operation
}
Expand Down
11 changes: 8 additions & 3 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -964,10 +964,15 @@ func (s *graphServer) buildGraphMux(ctx context.Context,

if s.redisClient != nil {
handlerOpts.RateLimitConfig = s.rateLimit
handlerOpts.RateLimiter = NewCosmoRateLimiter(&CosmoRateLimiterOptions{
RedisClient: s.redisClient,
Debug: s.rateLimit.Debug,
handlerOpts.RateLimiter, err = NewCosmoRateLimiter(&CosmoRateLimiterOptions{
RedisClient: s.redisClient,
Debug: s.rateLimit.Debug,
RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode,
KeySuffixExpression: s.rateLimit.KeySuffixExpression,
})
if err != nil {
return nil, fmt.Errorf("failed to create rate limiter: %w", err)
}
}

graphqlHandler := NewGraphQLHandler(handlerOpts)
Expand Down
2 changes: 1 addition & 1 deletion router/core/graphql_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv
RateLimit: buf.Bytes(),
}
if isHttpResponseWriter {
httpWriter.WriteHeader(http.StatusOK) // Always return 200 OK when we return a well-formed response
httpWriter.WriteHeader(h.rateLimiter.RejectStatusCode())
}
case errorTypeUnauthorized:
response.Errors[0].Message = "Unauthorized"
Expand Down
3 changes: 2 additions & 1 deletion router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import (
"context"
"crypto/ecdsa"
"fmt"
"github.com/wundergraph/cosmo/router/internal/expr"
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/wundergraph/cosmo/router/internal/expr"

"github.com/wundergraph/cosmo/router/pkg/config"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
Expand Down
75 changes: 63 additions & 12 deletions router/core/ratelimiter.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package core

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sync"

"github.com/expr-lang/expr/vm"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
"github.com/wundergraph/cosmo/router/internal/expr"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

Expand All @@ -19,21 +23,37 @@ var (
type CosmoRateLimiterOptions struct {
RedisClient *redis.Client
Debug bool

RejectStatusCode int

KeySuffixExpression string
}

func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) *CosmoRateLimiter {
func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, err error) {
limiter := redis_rate.NewLimiter(opts.RedisClient)
return &CosmoRateLimiter{
client: opts.RedisClient,
limiter: limiter,
debug: opts.Debug,
rl = &CosmoRateLimiter{
client: opts.RedisClient,
limiter: limiter,
debug: opts.Debug,
rejectStatusCode: opts.RejectStatusCode,
}
if opts.KeySuffixExpression != "" {
rl.keySuffixProgram, err = expr.CompileStringExpression(opts.KeySuffixExpression)
if err != nil {
return nil, err
}
}
return rl, nil
}

type CosmoRateLimiter struct {
client *redis.Client
limiter *redis_rate.Limiter
debug bool

rejectStatusCode int

keySuffixProgram *vm.Program
}

func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve.FetchInfo, input json.RawMessage) (result *resolve.RateLimitDeny, err error) {
Expand All @@ -46,11 +66,15 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve
Burst: ctx.RateLimitOptions.Burst,
Period: ctx.RateLimitOptions.Period,
}
allow, err := c.limiter.AllowN(ctx.Context(), ctx.RateLimitOptions.RateLimitKey, limit, requestRate)
key, err := c.generateKey(ctx)
if err != nil {
return nil, err
}
allow, err := c.limiter.AllowN(ctx.Context(), key, limit, requestRate)
if err != nil {
return nil, err
}
c.setRateLimitStats(ctx, requestRate, allow.Remaining, allow.RetryAfter.Milliseconds(), allow.ResetAfter.Milliseconds())
c.setRateLimitStats(ctx, key, requestRate, allow.Remaining, allow.RetryAfter.Milliseconds(), allow.ResetAfter.Milliseconds())
if allow.Allowed >= requestRate {
return nil, nil
}
Expand All @@ -60,11 +84,35 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve
return &resolve.RateLimitDeny{}, nil
}

func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context) (string, error) {
if c.keySuffixProgram == nil {
return ctx.RateLimitOptions.RateLimitKey, nil
}
rc := getRequestContext(ctx.Context())
if rc == nil {
return "", errors.New("no request context")
}
str, err := rc.ResolveStringExpression(c.keySuffixProgram)
if err != nil {
return "", fmt.Errorf("failed to resolve key suffix expression: %w", err)
}
buf := bytes.NewBuffer(make([]byte, 0, len(ctx.RateLimitOptions.RateLimitKey)+len(str)+1))
_, _ = buf.WriteString(ctx.RateLimitOptions.RateLimitKey)
_ = buf.WriteByte(':')
_, _ = buf.WriteString(str)
return buf.String(), nil
}

func (c *CosmoRateLimiter) RejectStatusCode() int {
return c.rejectStatusCode
}

type RateLimitStats struct {
RequestRate int `json:"requestRate"`
Remaining int `json:"remaining"`
RetryAfterMilliseconds int64 `json:"retryAfterMs"`
ResetAfterMilliseconds int64 `json:"resetAfterMs"`
Key string `json:"key,omitempty"`
RequestRate int `json:"requestRate"`
Remaining int `json:"remaining"`
RetryAfterMilliseconds int64 `json:"retryAfterMs"`
ResetAfterMilliseconds int64 `json:"resetAfterMs"`
}

func (c *CosmoRateLimiter) RenderResponseExtension(ctx *resolve.Context, out io.Writer) error {
Expand Down Expand Up @@ -98,17 +146,20 @@ func (c *CosmoRateLimiter) statsJSON(ctx *resolve.Context) ([]byte, error) {
if c.debug {
stats.ResetAfterMilliseconds = 1234
stats.RetryAfterMilliseconds = 1234
} else {
stats.Key = "" // hide key when not in debug mode
}
return json.Marshal(stats)
}

func (c *CosmoRateLimiter) setRateLimitStats(ctx *resolve.Context, requestRate, remaining int, retryAfter, resetAfter int64) {
func (c *CosmoRateLimiter) setRateLimitStats(ctx *resolve.Context, key string, requestRate, remaining int, retryAfter, resetAfter int64) {
v := ctx.Context().Value(rateLimitStatsCtxKey{})
if v == nil {
return
}
statsCtx := v.(*rateLimitStatsCtx)
statsCtx.mux.Lock()
statsCtx.stats.Key = key
statsCtx.stats.RequestRate = statsCtx.stats.RequestRate + requestRate
statsCtx.stats.Remaining = remaining
statsCtx.stats.RetryAfterMilliseconds = retryAfter
Expand Down
161 changes: 161 additions & 0 deletions router/core/ratelimiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package core

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router/internal/expr"
"github.com/wundergraph/cosmo/router/pkg/authentication"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

func expressionResolveContext(t *testing.T, header http.Header, claims map[string]any) *resolve.Context {
req, err := http.NewRequest(http.MethodGet, "http://localhost:3002/graphql", nil)
assert.NoError(t, err)
if header != nil {
req.Header = header
}
rcc := buildRequestContext(requestContextOptions{
r: req,
})
ctx := withRequestContext(context.Background(), rcc)
rc := &resolve.Context{
RateLimitOptions: resolve.RateLimitOptions{
RateLimitKey: "test",
},
}
if claims != nil {
rc = ContextWithClaims(rc, claims)
rcc.expressionContext.Request.Auth = expr.LoadAuth(rc.Context())
}
return rc.WithContext(ctx)
}

func TestRateLimiterGenerateKey(t *testing.T) {
t.Parallel()
t.Run("default", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{})
assert.NoError(t, err)
key, err := rl.generateKey(expressionResolveContext(t, nil, nil))
assert.NoError(t, err)
assert.Equal(t, "test", key)
})
t.Run("from header", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "request.header.Get('Authorization')",
})
require.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, http.Header{"Authorization": []string{"token"}}, nil),
)
assert.NoError(t, err)
assert.Equal(t, "test:token", key)
})
t.Run("from header number", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "request.header.Get('Authorization')",
})
assert.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, http.Header{"Authorization": []string{"123"}}, nil),
)
assert.NoError(t, err)
assert.Equal(t, "test:123", key)
})
t.Run("from header whitespace", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "trim(request.header.Get('Authorization'))",
})
assert.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, http.Header{"Authorization": []string{" token "}}, nil),
)
assert.NoError(t, err)
assert.Equal(t, "test:token", key)
})
t.Run("from claims", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "request.auth.claims.sub",
})
assert.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, nil, map[string]any{"sub": "token"}),
)
assert.NoError(t, err)
assert.Equal(t, "test:token", key)
})
t.Run("from claims invalid claim", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "request.auth.claims.sub",
})
assert.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, nil, map[string]any{"sub": 123}),
)
assert.Error(t, err)
assert.Empty(t, key)
})
t.Run("from claims or X-Forwarded-For header claims present", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "request.auth.claims.sub ?? request.header.Get('X-Forwarded-For')",
})
assert.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, map[string]any{"sub": "token"}),
)
assert.NoError(t, err)
assert.Equal(t, "test:token", key)
})
t.Run("from claims or X-Forwarded-For header claims not present", func(t *testing.T) {
t.Parallel()
rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{
KeySuffixExpression: "request.auth.claims.sub ?? request.header.Get('X-Forwarded-For')",
})
assert.NoError(t, err)
key, err := rl.generateKey(
expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, nil),
)
assert.NoError(t, err)
assert.Equal(t, "test:192.168.0.1", key)
})
}

func ContextWithClaims(ctx *resolve.Context, claims map[string]any) *resolve.Context {
auth := &FakeAuthenticator{
claims: claims,
}
withScopes := authentication.NewContext(context.Background(), auth)
return ctx.WithContext(withScopes)
}

type FakeAuthenticator struct {
claims map[string]any
scopes []string
}

func (f *FakeAuthenticator) Authenticator() string {
return "fake"
}

func (f *FakeAuthenticator) Claims() authentication.Claims {
return f.claims
}

func (f *FakeAuthenticator) SetScopes(scopes []string) {
//TODO implement me
panic("implement me")
}

func (f *FakeAuthenticator) Scopes() []string {
return f.scopes
}
Loading

0 comments on commit ffcb634

Please sign in to comment.