Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add database package #216

Merged
merged 4 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions pkg/database/driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package database_test

import (
"context"
"database/sql/driver"
"io"
"sync/atomic"
"testing"
)

type testDriver struct {
t *testing.T
conns atomic.Int32
queries atomic.Int32
execs atomic.Int32
stmts atomic.Int32
trans atomic.Int32
}

var _ driver.DriverContext = &testDriver{}

// Open implements driver.Driver.
func (td *testDriver) Open(name string) (driver.Conn, error) {
td.conns.Add(1)
return &testConn{
driver: td,
dbName: name,
}, nil
}

func (td *testDriver) OpenConnector(name string) (driver.Connector, error) {
return &testConnector{
driver: td,
dbName: name,
}, nil
}

func (td *testDriver) Close() {
if rows := td.queries.Load(); rows != 0 {
td.t.Errorf("%d rows left open after close", rows)
}
if stmts := td.stmts.Load(); stmts != 0 {
td.t.Errorf("%d statements left open after close", stmts)
}
if conns := td.conns.Load(); conns != 0 {
td.t.Errorf("%d connections left open after close", conns)
}
if trans := td.trans.Load(); trans != 0 {
td.t.Errorf("%d transactions left open after close", trans)
}
}

type testConnector struct {
driver *testDriver
dbName string
}

func (tc *testConnector) Driver() driver.Driver { return tc.driver }

func (tc *testConnector) Connect(context.Context) (driver.Conn, error) {
tc.driver.conns.Add(1)
return &testConn{
driver: tc.driver,
dbName: tc.dbName,
}, nil
}

type testConn struct {
driver *testDriver
dbName string
}

var _ driver.ExecerContext = &testConn{}
var _ driver.QueryerContext = &testConn{}

func (t *testConn) Begin() (driver.Tx, error) {
t.driver.trans.Add(1)
return &testTx{
driver: t.driver,
}, nil
}

func (t *testConn) Prepare(query string) (driver.Stmt, error) {
t.driver.stmts.Add(1)
return &testStmt{
driver: t.driver,
query: query,
}, nil
}

func (t *testConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
t.driver.queries.Add(1)
return &testRows{
driver: t.driver,
}, nil
}

func (t *testConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
t.driver.execs.Add(1)
return testResult{}, nil
}

func (t *testConn) Close() error { return nil }

type testStmt struct {
driver *testDriver
query string
}

func (t *testStmt) NumInput() int { return 0 }
marco6 marked this conversation as resolved.
Show resolved Hide resolved

func (t *testStmt) Exec(args []driver.Value) (driver.Result, error) {
t.driver.execs.Add(1)
return testResult{}, nil
}

func (t *testStmt) Query(args []driver.Value) (driver.Rows, error) {
t.driver.queries.Add(1)
return &testRows{
driver: t.driver,
}, nil
}

func (t *testStmt) Close() error { return nil }

type testRows struct {
driver *testDriver
}

func (t *testRows) Columns() []string { return nil }
func (t *testRows) Next(dest []driver.Value) error { return io.EOF }
func (t *testRows) Close() error { return nil }

type testResult struct{}

func (t testResult) LastInsertId() (int64, error) { return 0, nil }
func (t testResult) RowsAffected() (int64, error) { return 0, nil }

type testTx struct {
driver *testDriver
}

func (t *testTx) Commit() error { return nil }
func (t *testTx) Rollback() error { return nil }
36 changes: 36 additions & 0 deletions pkg/database/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package database

import (
"context"
"database/sql"
"errors"
)

var errDBClosed = errors.New("sql: database is closed")
marco6 marked this conversation as resolved.
Show resolved Hide resolved

type Interface interface {
marco6 marked this conversation as resolved.
Show resolved Hide resolved
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error)
Conn(ctx context.Context) (*sql.Conn, error)
Close() error
}

type Transaction interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
Commit() error
Rollback() error
}

type Wrapped[T Transaction] interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (T, error)
Conn(ctx context.Context) (*sql.Conn, error)
Close() error
}
172 changes: 172 additions & 0 deletions pkg/database/prepared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package database

import (
"context"
"database/sql"
"errors"
"fmt"
"sync"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
)

const otelName = "prepared"

var otelTracer trace.Tracer

func init() {
otelTracer = otel.Tracer(otelName)
}

type preparedDb[T Transaction] struct {
underlying Wrapped[T]
mu sync.RWMutex
cache map[string]*sql.Stmt
}

// NewPrepared creates a new Interface that wraps the given database and
// uses a prepare cache to reduce the number of prepare calls. The cache
// is only used when calling ExecContext and QueryContext methods on the
// main instance or in a transaction.
func NewPrepared[T Transaction](db Wrapped[T]) Interface {
return &preparedDb[T]{
underlying: db,
cache: make(map[string]*sql.Stmt),
}
}

func (db *preparedDb[T]) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
ctx, span := otelTracer.Start(ctx, "DB.ExecContext")
defer func() {
span.RecordError(err)
span.End()
}()

stmt, err := db.prepare(ctx, query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}

func (db *preparedDb[T]) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
ctx, span := otelTracer.Start(ctx, "DB.QueryContext")
defer func() {
span.RecordError(err)
span.End()
}()

stmt, err := db.prepare(ctx, query)
if err != nil {
return nil, err
}
return stmt.QueryContext(ctx, args...)
}

func (db *preparedDb[T]) PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
return db.underlying.PrepareContext(ctx, query)
}

func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.Stmt, err error) {
ctx, span := otelTracer.Start(ctx, fmt.Sprintf("%s.prepare", otelName))
defer func() {
span.RecordError(err)
span.End()
}()

db.mu.RLock()
stmt = db.cache[query]
db.mu.RUnlock()
if stmt != nil {
return stmt, nil
}

db.mu.Lock()
defer db.mu.Unlock()

if db.underlying == nil {
return nil, errDBClosed
}

// Given that some time has passed since the unlock of the read lock, and the lock of the
// write lock, another goroutine might have already prepared this query, so we should check
// again to avoid preparing the same query twice.
stmt = db.cache[query]
if stmt != nil {
return stmt, nil
}

prepared, err := db.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
db.cache[query] = prepared
return prepared, nil
}

func (db *preparedDb[T]) Conn(ctx context.Context) (*sql.Conn, error) {
return db.underlying.Conn(ctx)
}

func (db *preparedDb[T]) BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error) {
tx, err := db.underlying.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &preparedTx[T]{
Transaction: tx,
db: db,
}, nil
}

func (db *preparedDb[T]) Close() error {
db.mu.Lock()
defer db.mu.Unlock()

errs := []error{}
for _, stmt := range db.cache {
if err := stmt.Close(); err != nil {
errs = append(errs, err)
}
}
db.cache = nil

if err := db.underlying.Close(); err != nil {
errs = append(errs, err)
}
db.underlying = nil

return errors.Join(errs...)
}

type preparedTx[T Transaction] struct {
Transaction
db *preparedDb[T]
}

func (tx *preparedTx[T]) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
stmt, err := tx.db.prepare(ctx, query)
if err != nil {
return nil, err
}
return tx.StmtContext(ctx, stmt), nil
}

func (tx *preparedTx[T]) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}

func (tx *preparedTx[T]) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

return stmt.QueryContext(ctx, args...)
}
Loading
Loading