diff --git a/pkg/kine/drivers/dqlite/dqlite.go b/pkg/kine/drivers/dqlite/dqlite.go index f3967c0a..d41fcf6d 100644 --- a/pkg/kine/drivers/dqlite/dqlite.go +++ b/pkg/kine/drivers/dqlite/dqlite.go @@ -40,7 +40,6 @@ func NewVariant(ctx context.Context, datasourceName string, connectionPoolConfig if err := migrate(ctx, generic.DB.Underlying()); err != nil { return nil, nil, errors.Wrap(err, "failed to migrate DB from sqlite") } - generic.LockWrites = true generic.Retry = func(err error) bool { // get the inner-most error if possible err = errors.Cause(err) diff --git a/pkg/kine/drivers/generic/batch.go b/pkg/kine/drivers/generic/batch.go new file mode 100644 index 00000000..5912df97 --- /dev/null +++ b/pkg/kine/drivers/generic/batch.go @@ -0,0 +1,207 @@ +package generic + +import ( + "context" + "database/sql" + "sync" + + "github.com/canonical/k8s-dqlite/pkg/kine/prepared" +) + +type BatchConn interface { + ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) +} + +var _ BatchConn = &sql.DB{} +var _ BatchConn = &sql.Tx{} +var _ BatchConn = &sql.Conn{} + +var _ BatchConn = &prepared.DB{} +var _ BatchConn = &prepared.Tx{} + +type batchStatus int + +const ( + batchNotStarted batchStatus = iota + batchStarted + batchRunning +) + +type Batch struct { + db *prepared.DB + mu sync.Mutex + cv sync.Cond + status batchStatus + + queue []*batchJob + runId int64 +} + +func NewBatch(db *prepared.DB) *Batch { + b := &Batch{ + db: db, + } + b.cv.L = &b.mu + return b +} + +func (b *Batch) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + b.mu.Lock() + defer b.mu.Unlock() + + runId := b.runId + if b.status == batchRunning { + // The current run is already taking place. + runId++ + } + + job := &batchJob{ + ctx: ctx, + query: query, + args: args, + runId: runId, + } + b.queue = append(b.queue, job) + + b.run() + + for job.runId >= b.runId { + b.cv.Wait() + } + + if job.err != nil { + return nil, job.err + } + + return job, nil +} + +// run starts a batching job which will run until queue exaustion. +// run does not block other goroutine from enqueuing new jobs. +// +// It must be called while holding the mu lock. +func (b *Batch) run() { + if b.status == batchNotStarted { + b.status = batchStarted + + go func() { + b.mu.Lock() + defer b.mu.Unlock() + + b.status = batchRunning + defer func() { b.status = batchNotStarted }() + + for len(b.queue) > 0 { + queue := b.queue + b.queue = nil + + b.mu.Unlock() + b.execQueue(context.TODO(), queue) + b.mu.Lock() + + b.runId++ + b.cv.Broadcast() + } + }() + } +} + +func (b *Batch) execQueue(ctx context.Context, queue []*batchJob) { + // TODO limit batch duration + // TODO limit batch size + if len(queue) == 0 { + return // This should never happen. + } + if len(queue) == 1 { + // We don't need to address the error here as it will be + // handled by the goroutine waiting for this result + queue[0].exec(queue[0].ctx, b.db) + return + } + + transaction := func() error { + // TODO: this should be BEGIN IMMEDIATE + tx, err := b.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + for _, q := range queue { + // FIXME: + // In the case of SQLITE_FULL SQLITE_IOERR SQLITE_BUSY SQLITE_NOMEM + // we should explicitly rollback the whole transaction. However, it + // is a bit unclear to me what to do next though as: + // - SQLITE_FULL, SQLITE_IOERR mean that we have problems with the disk + // so, even retrying the batch will not work. We might throttle the + // max batch size, hoping in a checkpoint? + // - SQLITE_BUSY should never happen if we manage to get `IMMEDIATE` + // transactions in. + // - SQLITE_NOMEM, again, we could throttle here? + if err := q.exec(ctx, tx); err != nil { + return err + } + } + + return tx.Commit() + } + + if err := transaction(); err != nil { + for _, q := range queue { + q.err = err + } + } +} + +type batchJob struct { + ctx context.Context + query string + args []any + + runId int64 + lastInsertId int64 + rowsAffected int64 + err error +} + +var _ sql.Result = &batchJob{} + +func (bj *batchJob) exec(ctx context.Context, conn BatchConn) error { + select { + case <-bj.ctx.Done(): + bj.err = bj.ctx.Err() + return bj.err + default: + // From this point on, the job is not interruptible anymore + // as interrupting would mean that we would be forced to + // ROLLBACK the whole transaction. + } + + var result sql.Result + result, bj.err = conn.ExecContext(ctx, bj.query, bj.args...) + if bj.err != nil { + return bj.err + } + + bj.rowsAffected, bj.err = result.RowsAffected() + if bj.err != nil { + return bj.err + } + + bj.lastInsertId, bj.err = result.LastInsertId() + if bj.err != nil { + return bj.err + } + + return nil +} + +// LastInsertId implements sql.Result. +func (bj *batchJob) LastInsertId() (int64, error) { + return bj.lastInsertId, nil +} + +// RowsAffected implements sql.Result. +func (bj *batchJob) RowsAffected() (int64, error) { + return bj.rowsAffected, nil +} diff --git a/pkg/kine/drivers/generic/generic.go b/pkg/kine/drivers/generic/generic.go index 2a013806..fe41dc13 100644 --- a/pkg/kine/drivers/generic/generic.go +++ b/pkg/kine/drivers/generic/generic.go @@ -7,7 +7,6 @@ import ( "regexp" "strconv" "strings" - "sync" "time" "github.com/canonical/k8s-dqlite/pkg/kine/prepared" @@ -128,11 +127,9 @@ type TranslateErr func(error) error type ErrCode func(error) string type Generic struct { - sync.Mutex - - LockWrites bool LastInsertID bool DB *prepared.DB + batch *Batch GetCurrentSQL string RevisionSQL string ListRevisionStartSQL string @@ -163,11 +160,6 @@ type Generic struct { PollInterval time.Duration // WatchQueryTimeout is the timeout on the after query in the poll loop. WatchQueryTimeout time.Duration - - batchMu sync.Mutex - batchCv *sync.Cond - batchRunnig bool - batch []*batchedChange } type ConnectionPoolConfig struct { @@ -250,13 +242,11 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig } configureConnectionPooling(connPoolConfig, db) + wrappedDb := prepared.New(db) - if err != nil { - return nil, err - } - - driver := &Generic{ - DB: prepared.New(db), + return &Generic{ + DB: wrappedDb, + batch: NewBatch(wrappedDb), GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered), ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered), @@ -306,11 +296,11 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig DeleteSQL: q(` INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) - SELECT + SELECT name, 0 AS created, 1 AS deleted, - CASE + CASE WHEN kine.created THEN id ELSE create_revision END AS create_revision, @@ -324,14 +314,14 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig CreateSQL: q(` INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) - SELECT + SELECT ? AS name, 1 AS created, 0 AS deleted, - 0 AS create_revision, - COALESCE(id, 0) AS prev_revision, - ? AS lease, - ? AS value, + 0 AS create_revision, + COALESCE(id, 0) AS prev_revision, + ? AS lease, + ? AS value, NULL AS old_value FROM ( SELECT MAX(id) AS id, deleted @@ -342,11 +332,11 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig UpdateSQL: q(` INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) - SELECT + SELECT ? AS name, 0 AS created, 0 AS deleted, - CASE + CASE WHEN kine.created THEN id ELSE create_revision END AS create_revision, @@ -361,10 +351,7 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered), AdmissionControlPolicy: &allowAllPolicy{}, - } - driver.batchCv = sync.NewCond(&driver.batchMu) - - return driver, err + }, err } func (d *Generic) Close() error { @@ -443,12 +430,6 @@ func (d *Generic) execute(ctx context.Context, txName, query string, args ...int } defer done() - if d.LockWrites { - d.Lock() - defer d.Unlock() - span.AddEvent("acquired write lock") - } - start := time.Now() retryCount := 0 defer func() { @@ -463,7 +444,7 @@ func (d *Generic) execute(ctx context.Context, txName, query string, args ...int } else { logrus.Tracef("EXEC (try: %d) %v : %s", retryCount, args, Stripped(query)) } - result, err = d.DB.ExecContext(ctx, query, args...) + result, err = d.batch.ExecContext(ctx, query, args...) if err == nil { break } @@ -476,10 +457,6 @@ func (d *Generic) execute(ctx context.Context, txName, query string, args ...int return result, err } -type executor interface { - ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) -} - func (d *Generic) CountCurrent(ctx context.Context, prefix string, startKey string) (int64, int64, error) { var ( rev sql.NullInt64 @@ -538,228 +515,61 @@ func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision i return rev.Int64, id, err } -func (d *Generic) execBatchedOperation(ctx context.Context, change *batchedChange) (rev int64, succeeded bool, err error) { - d.batchMu.Lock() - defer d.batchMu.Unlock() - - d.batch = append(d.batch, change) - stop := context.AfterFunc(ctx, func() { - d.batchMu.Lock() - defer d.batchMu.Unlock() - - if !change.committed { - for i, c := range d.batch { - if c == change { - d.batch = append(d.batch[:i], d.batch[i+1:]...) - change.err = ctx.Err() - change.succeeded = false - change.committed = true - d.batchCv.Broadcast() - return - } - } - } - }) - defer stop() - - if !d.batchRunnig { - d.batchRunnig = true - go d.execBatch(context.TODO()) - } - - for !change.committed { - d.batchCv.Wait() - } - - return change.revision, change.succeeded, change.err -} - -func (d *Generic) execBatch(ctx context.Context) { - d.batchMu.Lock() - defer d.batchMu.Unlock() - - for len(d.batch) > 0 { - d.batchMu.Unlock() - d.execSingleBatch(ctx) - d.batchMu.Lock() - } - - d.batchRunnig = false -} +func (d *Generic) Create(ctx context.Context, key string, value []byte, ttl int64) (rev int64, succeeded bool, err error) { + ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Create", otelName)) -func (d *Generic) execSingleBatch(ctx context.Context) { - defer d.batchCv.Broadcast() - book := func() []*batchedChange { - d.batchMu.Lock() - defer d.batchMu.Unlock() - - batch := d.batch - d.batch = nil - return batch - } - - retry := func(changes []*batchedChange) { - d.batchMu.Lock() - defer d.batchMu.Unlock() - - d.batch = append(d.batch, changes...) - } - - for i := 0; i < maxRetries; i++ { - batch := book() - - switch len(batch) { - case 0: - return - case 1: - d.exec(ctx, d.DB, batch[0]) - if batch[0].err == nil { - // Autocommit was on. - batch[0].committed = true - return - } else if d.Retry == nil || !d.Retry(batch[0].err) { - // In this case, if a query had a hard error, - // it doesn't make sense to retry it. - batch[0].committed = true - return - } else { - retry(batch) - } - default: - // FIXME it would be nice to have a `BEGIN IMMEDIATE` here, - // this way the database can never be busy after the transaction - // started... - tx, err := d.DB.BeginTx(ctx, nil) - if err != nil { - // TODO log - break - } - defer tx.Rollback() - - for i, change := range batch { - d.exec(ctx, tx, change) - if change.err != nil { - if d.Retry == nil || !d.Retry(batch[0].err) { - // In this case, if a query had a hard error, - // it doesn't make sense to retry it, but the - // whole batch needs to be rolled back. - change.committed = true - retry(batch[:i]) - retry(batch[i+1:]) - } else { - // In this case we need to retry the whole batch - retry(batch) - } - if err := tx.Rollback(); err != nil { - logrus.WithError(err).Debug("cannot rollback transaction") - } - break - } - } - if err := tx.Commit(); err != nil { - logrus.WithError(err).Error("cannot commit transaction") - } - for _, change := range batch { - change.committed = true + defer func() { + if err != nil { + if d.TranslateErr != nil { + err = d.TranslateErr(err) } - return + span.RecordError(err) } - } -} - -func (d *Generic) exec(ctx context.Context, db executor, bc *batchedChange) { - switch bc.Type { - case batchCreate: - bc.revision, bc.succeeded, bc.err = d.create(ctx, db, bc.Key, bc.Value, bc.TTL) - case batchUpdate: - bc.revision, bc.succeeded, bc.err = d.update(ctx, db, bc.Key, bc.Value, bc.PrevRevision, bc.TTL) - default: - panic("WTF") - } - if d.TranslateErr != nil && bc.err != nil { - bc.err = d.TranslateErr(bc.err) - } -} - -func (d *Generic) Create(ctx context.Context, key string, value []byte, ttl int64) (int64, bool, error) { - return d.execBatchedOperation(ctx, &batchedChange{ - Type: batchCreate, - Key: key, - Value: value, - TTL: ttl, - }) -} - -func (d *Generic) create(ctx context.Context, db executor, key string, value []byte, ttl int64) (int64, bool, error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.create", otelName)) + span.SetAttributes(attribute.Int64("revision", rev)) + span.End() + }() span.SetAttributes( attribute.String("key", key), attribute.Int64("ttl", ttl), ) - defer span.End() - result, err := db.ExecContext(ctx, d.CreateSQL, key, ttl, value, key) + result, err := d.execute(ctx, "create_sql", d.CreateSQL, key, ttl, value, key) if err != nil { + logrus.WithError(err).Error("failed to create key") return 0, false, err } - if insertCount, err := result.RowsAffected(); err != nil { - span.RecordError(err) - logrus.WithError(err).Error("failed to create key") return 0, false, err } else if insertCount == 0 { return 0, false, nil } - - rev, err := result.LastInsertId() - if err != nil { - span.RecordError(err) - logrus.WithError(err).Error("failed to retrive inserted id") - return 0, false, err - } - span.SetAttributes(attribute.Int64("revision", rev)) + rev, err = result.LastInsertId() return rev, true, err } -func (d *Generic) Update(ctx context.Context, key string, value []byte, preRev, ttl int64) (int64, bool, error) { - return d.execBatchedOperation(ctx, &batchedChange{ - Type: batchUpdate, - Key: key, - Value: value, - PrevRevision: preRev, - TTL: ttl, - }) -} - -func (d *Generic) update(ctx context.Context, db executor, key string, value []byte, preRev, ttl int64) (int64, bool, error) { +func (d *Generic) Update(ctx context.Context, key string, value []byte, preRev, ttl int64) (rev int64, updated bool, err error) { ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Update", otelName)) - span.SetAttributes( - attribute.String("key", key), - attribute.Int64("ttl", ttl), - attribute.Int64("prevRev", preRev), - ) - defer span.End() + defer func() { + if err != nil { + if d.TranslateErr != nil { + err = d.TranslateErr(err) + } + span.RecordError(err) + } + span.End() + }() - result, err := db.ExecContext(ctx, d.UpdateSQL, key, ttl, value, key, preRev) + result, err := d.execute(ctx, "update_sql", d.UpdateSQL, key, ttl, value, key, preRev) if err != nil { + logrus.WithError(err).Error("failed to update key") return 0, false, err } - if insertCount, err := result.RowsAffected(); err != nil { - span.RecordError(err) - logrus.WithError(err).Error("failed to update key") return 0, false, err } else if insertCount == 0 { return 0, false, nil } - - rev, err := result.LastInsertId() - if err != nil { - span.RecordError(err) - logrus.WithError(err).Error("failed to retrive inserted id") - return 0, false, err - } - span.SetAttributes(attribute.Int64("revision", rev)) + rev, err = result.LastInsertId() return rev, true, err } @@ -1119,30 +929,3 @@ func (d *Generic) GetPollInterval() time.Duration { } return time.Second } - -type batchedChangeType int - -const ( - batchCreate batchedChangeType = iota + 1 - batchUpdate - batchDelete -) - -type batchedChange struct { - Type batchedChangeType - Key string - Value []byte - TTL int64 - PrevRevision int64 - - committed bool - succeeded bool - revision int64 - err error -} - -func (bc *batchedChange) Exec(ctx context.Context, db executor) { - if bc.committed { - return - } -} diff --git a/pkg/kine/prepared/db.go b/pkg/kine/prepared/db.go index 26a075d2..fa5e99e5 100644 --- a/pkg/kine/prepared/db.go +++ b/pkg/kine/prepared/db.go @@ -36,7 +36,7 @@ func New(db *sql.DB) *DB { func (db *DB) Underlying() *sql.DB { return db.underlying } func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.ExecContext", otelName)) + ctx, span := otelTracer.Start(ctx, "DB.ExecContext") defer func() { span.RecordError(err) span.End() @@ -50,7 +50,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (resul } func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.QueryContext", otelName)) + ctx, span := otelTracer.Start(ctx, "DB.QueryContext") defer func() { span.RecordError(err) span.End()