Skip to content

Commit

Permalink
feat: add extensions.code to rate limiting error
Browse files Browse the repository at this point in the history
  • Loading branch information
jensneuse committed Jan 8, 2025
1 parent 20e4f82 commit 6f6575d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
7 changes: 7 additions & 0 deletions v2/pkg/engine/resolve/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ type RateLimitOptions struct {
Period time.Duration
RateLimitKey string
RejectExceedingRequests bool

ErrorExtensionCode RateLimitErrorExtensionCode
}

type RateLimitErrorExtensionCode struct {
Enabled bool
Code string
}

type RateLimitDeny struct {
Expand Down
27 changes: 19 additions & 8 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1034,35 +1034,46 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re
func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result) error {
l.ctx.appendSubgraphError(goerrors.Join(res.err, NewRateLimitError(res.ds.Name, fetchItem.ResponsePath, res.rateLimitRejectedReason)))
pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath)
var (
err error
errorObject *astjson.Value
)
if res.ds.Name == "" {
if res.rateLimitRejectedReason == "" {
errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart))
errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart))
if err != nil {
return err
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
} else {
errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason))
errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason))
if err != nil {
return err
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
}
} else {
if res.rateLimitRejectedReason == "" {
errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart))
errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart))
if err != nil {
return err
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
} else {
errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason))
errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason))
if err != nil {
return err
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
}
}
if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled {
extension, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code))
if err != nil {
return err
}
errorObject, _, err = astjson.MergeValuesWithPath(errorObject, extension, "extensions")
if err != nil {
return err
}
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
return nil
}

Expand Down
16 changes: 16 additions & 0 deletions v2/pkg/engine/resolve/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ func TestRateLimiter(t *testing.T) {
assert.Equal(t, int64(1), limiter.rateLimitPreFetchCalls.Load())
}
}))
t.Run("deny with code", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) {

limiter := &testRateLimiter{
allowFn: func(ctx *Context, info *FetchInfo, input json.RawMessage) (*RateLimitDeny, error) {
return &RateLimitDeny{Reason: "rate limit exceeded"}, nil
},
}

res := generateTestFederationGraphQLResponse(t, ctrl)

return res, &Context{ctx: context.Background(), Variables: nil, rateLimiter: limiter, RateLimitOptions: RateLimitOptions{Enable: true, ErrorExtensionCode: RateLimitErrorExtensionCode{Enabled: true, Code: "RATE_LIMIT_EXCEEDED"}}},
`{"errors":[{"message":"Rate limit exceeded for Subgraph 'users' at Path 'query', Reason: rate limit exceeded.","extensions":{"code":"RATE_LIMIT_EXCEEDED"}},{"message":"Failed to fetch from Subgraph 'reviews' at Path 'query.me'.","extensions":{"errors":[{"message":"Failed to render Fetch Input","path":["me"]}]}},{"message":"Failed to fetch from Subgraph 'products' at Path '[email protected]'.","extensions":{"errors":[{"message":"Failed to render Fetch Input","path":["me","reviews","@","product"]}]}}],"data":{"me":null}}`,
func(t *testing.T) {
assert.Equal(t, int64(1), limiter.rateLimitPreFetchCalls.Load())
}
}))
t.Run("err all", testFnWithError(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {

limiter := &testRateLimiter{
Expand Down

0 comments on commit 6f6575d

Please sign in to comment.