From b3bf3d502c2c71e113ea3cc83bb24672efba9385 Mon Sep 17 00:00:00 2001 From: Marco Manino Date: Thu, 9 Jan 2025 08:34:44 +0100 Subject: [PATCH 1/3] Move prepared to a database package --- pkg/database/driver_test.go | 144 ++++++++++++++++++++++++ pkg/database/interface.go | 42 +++++++ pkg/database/prepared.go | 166 ++++++++++++++++++++++++++++ pkg/database/prepared_test.go | 139 +++++++++++++++++++++++ pkg/kine/drivers/dqlite/dqlite.go | 11 +- pkg/kine/drivers/generic/generic.go | 6 +- pkg/kine/drivers/sqlite/sqlite.go | 11 +- pkg/kine/prepared/db.go | 135 ---------------------- pkg/kine/prepared/tx.go | 32 ------ 9 files changed, 512 insertions(+), 174 deletions(-) create mode 100644 pkg/database/driver_test.go create mode 100644 pkg/database/interface.go create mode 100644 pkg/database/prepared.go create mode 100644 pkg/database/prepared_test.go delete mode 100644 pkg/kine/prepared/db.go delete mode 100644 pkg/kine/prepared/tx.go diff --git a/pkg/database/driver_test.go b/pkg/database/driver_test.go new file mode 100644 index 00000000..d8737fa4 --- /dev/null +++ b/pkg/database/driver_test.go @@ -0,0 +1,144 @@ +package database_test + +import ( + "context" + "database/sql/driver" + "io" + "sync/atomic" + "testing" +) + +type testDriver struct { + t *testing.T + conns atomic.Int32 + queries atomic.Int32 + execs atomic.Int32 + stmts atomic.Int32 + trans atomic.Int32 +} + +var _ driver.DriverContext = &testDriver{} + +// Open implements driver.Driver. +func (td *testDriver) Open(name string) (driver.Conn, error) { + td.conns.Add(1) + return &testConn{ + driver: td, + dbName: name, + }, nil +} + +func (td *testDriver) OpenConnector(name string) (driver.Connector, error) { + return &testConnector{ + driver: td, + dbName: name, + }, nil +} + +func (td *testDriver) Close() { + if rows := td.queries.Load(); rows != 0 { + td.t.Errorf("%d rows left open after close", rows) + } + if stmts := td.stmts.Load(); stmts != 0 { + td.t.Errorf("%d statements left open after close", stmts) + } + if conns := td.conns.Load(); conns != 0 { + td.t.Errorf("%d connections left open after close", conns) + } + if trans := td.trans.Load(); trans != 0 { + td.t.Errorf("%d transactions left open after close", trans) + } +} + +type testConnector struct { + driver *testDriver + dbName string +} + +func (tc *testConnector) Driver() driver.Driver { return tc.driver } + +func (tc *testConnector) Connect(context.Context) (driver.Conn, error) { + tc.driver.conns.Add(1) + return &testConn{ + driver: tc.driver, + dbName: tc.dbName, + }, nil +} + +type testConn struct { + driver *testDriver + dbName string +} + +var _ driver.ExecerContext = &testConn{} +var _ driver.QueryerContext = &testConn{} + +func (t *testConn) Begin() (driver.Tx, error) { + t.driver.trans.Add(1) + return &testTx{ + driver: t.driver, + }, nil +} + +func (t *testConn) Prepare(query string) (driver.Stmt, error) { + t.driver.stmts.Add(1) + return &testStmt{ + driver: t.driver, + query: query, + }, nil +} + +func (t *testConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + t.driver.queries.Add(1) + return &testRows{ + driver: t.driver, + }, nil +} + +func (t *testConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + t.driver.execs.Add(1) + return testResult{}, nil +} + +func (t *testConn) Close() error { return nil } + +type testStmt struct { + driver *testDriver + query string +} + +func (t *testStmt) NumInput() int { return 0 } + +func (t *testStmt) Exec(args []driver.Value) (driver.Result, error) { + t.driver.execs.Add(1) + return testResult{}, nil +} + +func (t *testStmt) Query(args []driver.Value) (driver.Rows, error) { + t.driver.queries.Add(1) + return &testRows{ + driver: t.driver, + }, nil +} + +func (t *testStmt) Close() error { return nil } + +type testRows struct { + driver *testDriver +} + +func (t *testRows) Columns() []string { return nil } +func (t *testRows) Next(dest []driver.Value) error { return io.EOF } +func (t *testRows) Close() error { return nil } + +type testResult struct{} + +func (t testResult) LastInsertId() (int64, error) { return 0, nil } +func (t testResult) RowsAffected() (int64, error) { return 0, nil } + +type testTx struct { + driver *testDriver +} + +func (t *testTx) Commit() error { return nil } +func (t *testTx) Rollback() error { return nil } diff --git a/pkg/database/interface.go b/pkg/database/interface.go new file mode 100644 index 00000000..eed13f6c --- /dev/null +++ b/pkg/database/interface.go @@ -0,0 +1,42 @@ +package database + +import ( + "context" + "database/sql" + "errors" +) + +var errDBClosed = errors.New("sql: database is closed") + +type Interface interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) + + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + + BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) + Conn(ctx context.Context) (*sql.Conn, error) + Close() error +} + +type Transaction interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) + + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt + Commit() error + Rollback() error +} + +type Wrapped[T Transaction] interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) + + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + + BeginTx(ctx context.Context, opts *sql.TxOptions) (T, error) + Conn(ctx context.Context) (*sql.Conn, error) + Close() error +} diff --git a/pkg/database/prepared.go b/pkg/database/prepared.go new file mode 100644 index 00000000..51d4527d --- /dev/null +++ b/pkg/database/prepared.go @@ -0,0 +1,166 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" +) + +const otelName = "prepared" + +var otelTracer trace.Tracer + +func init() { + otelTracer = otel.Tracer(otelName) +} + +type preparedDb[T Transaction] struct { + underlying Wrapped[T] + mu sync.RWMutex + store map[string]*sql.Stmt +} + +func NewPrepared[T Transaction](db Wrapped[T]) Interface { + return &preparedDb[T]{ + underlying: db, + store: make(map[string]*sql.Stmt), + } +} + +func (db *preparedDb[T]) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { + ctx, span := otelTracer.Start(ctx, "DB.ExecContext") + defer func() { + span.RecordError(err) + span.End() + }() + + stmt, err := db.prepare(ctx, query) + if err != nil { + return nil, err + } + return stmt.ExecContext(ctx, args...) +} + +func (db *preparedDb[T]) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { + ctx, span := otelTracer.Start(ctx, "DB.QueryContext") + defer func() { + span.RecordError(err) + span.End() + }() + + stmt, err := db.prepare(ctx, query) + if err != nil { + return nil, err + } + return stmt.QueryContext(ctx, args...) +} + +func (db *preparedDb[T]) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) { + return db.underlying.PrepareContext(ctx, query) +} + +func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.Stmt, err error) { + ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.prepare", otelName)) + defer func() { + span.RecordError(err) + span.End() + }() + + db.mu.RLock() + stmt = db.store[query] + db.mu.RUnlock() + if stmt != nil { + return stmt, nil + } + + db.mu.Lock() + defer db.mu.Unlock() + + if db.underlying == nil { + return nil, errDBClosed + } + + // Check again if the query was prepared during locking + stmt = db.store[query] + if stmt != nil { + return stmt, nil + } + + prepared, err := db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + db.store[query] = prepared + return prepared, nil +} + +func (db *preparedDb[T]) Conn(ctx context.Context) (*sql.Conn, error) { + return db.underlying.Conn(ctx) +} + +func (db *preparedDb[T]) BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) { + tx, err := db.underlying.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + return &preparedTx[T]{ + Transaction: tx, + db: db, + }, nil +} + +func (db *preparedDb[T]) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + + errs := []error{} + for _, stmt := range db.store { + if err := stmt.Close(); err != nil { + errs = append(errs, err) + } + } + db.store = nil + + if err := db.underlying.Close(); err != nil { + errs = append(errs, err) + } + db.underlying = nil + + return errors.Join(errs...) +} + +type preparedTx[T Transaction] struct { + Transaction + db *preparedDb[T] +} + +func (tx *preparedTx[T]) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + stmt, err := tx.db.prepare(ctx, query) + if err != nil { + return nil, err + } + return tx.StmtContext(ctx, stmt), nil +} + +func (tx *preparedTx[T]) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + stmt, err := tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return stmt.ExecContext(ctx, args...) +} + +func (tx *preparedTx[T]) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + stmt, err := tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + + return stmt.QueryContext(ctx, args...) +} diff --git a/pkg/database/prepared_test.go b/pkg/database/prepared_test.go new file mode 100644 index 00000000..a83a7761 --- /dev/null +++ b/pkg/database/prepared_test.go @@ -0,0 +1,139 @@ +package database_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/canonical/k8s-dqlite/pkg/database" +) + +func TestPreparedPrepare(t *testing.T) { + const expectedQuery = "test query" + + ctx := context.Background() + driver := &testDriver{ + t: t, + } + db := database.NewPrepared(sql.OpenDB(&testConnector{driver: driver})) + defer db.Close() + + stmt1, err := db.PrepareContext(ctx, expectedQuery) + if err != nil { + t.Fatal(err) + } + defer stmt1.Close() + + stmt2, err := db.PrepareContext(ctx, expectedQuery) + if err != nil { + t.Fatal(err) + } + defer stmt2.Close() + + if stmts := driver.stmts.Load(); stmts != 2 { + t.Errorf("invalid open statements: want 1, got %d", stmts) + } +} + +func TestPreparedQuery(t *testing.T) { + ctx := context.Background() + driver := &testDriver{ + t: t, + } + db := database.NewPrepared(sql.OpenDB(&testConnector{driver: driver})) + defer db.Close() + + rows, err := db.QueryContext(ctx, "query 1") + if err != nil { + t.Error(err) + } + rows.Close() + + if stmts := driver.stmts.Load(); stmts != 1 { + t.Errorf("unexpected number of open statements: want %d, got %d", 1, stmts) + } + + rows, err = db.QueryContext(ctx, "query 2") + if err != nil { + t.Error(err) + } + rows.Close() + + if stmts := driver.stmts.Load(); stmts != 2 { + t.Errorf("unexpected number of open statements: want 2, got %d", stmts) + } +} + +func TestPreparedExec(t *testing.T) { + ctx := context.Background() + driver := &testDriver{ + t: t, + } + db := database.NewPrepared(sql.OpenDB(&testConnector{driver: driver})) + defer db.Close() + + _, err := db.ExecContext(ctx, "query 1") + if err != nil { + t.Error(err) + } + + if stmts := driver.stmts.Load(); stmts != 1 { + t.Errorf("unexpected number of open statements: want %d, got %d", 1, stmts) + } + + _, err = db.ExecContext(ctx, "query 2") + if err != nil { + t.Error(err) + } + + if stmts := driver.stmts.Load(); stmts != 2 { + t.Errorf("unexpected number of open statements: want 2, got %d", stmts) + } +} + +func TestPreparedTx(t *testing.T) { + ctx := context.Background() + driver := &testDriver{ + t: t, + } + db := database.NewPrepared(sql.OpenDB(&testConnector{driver: driver})) + defer db.Close() + + transaction := func() error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return (err) + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "query 1") + if err != nil { + return (err) + } + + _, err = tx.ExecContext(ctx, "query 2") + if err != nil { + return (err) + } + return nil + } + + for i := 0; i < 5; i++ { + if err := transaction(); err != nil { + t.Error(err) + } + } + + if stmts := driver.stmts.Load(); stmts != 4 { + t.Errorf("unexpected number of open statements: want 2, got %d", stmts) + } + + _, err := db.ExecContext(ctx, "query 2") + if err != nil { + t.Error(err) + } + + if stmts := driver.stmts.Load(); stmts != 4 { + t.Errorf("unexpected number of open statements: want 2, got %d", stmts) + } +} diff --git a/pkg/kine/drivers/dqlite/dqlite.go b/pkg/kine/drivers/dqlite/dqlite.go index b613f053..7920a69d 100644 --- a/pkg/kine/drivers/dqlite/dqlite.go +++ b/pkg/kine/drivers/dqlite/dqlite.go @@ -37,7 +37,14 @@ func NewVariant(ctx context.Context, datasourceName string, connectionPoolConfig if err != nil { return nil, nil, errors.Wrap(err, "sqlite client") } - if err := migrate(ctx, generic.DB.Underlying()); err != nil { + + conn, err := generic.DB.Conn(ctx) + if err != nil { + return nil, nil, err + } + defer conn.Close() + + if err := migrate(ctx, conn); err != nil { return nil, nil, errors.Wrap(err, "failed to migrate DB from sqlite") } generic.LockWrites = true @@ -81,7 +88,7 @@ func NewVariant(ctx context.Context, datasourceName string, connectionPoolConfig return backend, generic, nil } -func migrate(ctx context.Context, newDB *sql.DB) (exitErr error) { +func migrate(ctx context.Context, newDB *sql.Conn) (exitErr error) { row := newDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM kine") var count int64 if err := row.Scan(&count); err != nil { diff --git a/pkg/kine/drivers/generic/generic.go b/pkg/kine/drivers/generic/generic.go index 27dfe581..d46f1a56 100644 --- a/pkg/kine/drivers/generic/generic.go +++ b/pkg/kine/drivers/generic/generic.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/canonical/k8s-dqlite/pkg/kine/prepared" + "github.com/canonical/k8s-dqlite/pkg/database" "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -242,7 +242,7 @@ type Generic struct { sync.Mutex LockWrites bool - DB *prepared.DB + DB database.Interface Retry ErrRetry TranslateErr TranslateErr ErrCode ErrCode @@ -321,7 +321,7 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig configureConnectionPooling(connPoolConfig, db) return &Generic{ - DB: prepared.New(db), + DB: database.NewPrepared(db), }, err } diff --git a/pkg/kine/drivers/sqlite/sqlite.go b/pkg/kine/drivers/sqlite/sqlite.go index bf2e3b91..ae65c094 100644 --- a/pkg/kine/drivers/sqlite/sqlite.go +++ b/pkg/kine/drivers/sqlite/sqlite.go @@ -64,7 +64,14 @@ func NewVariant(ctx context.Context, driverName, dataSourceName string, connecti return nil, nil, err } for i := 0; i < retryAttempts; i++ { - err = setup(ctx, dialect.DB.Underlying()) + err = func() error { + conn, err := dialect.DB.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + return setup(ctx, conn) + }() if err == nil { break } @@ -104,7 +111,7 @@ func NewVariant(ctx context.Context, driverName, dataSourceName string, connecti // it doesn't already exist, migrating key_value table contents to the Kine // table if the key_value table exists, all in a single database transaction. // changes are rolled back if an error occurs. -func setup(ctx context.Context, db *sql.DB) error { +func setup(ctx context.Context, db *sql.Conn) error { // Optimistically ask for the user_version without starting a transaction var currentSchemaVersion SchemaVersion diff --git a/pkg/kine/prepared/db.go b/pkg/kine/prepared/db.go deleted file mode 100644 index 26a075d2..00000000 --- a/pkg/kine/prepared/db.go +++ /dev/null @@ -1,135 +0,0 @@ -package prepared - -import ( - "context" - "database/sql" - "errors" - "fmt" - "sync" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -const otelName = "prepared" - -var otelTracer trace.Tracer - -func init() { - otelTracer = otel.Tracer(otelName) -} - -type DB struct { - underlying *sql.DB - mu sync.RWMutex - store map[string]*sql.Stmt -} - -func New(db *sql.DB) *DB { - return &DB{ - underlying: db, - store: make(map[string]*sql.Stmt), - } -} - -func (db *DB) Underlying() *sql.DB { return db.underlying } - -func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.ExecContext", otelName)) - defer func() { - span.RecordError(err) - span.End() - }() - - stmt, err := db.prepare(ctx, query) - if err != nil { - return nil, err - } - return stmt.ExecContext(ctx, args...) -} - -func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.QueryContext", otelName)) - defer func() { - span.RecordError(err) - span.End() - }() - - stmt, err := db.prepare(ctx, query) - if err != nil { - return nil, err - } - return stmt.QueryContext(ctx, args...) -} - -func (db *DB) Close() error { - db.mu.Lock() - defer db.mu.Unlock() - - errs := []error{} - for _, stmt := range db.store { - if err := stmt.Close(); err != nil { - errs = append(errs, err) - } - } - db.store = nil - - if err := db.underlying.Close(); err != nil { - errs = append(errs, err) - } - db.underlying = nil - - return errors.Join(errs...) -} - -func (db *DB) prepare(ctx context.Context, query string) (stmt *sql.Stmt, err error) { - ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.prepare", otelName)) - defer func() { - span.RecordError(err) - span.End() - }() - span.SetAttributes(attribute.String("query", query)) - - db.mu.RLock() - span.AddEvent("acquired read lock") - stmt = db.store[query] - db.mu.RUnlock() - if stmt != nil { - return stmt, nil - } - - db.mu.Lock() - span.AddEvent("acquired read-write lock") - defer db.mu.Unlock() - - if db.underlying == nil { - return nil, errors.New("database is closed") - } - - // Check again if the query was prepared during locking - stmt = db.store[query] - if stmt != nil { - return stmt, nil - } - - prepared, err := db.underlying.PrepareContext(ctx, query) - if err != nil { - return nil, err - } - - db.store[query] = prepared - return prepared, nil -} - -func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { - tx, err := db.underlying.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - - return &Tx{ - db: db, - tx: tx, - }, nil -} diff --git a/pkg/kine/prepared/tx.go b/pkg/kine/prepared/tx.go deleted file mode 100644 index 2d80f4cb..00000000 --- a/pkg/kine/prepared/tx.go +++ /dev/null @@ -1,32 +0,0 @@ -package prepared - -import ( - "context" - "database/sql" -) - -type Tx struct { - db *DB - tx *sql.Tx -} - -func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { - stmt, err := tx.db.prepare(ctx, query) - if err != nil { - return nil, err - } - - return tx.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) -} - -func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - stmt, err := tx.db.prepare(ctx, query) - if err != nil { - return nil, err - } - - return tx.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) -} - -func (tx *Tx) Commit() error { return tx.tx.Commit() } -func (tx *Tx) Rollback() error { return tx.tx.Rollback() } From 13de1f1a9b38f0e8f76b3cb554f248f5e8cbc567 Mon Sep 17 00:00:00 2001 From: Marco Manino Date: Thu, 9 Jan 2025 15:07:01 +0100 Subject: [PATCH 2/3] Addressing review --- pkg/database/prepared.go | 4 ++++ pkg/database/prepared_test.go | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pkg/database/prepared.go b/pkg/database/prepared.go index 51d4527d..9e0e3585 100644 --- a/pkg/database/prepared.go +++ b/pkg/database/prepared.go @@ -25,6 +25,10 @@ type preparedDb[T Transaction] struct { store map[string]*sql.Stmt } +// NewPrepared creates a new Interface that wraps the given database and +// uses a prepare cache to reduce the number of prepare calls. The cache +// is only used when calling ExecContext and QueryContext methods on the +// main instance or in a transaction. func NewPrepared[T Transaction](db Wrapped[T]) Interface { return &preparedDb[T]{ underlying: db, diff --git a/pkg/database/prepared_test.go b/pkg/database/prepared_test.go index a83a7761..36bb8bec 100644 --- a/pkg/database/prepared_test.go +++ b/pkg/database/prepared_test.go @@ -8,7 +8,7 @@ import ( "github.com/canonical/k8s-dqlite/pkg/database" ) -func TestPreparedPrepare(t *testing.T) { +func TestPreparedDistinctStmts(t *testing.T) { const expectedQuery = "test query" ctx := context.Background() @@ -31,7 +31,7 @@ func TestPreparedPrepare(t *testing.T) { defer stmt2.Close() if stmts := driver.stmts.Load(); stmts != 2 { - t.Errorf("invalid open statements: want 1, got %d", stmts) + t.Errorf("invalid open statements: want 2, got %d", stmts) } } @@ -50,7 +50,7 @@ func TestPreparedQuery(t *testing.T) { rows.Close() if stmts := driver.stmts.Load(); stmts != 1 { - t.Errorf("unexpected number of open statements: want %d, got %d", 1, stmts) + t.Errorf("unexpected number of open statements: want 1, got %d", stmts) } rows, err = db.QueryContext(ctx, "query 2") @@ -78,7 +78,7 @@ func TestPreparedExec(t *testing.T) { } if stmts := driver.stmts.Load(); stmts != 1 { - t.Errorf("unexpected number of open statements: want %d, got %d", 1, stmts) + t.Errorf("unexpected number of open statements: want 1, got %d", stmts) } _, err = db.ExecContext(ctx, "query 2") @@ -118,14 +118,14 @@ func TestPreparedTx(t *testing.T) { return nil } - for i := 0; i < 5; i++ { + for i := 0; i < 1; i++ { if err := transaction(); err != nil { t.Error(err) } } if stmts := driver.stmts.Load(); stmts != 4 { - t.Errorf("unexpected number of open statements: want 2, got %d", stmts) + t.Errorf("unexpected number of open statements: want 4, got %d", stmts) } _, err := db.ExecContext(ctx, "query 2") @@ -134,6 +134,6 @@ func TestPreparedTx(t *testing.T) { } if stmts := driver.stmts.Load(); stmts != 4 { - t.Errorf("unexpected number of open statements: want 2, got %d", stmts) + t.Errorf("unexpected number of open statements: want 4, got %d", stmts) } } From fa549e30257bfb08999cd1d6768cf510d22ffb5d Mon Sep 17 00:00:00 2001 From: Marco Manino Date: Tue, 14 Jan 2025 11:54:26 +0100 Subject: [PATCH 3/3] Addressing review --- pkg/database/interface.go | 6 ------ pkg/database/prepared.go | 18 ++++++++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/pkg/database/interface.go b/pkg/database/interface.go index eed13f6c..28ca63d9 100644 --- a/pkg/database/interface.go +++ b/pkg/database/interface.go @@ -11,9 +11,7 @@ var errDBClosed = errors.New("sql: database is closed") type Interface interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) Conn(ctx context.Context) (*sql.Conn, error) Close() error @@ -22,9 +20,7 @@ type Interface interface { type Transaction interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt Commit() error Rollback() error @@ -33,9 +29,7 @@ type Transaction interface { type Wrapped[T Transaction] interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (T, error) Conn(ctx context.Context) (*sql.Conn, error) Close() error diff --git a/pkg/database/prepared.go b/pkg/database/prepared.go index 9e0e3585..54a65f05 100644 --- a/pkg/database/prepared.go +++ b/pkg/database/prepared.go @@ -22,7 +22,7 @@ func init() { type preparedDb[T Transaction] struct { underlying Wrapped[T] mu sync.RWMutex - store map[string]*sql.Stmt + cache map[string]*sql.Stmt } // NewPrepared creates a new Interface that wraps the given database and @@ -32,7 +32,7 @@ type preparedDb[T Transaction] struct { func NewPrepared[T Transaction](db Wrapped[T]) Interface { return &preparedDb[T]{ underlying: db, - store: make(map[string]*sql.Stmt), + cache: make(map[string]*sql.Stmt), } } @@ -76,7 +76,7 @@ func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.S }() db.mu.RLock() - stmt = db.store[query] + stmt = db.cache[query] db.mu.RUnlock() if stmt != nil { return stmt, nil @@ -89,8 +89,10 @@ func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.S return nil, errDBClosed } - // Check again if the query was prepared during locking - stmt = db.store[query] + // Given that some time has passed since the unlock of the read lock, and the lock of the + // write lock, another goroutine might have already prepared this query, so we should check + // again to avoid preparing the same query twice. + stmt = db.cache[query] if stmt != nil { return stmt, nil } @@ -99,7 +101,7 @@ func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.S if err != nil { return nil, err } - db.store[query] = prepared + db.cache[query] = prepared return prepared, nil } @@ -124,12 +126,12 @@ func (db *preparedDb[T]) Close() error { defer db.mu.Unlock() errs := []error{} - for _, stmt := range db.store { + for _, stmt := range db.cache { if err := stmt.Close(); err != nil { errs = append(errs, err) } } - db.store = nil + db.cache = nil if err := db.underlying.Close(); err != nil { errs = append(errs, err)