From 4f0d572f6a0fffcd7ba9ddf11120e6e28201b327 Mon Sep 17 00:00:00 2001 From: Benjamin Schimke Date: Thu, 19 Oct 2023 16:59:50 +0200 Subject: [PATCH] Use channels for batch queue changes --- pkg/kine/drivers/generic/generic.go | 31 +++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/pkg/kine/drivers/generic/generic.go b/pkg/kine/drivers/generic/generic.go index 14e9d6cd..5564b3c3 100644 --- a/pkg/kine/drivers/generic/generic.go +++ b/pkg/kine/drivers/generic/generic.go @@ -128,6 +128,8 @@ type Generic struct { TranslateErr TranslateErr ErrCode ErrCode flushCh chan struct{} + addToBatchQueueCh chan BatchedInsert + removeFromBatchQueueCh chan int batchingQueue []BatchedInsert AdmissionControlPolicy AdmissionControlPolicy @@ -394,9 +396,28 @@ func (d *Generic) queryRowPrepared(ctx context.Context, txName, sql string, prep func (d *Generic) InitializeWriteBatching(ctx context.Context) { if d.BatchingEnabled && d.BatchingInterval > 0 && d.BatchingMaxQueries > 0 { d.flushCh = make(chan struct{}) + d.addToBatchQueueCh = make(chan BatchedInsert) + d.removeFromBatchQueueCh = make(chan int) d.batchingQueue = make([]BatchedInsert, 0, d.BatchingMaxQueries) go d.watchBatchQueue(ctx) + go d.batchQueueWatcher(ctx) + } +} + +func (d *Generic) batchQueueWatcher(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case item := <-d.addToBatchQueueCh: + d.batchingQueue = append(d.batchingQueue, item) + if len(d.batchingQueue) == d.BatchingMaxQueries { + d.flushCh <- struct{}{} + } + case count := <-d.removeFromBatchQueueCh: + d.batchingQueue = d.batchingQueue[count:] + } } } @@ -452,7 +473,6 @@ func (d *Generic) processBatchQueue(ctx context.Context) { copy(dcq, d.batchingQueue) ids, err := d.processBatchedInserts(ctx, dcq) - if err != nil { logrus.WithError(err).Error("Process batch queue error") return @@ -462,7 +482,9 @@ func (d *Generic) processBatchQueue(ctx context.Context) { insertItem.retCh <- ids[i] } - d.batchingQueue = make([]BatchedInsert, 0, d.BatchingMaxQueries) + d.removeFromBatchQueueCh <- len(dcq) + } else { + fmt.Printf("Queue is empty: %d\n", len(d.batchingQueue)) } } @@ -482,10 +504,7 @@ func (d *Generic) watchBatchQueue(ctx context.Context) { func (d *Generic) batchQueryRowPrepared(ctx context.Context, txName, sqli string, prepared *sql.Stmt, args ...interface{}) (id int64) { ret := make(chan int64) - d.batchingQueue = append(d.batchingQueue, BatchedInsert{retCh: ret, args: args}) - if len(d.batchingQueue) == d.BatchingMaxQueries { - d.flushCh <- struct{}{} - } + d.addToBatchQueueCh <- BatchedInsert{retCh: ret, args: args} return <-ret }