diff --git a/pipeline/resolve.go b/pipeline/resolve.go index 5ae24c804..c2840ec57 100644 --- a/pipeline/resolve.go +++ b/pipeline/resolve.go @@ -1,10 +1,12 @@ package pipeline import ( - "connectrpc.com/connect" "context" "errors" "fmt" + "sync/atomic" + + "connectrpc.com/connect" "github.com/streamingfast/bstream" "github.com/streamingfast/bstream/hub" pbbstream "github.com/streamingfast/bstream/pb/sf/bstream/v1" @@ -15,7 +17,6 @@ import ( pbsubstreams "github.com/streamingfast/substreams/pb/sf/substreams/v1" "github.com/streamingfast/substreams/reqctx" "go.uber.org/zap" - "sync/atomic" ) type getBlockFunc func() (uint64, error) @@ -48,7 +49,8 @@ func BuildRequestDetails( getRecentFinalBlock getBlockFunc, resolveCursor CursorResolver, getHeadBlock getBlockFunc, - segmentSize uint64) (req *reqctx.RequestDetails, undoSignal *pbsubstreamsrpc.BlockUndoSignal, err error) { + segmentSize uint64, +) (req *reqctx.RequestDetails, undoSignal *pbsubstreamsrpc.BlockUndoSignal, err error) { req = &reqctx.RequestDetails{ Modules: request.Modules, OutputModule: request.OutputModule, @@ -95,7 +97,7 @@ func BuildRequestDetails( return } -func BuildRequestDetailsFromSubrequest(request *pbssinternal.ProcessRangeRequest) (req *reqctx.RequestDetails) { +func BuildRequestDetailsFromSubrequest(ctx context.Context, request *pbssinternal.ProcessRangeRequest) (req *reqctx.RequestDetails) { req = &reqctx.RequestDetails{ Modules: request.Modules, OutputModule: request.OutputModule, @@ -107,6 +109,9 @@ func BuildRequestDetailsFromSubrequest(request *pbssinternal.ProcessRangeRequest ResolvedStartBlockNum: request.StartBlock(), UniqueID: nextUniqueID(), } + + req.SetStageLayerParallelExecutorCountFromContext(ctx) + return req } diff --git a/reqctx/request.go b/reqctx/request.go index 9b1afba24..a123e1115 100644 --- a/reqctx/request.go +++ b/reqctx/request.go @@ -1,8 +1,10 @@ package reqctx import ( + "context" "strconv" + "github.com/streamingfast/dauth" pbsubstreams "github.com/streamingfast/substreams/pb/sf/substreams/v1" ) @@ -45,3 +47,18 @@ func (d *RequestDetails) ShouldStreamCachedOutputs() bool { return d.ProductionMode && d.ResolvedStartBlockNum < d.LinearHandoffBlockNum } + +// SetStageLayerParallelExecutorCountFromContext sets the MaxStageLayerParallelExecutor from the context +// by first retrieving the dauth trusted headers and then parsing the value from the header, if present. +func (d *RequestDetails) SetStageLayerParallelExecutorCountFromContext(ctx context.Context) { + trustedHeaders := dauth.FromContext(ctx) + if trustedHeaders == nil { + return + } + + if parallelExecutors := trustedHeaders.Get("X-Sf-Substreams-Stage-Layer-Parallel-Executor-Max-Count"); parallelExecutors != "" { + if count, err := strconv.ParseUint(parallelExecutors, 10, 64); err == nil { + d.MaxStageLayerParallelExecutor = count + } + } +} diff --git a/service/tier1.go b/service/tier1.go index 7937ac1c6..489204d1a 100644 --- a/service/tier1.go +++ b/service/tier1.go @@ -437,18 +437,15 @@ func (s *Tier1Service) blocks(ctx context.Context, request *pbsubstreamsrpc.Requ requestDetails.MaxParallelJobs = count } } - if parallelExecutors := auth.Get("X-Sf-Substreams-Stage-Layer-Parallel-Executor-Max-Count"); parallelExecutors != "" { - if count, err := strconv.ParseUint(parallelExecutors, 10, 64); err == nil { - requestDetails.MaxStageLayerParallelExecutor = count - } - } - if ct := auth.Get("X-Sf-Substreams-Cache-Tag"); ct != "" { - if IsValidCacheTag(ct) { - cacheTag = ct + if tag := auth.Get("X-Sf-Substreams-Cache-Tag"); tag != "" { + if IsValidCacheTag(tag) { + cacheTag = tag } else { - return fmt.Errorf("invalid value for X-Sf-Substreams-Cache-Tag %s, should only contain letters, numbers, hyphens and underscores", ct) + return fmt.Errorf("invalid value for X-Sf-Substreams-Cache-Tag %s, should only contain letters, numbers, hyphens and underscores", tag) } } + + requestDetails.SetStageLayerParallelExecutorCountFromContext(ctx) } var requestStats *metrics.Stats diff --git a/service/tier2.go b/service/tier2.go index 76840a7c4..074f918dd 100644 --- a/service/tier2.go +++ b/service/tier2.go @@ -267,7 +267,7 @@ func (s *Tier2Service) processRange(ctx context.Context, request *pbssinternal.P return stream.NewErrInvalidArg(err.Error()) } - requestDetails := pipeline.BuildRequestDetailsFromSubrequest(request) + requestDetails := pipeline.BuildRequestDetailsFromSubrequest(ctx, request) ctx = reqctx.WithRequest(ctx, requestDetails) if s.moduleExecutionTracing { ctx = reqctx.WithModuleExecutionTracing(ctx)