diff --git a/pkg/kine/broadcaster/broadcaster.go b/pkg/kine/broadcaster/broadcaster.go index c7ecd014..8fa2f0e5 100644 --- a/pkg/kine/broadcaster/broadcaster.go +++ b/pkg/kine/broadcaster/broadcaster.go @@ -5,21 +5,21 @@ import ( "sync" ) -type ConnectFunc func() (chan interface{}, error) +type ConnectFunc[T any] func(ctx context.Context) (chan T, error) -type Broadcaster struct { +type Broadcaster[T any] struct { sync.Mutex running bool - subs map[chan interface{}]struct{} + subs map[chan T]struct{} } -func (b *Broadcaster) Subscribe(ctx context.Context) (<-chan interface{}, error) { +func (b *Broadcaster[T]) Subscribe(ctx context.Context) (<-chan T, error) { b.Lock() defer b.Unlock() - sub := make(chan interface{}, 100) + sub := make(chan T, 100) if b.subs == nil { - b.subs = map[chan interface{}]struct{}{} + b.subs = map[chan T]struct{}{} } b.subs[sub] = struct{}{} context.AfterFunc(ctx, func() { @@ -31,18 +31,18 @@ func (b *Broadcaster) Subscribe(ctx context.Context) (<-chan interface{}, error) return sub, nil } -func (b *Broadcaster) unsub(sub chan interface{}) { +func (b *Broadcaster[T]) unsub(sub chan T) { if _, ok := b.subs[sub]; ok { close(sub) delete(b.subs, sub) } } -func (b *Broadcaster) Start(connect ConnectFunc) error { +func (b *Broadcaster[T]) Start(ctx context.Context, connect ConnectFunc[T]) error { b.Lock() defer b.Unlock() - c, err := connect() + c, err := connect(ctx) if err != nil { return err } @@ -52,7 +52,7 @@ func (b *Broadcaster) Start(connect ConnectFunc) error { return nil } -func (b *Broadcaster) stream(ch chan interface{}) { +func (b *Broadcaster[T]) stream(ch chan T) { for item := range ch { b.publish(item) } @@ -65,7 +65,7 @@ func (b *Broadcaster) stream(ch chan interface{}) { b.running = false } -func (b *Broadcaster) publish(item interface{}) { +func (b *Broadcaster[T]) publish(item T) { b.Lock() defer b.Unlock() diff --git a/pkg/kine/drivers/sqlite/sqlite.go b/pkg/kine/drivers/sqlite/sqlite.go index f30cd62b..bf2e3b91 100644 --- a/pkg/kine/drivers/sqlite/sqlite.go +++ b/pkg/kine/drivers/sqlite/sqlite.go @@ -10,9 +10,8 @@ import ( "time" "github.com/canonical/k8s-dqlite/pkg/kine/drivers/generic" - "github.com/canonical/k8s-dqlite/pkg/kine/logstructured" - "github.com/canonical/k8s-dqlite/pkg/kine/logstructured/sqllog" "github.com/canonical/k8s-dqlite/pkg/kine/server" + "github.com/canonical/k8s-dqlite/pkg/kine/sqllog" "github.com/mattn/go-sqlite3" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -98,7 +97,7 @@ func NewVariant(ctx context.Context, driverName, dataSourceName string, connecti } } - return logstructured.New(sqllog.New(dialect)), dialect, nil + return sqllog.New(dialect), dialect, nil } // setup performs table setup, which may include creation of the Kine table if diff --git a/pkg/kine/endpoint/endpoint.go b/pkg/kine/endpoint/endpoint.go index 80de13b1..46880ca9 100644 --- a/pkg/kine/endpoint/endpoint.go +++ b/pkg/kine/endpoint/endpoint.go @@ -76,11 +76,10 @@ func Listen(ctx context.Context, config Config) (ETCDConfig, error) { go func() { if err := grpcServer.Serve(listener); err != nil { - logrus.Errorf("Kine server shutdown: %v", err) + logrus.Errorf("unexpected server shutdown: %v", err) } - listener.Close() - grpcServer.Stop() }() + context.AfterFunc(ctx, grpcServer.Stop) return ETCDConfig{ LeaderElect: leaderelect, @@ -145,8 +144,8 @@ func ListenAndReturnBackend(ctx context.Context, config Config) (ETCDConfig, ser logrus.Errorf("Kine server shutdown: %v", err) } listener.Close() - grpcServer.Stop() }() + context.AfterFunc(ctx, grpcServer.Stop) return ETCDConfig{ LeaderElect: leaderelect, diff --git a/pkg/kine/logstructured/logstructured.go b/pkg/kine/logstructured/logstructured.go deleted file mode 100644 index 84713c5f..00000000 --- a/pkg/kine/logstructured/logstructured.go +++ /dev/null @@ -1,333 +0,0 @@ -package logstructured - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/canonical/k8s-dqlite/pkg/kine/server" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -const otelName = "logstructured" - -var ( - otelTracer trace.Tracer -) - -func init() { - otelTracer = otel.Tracer(otelName) -} - -type Log interface { - Start(ctx context.Context) error - Wait() - CurrentRevision(ctx context.Context) (int64, error) - List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeletes bool) (int64, []*server.Event, error) - Create(ctx context.Context, key string, value []byte, lease int64) (rev int64, created bool, err error) - Update(ctx context.Context, key string, value []byte, revision, lease int64) (rev int64, updated bool, err error) - Delete(ctx context.Context, key string, revision int64) (rev int64, deleted bool, err error) - After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error) - Watch(ctx context.Context, prefix string) <-chan []*server.Event - Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) - DbSize(ctx context.Context) (int64, error) - DoCompact(ctx context.Context) error -} - -type LogStructured struct { - log Log - wg sync.WaitGroup -} - -func New(log Log) *LogStructured { - return &LogStructured{ - log: log, - } -} - -func (l *LogStructured) DoCompact(ctx context.Context) error { - return l.log.DoCompact(ctx) -} - -func (l *LogStructured) Start(ctx context.Context) error { - if err := l.log.Start(ctx); err != nil { - return err - } - l.Create(ctx, "/registry/health", []byte(`{"health":"true"}`), 0) - - l.wg.Add(1) - go func() { - defer l.wg.Done() - l.ttl(ctx) - }() - return nil -} - -func (l *LogStructured) Wait() { - l.wg.Wait() - l.log.Wait() -} - -func (l *LogStructured) Get(ctx context.Context, key, rangeEnd string, limit, revision int64) (revRet int64, kvRet *server.KeyValue, errRet error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Get", otelName)) - span.SetAttributes( - attribute.String("key", key), - attribute.String("rangeEnd", rangeEnd), - attribute.Int64("limit", limit), - attribute.Int64("revision", revision), - ) - defer func() { - logrus.Debugf("GET %s, rev=%d => rev=%d, kv=%v, err=%v", key, revision, revRet, kvRet != nil, errRet) - span.SetAttributes(attribute.Int64("current-revision", revRet)) - span.RecordError(errRet) - span.End() - }() - - rev, event, err := l.get(ctx, key, rangeEnd, limit, revision, false) - if event == nil { - return rev, nil, err - } - return rev, event.KV, err -} - -func (l *LogStructured) get(ctx context.Context, key, rangeEnd string, limit, revision int64, includeDeletes bool) (int64, *server.Event, error) { - var err error - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.get", otelName)) - defer func() { - span.RecordError(err) - span.End() - }() - span.SetAttributes( - attribute.String("key", key), - attribute.String("rangeEnd", rangeEnd), - attribute.Int64("limit", limit), - attribute.Int64("revision", revision), - attribute.Bool("includeDeletes", includeDeletes), - ) - rev, events, err := l.log.List(ctx, key, rangeEnd, limit, revision, includeDeletes) - if err == server.ErrCompacted { - span.AddEvent("key already compacted") - // ignore compacted when getting by revision - err = nil - } else if err != nil { - return 0, nil, err - } - if revision != 0 { - rev = revision - } - if len(events) == 0 { - return rev, nil, nil - } - return rev, events[0], nil -} - -func (l *LogStructured) Create(ctx context.Context, key string, value []byte, lease int64) (rev int64, created bool, err error) { - rev, created, err = l.log.Create(ctx, key, value, lease) - logrus.Debugf("CREATE %s, size=%d, lease=%d => rev=%d, err=%v", key, len(value), lease, rev, err) - return rev, created, err -} - -func (l *LogStructured) Delete(ctx context.Context, key string, revision int64) (revRet int64, deleted bool, errRet error) { - rev, del, err := l.log.Delete(ctx, key, revision) - logrus.Debugf("DELETE %s, rev=%d => rev=%d, deleted=%v, err=%v", key, revision, rev, del, err) - return rev, del, err -} - -func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit, revision int64) (revRet int64, kvRet []*server.KeyValue, errRet error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.List", otelName)) - span.SetAttributes( - attribute.String("prefix", prefix), - attribute.String("startKey", startKey), - attribute.Int64("limit", limit), - attribute.Int64("revision", revision), - ) - - defer func() { - logrus.Debugf("LIST %s, start=%s, limit=%d, rev=%d => rev=%d, kvs=%d, err=%v", prefix, startKey, limit, revision, revRet, len(kvRet), errRet) - span.RecordError(errRet) - span.End() - }() - - rev, events, err := l.log.List(ctx, prefix, startKey, limit, revision, false) - if err != nil { - return 0, nil, err - } - - kvs := make([]*server.KeyValue, len(events)) - for i, event := range events { - kvs[i] = event.KV - } - return rev, kvs, nil -} - -func (l *LogStructured) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Count", otelName)) - defer func() { - logrus.Debugf("COUNT prefix=%s startKey=%s => rev=%d, count=%d, err=%v", prefix, startKey, revRet, count, err) - span.SetAttributes( - attribute.String("prefix", prefix), - attribute.String("startKey", startKey), - attribute.Int64("revision", revision), - attribute.Int64("current-revision", revRet), - attribute.Int64("count", count), - ) - span.RecordError(err) - span.End() - }() - return l.log.Count(ctx, prefix, startKey, revision) -} - -func (l *LogStructured) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, updateRet bool, errRet error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Update", otelName)) - defer func() { - logrus.Debugf("UPDATE %s, value=%d, rev=%d, lease=%v => rev=%d, updated=%v, err=%v", key, len(value), revision, lease, revRet, updateRet, errRet) - span.SetAttributes( - attribute.String("key", key), - attribute.Int64("revision", revision), - attribute.Int64("lease", lease), - attribute.Int64("value-size", int64(len(value))), - attribute.Int64("current-revision", revRet), - attribute.Bool("updated", updateRet), - ) - span.End() - }() - return l.log.Update(ctx, key, value, revision, lease) -} - -func (l *LogStructured) ttlEvents(ctx context.Context) chan *server.Event { - result := make(chan *server.Event) - var shouldClose atomic.Bool - - l.wg.Add(2) - go func() { - defer l.wg.Done() - - rev, events, err := l.log.List(ctx, "/", "", 1000, 0, false) - for len(events) > 0 { - if err != nil { - logrus.Errorf("failed to read old events for ttl: %v", err) - return - } - - for _, event := range events { - if event.KV.Lease > 0 { - result <- event - } - } - - _, events, err = l.log.List(ctx, "/", events[len(events)-1].KV.Key, 1000, rev, false) - } - - if !shouldClose.CompareAndSwap(false, true) { - close(result) - } - }() - - go func() { - defer l.wg.Done() - - for events := range l.log.Watch(ctx, "/") { - for _, event := range events { - if event.KV.Lease > 0 { - result <- event - } - } - } - - if !shouldClose.CompareAndSwap(false, true) { - close(result) - } - }() - - return result -} - -func (l *LogStructured) ttl(ctx context.Context) { - // very naive TTL support - mutex := &sync.Mutex{} - for event := range l.ttlEvents(ctx) { - go func(event *server.Event) { - select { - case <-ctx.Done(): - return - case <-time.After(time.Duration(event.KV.Lease) * time.Second): - } - mutex.Lock() - l.Delete(ctx, event.KV.Key, event.KV.ModRevision) - mutex.Unlock() - }(event) - } -} - -func (l *LogStructured) Watch(ctx context.Context, prefix string, revision int64) <-chan []*server.Event { - logrus.Debugf("WATCH %s, revision=%d", prefix, revision) - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Watch", otelName)) - defer span.End() - span.SetAttributes( - attribute.String("prefix", prefix), - attribute.Int64("revision", revision), - ) - - // starting watching right away so we don't miss anything - ctx, cancel := context.WithCancel(ctx) - readChan := l.log.Watch(ctx, prefix) - - // include the current revision in list - if revision > 0 { - revision -= 1 - } - - result := make(chan []*server.Event, 100) - - rev, kvs, err := l.log.After(ctx, prefix, revision, 0) - if err != nil { - logrus.Errorf("failed to list %s for revision %d", prefix, revision) - msg := fmt.Sprintf("failed to list %s for revision %d", prefix, revision) - span.AddEvent(msg) - logrus.Errorf(msg) - cancel() - } - - logrus.Debugf("WATCH LIST key=%s rev=%d => rev=%d kvs=%d", prefix, revision, rev, len(kvs)) - span.SetAttributes(attribute.Int64("current-revision", rev), attribute.Int64("kvs-count", int64(len(kvs)))) - - l.wg.Add(1) - go func() { - defer l.wg.Done() - - lastRevision := revision - if len(kvs) > 0 { - lastRevision = rev - } - - if len(kvs) > 0 { - result <- kvs - } - - // always ensure we fully read the channel - for i := range readChan { - result <- filter(i, lastRevision) - } - close(result) - cancel() - }() - - return result -} - -func filter(events []*server.Event, rev int64) []*server.Event { - for len(events) > 0 && events[0].KV.ModRevision <= rev { - events = events[1:] - } - - return events -} - -func (l *LogStructured) DbSize(ctx context.Context) (int64, error) { - return l.log.DbSize(ctx) -} diff --git a/pkg/kine/server/compact.go b/pkg/kine/server/compact.go index f1b82976..1e1aed8d 100644 --- a/pkg/kine/server/compact.go +++ b/pkg/kine/server/compact.go @@ -29,10 +29,8 @@ func (l *LimitedServer) compact(ctx context.Context) (*etcdserverpb.TxnResponse, Response: &etcdserverpb.ResponseOp_ResponseRange{ ResponseRange: &etcdserverpb.RangeResponse{ Header: &etcdserverpb.ResponseHeader{}, - Kvs: []*mvccpb.KeyValue{ - &mvccpb.KeyValue{}, - }, - Count: 1, + Kvs: []*mvccpb.KeyValue{{}}, + Count: 1, }, }, }, diff --git a/pkg/kine/server/delete.go b/pkg/kine/server/delete.go index 180a4b25..0b4b1429 100644 --- a/pkg/kine/server/delete.go +++ b/pkg/kine/server/delete.go @@ -21,8 +21,7 @@ func isDelete(txn *etcdserverpb.TxnRequest) (int64, string, bool) { return 0, "", false } -func (l *LimitedServer) delete(ctx context.Context, key string, revision int64) (*etcdserverpb.TxnResponse, error) { - var err error +func (l *LimitedServer) delete(ctx context.Context, key string, revision int64) (_ *etcdserverpb.TxnResponse, err error) { deleteCnt.Add(ctx, 1) ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.delete", otelName)) defer func() { @@ -57,7 +56,7 @@ func (l *LimitedServer) delete(ctx context.Context, key string, revision int64) }, } } else { - rev, kv, err := l.backend.Get(ctx, key, "", 1, rev) + rev, kv, err := l.backend.List(ctx, key, "", 1, rev) if err != nil { return nil, err } @@ -66,7 +65,7 @@ func (l *LimitedServer) delete(ctx context.Context, key string, revision int64) Response: &etcdserverpb.ResponseOp_ResponseRange{ ResponseRange: &etcdserverpb.RangeResponse{ Header: txnHeader(rev), - Kvs: toKVs(kv), + Kvs: toKVs(kv...), }, }, }, diff --git a/pkg/kine/server/get.go b/pkg/kine/server/get.go index 2f1ce6d5..03daeb4a 100644 --- a/pkg/kine/server/get.go +++ b/pkg/kine/server/get.go @@ -19,25 +19,23 @@ func (l *LimitedServer) get(ctx context.Context, r *etcdserverpb.RangeRequest) ( span.SetAttributes( attribute.String("key", string(r.Key)), - attribute.String("rangeEnd", string(r.RangeEnd)), attribute.Int64("limit", r.Limit), attribute.Int64("revision", r.Revision), ) - if r.Limit != 0 && len(r.RangeEnd) != 0 { - err := fmt.Errorf("invalid combination of rangeEnd and limit, limit should be 0 got %d", r.Limit) - return nil, err + + if len(r.RangeEnd) != 0 { + return nil, fmt.Errorf("unexpected rangeEnd: want empty, got %s", r.RangeEnd) + } + if r.Limit != 0 { + return nil, fmt.Errorf("unexpected limit: want 0, got %d", r.Limit) } - rev, kv, err := l.backend.Get(ctx, string(r.Key), string(r.RangeEnd), r.Limit, r.Revision) + rev, kv, err := l.backend.List(ctx, string(r.Key), "", 1, r.Revision) if err != nil { return nil, err } - - resp := &RangeResponse{ + return &RangeResponse{ Header: txnHeader(rev), - } - if kv != nil { - resp.Kvs = []*KeyValue{kv} - } - return resp, nil + Kvs: kv, + }, nil } diff --git a/pkg/kine/server/types.go b/pkg/kine/server/types.go index 61580859..a437f0e0 100644 --- a/pkg/kine/server/types.go +++ b/pkg/kine/server/types.go @@ -13,16 +13,16 @@ var ( type Backend interface { Start(ctx context.Context) error - Wait() - Get(ctx context.Context, key, rangeEnd string, limit, revision int64) (int64, *KeyValue, error) + Stop() error Create(ctx context.Context, key string, value []byte, lease int64) (int64, bool, error) Delete(ctx context.Context, key string, revision int64) (int64, bool, error) List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*KeyValue, error) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, bool, error) - Watch(ctx context.Context, key string, revision int64) <-chan []*Event + Watch(ctx context.Context, key string, revision int64) (<-chan []*Event, error) DbSize(ctx context.Context) (int64, error) DoCompact(ctx context.Context) error + Close() error } type KeyValue struct { diff --git a/pkg/kine/server/update.go b/pkg/kine/server/update.go index 2cdf172f..fbc42b18 100644 --- a/pkg/kine/server/update.go +++ b/pkg/kine/server/update.go @@ -25,12 +25,7 @@ func isUpdate(txn *etcdserverpb.TxnRequest) (int64, string, []byte, int64, bool) return 0, "", nil, 0, false } -func (l *LimitedServer) update(ctx context.Context, rev int64, key string, value []byte, lease int64) (*etcdserverpb.TxnResponse, error) { - var ( - kv *KeyValue - succeeded bool - err error - ) +func (l *LimitedServer) update(ctx context.Context, rev int64, key string, value []byte, lease int64) (_ *etcdserverpb.TxnResponse, err error) { updateCnt.Add(ctx, 1) ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.update", otelName)) @@ -44,6 +39,7 @@ func (l *LimitedServer) update(ctx context.Context, rev int64, key string, value attribute.Int64("revision", rev), ) + var succeeded bool if rev == 0 { rev, succeeded, err = l.backend.Create(ctx, key, value, lease) } else { @@ -70,20 +66,18 @@ func (l *LimitedServer) update(ctx context.Context, rev int64, key string, value }, } } else { - rev, kv, err = l.backend.Get(ctx, key, "", 1, rev) + rev, kv, err := l.backend.List(ctx, key, "", 1, rev) if err != nil { return nil, err } - resp.Responses = []*etcdserverpb.ResponseOp{ - { - Response: &etcdserverpb.ResponseOp_ResponseRange{ - ResponseRange: &etcdserverpb.RangeResponse{ - Header: txnHeader(rev), - Kvs: toKVs(kv), - }, + resp.Responses = []*etcdserverpb.ResponseOp{{ + Response: &etcdserverpb.ResponseOp_ResponseRange{ + ResponseRange: &etcdserverpb.RangeResponse{ + Header: txnHeader(rev), + Kvs: toKVs(kv...), }, }, - } + }} } return resp, nil diff --git a/pkg/kine/server/watch.go b/pkg/kine/server/watch.go index 9fe34bf8..b78f9e8d 100644 --- a/pkg/kine/server/watch.go +++ b/pkg/kine/server/watch.go @@ -71,7 +71,12 @@ func (w *watcher) Start(ctx context.Context, r *etcdserverpb.WatchCreateRequest) return } - for events := range w.backend.Watch(ctx, key, r.StartRevision) { + watchCh, err := w.backend.Watch(ctx, key, r.StartRevision) + if err != nil { + w.Cancel(id, err) + return + } + for events := range watchCh { if len(events) == 0 { continue } diff --git a/pkg/kine/logstructured/sqllog/sql.go b/pkg/kine/sqllog/sqllog.go similarity index 72% rename from pkg/kine/logstructured/sqllog/sql.go rename to pkg/kine/sqllog/sqllog.go index 282770c6..83468230 100644 --- a/pkg/kine/logstructured/sqllog/sql.go +++ b/pkg/kine/sqllog/sqllog.go @@ -19,9 +19,9 @@ import ( ) const ( + otelName = "sqllog" SupersededCount = 100 compactBatchSize = 1000 - otelName = "sqllog" ) var ( @@ -41,22 +41,6 @@ func init() { } } -type SQLLog struct { - d Dialect - broadcaster broadcaster.Broadcaster - ctx context.Context - notify chan int64 - wg sync.WaitGroup -} - -func New(d Dialect) *SQLLog { - l := &SQLLog{ - d: d, - notify: make(chan int64, 1024), - } - return l -} - type Dialect interface { List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, error) @@ -78,18 +62,74 @@ type Dialect interface { Close() error } -func (s *SQLLog) Start(ctx context.Context) (err error) { - s.ctx = ctx - context.AfterFunc(ctx, func() { - if err := s.d.Close(); err != nil { - logrus.Errorf("cannot close database: %v", err) - } - }) - return s.broadcaster.Start(s.startWatch) +type SQLLog struct { + mu sync.Mutex + stop func() + started bool + + d Dialect + broadcaster broadcaster.Broadcaster[[]*server.Event] + notify chan int64 + wg sync.WaitGroup } -func (s *SQLLog) Wait() { +func New(d Dialect) *SQLLog { + return &SQLLog{ + d: d, + notify: make(chan int64, 1024), + } +} + +func (s *SQLLog) Start(startCtx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.started { + return nil + } + + _, _, err := s.Create(startCtx, "/registry/health", []byte(`{"health":"true"}`), 0) + if err != nil { + return err + } + + ctx, stop := context.WithCancel(context.Background()) + err = s.broadcaster.Start(ctx, s.startWatch) + if err != nil { + stop() + return err + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.ttl(ctx) + }() + + s.stop = stop + s.started = true + return nil +} + +func (s *SQLLog) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.started { + return nil + } + + s.stop() s.wg.Wait() + s.stop, s.started = nil, false + return nil +} + +func (s *SQLLog) Close() error { + stopErr := s.Stop() + closeErr := s.d.Close() + + return errors.Join(stopErr, closeErr) } func (s *SQLLog) compactStart(ctx context.Context) error { @@ -98,7 +138,7 @@ func (s *SQLLog) compactStart(ctx context.Context) error { return err } - events, err := RowsToEvents(rows) + events, err := ScanAll(rows, scanEvent) if err != nil { return err } @@ -136,6 +176,7 @@ func (s *SQLLog) compactStart(ctx context.Context) error { // from test functions that have access to the backend. func (s *SQLLog) DoCompact(ctx context.Context) (err error) { ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.DoCompact", otelName)) + compactCnt.Add(ctx, 1) defer func() { span.RecordError(err) span.End() @@ -166,7 +207,7 @@ func (s *SQLLog) DoCompact(ctx context.Context) (err error) { if batchRevision > target { batchRevision = target } - if err := s.d.Compact(s.ctx, batchRevision); err != nil { + if err := s.d.Compact(ctx, batchRevision); err != nil { return err } start = batchRevision @@ -174,10 +215,6 @@ func (s *SQLLog) DoCompact(ctx context.Context) (err error) { return nil } -func (s *SQLLog) CurrentRevision(ctx context.Context) (int64, error) { - return s.d.CurrentRevision(ctx) -} - func (s *SQLLog) After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error) { var err error ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.After", otelName)) @@ -206,14 +243,14 @@ func (s *SQLLog) After(ctx context.Context, prefix string, revision, limit int64 return 0, nil, err } - result, err := RowsToEvents(rows) + result, err := ScanAll(rows, scanEvent) if err != nil { return 0, nil, err } return currentRevision, result, err } -func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (int64, []*server.Event, error) { +func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*server.KeyValue, error) { var err error ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.List", otelName)) @@ -226,7 +263,6 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis attribute.String("startKey", startKey), attribute.Int64("limit", limit), attribute.Int64("revision", revision), - attribute.Bool("includeDeleted", includeDeleted), ) compactRevision, currentRevision, err := s.d.GetCompactRevision(ctx) @@ -250,12 +286,12 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis startKey = "" } - rows, err := s.d.List(ctx, prefix, startKey, limit, revision, includeDeleted) + rows, err := s.d.List(ctx, prefix, startKey, limit, revision, false) if err != nil { return 0, nil, err } - result, err := RowsToEvents(rows) + result, err := ScanAll(rows, scanKeyValue) if err != nil { return 0, nil, err } @@ -263,70 +299,125 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis return currentRevision, result, err } -func RowsToEvents(rows *sql.Rows) ([]*server.Event, error) { - var result []*server.Event - defer rows.Close() - - for rows.Next() { - event := &server.Event{} - if err := scan(rows, event); err != nil { - return nil, err +func (s *SQLLog) ttl(ctx context.Context) { + run := func(ctx context.Context, key string, revision int64, timeout time.Duration) { + select { + case <-ctx.Done(): + return + case <-time.After(timeout): + s.Delete(ctx, key, revision) } - result = append(result, event) } - return result, nil + s.wg.Add(1) + go func() { + defer s.wg.Done() + + rev, kvs, err := s.List(ctx, "/", "", 1000, 0) + for len(kvs) > 0 { + if err != nil { + logrus.Errorf("failed to read old events for ttl: %v", err) + return + } + + for _, kv := range kvs { + if kv.Lease > 0 { + go run(ctx, kv.Key, kv.ModRevision, time.Duration(kv.Lease)*time.Second) + } + } + + _, kvs, err = s.List(ctx, "/", kvs[len(kvs)-1].Key, 1000, rev) + } + + watchCh, err := s.Watch(ctx, "/", rev) + if err != nil { + logrus.Errorf("failed to watch events for ttl: %v", err) + return + } + + for events := range watchCh { + for _, event := range events { + if event.KV.Lease > 0 { + go run(ctx, event.KV.Key, event.KV.ModRevision, time.Duration(event.KV.Lease)*time.Second) + } + } + } + }() } -func (s *SQLLog) Watch(ctx context.Context, prefix string) <-chan []*server.Event { - res := make(chan []*server.Event, 100) +func (s *SQLLog) Watch(ctx context.Context, key string, startRevision int64) (<-chan []*server.Event, error) { + ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.Watch", otelName)) + defer span.End() + span.SetAttributes( + attribute.String("key", key), + attribute.Int64("startRevision", startRevision), + ) + values, err := s.broadcaster.Subscribe(ctx) if err != nil { - return nil + return nil, err } - checkPrefix := strings.HasSuffix(prefix, "/") + if startRevision > 0 { + startRevision = startRevision - 1 + } + + initialRevision, initialEvents, err := s.After(ctx, key, startRevision, 0) + if err != nil { + span.RecordError(err) + return nil, err + } + + res := make(chan []*server.Event, 100) + if len(initialEvents) > 0 { + res <- initialEvents + } s.wg.Add(1) go func() { - defer s.wg.Done() - defer close(res) - - for i := range values { - events, ok := filter(i, checkPrefix, prefix) - if ok { - res <- events + defer func() { + close(res) + s.wg.Done() + }() + + for events := range values { + filtered := filterEvents(events, key, initialRevision) + if len(filtered) > 0 { + res <- filtered } } }() - - return res + return res, nil } -func filter(events interface{}, checkPrefix bool, prefix string) ([]*server.Event, bool) { - eventList := events.([]*server.Event) - filteredEventList := make([]*server.Event, 0, len(eventList)) +func filterEvents(events []*server.Event, key string, startRevision int64) []*server.Event { + filteredEventList := make([]*server.Event, 0, len(events)) + checkPrefix := strings.HasSuffix(key, "/") - for _, event := range eventList { - if (checkPrefix && strings.HasPrefix(event.KV.Key, prefix)) || event.KV.Key == prefix { - filteredEventList = append(filteredEventList, event) + for _, event := range events { + if event.KV.ModRevision <= startRevision { + continue + } + if !(checkPrefix && strings.HasPrefix(event.KV.Key, key)) && event.KV.Key != key { + continue } + filteredEventList = append(filteredEventList, event) } - return filteredEventList, len(filteredEventList) > 0 + return filteredEventList } -func (s *SQLLog) startWatch() (chan interface{}, error) { - if err := s.compactStart(s.ctx); err != nil { +func (s *SQLLog) startWatch(ctx context.Context) (chan []*server.Event, error) { + if err := s.compactStart(ctx); err != nil { return nil, err } - pollStart, _, err := s.d.GetCompactRevision(s.ctx) + pollStart, _, err := s.d.GetCompactRevision(ctx) if err != nil { return nil, err } - c := make(chan interface{}) + c := make(chan []*server.Event) // start compaction and polling at the same time to watch starts // at the oldest revision, but compaction doesn't create gaps s.wg.Add(2) @@ -338,10 +429,10 @@ func (s *SQLLog) startWatch() (chan interface{}, error) { for { select { - case <-s.ctx.Done(): + case <-ctx.Done(): return case <-t.C: - if err := s.DoCompact(s.ctx); err != nil { + if err := s.DoCompact(ctx); err != nil { logrus.WithError(err).Trace("compaction failed") } } @@ -350,13 +441,13 @@ func (s *SQLLog) startWatch() (chan interface{}, error) { go func() { defer s.wg.Done() - s.poll(c, pollStart) + s.poll(ctx, c, pollStart) }() return c, nil } -func (s *SQLLog) poll(result chan interface{}, pollStart int64) { +func (s *SQLLog) poll(ctx context.Context, result chan []*server.Event, pollStart int64) { var ( last = pollStart skip int64 @@ -371,7 +462,7 @@ func (s *SQLLog) poll(result chan interface{}, pollStart int64) { for { if waitForMore { select { - case <-s.ctx.Done(): + case <-ctx.Done(): return case check := <-s.notify: if check <= last { @@ -381,7 +472,7 @@ func (s *SQLLog) poll(result chan interface{}, pollStart int64) { } } waitForMore = true - watchCtx, cancel := context.WithTimeout(s.ctx, s.d.GetWatchQueryTimeout()) + watchCtx, cancel := context.WithTimeout(ctx, s.d.GetWatchQueryTimeout()) defer cancel() rows, err := s.d.After(watchCtx, last, 500) @@ -392,7 +483,7 @@ func (s *SQLLog) poll(result chan interface{}, pollStart int64) { continue } - events, err := RowsToEvents(rows) + events, err := ScanAll(rows, scanEvent) if err != nil { logrus.Errorf("fail to convert rows changes: %v", err) continue @@ -427,7 +518,7 @@ func (s *SQLLog) poll(result chan interface{}, pollStart int64) { s.notifyWatcherPoll(next) break } else { - if err := s.d.Fill(s.ctx, next); err == nil { + if err := s.d.Fill(ctx, next); err == nil { logrus.Debugf("FILL, revision=%d, err=%v", next, err) s.notifyWatcherPoll(next) } else { @@ -534,9 +625,66 @@ func (s *SQLLog) notifyWatcherPoll(revision int64) { } } -func scan(rows *sql.Rows, event *server.Event) error { - event.KV = &server.KeyValue{} - event.PrevKV = &server.KeyValue{} +func (s *SQLLog) DbSize(ctx context.Context) (int64, error) { + var err error + ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.DbSize", otelName)) + defer func() { + span.RecordError(err) + span.End() + }() + size, err := s.d.GetSize(ctx) + span.SetAttributes(attribute.Int64("size", size)) + return size, err +} + +func ScanAll[T any](rows *sql.Rows, scanOne func(*sql.Rows) (T, error)) ([]T, error) { + var result []T + defer rows.Close() + + for rows.Next() { + item, err := scanOne(rows) + if err != nil { + return nil, err + } + result = append(result, item) + } + + return result, nil +} + +func scanKeyValue(rows *sql.Rows) (*server.KeyValue, error) { + kv := &server.KeyValue{} + var create, delete bool + var prevRevision int64 + var prevValue []byte + + err := rows.Scan( + &kv.ModRevision, + &kv.Key, + &create, + &delete, + &kv.CreateRevision, + &prevRevision, + &kv.Lease, + &kv.Value, + &prevValue, + ) + if err != nil { + return nil, err + } + + if create { + kv.CreateRevision = kv.ModRevision + } + + return kv, nil +} + +func scanEvent(rows *sql.Rows) (*server.Event, error) { + event := &server.Event{ + KV: &server.KeyValue{}, + PrevKV: &server.KeyValue{}, + } err := rows.Scan( &event.KV.ModRevision, @@ -550,7 +698,7 @@ func scan(rows *sql.Rows, event *server.Event) error { &event.PrevKV.Value, ) if err != nil { - return err + return nil, err } if event.Create { @@ -562,17 +710,5 @@ func scan(rows *sql.Rows, event *server.Event) error { event.PrevKV.Lease = event.KV.Lease } - return nil -} - -func (s *SQLLog) DbSize(ctx context.Context) (int64, error) { - var err error - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.DbSize", otelName)) - defer func() { - span.RecordError(err) - span.End() - }() - size, err := s.d.GetSize(ctx) - span.SetAttributes(attribute.Int64("size", size)) - return size, err + return event, nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index 02a3e672..21c2f89f 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -351,6 +351,10 @@ func (s *Server) Start(ctx context.Context) error { // Shutdown cleans up any resources and attempts to hand-over and shutdown the dqlite application. func (s *Server) Shutdown(ctx context.Context) error { + if err := s.backend.Close(); err != nil { + return err + } + logrus.Debug("Handing over dqlite leadership") if err := s.app.Handover(ctx); err != nil { logrus.WithError(err).Errorf("Failed to handover dqlite") @@ -359,8 +363,8 @@ func (s *Server) Shutdown(ctx context.Context) error { if err := s.app.Close(); err != nil { return fmt.Errorf("failed to close dqlite app: %w", err) } + close(s.mustStopCh) - s.backend.Wait() return nil } diff --git a/test/compaction_test.go b/test/compaction_test.go index 1ff399dd..15e35484 100644 --- a/test/compaction_test.go +++ b/test/compaction_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/canonical/k8s-dqlite/pkg/kine/endpoint" - "github.com/canonical/k8s-dqlite/pkg/kine/logstructured/sqllog" + "github.com/canonical/k8s-dqlite/pkg/kine/sqllog" . "github.com/onsi/gomega" ) diff --git a/test/create_test.go b/test/create_test.go index a69af1bb..b7f0dc92 100644 --- a/test/create_test.go +++ b/test/create_test.go @@ -70,10 +70,10 @@ func BenchmarkCreate(b *testing.B) { } } -func createKey(ctx context.Context, g Gomega, client *clientv3.Client, key string, value string) int64 { +func createKey(ctx context.Context, g Gomega, client *clientv3.Client, key string, value string, opts ...clientv3.OpOption) int64 { resp, err := client.Txn(ctx). If(clientv3.Compare(clientv3.ModRevision(key), "=", 0)). - Then(clientv3.OpPut(key, value)). + Then(clientv3.OpPut(key, value, opts...)). Commit() g.Expect(err).To(BeNil()) diff --git a/test/lease_test.go b/test/lease_test.go index da737d31..41f648ac 100644 --- a/test/lease_test.go +++ b/test/lease_test.go @@ -18,6 +18,10 @@ const ( // TestLease is unit testing for the lease operation. func TestLease(t *testing.T) { + const leaseKey = "/leaseTestKey" + const leaseValue = "testValue" + const ttlSeconds = 1 + for _, backendType := range []string{endpoint.SQLiteBackend, endpoint.DQLiteBackend} { t.Run(backendType, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -25,57 +29,32 @@ func TestLease(t *testing.T) { kine := newKineServer(ctx, t, &kineOptions{backendType: backendType}) - t.Run("LeaseGrant", func(t *testing.T) { - g := NewWithT(t) - ttl := int64(300) - resp, err := kine.client.Lease.Grant(ctx, ttl) + g := NewWithT(t) + lease := grantLease(ctx, g, kine.client, ttlSeconds) - g.Expect(err).To(BeNil()) - g.Expect(resp.ID).To(Equal(clientv3.LeaseID(ttl))) - g.Expect(resp.TTL).To(Equal(ttl)) - }) + createKey(ctx, g, kine.client, leaseKey, leaseValue, clientv3.WithLease(lease)) - t.Run("UseLease", func(t *testing.T) { - ttl := int64(1) - t.Run("CreateWithLease", func(t *testing.T) { - g := NewWithT(t) + resp, err := kine.client.Get(ctx, leaseKey, clientv3.WithRange("")) + g.Expect(err).To(BeNil()) + g.Expect(resp.Kvs).To(HaveLen(1)) + g.Expect(resp.Kvs[0].Key).To(Equal([]byte(leaseKey))) + g.Expect(resp.Kvs[0].Value).To(Equal([]byte(leaseValue))) + g.Expect(resp.Kvs[0].Lease).To(Equal(int64(lease))) - { - resp, err := kine.client.Lease.Grant(ctx, ttl) - g.Expect(err).To(BeNil()) - g.Expect(resp.ID).To(Equal(clientv3.LeaseID(ttl))) - g.Expect(resp.TTL).To(Equal(ttl)) - } + g.Eventually(func() []*mvccpb.KeyValue { + resp, err := kine.client.Get(ctx, leaseKey, clientv3.WithRange("")) + g.Expect(err).To(BeNil()) + return resp.Kvs + }, time.Duration(ttlSeconds*2)*time.Second, testExpirePollPeriod, ctx).Should(BeEmpty()) + }) + } +} - { - resp, err := kine.client.Txn(ctx). - If(clientv3.Compare(clientv3.ModRevision("/leaseTestKey"), "=", 0)). - Then(clientv3.OpPut("/leaseTestKey", "testValue", clientv3.WithLease(clientv3.LeaseID(ttl)))). - Commit() - g.Expect(err).To(BeNil()) - g.Expect(resp.Succeeded).To(BeTrue()) - } +func grantLease(ctx context.Context, g Gomega, client *clientv3.Client, ttl int64) clientv3.LeaseID { + resp, err := client.Lease.Grant(ctx, ttl) - { - resp, err := kine.client.Get(ctx, "/leaseTestKey", clientv3.WithRange("")) - g.Expect(err).To(BeNil()) - g.Expect(resp.Kvs).To(HaveLen(1)) - g.Expect(resp.Kvs[0].Key).To(Equal([]byte("/leaseTestKey"))) - g.Expect(resp.Kvs[0].Value).To(Equal([]byte("testValue"))) - g.Expect(resp.Kvs[0].Lease).To(Equal(ttl)) - } - }) + g.Expect(err).To(BeNil()) + g.Expect(resp.TTL).To(Equal(ttl)) - t.Run("KeyShouldExpire", func(t *testing.T) { - g := NewWithT(t) - // timeout ttl*2 seconds, poll 100ms - g.Eventually(func() []*mvccpb.KeyValue { - resp, err := kine.client.Get(ctx, "/leaseTestKey", clientv3.WithRange("")) - g.Expect(err).To(BeNil()) - return resp.Kvs - }, time.Duration(ttl*2)*time.Second, testExpirePollPeriod, ctx).Should(BeEmpty()) - }) - }) - }) - } + return resp.ID } diff --git a/test/util_test.go b/test/util_test.go index ede2d010..6d0d7c1c 100644 --- a/test/util_test.go +++ b/test/util_test.go @@ -88,7 +88,9 @@ func newKineServer(ctx context.Context, tb testing.TB, options *kineOptions) *ki tb.Fatal(err) } tb.Cleanup(func() { - backend.Wait() + if err := backend.Close(); err != nil { + tb.Error("cannot close backend", err) + } }) if options.setup != nil { diff --git a/test/watch_test.go b/test/watch_test.go index 9222c598..860ce37b 100644 --- a/test/watch_test.go +++ b/test/watch_test.go @@ -30,8 +30,9 @@ func TestWatch(t *testing.T) { kine := newKineServer(ctx, t, &kineOptions{backendType: backendType}) // start watching for events on key - const prefix = "test/" - watchCh := kine.client.Watch(ctx, prefix) + const watchedPrefix = "watched/" + const ingnoredPrefix = "ignored/" + watchCh := kine.client.Watch(ctx, watchedPrefix) t.Run("ReceiveNothingUntilActivity", func(t *testing.T) { g := NewWithT(t) @@ -41,9 +42,10 @@ func TestWatch(t *testing.T) { t.Run("Create", func(t *testing.T) { g := NewWithT(t) - key := prefix + "createdKey" value := "testValue" + key := watchedPrefix + "createdKey" rev := createKey(ctx, g, kine.client, key, value) + createKey(ctx, g, kine.client, ingnoredPrefix+"createdKey", value) g.Eventually(watchCh, pollTimeout).Should(ReceiveEvents(g, CreateEvent(g, key, value, rev), @@ -54,7 +56,7 @@ func TestWatch(t *testing.T) { t.Run("Update", func(t *testing.T) { g := NewWithT(t) - key := prefix + "updatedKey" + key := watchedPrefix + "updatedKey" createValue := "testValue1" createRev := createKey(ctx, g, kine.client, key, createValue) g.Eventually(watchCh, pollTimeout).Should(ReceiveEvents(g, @@ -73,7 +75,7 @@ func TestWatch(t *testing.T) { t.Run("Delete", func(t *testing.T) { g := NewWithT(t) - key := prefix + "deletedKey" + key := watchedPrefix + "deletedKey" createValue := "testValue" createRev := createKey(ctx, g, kine.client, key, createValue) g.Eventually(watchCh, pollTimeout).Should(ReceiveEvents(g, @@ -93,7 +95,7 @@ func TestWatch(t *testing.T) { defer cancel() g := NewWithT(t) - key := prefix + "revisionKey" + key := watchedPrefix + "revisionKey" createValue := "testValue1" createRev := createKey(ctx, g, kine.client, key, createValue)