diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 7db2ac172..fde2a351e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -21,7 +21,7 @@ import ( ) const ( - HeartbeatInterval = 5 * time.Second + DefaultHeartbeatInterval = 5 * time.Second ) var ( @@ -65,8 +65,9 @@ type Resolver struct { reporter Reporter asyncErrorWriter AsyncErrorWriter - propagateSubgraphErrors bool - propagateSubgraphStatusCodes bool + propagateSubgraphErrors bool + propagateSubgraphStatusCodes bool + multipartSubHeartbeatInterval time.Duration } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -142,6 +143,8 @@ type ResolverOptions struct { ResolvableOptions ResolvableOptions // AllowedCustomSubgraphErrorFields defines which fields are allowed in the subgraph error when in passthrough mode AllowedSubgraphErrorFields []string + // MultipartSubHeartbeatInterval defines the interval in which a heartbeat is sent to all multipart subscriptions + MultipartSubHeartbeatInterval time.Duration } // New returns a new Resolver, ctx.Done() is used to cancel all active subscriptions & streams @@ -151,6 +154,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { options.MaxConcurrency = 32 } + if options.MultipartSubHeartbeatInterval <= 0 { + options.MultipartSubHeartbeatInterval = DefaultHeartbeatInterval + } + // We transform the allowed fields into a map for faster lookups allowedExtensionFields := make(map[string]struct{}, len(options.AllowedErrorExtensionFields)) for _, field := range options.AllowedErrorExtensionFields { @@ -176,18 +183,19 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { } resolver := &Resolver{ - ctx: ctx, - options: options, - propagateSubgraphErrors: options.PropagateSubgraphErrors, - propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, - events: make(chan subscriptionEvent), - triggers: make(map[uint64]*trigger), - heartbeatSubscriptions: make(map[*Context]*sub), - reporter: options.Reporter, - asyncErrorWriter: options.AsyncErrorWriter, - triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), - allowedErrorExtensionFields: allowedExtensionFields, - allowedErrorFields: allowedErrorFields, + ctx: ctx, + options: options, + propagateSubgraphErrors: options.PropagateSubgraphErrors, + propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, + events: make(chan subscriptionEvent), + triggers: make(map[uint64]*trigger), + heartbeatSubscriptions: make(map[*Context]*sub), + reporter: options.Reporter, + asyncErrorWriter: options.AsyncErrorWriter, + triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), + allowedErrorExtensionFields: allowedExtensionFields, + allowedErrorFields: allowedErrorFields, + multipartSubHeartbeatInterval: options.MultipartSubHeartbeatInterval, } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 0; i < options.MaxConcurrency; i++ { @@ -358,7 +366,7 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput func (r *Resolver) handleEvents() { done := r.ctx.Done() - heartbeat := time.NewTicker(HeartbeatInterval) + heartbeat := time.NewTicker(r.multipartSubHeartbeatInterval) defer heartbeat.Stop() for { select { @@ -407,7 +415,7 @@ func (r *Resolver) handleHeartbeat(data []byte) { // check if the last write to the subscription was more than heartbeat interval ago c, s := c, s s.mux.Lock() - skipHeartbeat := now.Sub(s.lastWrite) < HeartbeatInterval + skipHeartbeat := now.Sub(s.lastWrite) < r.multipartSubHeartbeatInterval s.mux.Unlock() if skipHeartbeat { continue