From 60459dcb59290bfb921bc9ce20b5d15b67ea7c4a Mon Sep 17 00:00:00 2001 From: Haytham AbuelFutuh Date: Mon, 16 Dec 2019 13:25:29 -0800 Subject: [PATCH] Make catalog async client array-aware --- .../pluginmachinery/catalog/async_client.go | 61 +++++- .../catalog/async_client_impl.go | 175 +++++++++-------- .../catalog/async_client_impl_test.go | 107 ++++++++--- go/tasks/pluginmachinery/catalog/config.go | 31 ++- .../pluginmachinery/catalog/config_flags.go | 14 +- .../catalog/config_flags_test.go | 116 ++++++++---- .../catalog/mocks/async_client.go | 128 ++++++++++--- .../catalog/reader_array_processor.go | 131 +++++++++++++ .../catalog/reader_processor.go | 84 --------- go/tasks/pluginmachinery/catalog/response.go | 4 +- .../catalog/writer_array_processor.go | 116 ++++++++++++ .../catalog/writer_processor.go | 56 ------ go/tasks/plugins/array/catalog.go | 178 ++++-------------- go/tasks/plugins/array/catalog_test.go | 2 +- tests/end_to_end.go | 23 ++- 15 files changed, 733 insertions(+), 493 deletions(-) create mode 100644 go/tasks/pluginmachinery/catalog/reader_array_processor.go delete mode 100644 go/tasks/pluginmachinery/catalog/reader_processor.go create mode 100644 go/tasks/pluginmachinery/catalog/writer_array_processor.go delete mode 100644 go/tasks/pluginmachinery/catalog/writer_processor.go diff --git a/go/tasks/pluginmachinery/catalog/async_client.go b/go/tasks/pluginmachinery/catalog/async_client.go index d6c4604ac..992612fa2 100644 --- a/go/tasks/pluginmachinery/catalog/async_client.go +++ b/go/tasks/pluginmachinery/catalog/async_client.go @@ -3,6 +3,10 @@ package catalog import ( "context" + "github.com/lyft/flytestdlib/storage" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/bitarray" "github.com/lyft/flytestdlib/errors" @@ -24,7 +28,8 @@ const ( type UploadRequest struct { Key Key - ArtifactData io.OutputReader + ArtifactData io.OutputFilePaths + DataStore *storage.DataStore ArtifactMetadata Metadata } @@ -49,8 
+54,9 @@ type UploadFuture interface { // Catalog Download Request to represent async operation download request. type DownloadRequest struct { - Key Key - Target io.OutputWriter + Key Key + Target io.OutputWriter + DataStore *storage.DataStore } // Catalog download future to represent async process of downloading catalog artifacts. @@ -73,13 +79,58 @@ type DownloadResponse interface { GetCachedCount() int } +// Catalog array download request to represent an async download operation covering the subtasks of an array task. +type DownloadArrayRequest struct { + // Identifier is the same among all subtasks of the array + Identifier core.Identifier + // Cache version is the same among all subtasks of the array + CacheVersion string + // Interface is the same among all subtasks of the array + TypedInterface core.TypedInterface + + dataStore *storage.DataStore + + // Base input reader to build subtasks input readers from + BaseInputReader io.InputReader + // Base output writer to build subtasks output writers from + BaseTarget io.OutputWriter + Indexes *bitarray.BitSet + Count int +} + +type UploadArrayRequest struct { + // Identifier is the same among all subtasks of the array + Identifier core.Identifier + // Cache version is the same among all subtasks of the array + CacheVersion string + // Interface is the same among all subtasks of the array + TypedInterface core.TypedInterface + // ArtifactMetadata is the same among all subtasks of the array + ArtifactMetadata Metadata + + dataStore *storage.DataStore + + // Base input reader to build subtasks input readers from + BaseInputReader io.InputReader + // Base output file paths to build subtasks artifact data from + BaseArtifactData io.OutputFilePaths + Indexes *bitarray.BitSet + Count int +} + // An interface that helps async interaction with catalog service type AsyncClient interface { // Returns if an entry exists for the given task and input.
It returns the data as a LiteralMap - Download(ctx context.Context, requests ...DownloadRequest) (outputFuture DownloadFuture, err error) + Download(ctx context.Context, request DownloadRequest) (outputFuture DownloadFuture, err error) + + // Adds a new entry to catalog for the given task execution context and the generated output + Upload(ctx context.Context, request UploadRequest) (putFuture UploadFuture, err error) + + // Returns if an entry exists for the given task and input. It returns the data as a LiteralMap + DownloadArray(ctx context.Context, request DownloadArrayRequest) (outputFuture DownloadFuture, err error) // Adds a new entry to catalog for the given task execution context and the generated output - Upload(ctx context.Context, requests ...UploadRequest) (putFuture UploadFuture, err error) + UploadArray(ctx context.Context, requests UploadArrayRequest) (putFuture UploadFuture, err error) } var _ AsyncClient = AsyncClientImpl{} diff --git a/go/tasks/pluginmachinery/catalog/async_client_impl.go b/go/tasks/pluginmachinery/catalog/async_client_impl.go index cee7dc39b..8fac81d67 100644 --- a/go/tasks/pluginmachinery/catalog/async_client_impl.go +++ b/go/tasks/pluginmachinery/catalog/async_client_impl.go @@ -9,8 +9,6 @@ import ( "github.com/lyft/flytestdlib/promutils" - "github.com/lyft/flytestdlib/bitarray" - "github.com/lyft/flytestdlib/errors" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue" @@ -22,12 +20,8 @@ var base32Encoder = base32.NewEncoding(specialEncoderKey).WithPadding(base32.NoP // An async-client for catalog that can queue download and upload requests on workqueues. 
type AsyncClientImpl struct { - Reader workqueue.IndexedWorkQueue - Writer workqueue.IndexedWorkQueue -} - -func formatWorkItemID(key Key, idx int, suffix string) string { - return fmt.Sprintf("%v-%v-%v", key, idx, suffix) + ArrayReader workqueue.IndexedWorkQueue + ArrayWriter workqueue.IndexedWorkQueue } func consistentHash(str string) (string, error) { @@ -41,96 +35,113 @@ func consistentHash(str string) (string, error) { return base32Encoder.EncodeToString(b), nil } -func (c AsyncClientImpl) Download(ctx context.Context, requests ...DownloadRequest) (outputFuture DownloadFuture, err error) { - status := ResponseStatusReady - cachedResults := bitarray.NewBitSet(uint(len(requests))) - cachedCount := 0 - var respErr error - for idx, request := range requests { - uniqueOutputLoc, err := consistentHash(request.Target.GetOutputPrefixPath().String()) - if err != nil { - return nil, err - } +// Returns if an entry exists for the given task and input. It returns the data as a LiteralMap +func (c AsyncClientImpl) DownloadArray(ctx context.Context, request DownloadArrayRequest) (outputFuture DownloadFuture, err error) { + workItemID := fmt.Sprintf("%v-%v-%v-%v-%v-%v", request.Identifier.String(), request.Count, + request.BaseTarget.GetOutputPrefixPath(), request.TypedInterface, request.BaseInputReader.GetInputPrefixPath(), + request.CacheVersion) - workItemID := formatWorkItemID(request.Key, idx, uniqueOutputLoc) - err = c.Reader.Queue(ctx, workItemID, NewReaderWorkItem( - request.Key, - request.Target)) + hashedID, err := consistentHash(workItemID) + if err != nil { + return nil, err + } - if err != nil { - return nil, err - } + err = c.ArrayReader.Queue(ctx, hashedID, NewArrayReaderWorkItem(request)) + if err != nil { + return nil, err + } - info, found, err := c.Reader.Get(workItemID) - if err != nil { - return nil, errors.Wrapf(ErrSystemError, err, "Failed to lookup from reader workqueue for info: %v", workItemID) - } + info, found, err := c.ArrayReader.Get(hashedID) + 
if err != nil { + return nil, errors.Wrapf(ErrSystemError, err, "Failed to lookup from reader workqueue for info: %v", workItemID) + } - if !found { - return nil, errors.Errorf(ErrSystemError, "Item not found in the reader workqueue even though it was just added. ItemID: %v", workItemID) - } + if !found { + return nil, errors.Errorf(ErrSystemError, "Item not found in the reader workqueue even though it was just added. ItemID: %v", workItemID) + } - switch info.Status() { - case workqueue.WorkStatusSucceeded: - readerWorkItem, casted := info.Item().(*ReaderWorkItem) - if !casted { - return nil, errors.Errorf(ErrSystemError, "Item wasn't casted to ReaderWorkItem. ItemID: %v. Type: %v", workItemID, reflect.TypeOf(info)) - } - - if readerWorkItem.IsCached() { - cachedResults.Set(uint(idx)) - cachedCount++ - } - case workqueue.WorkStatusFailed: - respErr = info.Error() - case workqueue.WorkStatusNotDone: - status = ResponseStatusNotReady + switch info.Status() { + case workqueue.WorkStatusSucceeded: + readerWorkItem, casted := info.Item().(*ArrayReaderWorkItem) + if !casted { + return nil, errors.Errorf(ErrSystemError, "Item wasn't casted to ReaderWorkItem. ItemID: %v. 
Type: %v", workItemID, reflect.TypeOf(info)) } - } - return newDownloadFuture(status, respErr, cachedResults, len(requests), cachedCount), nil + return newDownloadFuture(ResponseStatusReady, nil, readerWorkItem.CachedResults(), request.Count), nil + case workqueue.WorkStatusFailed: + return newDownloadFuture(ResponseStatusReady, info.Error(), nil, request.Count), nil + default: + return newDownloadFuture(ResponseStatusNotReady, nil, nil, request.Count), nil + } } -func (c AsyncClientImpl) Upload(ctx context.Context, requests ...UploadRequest) (putFuture UploadFuture, err error) { - status := ResponseStatusReady - var respErr error - for idx, request := range requests { - workItemID := formatWorkItemID(request.Key, idx, "") - err := c.Writer.Queue(ctx, workItemID, NewWriterWorkItem( - request.Key, - request.ArtifactData, - request.ArtifactMetadata)) - - if err != nil { - return nil, err - } +// Adds a new entry to catalog for the given task execution context and the generated output +func (c AsyncClientImpl) UploadArray(ctx context.Context, request UploadArrayRequest) (putFuture UploadFuture, err error) { + workItemID := fmt.Sprintf("%v-%v-%v-%v-%v", request.Identifier.String(), request.Count, + request.TypedInterface, request.BaseInputReader.GetInputPrefixPath(), request.CacheVersion) - info, found, err := c.Writer.Get(workItemID) - if err != nil { - return nil, errors.Wrapf(ErrSystemError, err, "Failed to lookup from writer workqueue for info: %v", workItemID) - } + hashedID, err := consistentHash(workItemID) + if err != nil { + return nil, err + } - if !found { - return nil, errors.Errorf(ErrSystemError, "Item not found in the writer workqueue even though it was just added. 
ItemID: %v", workItemID) - } + err = c.ArrayWriter.Queue(ctx, hashedID, NewArrayWriterWorkItem(request)) + if err != nil { + return nil, err + } - switch info.Status() { - case workqueue.WorkStatusNotDone: - status = ResponseStatusNotReady - case workqueue.WorkStatusFailed: - respErr = info.Error() - } + info, found, err := c.ArrayWriter.Get(hashedID) + if err != nil { + return nil, errors.Wrapf(ErrSystemError, err, "Failed to lookup from reader workqueue for info: %v", workItemID) + } + + if !found { + return nil, errors.Errorf(ErrSystemError, "Item not found in the reader workqueue even though it was just added. ItemID: %v", workItemID) + } + + switch info.Status() { + case workqueue.WorkStatusSucceeded: + return newUploadFuture(ResponseStatusReady, nil), nil + case workqueue.WorkStatusFailed: + return newUploadFuture(ResponseStatusReady, info.Error()), nil + default: + return newUploadFuture(ResponseStatusNotReady, nil), nil } +} + +func (c AsyncClientImpl) Download(ctx context.Context, request DownloadRequest) (outputFuture DownloadFuture, err error) { + return c.DownloadArray(ctx, DownloadArrayRequest{ + Identifier: request.Key.Identifier, + CacheVersion: request.Key.CacheVersion, + TypedInterface: request.Key.TypedInterface, + BaseInputReader: request.Key.InputReader, + BaseTarget: request.Target, + dataStore: request.DataStore, + Indexes: nil, + Count: 0, + }) +} - return newUploadFuture(status, respErr), nil +func (c AsyncClientImpl) Upload(ctx context.Context, requests UploadRequest) (putFuture UploadFuture, err error) { + return c.UploadArray(ctx, UploadArrayRequest{ + Identifier: requests.Key.Identifier, + CacheVersion: requests.Key.CacheVersion, + TypedInterface: requests.Key.TypedInterface, + ArtifactMetadata: requests.ArtifactMetadata, + dataStore: requests.DataStore, + BaseInputReader: requests.Key.InputReader, + BaseArtifactData: requests.ArtifactData, + Indexes: nil, + Count: 0, + }) } func (c AsyncClientImpl) Start(ctx context.Context) error { - 
if err := c.Reader.Start(ctx); err != nil { + if err := c.ArrayReader.Start(ctx); err != nil { return errors.Wrapf(ErrSystemError, err, "Failed to start reader queue.") } - if err := c.Writer.Start(ctx); err != nil { + if err := c.ArrayWriter.Start(ctx); err != nil { return errors.Wrapf(ErrSystemError, err, "Failed to start writer queue.") } @@ -138,20 +149,20 @@ func (c AsyncClientImpl) Start(ctx context.Context) error { } func NewAsyncClient(client Client, cfg Config, scope promutils.Scope) (AsyncClientImpl, error) { - readerWorkQueue, err := workqueue.NewIndexedWorkQueue("reader", NewReaderProcessor(client), cfg.ReaderWorkqueueConfig, + arrayReaderWorkQueue, err := workqueue.NewIndexedWorkQueue("reader", NewArrayReaderProcessor(client, cfg.Reader.MaxItemsPerRound), cfg.Reader.Workqueue, scope.NewSubScope("reader")) if err != nil { return AsyncClientImpl{}, err } - writerWorkQueue, err := workqueue.NewIndexedWorkQueue("writer", NewWriterProcessor(client), cfg.WriterWorkqueueConfig, + arrayWriterWorkQueue, err := workqueue.NewIndexedWorkQueue("writer", NewWriterArrayProcessor(client, cfg.Writer.MaxItemsPerRound), cfg.Writer.Workqueue, scope.NewSubScope("writer")) if err != nil { return AsyncClientImpl{}, err } return AsyncClientImpl{ - Reader: readerWorkQueue, - Writer: writerWorkQueue, + ArrayWriter: arrayWriterWorkQueue, + ArrayReader: arrayReaderWorkQueue, }, nil } diff --git a/go/tasks/pluginmachinery/catalog/async_client_impl_test.go b/go/tasks/pluginmachinery/catalog/async_client_impl_test.go index fe3078c0a..5965ca1b1 100644 --- a/go/tasks/pluginmachinery/catalog/async_client_impl_test.go +++ b/go/tasks/pluginmachinery/catalog/async_client_impl_test.go @@ -2,9 +2,19 @@ package catalog import ( "context" - "reflect" "testing" + "github.com/go-test/deep" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/storage" + 
"github.com/stretchr/testify/assert" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + mocks2 "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks" "github.com/lyft/flytestdlib/bitarray" @@ -13,12 +23,27 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue" ) +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} + +func invertBitSet(input *bitarray.BitSet, limit uint) *bitarray.BitSet { + output := bitarray.NewBitSet(limit) + for i := uint(0); i < limit; i++ { + if !input.IsSet(i) { + output.Set(i) + } + } + + return output +} + func TestAsyncClientImpl_Download(t *testing.T) { ctx := context.Background() q := &mocks.IndexedWorkQueue{} info := &mocks.WorkItemInfo{} - info.OnItem().Return(NewReaderWorkItem(Key{}, &mocks2.OutputWriter{})) + info.OnItem().Return(NewArrayReaderWorkItem(DownloadArrayRequest{})) info.OnStatus().Return(workqueue.WorkStatusSucceeded) q.OnGetMatch(mock.Anything).Return(info, true, nil) q.OnQueueMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -27,32 +52,44 @@ func TestAsyncClientImpl_Download(t *testing.T) { ow.OnGetOutputPrefixPath().Return("/prefix/") ow.OnGetOutputPath().Return("/prefix/outputs.pb") + ir := &mocks2.InputReader{} + ir.OnGetInputPrefixPath().Return("/prefix/") + + ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + tests := []struct { name string reader workqueue.IndexedWorkQueue - requests []DownloadRequest + request DownloadArrayRequest wantOutputFuture DownloadFuture wantErr bool }{ - {"DownloadQueued", q, []DownloadRequest{ - { - Key: Key{}, - Target: ow, - }, - }, newDownloadFuture(ResponseStatusReady, nil, bitarray.NewBitSet(1), 1, 0), false}, + {"DownloadQueued", q, DownloadArrayRequest{ + Identifier: core.Identifier{}, + CacheVersion: "", + TypedInterface: core.TypedInterface{}, + dataStore: ds, + 
BaseInputReader: ir, + BaseTarget: ow, + Indexes: invertBitSet(bitarray.NewBitSet(1), 1), + Count: 1, + }, newDownloadFuture(ResponseStatusReady, nil, bitarray.NewBitSet(1), 1), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := AsyncClientImpl{ - Reader: tt.reader, + ArrayReader: tt.reader, } - gotOutputFuture, err := c.Download(ctx, tt.requests...) + + gotOutputFuture, err := c.DownloadArray(ctx, tt.request) if (err != nil) != tt.wantErr { t.Errorf("AsyncClientImpl.Download() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(gotOutputFuture, tt.wantOutputFuture) { - t.Errorf("AsyncClientImpl.Download() = %v, want %v", gotOutputFuture, tt.wantOutputFuture) + + if diff := deep.Equal(tt.wantOutputFuture, gotOutputFuture); diff != nil { + t.Errorf("expected != actual. Diff: %v", diff) } }) } @@ -63,35 +100,53 @@ func TestAsyncClientImpl_Upload(t *testing.T) { q := &mocks.IndexedWorkQueue{} info := &mocks.WorkItemInfo{} - info.OnItem().Return(NewReaderWorkItem(Key{}, &mocks2.OutputWriter{})) + info.OnItem().Return(NewArrayReaderWorkItem(DownloadArrayRequest{})) info.OnStatus().Return(workqueue.WorkStatusSucceeded) - q.OnGet("{UNSPECIFIED {} [] 0}:-0-").Return(info, true, nil) + q.OnGet("cfqacua").Return(info, true, nil) q.OnQueueMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + ir := &mocks2.InputReader{} + ir.OnGetInputPrefixPath().Return("/prefix/") + + ow := &mocks2.OutputWriter{} + ow.OnGetOutputPrefixPath().Return("/prefix/") + ow.OnGetOutputPath().Return("/prefix/outputs.pb") + + ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + tests := []struct { name string - requests []UploadRequest + request UploadArrayRequest wantPutFuture UploadFuture wantErr bool }{ - {"UploadSucceeded", []UploadRequest{ - { - Key: Key{}, - }, + {"UploadSucceeded", UploadArrayRequest{ + Identifier: core.Identifier{}, + CacheVersion: "", + 
TypedInterface: core.TypedInterface{}, + ArtifactMetadata: Metadata{}, + dataStore: ds, + BaseInputReader: ir, + BaseArtifactData: ow, + Indexes: invertBitSet(bitarray.NewBitSet(1), 1), + Count: 1, }, newUploadFuture(ResponseStatusReady, nil), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := AsyncClientImpl{ - Writer: q, + ArrayWriter: q, } - gotPutFuture, err := c.Upload(ctx, tt.requests...) + + gotPutFuture, err := c.UploadArray(ctx, tt.request) if (err != nil) != tt.wantErr { t.Errorf("AsyncClientImpl.Upload() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(gotPutFuture, tt.wantPutFuture) { - t.Errorf("AsyncClientImpl.Upload() = %v, want %v", gotPutFuture, tt.wantPutFuture) + + if diff := deep.Equal(tt.wantPutFuture, gotPutFuture); diff != nil { + t.Errorf("expected != actual. Diff: %v", diff) } }) } @@ -116,8 +171,8 @@ func TestAsyncClientImpl_Start(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := AsyncClientImpl{ - Reader: tt.fields.Reader, - Writer: tt.fields.Writer, + ArrayReader: tt.fields.Reader, + ArrayWriter: tt.fields.Writer, } if err := c.Start(tt.args.ctx); (err != nil) != tt.wantErr { t.Errorf("AsyncClientImpl.Start() error = %v, wantErr %v", err, tt.wantErr) diff --git a/go/tasks/pluginmachinery/catalog/config.go b/go/tasks/pluginmachinery/catalog/config.go index d3b7609e5..049ec665e 100644 --- a/go/tasks/pluginmachinery/catalog/config.go +++ b/go/tasks/pluginmachinery/catalog/config.go @@ -10,20 +10,31 @@ import ( var cfgSection = config.MustRegisterSubSection("catalogCache", defaultConfig) type Config struct { - ReaderWorkqueueConfig workqueue.Config `json:"reader" pflag:",Catalog reader workqueue config. Make sure the index cache must be big enough to accommodate the biggest array task allowed to run on the system."` - WriterWorkqueueConfig workqueue.Config `json:"writer" pflag:",Catalog writer workqueue config. 
Make sure the index cache must be big enough to accommodate the biggest array task allowed to run on the system."` + Reader ProcessorConfig `json:"reader" pflag:",Catalog reader processor config."` + Writer ProcessorConfig `json:"writer" pflag:",Catalog writer processor config."` +} + +type ProcessorConfig struct { + Workqueue workqueue.Config `json:"queue" pflag:",Workqueue config. Make sure the index cache must be big enough to accommodate the biggest array task allowed to run on the system."` + MaxItemsPerRound int `json:"itemsPerRound" pflag:",Max number of items to process in each round. Under load, this ensures fairness between different array jobs and avoid head-of-line blocking."` } var defaultConfig = &Config{ - ReaderWorkqueueConfig: workqueue.Config{ - MaxRetries: 3, - Workers: 10, - IndexCacheMaxItems: 1000, + Reader: ProcessorConfig{ + Workqueue: workqueue.Config{ + MaxRetries: 3, + Workers: 10, + IndexCacheMaxItems: 1000, + }, + MaxItemsPerRound: 100, }, - WriterWorkqueueConfig: workqueue.Config{ - MaxRetries: 3, - Workers: 10, - IndexCacheMaxItems: 1000, + Writer: ProcessorConfig{ + Workqueue: workqueue.Config{ + MaxRetries: 3, + Workers: 10, + IndexCacheMaxItems: 1000, + }, + MaxItemsPerRound: 100, }, } diff --git a/go/tasks/pluginmachinery/catalog/config_flags.go b/go/tasks/pluginmachinery/catalog/config_flags.go index f7bd0561f..4a8ebd212 100755 --- a/go/tasks/pluginmachinery/catalog/config_flags.go +++ b/go/tasks/pluginmachinery/catalog/config_flags.go @@ -41,11 +41,13 @@ func (Config) mustMarshalJSON(v json.Marshaler) string { // flags is json-name.json-sub-name... etc. 
func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.workers"), defaultConfig.ReaderWorkqueueConfig.Workers, "Number of concurrent workers to start processing the queue.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.maxRetries"), defaultConfig.ReaderWorkqueueConfig.MaxRetries, "Maximum number of retries per item.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.maxItems"), defaultConfig.ReaderWorkqueueConfig.IndexCacheMaxItems, "Maximum number of entries to keep in the index.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.workers"), defaultConfig.WriterWorkqueueConfig.Workers, "Number of concurrent workers to start processing the queue.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.maxRetries"), defaultConfig.WriterWorkqueueConfig.MaxRetries, "Maximum number of retries per item.") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.maxItems"), defaultConfig.WriterWorkqueueConfig.IndexCacheMaxItems, "Maximum number of entries to keep in the index.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.queue.workers"), defaultConfig.Reader.Workqueue.Workers, "Number of concurrent workers to start processing the queue.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.queue.maxRetries"), defaultConfig.Reader.Workqueue.MaxRetries, "Maximum number of retries per item.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.queue.maxItems"), defaultConfig.Reader.Workqueue.IndexCacheMaxItems, "Maximum number of entries to keep in the index.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "reader.itemsPerRound"), defaultConfig.Reader.MaxItemsPerRound, "Max number of items to process in each round. 
Under load, this ensures fairness between different array jobs and avoid head-of-line blocking.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.queue.workers"), defaultConfig.Writer.Workqueue.Workers, "Number of concurrent workers to start processing the queue.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.queue.maxRetries"), defaultConfig.Writer.Workqueue.MaxRetries, "Maximum number of retries per item.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.queue.maxItems"), defaultConfig.Writer.Workqueue.IndexCacheMaxItems, "Maximum number of entries to keep in the index.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "writer.itemsPerRound"), defaultConfig.Writer.MaxItemsPerRound, "Max number of items to process in each round. Under load, this ensures fairness between different array jobs and avoid head-of-line blocking.") return cmdFlags } diff --git a/go/tasks/pluginmachinery/catalog/config_flags_test.go b/go/tasks/pluginmachinery/catalog/config_flags_test.go index 0b9b1efd9..fb190f760 100755 --- a/go/tasks/pluginmachinery/catalog/config_flags_test.go +++ b/go/tasks/pluginmachinery/catalog/config_flags_test.go @@ -99,11 +99,11 @@ func TestConfig_SetFlags(t *testing.T) { cmdFlags := actual.GetPFlagSet("") assert.True(t, cmdFlags.HasFlags()) - t.Run("Test_reader.workers", func(t *testing.T) { + t.Run("Test_reader.queue.workers", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("reader.workers"); err == nil { - assert.Equal(t, int(defaultConfig.ReaderWorkqueueConfig.Workers), vInt) + if vInt, err := cmdFlags.GetInt("reader.queue.workers"); err == nil { + assert.Equal(t, int(defaultConfig.Reader.Workqueue.Workers), vInt) } else { assert.FailNow(t, err.Error()) } @@ -112,20 +112,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("reader.workers", testValue) - if vInt, err := 
cmdFlags.GetInt("reader.workers"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ReaderWorkqueueConfig.Workers) + cmdFlags.Set("reader.queue.workers", testValue) + if vInt, err := cmdFlags.GetInt("reader.queue.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Reader.Workqueue.Workers) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_reader.maxRetries", func(t *testing.T) { + t.Run("Test_reader.queue.maxRetries", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("reader.maxRetries"); err == nil { - assert.Equal(t, int(defaultConfig.ReaderWorkqueueConfig.MaxRetries), vInt) + if vInt, err := cmdFlags.GetInt("reader.queue.maxRetries"); err == nil { + assert.Equal(t, int(defaultConfig.Reader.Workqueue.MaxRetries), vInt) } else { assert.FailNow(t, err.Error()) } @@ -134,20 +134,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("reader.maxRetries", testValue) - if vInt, err := cmdFlags.GetInt("reader.maxRetries"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ReaderWorkqueueConfig.MaxRetries) + cmdFlags.Set("reader.queue.maxRetries", testValue) + if vInt, err := cmdFlags.GetInt("reader.queue.maxRetries"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Reader.Workqueue.MaxRetries) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_reader.maxItems", func(t *testing.T) { + t.Run("Test_reader.queue.maxItems", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("reader.maxItems"); err == nil { - assert.Equal(t, int(defaultConfig.ReaderWorkqueueConfig.IndexCacheMaxItems), vInt) + if vInt, err := cmdFlags.GetInt("reader.queue.maxItems"); err == nil { + assert.Equal(t, 
int(defaultConfig.Reader.Workqueue.IndexCacheMaxItems), vInt) } else { assert.FailNow(t, err.Error()) } @@ -156,20 +156,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("reader.maxItems", testValue) - if vInt, err := cmdFlags.GetInt("reader.maxItems"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ReaderWorkqueueConfig.IndexCacheMaxItems) + cmdFlags.Set("reader.queue.maxItems", testValue) + if vInt, err := cmdFlags.GetInt("reader.queue.maxItems"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Reader.Workqueue.IndexCacheMaxItems) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_writer.workers", func(t *testing.T) { + t.Run("Test_reader.itemsPerRound", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("writer.workers"); err == nil { - assert.Equal(t, int(defaultConfig.WriterWorkqueueConfig.Workers), vInt) + if vInt, err := cmdFlags.GetInt("reader.itemsPerRound"); err == nil { + assert.Equal(t, int(defaultConfig.Reader.MaxItemsPerRound), vInt) } else { assert.FailNow(t, err.Error()) } @@ -178,20 +178,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("writer.workers", testValue) - if vInt, err := cmdFlags.GetInt("writer.workers"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WriterWorkqueueConfig.Workers) + cmdFlags.Set("reader.itemsPerRound", testValue) + if vInt, err := cmdFlags.GetInt("reader.itemsPerRound"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Reader.MaxItemsPerRound) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_writer.maxRetries", func(t *testing.T) { + t.Run("Test_writer.queue.workers", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - 
if vInt, err := cmdFlags.GetInt("writer.maxRetries"); err == nil { - assert.Equal(t, int(defaultConfig.WriterWorkqueueConfig.MaxRetries), vInt) + if vInt, err := cmdFlags.GetInt("writer.queue.workers"); err == nil { + assert.Equal(t, int(defaultConfig.Writer.Workqueue.Workers), vInt) } else { assert.FailNow(t, err.Error()) } @@ -200,20 +200,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("writer.maxRetries", testValue) - if vInt, err := cmdFlags.GetInt("writer.maxRetries"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WriterWorkqueueConfig.MaxRetries) + cmdFlags.Set("writer.queue.workers", testValue) + if vInt, err := cmdFlags.GetInt("writer.queue.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Writer.Workqueue.Workers) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_writer.maxItems", func(t *testing.T) { + t.Run("Test_writer.queue.maxRetries", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("writer.maxItems"); err == nil { - assert.Equal(t, int(defaultConfig.WriterWorkqueueConfig.IndexCacheMaxItems), vInt) + if vInt, err := cmdFlags.GetInt("writer.queue.maxRetries"); err == nil { + assert.Equal(t, int(defaultConfig.Writer.Workqueue.MaxRetries), vInt) } else { assert.FailNow(t, err.Error()) } @@ -222,9 +222,53 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("writer.maxItems", testValue) - if vInt, err := cmdFlags.GetInt("writer.maxItems"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.WriterWorkqueueConfig.IndexCacheMaxItems) + cmdFlags.Set("writer.queue.maxRetries", testValue) + if vInt, err := cmdFlags.GetInt("writer.queue.maxRetries"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), 
&actual.Writer.Workqueue.MaxRetries) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_writer.queue.maxItems", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("writer.queue.maxItems"); err == nil { + assert.Equal(t, int(defaultConfig.Writer.Workqueue.IndexCacheMaxItems), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("writer.queue.maxItems", testValue) + if vInt, err := cmdFlags.GetInt("writer.queue.maxItems"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Writer.Workqueue.IndexCacheMaxItems) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_writer.itemsPerRound", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("writer.itemsPerRound"); err == nil { + assert.Equal(t, int(defaultConfig.Writer.MaxItemsPerRound), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("writer.itemsPerRound", testValue) + if vInt, err := cmdFlags.GetInt("writer.itemsPerRound"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Writer.MaxItemsPerRound) } else { assert.FailNow(t, err.Error()) diff --git a/go/tasks/pluginmachinery/catalog/mocks/async_client.go b/go/tasks/pluginmachinery/catalog/mocks/async_client.go index 46fc8c5ec..fd8fd08b6 100644 --- a/go/tasks/pluginmachinery/catalog/mocks/async_client.go +++ b/go/tasks/pluginmachinery/catalog/mocks/async_client.go @@ -23,8 +23,8 @@ func (_m AsyncClient_Download) Return(outputFuture catalog.DownloadFuture, err e return &AsyncClient_Download{Call: _m.Call.Return(outputFuture, err)} } -func (_m *AsyncClient) OnDownload(ctx context.Context, requests ...catalog.DownloadRequest) 
*AsyncClient_Download { - c := _m.On("Download", ctx, requests) +func (_m *AsyncClient) OnDownload(ctx context.Context, request catalog.DownloadRequest) *AsyncClient_Download { + c := _m.On("Download", ctx, request) return &AsyncClient_Download{Call: c} } @@ -33,20 +33,54 @@ func (_m *AsyncClient) OnDownloadMatch(matchers ...interface{}) *AsyncClient_Dow return &AsyncClient_Download{Call: c} } -// Download provides a mock function with given fields: ctx, requests -func (_m *AsyncClient) Download(ctx context.Context, requests ...catalog.DownloadRequest) (catalog.DownloadFuture, error) { - _va := make([]interface{}, len(requests)) - for _i := range requests { - _va[_i] = requests[_i] +// Download provides a mock function with given fields: ctx, request +func (_m *AsyncClient) Download(ctx context.Context, request catalog.DownloadRequest) (catalog.DownloadFuture, error) { + ret := _m.Called(ctx, request) + + var r0 catalog.DownloadFuture + if rf, ok := ret.Get(0).(func(context.Context, catalog.DownloadRequest) catalog.DownloadFuture); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(catalog.DownloadFuture) + } } - var _ca []interface{} - _ca = append(_ca, ctx) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) 
+ + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, catalog.DownloadRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncClient_DownloadArray struct { + *mock.Call +} + +func (_m AsyncClient_DownloadArray) Return(outputFuture catalog.DownloadFuture, err error) *AsyncClient_DownloadArray { + return &AsyncClient_DownloadArray{Call: _m.Call.Return(outputFuture, err)} +} + +func (_m *AsyncClient) OnDownloadArray(ctx context.Context, request catalog.DownloadArrayRequest) *AsyncClient_DownloadArray { + c := _m.On("DownloadArray", ctx, request) + return &AsyncClient_DownloadArray{Call: c} +} + +func (_m *AsyncClient) OnDownloadArrayMatch(matchers ...interface{}) *AsyncClient_DownloadArray { + c := _m.On("DownloadArray", matchers...) + return &AsyncClient_DownloadArray{Call: c} +} + +// DownloadArray provides a mock function with given fields: ctx, request +func (_m *AsyncClient) DownloadArray(ctx context.Context, request catalog.DownloadArrayRequest) (catalog.DownloadFuture, error) { + ret := _m.Called(ctx, request) var r0 catalog.DownloadFuture - if rf, ok := ret.Get(0).(func(context.Context, ...catalog.DownloadRequest) catalog.DownloadFuture); ok { - r0 = rf(ctx, requests...) + if rf, ok := ret.Get(0).(func(context.Context, catalog.DownloadArrayRequest) catalog.DownloadFuture); ok { + r0 = rf(ctx, request) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(catalog.DownloadFuture) @@ -54,8 +88,8 @@ func (_m *AsyncClient) Download(ctx context.Context, requests ...catalog.Downloa } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, ...catalog.DownloadRequest) error); ok { - r1 = rf(ctx, requests...) 
+ if rf, ok := ret.Get(1).(func(context.Context, catalog.DownloadArrayRequest) error); ok { + r1 = rf(ctx, request) } else { r1 = ret.Error(1) } @@ -71,8 +105,8 @@ func (_m AsyncClient_Upload) Return(putFuture catalog.UploadFuture, err error) * return &AsyncClient_Upload{Call: _m.Call.Return(putFuture, err)} } -func (_m *AsyncClient) OnUpload(ctx context.Context, requests ...catalog.UploadRequest) *AsyncClient_Upload { - c := _m.On("Upload", ctx, requests) +func (_m *AsyncClient) OnUpload(ctx context.Context, request catalog.UploadRequest) *AsyncClient_Upload { + c := _m.On("Upload", ctx, request) return &AsyncClient_Upload{Call: c} } @@ -81,20 +115,54 @@ func (_m *AsyncClient) OnUploadMatch(matchers ...interface{}) *AsyncClient_Uploa return &AsyncClient_Upload{Call: c} } -// Upload provides a mock function with given fields: ctx, requests -func (_m *AsyncClient) Upload(ctx context.Context, requests ...catalog.UploadRequest) (catalog.UploadFuture, error) { - _va := make([]interface{}, len(requests)) - for _i := range requests { - _va[_i] = requests[_i] +// Upload provides a mock function with given fields: ctx, request +func (_m *AsyncClient) Upload(ctx context.Context, request catalog.UploadRequest) (catalog.UploadFuture, error) { + ret := _m.Called(ctx, request) + + var r0 catalog.UploadFuture + if rf, ok := ret.Get(0).(func(context.Context, catalog.UploadRequest) catalog.UploadFuture); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(catalog.UploadFuture) + } } - var _ca []interface{} - _ca = append(_ca, ctx) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) 
+ + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, catalog.UploadRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncClient_UploadArray struct { + *mock.Call +} + +func (_m AsyncClient_UploadArray) Return(putFuture catalog.UploadFuture, err error) *AsyncClient_UploadArray { + return &AsyncClient_UploadArray{Call: _m.Call.Return(putFuture, err)} +} + +func (_m *AsyncClient) OnUploadArray(ctx context.Context, requests catalog.UploadArrayRequest) *AsyncClient_UploadArray { + c := _m.On("UploadArray", ctx, requests) + return &AsyncClient_UploadArray{Call: c} +} + +func (_m *AsyncClient) OnUploadArrayMatch(matchers ...interface{}) *AsyncClient_UploadArray { + c := _m.On("UploadArray", matchers...) + return &AsyncClient_UploadArray{Call: c} +} + +// UploadArray provides a mock function with given fields: ctx, requests +func (_m *AsyncClient) UploadArray(ctx context.Context, requests catalog.UploadArrayRequest) (catalog.UploadFuture, error) { + ret := _m.Called(ctx, requests) var r0 catalog.UploadFuture - if rf, ok := ret.Get(0).(func(context.Context, ...catalog.UploadRequest) catalog.UploadFuture); ok { - r0 = rf(ctx, requests...) + if rf, ok := ret.Get(0).(func(context.Context, catalog.UploadArrayRequest) catalog.UploadFuture); ok { + r0 = rf(ctx, requests) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(catalog.UploadFuture) @@ -102,8 +170,8 @@ func (_m *AsyncClient) Upload(ctx context.Context, requests ...catalog.UploadReq } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, ...catalog.UploadRequest) error); ok { - r1 = rf(ctx, requests...) 
+ if rf, ok := ret.Get(1).(func(context.Context, catalog.UploadArrayRequest) error); ok { + r1 = rf(ctx, requests) } else { r1 = ret.Error(1) } diff --git a/go/tasks/pluginmachinery/catalog/reader_array_processor.go b/go/tasks/pluginmachinery/catalog/reader_array_processor.go new file mode 100644 index 000000000..1d31a59a4 --- /dev/null +++ b/go/tasks/pluginmachinery/catalog/reader_array_processor.go @@ -0,0 +1,131 @@ +package catalog + +import ( + "context" + "fmt" + "reflect" + "strconv" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/lyft/flytestdlib/bitarray" + + "github.com/lyft/flyteplugins/go/tasks/errors" + + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue" +) + +type ArrayReaderWorkItem struct { + // ArrayReaderWorkItem outputs: + cached *bitarray.BitSet + progress *bitarray.BitSet + + // ArrayReaderWorkItem Inputs: + downloadRequest DownloadArrayRequest +} + +func (item ArrayReaderWorkItem) CachedResults() *bitarray.BitSet { + return item.cached +} + +func NewArrayReaderWorkItem(request DownloadArrayRequest) *ArrayReaderWorkItem { + return &ArrayReaderWorkItem{ + downloadRequest: request, + } +} + +type ArrayReaderProcessor struct { + catalogClient Client + maxItemsPerRound int +} + +func (p ArrayReaderProcessor) Process(ctx context.Context, workItem workqueue.WorkItem) (workqueue.WorkStatus, error) { + wi, casted := workItem.(*ArrayReaderWorkItem) + if !casted { + return workqueue.WorkStatusNotDone, fmt.Errorf("wrong work item type. 
Received: %v", reflect.TypeOf(workItem)) + } + + if wi.cached == nil { + wi.cached = bitarray.NewBitSet(uint(wi.downloadRequest.Count)) + wi.progress = bitarray.NewBitSet(uint(wi.downloadRequest.Count)) + } + + isArray := wi.downloadRequest.Count > 0 + + for i := uint(0); i < uint(wi.downloadRequest.Count); i++ { + inputReader := wi.downloadRequest.BaseInputReader + if isArray { + if !wi.downloadRequest.Indexes.IsSet(i) { + continue + } + + if wi.progress.IsSet(i) { + logger.Debugf(ctx, "Catalog lookup already ran for index [%v], result [%v].", i, wi.cached.IsSet(i)) + continue + } + + indexedInputLocation, err := wi.downloadRequest.dataStore.ConstructReference(ctx, + wi.downloadRequest.BaseInputReader.GetInputPrefixPath(), + strconv.Itoa(int(i))) + if err != nil { + return workqueue.WorkStatusNotDone, err + } + + inputReader = ioutils.NewRemoteFileInputReader(ctx, wi.downloadRequest.dataStore, + ioutils.NewInputFilePaths(ctx, wi.downloadRequest.dataStore, indexedInputLocation)) + } + + k := Key{ + Identifier: wi.downloadRequest.Identifier, + CacheVersion: wi.downloadRequest.CacheVersion, + TypedInterface: wi.downloadRequest.TypedInterface, + InputReader: inputReader, + } + op, err := p.catalogClient.Get(ctx, k) + if err != nil { + if IsNotFound(err) { + logger.Infof(ctx, "Artifact not found in Catalog. Key: %v", k) + wi.progress.Set(i) + } else { + err = errors.Wrapf("CausedBy", err, "Failed to call catalog for Key: %v.", k) + logger.Warnf(ctx, "Cache call failed: %v", err) + return workqueue.WorkStatusFailed, err + } + } else if op != nil { + writer := wi.downloadRequest.BaseTarget + if isArray { + // TODO: Check task interface, if it has outputs but literalmap is empty (or not matching output), error. 
+ dataReference, err := wi.downloadRequest.dataStore.ConstructReference(ctx, + wi.downloadRequest.BaseTarget.GetOutputPrefixPath(), strconv.Itoa(int(i))) + if err != nil { + return workqueue.WorkStatusFailed, err + } + + writer = ioutils.NewRemoteFileOutputWriter(ctx, wi.downloadRequest.dataStore, + ioutils.NewRemoteFileOutputPaths(ctx, wi.downloadRequest.dataStore, dataReference)) + } + + logger.Debugf(ctx, "Persisting output to %v", writer.GetOutputPath()) + err = writer.Put(ctx, op) + if err != nil { + err = errors.Wrapf("CausedBy", err, "Failed to persist cached output for Key: %v.", k) + logger.Warnf(ctx, "Cache write to output writer failed: %v", err) + return workqueue.WorkStatusFailed, err + } + + wi.cached.Set(i) + wi.progress.Set(i) + } + } + + logger.Debugf(ctx, "Successfully wrote to catalog. Identifier [%v]", wi.downloadRequest.Identifier) + return workqueue.WorkStatusSucceeded, nil +} + +func NewArrayReaderProcessor(catalogClient Client, maxItemsPerRound int) ArrayReaderProcessor { + return ArrayReaderProcessor{ + catalogClient: catalogClient, + maxItemsPerRound: maxItemsPerRound, + } +} diff --git a/go/tasks/pluginmachinery/catalog/reader_processor.go b/go/tasks/pluginmachinery/catalog/reader_processor.go deleted file mode 100644 index 0875df7dd..000000000 --- a/go/tasks/pluginmachinery/catalog/reader_processor.go +++ /dev/null @@ -1,84 +0,0 @@ -package catalog - -import ( - "context" - "fmt" - "reflect" - - "github.com/lyft/flyteplugins/go/tasks/errors" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" - "github.com/lyft/flytestdlib/logger" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue" -) - -type ReaderWorkItem struct { - // ReaderWorkItem outputs: - cached bool - - // ReaderWorkItem Inputs: - outputsWriter io.OutputWriter - // Inputs to query data catalog - key Key -} - -func (item ReaderWorkItem) IsCached() bool { - return item.cached -} - -func NewReaderWorkItem(key Key, outputsWriter io.OutputWriter) 
*ReaderWorkItem { - return &ReaderWorkItem{ - key: key, - outputsWriter: outputsWriter, - } -} - -type ReaderProcessor struct { - catalogClient Client -} - -func (p ReaderProcessor) Process(ctx context.Context, workItem workqueue.WorkItem) (workqueue.WorkStatus, error) { - wi, casted := workItem.(*ReaderWorkItem) - if !casted { - return workqueue.WorkStatusNotDone, fmt.Errorf("wrong work item type. Received: %v", reflect.TypeOf(workItem)) - } - - op, err := p.catalogClient.Get(ctx, wi.key) - if err != nil { - if IsNotFound(err) { - logger.Infof(ctx, "Artifact not found in Catalog. Key: %v", wi.key) - wi.cached = false - return workqueue.WorkStatusSucceeded, nil - } - - err = errors.Wrapf("CausedBy", err, "Failed to call catalog for Key: %v.", wi.key) - logger.Warnf(ctx, "Cache call failed: %v", err) - return workqueue.WorkStatusFailed, err - } - - if op == nil { - wi.cached = false - return workqueue.WorkStatusSucceeded, nil - } - - // TODO: Check task interface, if it has outputs but literalmap is empty (or not matching output), error. - logger.Debugf(ctx, "Persisting output to %v", wi.outputsWriter.GetOutputPath()) - err = wi.outputsWriter.Put(ctx, op) - if err != nil { - err = errors.Wrapf("CausedBy", err, "Failed to persist cached output for Key: %v.", wi.key) - logger.Warnf(ctx, "Cache write to output writer failed: %v", err) - return workqueue.WorkStatusFailed, err - } - - wi.cached = true - - logger.Debugf(ctx, "Successfully wrote to catalog. 
Key [%v]", wi.key) - return workqueue.WorkStatusSucceeded, nil -} - -func NewReaderProcessor(catalogClient Client) ReaderProcessor { - return ReaderProcessor{ - catalogClient: catalogClient, - } -} diff --git a/go/tasks/pluginmachinery/catalog/response.go b/go/tasks/pluginmachinery/catalog/response.go index da2c4936e..b9e403bdf 100644 --- a/go/tasks/pluginmachinery/catalog/response.go +++ b/go/tasks/pluginmachinery/catalog/response.go @@ -59,15 +59,13 @@ func (r downloadFuture) GetCachedCount() int { return r.cachedCount } -func newDownloadFuture(status ResponseStatus, err error, cachedResults *bitarray.BitSet, resultsSize int, - cachedCount int) downloadFuture { +func newDownloadFuture(status ResponseStatus, err error, cachedResults *bitarray.BitSet, resultsSize int) downloadFuture { return downloadFuture{ future: &future{ responseStatus: status, err: err, }, - cachedCount: cachedCount, cachedResults: cachedResults, resultsSize: resultsSize, } diff --git a/go/tasks/pluginmachinery/catalog/writer_array_processor.go b/go/tasks/pluginmachinery/catalog/writer_array_processor.go new file mode 100644 index 000000000..7e12efe89 --- /dev/null +++ b/go/tasks/pluginmachinery/catalog/writer_array_processor.go @@ -0,0 +1,116 @@ +package catalog + +import ( + "context" + "fmt" + "reflect" + "strconv" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + + "github.com/lyft/flytestdlib/bitarray" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + + "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue" + "github.com/lyft/flytestdlib/logger" +) + +type WriterArrayWorkItem struct { + // WriterArrayWorkItem Inputs + request UploadArrayRequest + + progress *bitarray.BitSet +} + +func NewArrayWriterWorkItem(request UploadArrayRequest) *WriterArrayWorkItem { + return &WriterArrayWorkItem{ + request: request, + } +} + +type writerArrayProcessor struct { + catalogClient Client + maxItemsPerRound int 
+} + +func (p writerArrayProcessor) Process(ctx context.Context, workItem workqueue.WorkItem) (workqueue.WorkStatus, error) { + wi, casted := workItem.(*WriterArrayWorkItem) + if !casted { + return workqueue.WorkStatusNotDone, fmt.Errorf("wrong work item type. Received: %v", reflect.TypeOf(workItem)) + } + + if wi.progress == nil { + wi.progress = bitarray.NewBitSet(uint(wi.request.Count)) + } + + isArray := wi.request.Count > 0 + for i := uint(0); i < uint(wi.request.Count); i++ { + inputReader := wi.request.BaseInputReader + var outputReader io.OutputReader + if isArray { + if !wi.request.Indexes.IsSet(i) { + continue + } + + if wi.progress.IsSet(i) { + logger.Debugf(ctx, "Catalog lookup already ran for index [%v].", i) + continue + } + + indexedInputLocation, err := wi.request.dataStore.ConstructReference(ctx, + wi.request.BaseInputReader.GetInputPrefixPath(), + strconv.Itoa(int(i))) + if err != nil { + return workqueue.WorkStatusNotDone, err + } + + inputReader = ioutils.NewRemoteFileInputReader(ctx, wi.request.dataStore, + ioutils.NewInputFilePaths(ctx, wi.request.dataStore, indexedInputLocation)) + + indexedOutputLocation, err := wi.request.dataStore.ConstructReference(ctx, + wi.request.BaseArtifactData.GetOutputPrefixPath(), + strconv.Itoa(int(i))) + if err != nil { + return workqueue.WorkStatusNotDone, err + } + + // TODO: size limit is weird to be passed here... 
+ outputReader = ioutils.NewRemoteFileOutputReader(ctx, wi.request.dataStore, + ioutils.NewRemoteFileOutputPaths(ctx, wi.request.dataStore, indexedOutputLocation), + int64(999999999)) + } else { + outputReader = ioutils.NewRemoteFileOutputReader(ctx, wi.request.dataStore, + wi.request.BaseArtifactData, int64(999999999)) + } + + k := Key{ + Identifier: wi.request.Identifier, + CacheVersion: wi.request.CacheVersion, + TypedInterface: wi.request.TypedInterface, + InputReader: inputReader, + } + + err := p.catalogClient.Put(ctx, k, outputReader, wi.request.ArtifactMetadata) + if err != nil { + logger.Errorf(ctx, "Error putting to catalog [%s]", err) + return workqueue.WorkStatusNotDone, errors.Wrapf(errors.DownstreamSystemError, err, + "Error writing to catalog, key id [%v] cache version [%v]", + k.Identifier, k.CacheVersion) + } + + wi.progress.Set(i) + } + + logger.Debugf(ctx, "Successfully wrote to catalog.") + + return workqueue.WorkStatusSucceeded, nil +} + +func NewWriterArrayProcessor(catalogClient Client, maxItemsPerRound int) workqueue.Processor { + return writerArrayProcessor{ + catalogClient: catalogClient, + maxItemsPerRound: maxItemsPerRound, + } +} diff --git a/go/tasks/pluginmachinery/catalog/writer_processor.go b/go/tasks/pluginmachinery/catalog/writer_processor.go deleted file mode 100644 index aa9a1e4df..000000000 --- a/go/tasks/pluginmachinery/catalog/writer_processor.go +++ /dev/null @@ -1,56 +0,0 @@ -package catalog - -import ( - "context" - "fmt" - "reflect" - - "github.com/lyft/flyteplugins/go/tasks/errors" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue" - "github.com/lyft/flytestdlib/logger" -) - -type WriterWorkItem struct { - // WriterWorkItem Inputs - key Key - data io.OutputReader - metadata Metadata -} - -func NewWriterWorkItem(key Key, data io.OutputReader, metadata Metadata) *WriterWorkItem { - return &WriterWorkItem{ - key: key, - data: data, - metadata: 
metadata, - } -} - -type writerProcessor struct { - catalogClient Client -} - -func (p writerProcessor) Process(ctx context.Context, workItem workqueue.WorkItem) (workqueue.WorkStatus, error) { - wi, casted := workItem.(*WriterWorkItem) - if !casted { - return workqueue.WorkStatusNotDone, fmt.Errorf("wrong work item type. Received: %v", reflect.TypeOf(workItem)) - } - - err := p.catalogClient.Put(ctx, wi.key, wi.data, wi.metadata) - if err != nil { - logger.Errorf(ctx, "Error putting to catalog [%s]", err) - return workqueue.WorkStatusNotDone, errors.Wrapf(errors.DownstreamSystemError, err, - "Error writing to catalog, key id [%v] cache version [%v]", - wi.key.Identifier, wi.key.CacheVersion) - } - - logger.Debugf(ctx, "Successfully wrote to catalog. Key [%v]", wi.key) - - return workqueue.WorkStatusSucceeded, nil -} - -func NewWriterProcessor(catalogClient Client) workqueue.Processor { - return writerProcessor{ - catalogClient: catalogClient, - } -} diff --git a/go/tasks/plugins/array/catalog.go b/go/tasks/plugins/array/catalog.go index e3ca301f7..38097443a 100644 --- a/go/tasks/plugins/array/catalog.go +++ b/go/tasks/plugins/array/catalog.go @@ -55,27 +55,21 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex return state, nil } - // Otherwise, run the data catalog steps - create and submit work items to the catalog processor, - // build input readers - inputReaders, err := ConstructInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), int(arrayJob.Size)) - if err != nil { - return state, err - } - - // build output writers - outputWriters, err := ConstructOutputWriters(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), int(arrayJob.Size)) - if err != nil { - return state, err - } + iface := *taskTemplate.Interface + iface.Outputs = makeSingularTaskInterface(iface.Outputs) - // build work items from inputs and outputs - workItems, err := ConstructCatalogReaderWorkItems(ctx, tCtx.TaskReader(), 
inputReaders, outputWriters) - if err != nil { - return state, err + request := catalog.DownloadArrayRequest{ + Identifier: *taskTemplate.GetId(), + CacheVersion: taskTemplate.Metadata.DiscoveryVersion, + TypedInterface: iface, + BaseInputReader: tCtx.InputReader(), + BaseTarget: tCtx.OutputWriter(), + Indexes: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(arrayJob.Size)), uint(arrayJob.Size)), + Count: int(arrayJob.Size), } // Check catalog, and if we have responses from catalog for everything, then move to writing the mapping file. - future, err := tCtx.Catalog().Download(ctx, workItems...) + future, err := tCtx.Catalog().DownloadArray(ctx, request) if err != nil { return state, err } @@ -156,45 +150,40 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state return state, errors.Errorf(errors.BadTaskSpecification, "Could not extract custom array job") } - // input readers - inputReaders, err := ConstructInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), int(arrayJob.Size)) - if err != nil { - return nil, err - } - - // output reader - outputReaders, err := ConstructOutputReaders(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), int(arrayJob.Size)) - if err != nil { - return nil, err - } - - iface := *taskTemplate.Interface - iface.Outputs = makeSingularTaskInterface(iface.Outputs) - + taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() // Do not cache failed tasks. Retrieve the final phase from array status and unset the non-successful ones. 
tasksToCache := state.GetIndexesToCache().DeepCopy() + toCacheCount := 0 for idx, phaseIdx := range state.ArrayStatus.Detailed.GetItems() { phase := core.Phases[phaseIdx] if !phase.IsSuccess() { tasksToCache.Clear(uint(idx)) + } else { + toCacheCount++ } } - // Create catalog put items, but only put the ones that were not originally cached (as read from the catalog results bitset) - catalogWriterItems, err := ConstructCatalogUploadRequests(*tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().TaskId, - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), taskTemplate.Metadata.DiscoveryVersion, - iface, &tasksToCache, inputReaders, outputReaders) - - if err != nil { - return nil, err - } - - if len(catalogWriterItems) == 0 { + if toCacheCount == 0 { state.SetPhase(phaseOnSuccess, core.DefaultPhaseVersion).SetReason("No outputs need to be cached.") return state, nil } - allWritten, err := WriteToCatalog(ctx, tCtx.TaskRefreshIndicator(), tCtx.Catalog(), catalogWriterItems) + iface := *taskTemplate.Interface + iface.Outputs = makeSingularTaskInterface(iface.Outputs) + request := catalog.UploadArrayRequest{ + Identifier: *tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().TaskId, + CacheVersion: taskTemplate.Metadata.DiscoveryVersion, + TypedInterface: iface, + ArtifactMetadata: catalog.Metadata{ + TaskExecutionIdentifier: &taskExecID, + }, + BaseInputReader: tCtx.InputReader(), + BaseArtifactData: tCtx.OutputWriter(), + Indexes: &tasksToCache, + Count: int(arrayJob.Size), + } + + allWritten, err := WriteToCatalog(ctx, tCtx.TaskRefreshIndicator(), tCtx.Catalog(), request) if err != nil { return nil, err } @@ -207,10 +196,10 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state } func WriteToCatalog(ctx context.Context, ownerSignal core.SignalAsync, catalogClient catalog.AsyncClient, - workItems []catalog.UploadRequest) (bool, error) { + workItem catalog.UploadArrayRequest) (bool, error) { // Enqueue work items - future, 
err := catalogClient.Upload(ctx, workItems...) + future, err := catalogClient.UploadArray(ctx, workItem) if err != nil { return false, errors.Wrapf(arrayCore.ErrorWorkQueue, err, "Error enqueuing work items") @@ -233,41 +222,6 @@ func WriteToCatalog(ctx context.Context, ownerSignal core.SignalAsync, catalogCl return false, nil } -func ConstructCatalogUploadRequests(keyId idlCore.Identifier, taskExecId idlCore.TaskExecutionIdentifier, - cacheVersion string, taskInterface idlCore.TypedInterface, whichTasksToCache *bitarray.BitSet, - inputReaders []io.InputReader, outputReaders []io.OutputReader) ([]catalog.UploadRequest, error) { - - writerWorkItems := make([]catalog.UploadRequest, 0, len(inputReaders)) - - if len(inputReaders) != len(outputReaders) { - return nil, errors.Errorf(arrayCore.ErrorInternalMismatch, "Length different building catalog writer items %d %d", - len(inputReaders), len(outputReaders)) - } - - for idx, input := range inputReaders { - if !whichTasksToCache.IsSet(uint(idx)) { - continue - } - - wi := catalog.UploadRequest{ - Key: catalog.Key{ - Identifier: keyId, - InputReader: input, - CacheVersion: cacheVersion, - TypedInterface: taskInterface, - }, - ArtifactData: outputReaders[idx], - ArtifactMetadata: catalog.Metadata{ - TaskExecutionIdentifier: &taskExecId, - }, - } - - writerWorkItems = append(writerWorkItems, wi) - } - - return writerWorkItems, nil -} - func NewLiteralScalarOfInteger(number int64) *idlCore.Literal { return &idlCore.Literal{ Value: &idlCore.Literal_Scalar{ @@ -325,70 +279,6 @@ func makeSingularTaskInterface(varMap *idlCore.VariableMap) *idlCore.VariableMap } -func ConstructCatalogReaderWorkItems(ctx context.Context, taskReader core.TaskReader, inputs []io.InputReader, - outputs []io.OutputWriter) ([]catalog.DownloadRequest, error) { - - t, err := taskReader.Read(ctx) - if err != nil { - return nil, err - } - - workItems := make([]catalog.DownloadRequest, 0, len(inputs)) - - iface := *t.Interface - iface.Outputs = 
makeSingularTaskInterface(iface.Outputs) - - for idx, inputReader := range inputs { - // TODO: Check if Id or Interface are empty and return err - item := catalog.DownloadRequest{ - Key: catalog.Key{ - Identifier: *t.Id, - CacheVersion: t.GetMetadata().DiscoveryVersion, - InputReader: inputReader, - TypedInterface: iface, - }, - Target: outputs[idx], - } - workItems = append(workItems, item) - } - - return workItems, nil -} - -func ConstructInputReaders(ctx context.Context, dataStore *storage.DataStore, inputPrefix storage.DataReference, - size int) ([]io.InputReader, error) { - - inputReaders := make([]io.InputReader, 0, size) - for i := 0; i < size; i++ { - indexedInputLocation, err := dataStore.ConstructReference(ctx, inputPrefix, strconv.Itoa(i)) - if err != nil { - return inputReaders, err - } - - inputReader := ioutils.NewRemoteFileInputReader(ctx, dataStore, ioutils.NewInputFilePaths(ctx, dataStore, indexedInputLocation)) - inputReaders = append(inputReaders, inputReader) - } - - return inputReaders, nil -} - -func ConstructOutputWriters(ctx context.Context, dataStore *storage.DataStore, outputPrefix storage.DataReference, - size int) ([]io.OutputWriter, error) { - - outputWriters := make([]io.OutputWriter, 0, size) - - for i := 0; i < size; i++ { - ow, err := ConstructOutputWriter(ctx, dataStore, outputPrefix, i) - if err != nil { - return outputWriters, err - } - - outputWriters = append(outputWriters, ow) - } - - return outputWriters, nil -} - func ConstructOutputWriter(ctx context.Context, dataStore *storage.DataStore, outputPrefix storage.DataReference, index int) (io.OutputWriter, error) { dataReference, err := dataStore.ConstructReference(ctx, outputPrefix, strconv.Itoa(index)) diff --git a/go/tasks/plugins/array/catalog_test.go b/go/tasks/plugins/array/catalog_test.go index a3596e381..b6c2322df 100644 --- a/go/tasks/plugins/array/catalog_test.go +++ b/go/tasks/plugins/array/catalog_test.go @@ -59,7 +59,7 @@ func runDetermineDiscoverabilityTest(t 
testing.TB, taskTemplate *core.TaskTempla assert.NoError(t, err) cat := &catalogMocks.AsyncClient{} - cat.OnDownloadMatch(mock.Anything, mock.Anything).Return(future, nil) + cat.OnDownloadArrayMatch(mock.Anything, mock.Anything).Return(future, nil) ir := &ioMocks.InputReader{} ir.OnGetInputPrefixPath().Return("/prefix/") diff --git a/tests/end_to_end.go b/tests/end_to_end.go index 6c11a12e5..76eea2c54 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -8,9 +8,8 @@ import ( "testing" "time" - "k8s.io/apimachinery/pkg/util/rand" - "github.com/go-test/deep" + "k8s.io/apimachinery/pkg/util/rand" v12 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -199,15 +198,19 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i catData.Store(key, o) }) cat, err := catalog.NewAsyncClient(catClient, catalog.Config{ - ReaderWorkqueueConfig: workqueue.Config{ - MaxRetries: 0, - Workers: 2, - IndexCacheMaxItems: 100, + Reader: catalog.ProcessorConfig{ + Workqueue: workqueue.Config{ + MaxRetries: 0, + Workers: 2, + IndexCacheMaxItems: 100, + }, }, - WriterWorkqueueConfig: workqueue.Config{ - MaxRetries: 0, - Workers: 2, - IndexCacheMaxItems: 100, + Writer: catalog.ProcessorConfig{ + Workqueue: workqueue.Config{ + MaxRetries: 0, + Workers: 2, + IndexCacheMaxItems: 100, + }, }, }, promutils.NewTestScope()) assert.NoError(t, err)