diff --git a/pkg/kine/drivers/generic/generic.go b/pkg/kine/drivers/generic/generic.go index 1a464c7c..e75323f6 100644 --- a/pkg/kine/drivers/generic/generic.go +++ b/pkg/kine/drivers/generic/generic.go @@ -5,13 +5,11 @@ import ( "database/sql" "fmt" "regexp" - "strconv" "strings" "sync" "time" "github.com/canonical/k8s-dqlite/pkg/kine/prepared" - "github.com/pkg/errors" "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -78,7 +76,6 @@ func init() { if err != nil { logrus.WithError(err).Warning("Otel failed to create create counter") } - } var ( @@ -90,39 +87,20 @@ var ( listSQL = fmt.Sprintf(` SELECT %s - FROM kine kv + FROM kine AS kv JOIN ( SELECT MAX(mkv.id) as id - FROM kine mkv + FROM kine AS mkv WHERE mkv.name >= ? AND mkv.name < ? - %%s - GROUP BY mkv.name) maxkv + AND mkv.id <= ? + GROUP BY mkv.name + ) AS maxkv ON maxkv.id = kv.id - WHERE - (kv.deleted = 0 OR ?) + WHERE (kv.deleted = 0 OR ?) ORDER BY kv.name ASC, kv.id ASC `, columns) - revisionAfterSQL = fmt.Sprintf(` - SELECT * - FROM ( - SELECT %s - FROM kine AS kv - JOIN ( - SELECT MAX(mkv.id) AS id - FROM kine AS mkv - WHERE mkv.name >= ? AND mkv.name < ? - AND mkv.id <= ? - GROUP BY mkv.name - ) AS maxkv - ON maxkv.id = kv.id - WHERE - ? OR kv.deleted = 0 - ) AS lkv - ORDER BY lkv.name ASC, lkv.theid ASC - `, columns) - revisionIntervalSQL = ` SELECT ( SELECT MAX(prev_revision) @@ -132,6 +110,100 @@ var ( SELECT MAX(id) FROM kine ) AS high` + + listRevisionStartSQL = listSQL + + countRevisionSQL = fmt.Sprintf(` + SELECT COUNT(*) + FROM ( + %s + )`, listSQL) + + afterSQLPrefix = fmt.Sprintf(` + SELECT %s + FROM kine AS kv + WHERE + kv.name >= ? AND kv.name < ? + AND kv.id > ? + ORDER BY kv.id ASC`, columns) + + afterSQL = fmt.Sprintf(` + SELECT %s + FROM kine AS kv + WHERE kv.id > ? + ORDER BY kv.id ASC + `, columns) + + deleteRevSQL = ` + DELETE FROM kine + WHERE id = ?` + + updateCompactSQL = ` + UPDATE kine + SET prev_revision = max(prev_revision, ?) + WHERE name = 'compact_rev_key'` + + deleteSQL = ` + INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) + SELECT + name, + 0 AS created, + 1 AS deleted, + CASE + WHEN kine.created THEN id + ELSE create_revision + END AS create_revision, + id AS prev_revision, + lease, + NULL AS value, + value AS old_value + FROM kine WHERE id = (SELECT MAX(id) FROM kine WHERE name = ?) + AND deleted = 0 + AND id = ?` + + createSQL = ` + INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) + SELECT + ? AS name, + 1 AS created, + 0 AS deleted, + 0 AS create_revision, + COALESCE(id, 0) AS prev_revision, + ? AS lease, + ? AS value, + NULL AS old_value + FROM ( + SELECT MAX(id) AS id, deleted + FROM kine + WHERE name = ? + ) maxkv + WHERE maxkv.deleted = 1 OR id IS NULL` + + updateSQL = ` + INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) + SELECT + ? AS name, + 0 AS created, + 0 AS deleted, + CASE + WHEN kine.created THEN id + ELSE create_revision + END AS create_revision, + id AS prev_revision, + ? AS lease, + ? AS value, + value AS old_value + FROM kine WHERE id = (SELECT MAX(id) FROM kine WHERE name = ?) + AND deleted = 0 + AND id = ?` + + fillSQL = ` + INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)` + + getSizeSQL = ` + SELECT (page_count - freelist_count) * page_size + FROM pragma_page_count(), pragma_page_size(), pragma_freelist_count()` ) const maxRetries = 500 @@ -150,27 +222,11 @@ type ErrCode func(error) string type Generic struct { sync.Mutex - LockWrites bool - DB *prepared.DB - GetCurrentSQL string - RevisionSQL string - ListRevisionStartSQL string - GetRevisionAfterSQL string - CountCurrentSQL string - CountRevisionSQL string - AfterSQLPrefix string - AfterSQL string - DeleteRevSQL string - CompactSQL string - UpdateCompactSQL string - DeleteSQL string - FillSQL string - CreateSQL string - UpdateSQL string - GetSizeSQL string - Retry ErrRetry - TranslateErr TranslateErr - ErrCode ErrCode + LockWrites bool + DB *prepared.DB + Retry ErrRetry + TranslateErr TranslateErr + ErrCode ErrCode // CompactInterval is interval between database compactions performed by kine. CompactInterval time.Duration @@ -208,22 +264,6 @@ func configureConnectionPooling(connPoolConfig *ConnectionPoolConfig, db *sql.DB db.SetConnMaxIdleTime(connPoolConfig.MaxIdleTime) } -func q(sql, param string, numbered bool) string { - if param == "?" && !numbered { - return sql - } - - regex := regexp.MustCompile(`\?`) - n := 0 - return regex.ReplaceAllStringFunc(sql, func(string) string { - if numbered { - n++ - return param + strconv.Itoa(n) - } - return param - }) -} - func openAndTest(driverName, dataSourceName string) (*sql.DB, error) { db, err := sql.Open(driverName, dataSourceName) if err != nil { @@ -240,7 +280,7 @@ func openAndTest(driverName, dataSourceName string) (*sql.DB, error) { return db, nil } -func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig *ConnectionPoolConfig, paramCharacter string, numbered bool) (*Generic, error) { +func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig *ConnectionPoolConfig) (*Generic, error) { var ( db *sql.DB err error @@ -263,103 +303,6 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig return &Generic{ DB: prepared.New(db), - - GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered), - ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered), - GetRevisionAfterSQL: q(revisionAfterSQL, paramCharacter, numbered), - - CountCurrentSQL: q(fmt.Sprintf(` - SELECT (%s), COUNT(*) - FROM ( - %s - ) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered), - - CountRevisionSQL: q(fmt.Sprintf(` - SELECT (%s), COUNT(c.theid) - FROM ( - %s - ) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.id <= ?")), paramCharacter, numbered), - - AfterSQLPrefix: q(fmt.Sprintf(` - SELECT %s - FROM kine AS kv - WHERE - kv.name >= ? AND kv.name < ? - AND kv.id > ? - ORDER BY kv.id ASC`, columns), paramCharacter, numbered), - - AfterSQL: q(fmt.Sprintf(` - SELECT %s - FROM kine AS kv - WHERE kv.id > ? - ORDER BY kv.id ASC - `, columns), paramCharacter, numbered), - - DeleteRevSQL: q(` - DELETE FROM kine - WHERE id = ?`, paramCharacter, numbered), - - UpdateCompactSQL: q(` - UPDATE kine - SET prev_revision = max(prev_revision, ?) - WHERE name = 'compact_rev_key'`, paramCharacter, numbered), - - DeleteSQL: q(` - INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) - SELECT - name, - 0 AS created, - 1 AS deleted, - CASE - WHEN kine.created THEN id - ELSE create_revision - END AS create_revision, - id AS prev_revision, - lease, - NULL AS value, - value AS old_value - FROM kine WHERE id = (SELECT MAX(id) FROM kine WHERE name = ?) - AND deleted = 0 - AND id = ?`, paramCharacter, numbered), - - CreateSQL: q(` - INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) - SELECT - ? AS name, - 1 AS created, - 0 AS deleted, - 0 AS create_revision, - COALESCE(id, 0) AS prev_revision, - ? AS lease, - ? AS value, - NULL AS old_value - FROM ( - SELECT MAX(id) AS id, deleted - FROM kine - WHERE name = ? - ) maxkv - WHERE maxkv.deleted = 1 OR id IS NULL`, paramCharacter, numbered), - - UpdateSQL: q(` - INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value) - SELECT - ? AS name, - 0 AS created, - 0 AS deleted, - CASE - WHEN kine.created THEN id - ELSE create_revision - END AS create_revision, - id AS prev_revision, - ? AS lease, - ? AS value, - value AS old_value - FROM kine WHERE id = (SELECT MAX(id) FROM kine WHERE name = ?) - AND deleted = 0 - AND id = ?`, paramCharacter, numbered), - - FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered), }, err } @@ -460,62 +403,29 @@ func (d *Generic) execute(ctx context.Context, txName, query string, args ...int return result, err } -func (d *Generic) CountCurrent(ctx context.Context, prefix string, startKey string) (int64, int64, error) { - var ( - rev sql.NullInt64 - id int64 - ) - - start, end := getPrefixRange(prefix) - if startKey != "" { - start = startKey + "\x01" - } - rows, err := d.query(ctx, "count_current", d.CountCurrentSQL, start, end, false) - if err != nil { - return 0, 0, err - } - defer rows.Close() - - if !rows.Next() { - if err := rows.Err(); err != nil { - return 0, 0, err - } - return 0, 0, sql.ErrNoRows - } - - if err := rows.Scan(&rev, &id); err != nil { - return 0, 0, err - } - return rev.Int64, id, nil -} - -func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) { - var ( - rev sql.NullInt64 - id int64 - ) - +func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, error) { start, end := getPrefixRange(prefix) if startKey != "" { start = startKey + "\x01" } - rows, err := d.query(ctx, "count_revision", d.CountRevisionSQL, start, end, revision, false) + rows, err := d.query(ctx, "count_revision", countRevisionSQL, start, end, revision, false) if err != nil { - return 0, 0, err + return 0, err } defer rows.Close() if !rows.Next() { if err := rows.Err(); err != nil { - return 0, 0, err + return 0, err } - return 0, 0, sql.ErrNoRows + return 0, sql.ErrNoRows } - if err := rows.Scan(&rev, &id); err != nil { - return 0, 0, err + var id int64 + if err := rows.Scan(&id); err != nil { + return 0, err } - return rev.Int64, id, err + return id, err } func (d *Generic) Create(ctx context.Context, key string, value []byte, ttl int64) (rev int64, succeeded bool, err error) { @@ -537,7 +447,7 @@ func (d *Generic) Create(ctx context.Context, key string, value []byte, ttl int6 ) createCnt.Add(ctx, 1) - result, err := d.execute(ctx, "create_sql", d.CreateSQL, key, ttl, value, key) + result, err := d.execute(ctx, "create_sql", createSQL, key, ttl, value, key) if err != nil { logrus.WithError(err).Error("failed to create key") return 0, false, err @@ -564,7 +474,7 @@ func (d *Generic) Update(ctx context.Context, key string, value []byte, preRev, }() updateCnt.Add(ctx, 1) - result, err := d.execute(ctx, "update_sql", d.UpdateSQL, key, ttl, value, key, preRev) + result, err := d.execute(ctx, "update_sql", updateSQL, key, ttl, value, key, preRev) if err != nil { logrus.WithError(err).Error("failed to update key") return 0, false, err @@ -589,7 +499,7 @@ func (d *Generic) Delete(ctx context.Context, key string, revision int64) (rev i }() span.SetAttributes(attribute.String("key", key)) - result, err := d.execute(ctx, "delete_sql", d.DeleteSQL, key, revision) + result, err := d.execute(ctx, "delete_sql", deleteSQL, key, revision) if err != nil { logrus.WithError(err).Error("failed to delete key") return 0, false, err @@ -688,7 +598,7 @@ func (d *Generic) tryCompact(ctx context.Context, start, end int64) (err error) return err } - if _, err = tx.ExecContext(ctx, d.UpdateCompactSQL, end); err != nil { + if _, err = tx.ExecContext(ctx, updateCompactSQL, end); err != nil { return err } return tx.Commit() @@ -737,42 +647,21 @@ func (d *Generic) DeleteRevision(ctx context.Context, revision int64) error { }() span.SetAttributes(attribute.Int64("revision", revision)) - _, err = d.execute(ctx, "delete_rev_sql", d.DeleteRevSQL, revision) + _, err = d.execute(ctx, "delete_rev_sql", deleteRevSQL, revision) return err } -func (d *Generic) ListCurrent(ctx context.Context, prefix, startKey string, limit int64, includeDeleted bool) (*sql.Rows, error) { - sql := d.GetCurrentSQL +func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) { start, end := getPrefixRange(prefix) - // NOTE(neoaggelos): don't ignore startKey if set if startKey != "" { start = startKey + "\x01" } - + sql := listRevisionStartSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT ?", sql) - return d.query(ctx, "get_current_sql_limit", sql, start, end, includeDeleted, limit) + return d.query(ctx, "list_revision_start_sql_limit", sql, start, end, revision, includeDeleted, limit) } - return d.query(ctx, "get_current_sql", sql, start, end, includeDeleted) -} - -func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) { - start, end := getPrefixRange(prefix) - if startKey == "" { - sql := d.ListRevisionStartSQL - if limit > 0 { - sql = fmt.Sprintf("%s LIMIT ?", sql) - return d.query(ctx, "list_revision_start_sql_limit", sql, start, end, revision, includeDeleted, limit) - } - return d.query(ctx, "list_revision_start_sql", sql, start, end, revision, includeDeleted) - } - - sql := d.GetRevisionAfterSQL - if limit > 0 { - sql = fmt.Sprintf("%s LIMIT ?", sql) - return d.query(ctx, "get_revision_after_sql_limit", sql, startKey+"\x01", end, revision, includeDeleted, limit) - } - return d.query(ctx, "get_revision_after_sql", sql, startKey+"\x01", end, revision, includeDeleted) + return d.query(ctx, "list_revision_start_sql", sql, start, end, revision, includeDeleted) } func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) { @@ -808,7 +697,7 @@ func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) { func (d *Generic) AfterPrefix(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) { start, end := getPrefixRange(prefix) - sql := d.AfterSQLPrefix + sql := afterSQLPrefix if limit > 0 { sql = fmt.Sprintf("%s LIMIT ?", sql) return d.query(ctx, "after_sql_prefix_limit", sql, start, end, rev, limit) @@ -817,7 +706,7 @@ func (d *Generic) AfterPrefix(ctx context.Context, prefix string, rev, limit int } func (d *Generic) After(ctx context.Context, rev, limit int64) (*sql.Rows, error) { - sql := d.AfterSQL + sql := afterSQL if limit > 0 { sql = fmt.Sprintf("%s LIMIT ?", sql) return d.query(ctx, "after_sql_limit", sql, rev, limit) @@ -827,7 +716,7 @@ func (d *Generic) After(ctx context.Context, rev, limit int64) (*sql.Rows, error func (d *Generic) Fill(ctx context.Context, revision int64) error { fillCnt.Add(ctx, 1) - _, err := d.execute(ctx, "fill_sql", d.FillSQL, revision, fmt.Sprintf("gap-%d", revision), 0, 1, 0, 0, 0, nil, nil) + _, err := d.execute(ctx, "fill_sql", fillSQL, revision, fmt.Sprintf("gap-%d", revision), 0, 1, 0, 0, 0, nil, nil) return err } @@ -836,10 +725,7 @@ func (d *Generic) IsFill(key string) bool { } func (d *Generic) GetSize(ctx context.Context) (int64, error) { - if d.GetSizeSQL == "" { - return 0, errors.New("driver does not support size reporting") - } - rows, err := d.query(ctx, "get_size_sql", d.GetSizeSQL) + rows, err := d.query(ctx, "get_size_sql", getSizeSQL) if err != nil { return 0, err } diff --git a/pkg/kine/drivers/sqlite/sqlite.go b/pkg/kine/drivers/sqlite/sqlite.go index 42982312..f30cd62b 100644 --- a/pkg/kine/drivers/sqlite/sqlite.go +++ b/pkg/kine/drivers/sqlite/sqlite.go @@ -60,7 +60,7 @@ func NewVariant(ctx context.Context, driverName, dataSourceName string, connecti opts.dsn = "./db/state.db?_journal=WAL&_synchronous=FULL&_foreign_keys=1" } - dialect, err := generic.Open(ctx, driverName, opts.dsn, connectionPoolConfig, "?", false) + dialect, err := generic.Open(ctx, driverName, opts.dsn, connectionPoolConfig) if err != nil { return nil, nil, err } @@ -84,7 +84,6 @@ func NewVariant(ctx context.Context, driverName, dataSourceName string, connecti } return err } - dialect.GetSizeSQL = `SELECT (page_count - freelist_count) * page_size FROM pragma_page_count(), pragma_page_size(), pragma_freelist_count()` dialect.CompactInterval = opts.compactInterval dialect.PollInterval = opts.pollInterval diff --git a/pkg/kine/logstructured/logstructured.go b/pkg/kine/logstructured/logstructured.go index 4ce36383..84713c5f 100644 --- a/pkg/kine/logstructured/logstructured.go +++ b/pkg/kine/logstructured/logstructured.go @@ -70,6 +70,7 @@ func (l *LogStructured) Start(ctx context.Context) error { func (l *LogStructured) Wait() { l.wg.Wait() + l.log.Wait() } func (l *LogStructured) Get(ctx context.Context, key, rangeEnd string, limit, revision int64) (revRet int64, kvRet *server.KeyValue, errRet error) { @@ -81,9 +82,8 @@ func (l *LogStructured) Get(ctx context.Context, key, rangeEnd string, limit, re attribute.Int64("revision", revision), ) defer func() { - l.adjustRevision(ctx, &revRet) logrus.Debugf("GET %s, rev=%d => rev=%d, kv=%v, err=%v", key, revision, revRet, kvRet != nil, errRet) - span.SetAttributes(attribute.Int64("adjusted-revision", revRet)) + span.SetAttributes(attribute.Int64("current-revision", revRet)) span.RecordError(errRet) span.End() }() @@ -114,8 +114,7 @@ func (l *LogStructured) get(ctx context.Context, key, rangeEnd string, limit, re span.AddEvent("key already compacted") // ignore compacted when getting by revision err = nil - } - if err != nil { + } else if err != nil { return 0, nil, err } if revision != 0 { @@ -127,16 +126,6 @@ func (l *LogStructured) get(ctx context.Context, key, rangeEnd string, limit, re return rev, events[0], nil } -func (l *LogStructured) adjustRevision(ctx context.Context, rev *int64) { - if *rev != 0 { - return - } - - if newRev, err := l.log.CurrentRevision(ctx); err == nil { - *rev = newRev - } -} - func (l *LogStructured) Create(ctx context.Context, key string, value []byte, lease int64) (rev int64, created bool, err error) { rev, created, err = l.log.Create(ctx, key, value, lease) logrus.Debugf("CREATE %s, size=%d, lease=%d => rev=%d, err=%v", key, len(value), lease, rev, err) @@ -151,17 +140,15 @@ func (l *LogStructured) Delete(ctx context.Context, key string, revision int64) func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit, revision int64) (revRet int64, kvRet []*server.KeyValue, errRet error) { ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.List", otelName)) + span.SetAttributes( + attribute.String("prefix", prefix), + attribute.String("startKey", startKey), + attribute.Int64("limit", limit), + attribute.Int64("revision", revision), + ) defer func() { logrus.Debugf("LIST %s, start=%s, limit=%d, rev=%d => rev=%d, kvs=%d, err=%v", prefix, startKey, limit, revision, revRet, len(kvRet), errRet) - span.SetAttributes( - attribute.String("prefix", prefix), - attribute.String("startKey", startKey), - attribute.Int64("limit", limit), - attribute.Int64("revision", revision), - attribute.Int64("adjusted-revision", revRet), - attribute.Int64("kv-count", int64(len(kvRet))), - ) span.RecordError(errRet) span.End() }() @@ -170,23 +157,10 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit if err != nil { return 0, nil, err } - if revision == 0 && len(events) == 0 { - // if no revision is requested and no events are returned, then - // get the current revision and relist. Relist is required because - // between now and getting the current revision something could have - // been created. - currentRev, err := l.log.CurrentRevision(ctx) - if err != nil { - return 0, nil, err - } - return l.List(ctx, prefix, startKey, limit, currentRev) - } else if revision != 0 { - rev = revision - } - kvs := make([]*server.KeyValue, 0, len(events)) - for _, event := range events { - kvs = append(kvs, event.KV) + kvs := make([]*server.KeyValue, len(events)) + for i, event := range events { + kvs[i] = event.KV } return rev, kvs, nil } @@ -199,27 +173,13 @@ func (l *LogStructured) Count(ctx context.Context, prefix, startKey string, revi attribute.String("prefix", prefix), attribute.String("startKey", startKey), attribute.Int64("revision", revision), - attribute.Int64("adjusted-revision", revRet), + attribute.Int64("current-revision", revRet), attribute.Int64("count", count), ) span.RecordError(err) span.End() }() - rev, count, err := l.log.Count(ctx, prefix, startKey, revision) - if err != nil { - return 0, 0, err - } - - if count == 0 { - // if count is zero, then so is revision, so now get the current revision and re-count at that revision - currentRev, err := l.log.CurrentRevision(ctx) - if err != nil { - return 0, 0, err - } - rev, rows, err := l.List(ctx, prefix, prefix, 1000, currentRev) - return rev, int64(len(rows)), err - } - return rev, count, nil + return l.log.Count(ctx, prefix, startKey, revision) } func (l *LogStructured) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, updateRet bool, errRet error) { @@ -231,7 +191,7 @@ func (l *LogStructured) Update(ctx context.Context, key string, value []byte, re attribute.Int64("revision", revision), attribute.Int64("lease", lease), attribute.Int64("value-size", int64(len(value))), - attribute.Int64("adjusted-revision", revRet), + attribute.Int64("current-revision", revRet), attribute.Bool("updated", updateRet), ) span.End() diff --git a/pkg/kine/logstructured/sqllog/sql.go b/pkg/kine/logstructured/sqllog/sql.go index dab85524..282770c6 100644 --- a/pkg/kine/logstructured/sqllog/sql.go +++ b/pkg/kine/logstructured/sqllog/sql.go @@ -58,10 +58,8 @@ func New(d Dialect) *SQLLog { } type Dialect interface { - ListCurrent(ctx context.Context, prefix, startKey string, limit int64, includeDeleted bool) (*sql.Rows, error) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) - CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error) - Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) + Count(ctx context.Context, prefix, startKey string, revision int64) (int64, error) CurrentRevision(ctx context.Context) (int64, error) AfterPrefix(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) After(ctx context.Context, rev, limit int64) (*sql.Rows, error) @@ -192,34 +190,32 @@ func (s *SQLLog) After(ctx context.Context, prefix string, revision, limit int64 attribute.Int64("revision", revision), attribute.Int64("limit", limit), ) - rows, err := s.d.AfterPrefix(ctx, prefix, revision, limit) + + compactRevision, currentRevision, err := s.d.GetCompactRevision(ctx) if err != nil { return 0, nil, err } + if revision == 0 || revision > currentRevision { + revision = currentRevision + } else if revision < compactRevision { + return currentRevision, nil, server.ErrCompacted + } - result, err := RowsToEvents(rows) + rows, err := s.d.AfterPrefix(ctx, prefix, revision, limit) if err != nil { return 0, nil, err } - compact, rev, err := s.d.GetCompactRevision(ctx) - + result, err := RowsToEvents(rows) if err != nil { return 0, nil, err } - - if revision > 0 && revision < compact { - return rev, result, server.ErrCompacted - } - - return rev, result, err + return currentRevision, result, err } func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (int64, []*server.Event, error) { - var ( - rows *sql.Rows - err error - ) + var err error + ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.List", otelName)) defer func() { span.RecordError(err) @@ -233,6 +229,16 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis attribute.Bool("includeDeleted", includeDeleted), ) + compactRevision, currentRevision, err := s.d.GetCompactRevision(ctx) + if err != nil { + return 0, nil, err + } + if revision == 0 || revision > currentRevision { + revision = currentRevision + } else if revision < compactRevision { + return currentRevision, nil, server.ErrCompacted + } + // It's assumed that when there is a start key that that key exists. if strings.HasSuffix(prefix, "/") { // In the situation of a list start the startKey will not exist so set to "" @@ -244,11 +250,7 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis startKey = "" } - if revision == 0 { - rows, err = s.d.ListCurrent(ctx, prefix, startKey, limit, includeDeleted) - } else { - rows, err = s.d.List(ctx, prefix, startKey, limit, revision, includeDeleted) - } + rows, err := s.d.List(ctx, prefix, startKey, limit, revision, includeDeleted) if err != nil { return 0, nil, err } @@ -258,18 +260,7 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis return 0, nil, err } - compact, rev, err := s.d.GetCompactRevision(ctx) - if err != nil { - return 0, nil, err - } - - if revision > 0 && revision < compact { - return rev, result, server.ErrCompacted - } - - s.notifyWatcherPoll(rev) - - return rev, result, err + return currentRevision, result, err } func RowsToEvents(rows *sql.Rows) ([]*server.Event, error) { @@ -486,11 +477,21 @@ func (s *SQLLog) Count(ctx context.Context, prefix, startKey string, revision in attribute.String("startKey", startKey), attribute.Int64("revision", revision), ) - if revision == 0 { - return s.d.CountCurrent(ctx, prefix, startKey) - } - return s.d.Count(ctx, prefix, startKey, revision) + compactRevision, currentRevision, err := s.d.GetCompactRevision(ctx) + if err != nil { + return 0, 0, err + } + if revision == 0 || revision > currentRevision { + revision = currentRevision + } else if revision < compactRevision { + return currentRevision, 0, server.ErrCompacted + } + count, err := s.d.Count(ctx, prefix, startKey, revision) + if err != nil { + return 0, 0, err + } + return currentRevision, count, nil } func (s *SQLLog) Create(ctx context.Context, key string, value []byte, lease int64) (int64, bool, error) { diff --git a/test/update_test.go b/test/update_test.go index 78e31e50..b71e013e 100644 --- a/test/update_test.go +++ b/test/update_test.go @@ -120,7 +120,7 @@ func BenchmarkUpdate(b *testing.B) { run := func(start int) { defer wg.Done() benchKey := fmt.Sprintf("benchKey-%d", start) - for i, lastModRev := 0, int64(0); i < b.N; i += workers { + for i, lastModRev := start, int64(0); i < b.N; i += workers { value := fmt.Sprintf("value-%d", i) lastModRev = updateRev(ctx, g, kine.client, benchKey, lastModRev, value) }