From 62b339910441852c50276fc9079afa1c1316ba56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 29 Nov 2024 15:24:10 +0000 Subject: [PATCH] pgx: don't use database/sql interface it'd be nice to have contexts for many of these methods, but that'd be a much wider change --- database/pgx/pgx.go | 106 ++++++++++++++---------------------- database/pgx/pgx_test.go | 55 ++++++++----------- database/pgx/v5/pgx.go | 75 ++++++++++--------------- database/pgx/v5/pgx_test.go | 55 ++++++++----------- 4 files changed, 117 insertions(+), 174 deletions(-) diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 7e42d29c9..1a16fb268 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -5,7 +5,6 @@ package pgx import ( "context" - "database/sql" "fmt" "io" nurl "net/url" @@ -22,8 +21,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" - _ "github.com/jackc/pgx/v4/stdlib" - "github.com/lib/pq" + "github.com/jackc/pgx/v4" ) const ( @@ -69,27 +67,26 @@ type Config struct { type Postgres struct { // Locking and unlocking need to use the same connection - conn *sql.Conn - db *sql.DB + conn *pgx.Conn isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(instance *pgx.Conn, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := instance.Ping(context.Background()); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := instance.QueryRow(context.Background(), query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -103,7 +100,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := instance.QueryRow(context.Background(), query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -139,15 +136,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ - conn: conn, - db: instance, + conn: instance, config: config, } @@ -173,7 +163,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db purl.Scheme = "postgres" - db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String()) + db, err := pgx.Connect(context.Background(), migrate.FilterCustomQuery(purl).String()) if err != nil { return nil, err } @@ -240,10 +230,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } func (p *Postgres) Close() error { - connErr := p.conn.Close() - dbErr := p.db.Close() - if connErr != nil || dbErr != nil { - return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) + connErr := p.conn.Close(context.Background()) + if connErr != nil { + return fmt.Errorf("conn: %w", connErr) } return nil } @@ -283,19 +272,19 @@ func (p *Postgres) applyAdvisoryLock() error { // This will wait indefinitely until the lock can be acquired. query := `SELECT pg_advisory_lock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + if _, err := p.conn.Exec(context.Background(), query, aid); err != nil { return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} } return nil } func (p *Postgres) applyTableLock() error { - tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) + tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } defer func() { - errRollback := tx.Rollback() + errRollback := tx.Rollback(context.Background()) if errRollback != nil { err = multierror.Append(err, errRollback) } @@ -306,17 +295,12 @@ func (p *Postgres) applyTableLock() error { return err } - query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" - rows, err := tx.Query(query, aid) + query := "SELECT * FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" + rows, err := tx.Query(context.Background(), query, aid) if err != nil { return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} } - - defer func() { - if errClose := rows.Close(); errClose != nil { - err = multierror.Append(err, errClose) - } - }() + defer rows.Close() // If row exists at all, lock is present locked := rows.Next() @@ -324,12 +308,12 @@ func (p *Postgres) applyTableLock() error { return database.ErrLocked } - query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)" - if _, err := tx.Exec(query, aid); err != nil { + query = "INSERT INTO " + quoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)" + if _, err := tx.Exec(context.Background(), query, aid); err != nil { return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} } - return tx.Commit() + return tx.Commit(context.Background()) } func (p *Postgres) releaseAdvisoryLock() error { @@ -339,7 +323,7 @@ func (p *Postgres) releaseAdvisoryLock() error { } query := `SELECT pg_advisory_unlock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + if _, err := p.conn.Exec(context.Background(), query, aid); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -352,8 +336,8 @@ func (p *Postgres) releaseTableLock() error { return err } - query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" - if _, err := p.db.Exec(query, aid); err != nil { + query := "DELETE FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" + if _, err := p.conn.Exec(context.Background(), query, aid); err != nil { return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} } @@ -391,7 +375,7 @@ func (p *Postgres) runStatement(statement []byte) error { if strings.TrimSpace(query) == "" { return nil } - if _, err := p.conn.ExecContext(ctx, query); err != nil { + if _, err := p.conn.Exec(ctx, query); err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { var line uint @@ -448,14 +432,14 @@ func runesLastIndex(input []rune, target rune) int { } func (p *Postgres) SetVersion(version int, dirty bool) error { - tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) + tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) - if _, err := tx.Exec(query); err != nil { - if errRollback := tx.Rollback(); errRollback != nil { + if _, err := tx.Exec(context.Background(), query); err != nil { + if errRollback := tx.Rollback(context.Background()); errRollback != nil { err = multierror.Append(err, errRollback) } return &database.Error{OrigErr: err, Query: []byte(query)} @@ -466,15 +450,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == database.NilVersion && dirty) { query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` - if _, err := tx.Exec(query, version, dirty); err != nil { - if errRollback := tx.Rollback(); errRollback != nil { + if _, err := tx.Exec(context.Background(), query, version, dirty); err != nil { + if errRollback := tx.Rollback(context.Background()); errRollback != nil { err = multierror.Append(err, errRollback) } return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := tx.Commit(); err != nil { + if err := tx.Commit(context.Background()); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -483,9 +467,9 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { func (p *Postgres) Version() (version int, dirty bool, err error) { query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` - err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) + err = p.conn.QueryRow(context.Background(), query).Scan(&version, &dirty) switch { - case err == sql.ErrNoRows: + case err == pgx.ErrNoRows: return database.NilVersion, false, nil case err != nil: @@ -504,15 +488,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) { func (p *Postgres) Drop() (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` - tables, err := p.conn.QueryContext(context.Background(), query) + tables, err := p.conn.Query(context.Background(), query) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } - defer func() { - if errClose := tables.Close(); errClose != nil { - err = multierror.Append(err, errClose) - } - }() + defer tables.Close() // delete one table after another tableNames := make([]string, 0) @@ -539,7 +519,7 @@ func (p *Postgres) Drop() (err error) { // delete one by one ... for _, t := range tableNames { query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` - if _, err := p.conn.ExecContext(context.Background(), query); err != nil { + if _, err := p.conn.Exec(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } } @@ -571,7 +551,7 @@ func (p *Postgres) ensureVersionTable() (err error) { // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` - row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) + row := p.conn.QueryRow(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) var count int err = row.Scan(&count) @@ -584,7 +564,7 @@ func (p *Postgres) ensureVersionTable() (err error) { } query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` - if _, err = p.conn.ExecContext(context.Background(), query); err != nil { + if _, err = p.conn.Exec(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -598,15 +578,15 @@ func (p *Postgres) ensureLockTable() error { var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` - if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil { + if err := p.conn.QueryRow(context.Background(), query, p.config.LockTable).Scan(&count); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { return nil } - query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)` - if _, err := p.db.Exec(query); err != nil { + query = `CREATE TABLE ` + quoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)` + if _, err := p.conn.Exec(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -615,9 +595,5 @@ func (p *Postgres) ensureLockTable() error { // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return `"` + strings.Replace(name, `"`, `""`, -1) + `"` + return pgx.Identifier([]string{name}).Sanitize() } diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 03977973d..60e333af7 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -4,8 +4,6 @@ package pgx import ( "context" - "database/sql" - sqldriver "database/sql/driver" "errors" "fmt" "io" @@ -22,6 +20,8 @@ import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" _ "github.com/golang-migrate/migrate/v4/source/file" + + "github.com/jackc/pgx/v4" ) const ( @@ -53,20 +53,17 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return false } - db, err := sql.Open("pgx", pgConnectionString(ip, port)) + db, err := pgx.Connect(ctx, pgConnectionString(ip, port)) if err != nil { return false } defer func() { - if err := db.Close(); err != nil { + if err := db.Close(context.Background()); err != nil { log.Println("close error:", err) } }() - if err = db.PingContext(ctx); err != nil { - switch err { - case sqldriver.ErrBadConn, io.EOF: - return false - default: + if err := db.Ping(ctx); err != nil { + if !errors.Is(err, io.EOF) { log.Println(err) } return false @@ -181,7 +178,7 @@ func TestMultipleStatements(t *testing.T) { // make sure second table exists var exists bool - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -214,7 +211,7 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { // make sure created index exists var exists bool - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -388,7 +385,7 @@ func TestMigrationTableOption(t *testing.T) { // make sure migrate.schema_migrations table exists var exists bool - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -400,7 +397,7 @@ func TestMigrationTableOption(t *testing.T) { if err != nil { t.Fatal(err) } - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -672,23 +669,7 @@ func TestWithInstance_Concurrent(t *testing.T) { // The number of concurrent processes running WithInstance const concurrency = 30 - - // We can instantiate a single database handle because it is - // actually a connection pool, and so, each of the below go - // routines will have a high probability of using a separate - // connection, which is something we want to exercise. - db, err := sql.Open("pgx", pgConnectionString(ip, port)) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := db.Close(); err != nil { - t.Error(err) - } - }() - - db.SetMaxIdleConns(concurrency) - db.SetMaxOpenConns(concurrency) + connString := pgConnectionString(ip, port) var wg sync.WaitGroup defer wg.Wait() @@ -697,7 +678,19 @@ func TestWithInstance_Concurrent(t *testing.T) { for i := 0; i < concurrency; i++ { go func(i int) { defer wg.Done() - _, err := WithInstance(db, &Config{}) + + db, err := pgx.Connect(context.Background(), connString) + if err != nil { + t.Errorf("process %d error: %s", i, err) + return + } + defer func() { + if err := db.Close(context.Background()); err != nil { + t.Error(err) + } + }() + + _, err = WithInstance(db, &Config{}) if err != nil { t.Errorf("process %d error: %s", i, err) } diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..76cc854ed 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -5,7 +5,6 @@ package pgx import ( "context" - "database/sql" "fmt" "io" nurl "net/url" @@ -21,8 +20,8 @@ import ( "github.com/golang-migrate/migrate/v4/database/multistmt" "github.com/hashicorp/go-multierror" "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - _ "github.com/jackc/pgx/v5/stdlib" ) func init() { @@ -57,27 +56,26 @@ type Config struct { type Postgres struct { // Locking and unlocking need to use the same connection - conn *sql.Conn - db *sql.DB + conn *pgx.Conn isLocked atomic.Bool // Open and WithInstance need to guarantee that config is never nil config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(instance *pgx.Conn, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := instance.Ping(context.Background()); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := instance.QueryRow(context.Background(), query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -91,7 +89,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := instance.QueryRow(context.Background(), query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -119,15 +117,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ - conn: conn, - db: instance, + conn: instance, config: config, } @@ -149,7 +140,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db purl.Scheme = "postgres" - db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String()) + db, err := pgx.Connect(context.Background(), migrate.FilterCustomQuery(purl).String()) if err != nil { return nil, err } @@ -211,10 +202,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } func (p *Postgres) Close() error { - connErr := p.conn.Close() - dbErr := p.db.Close() - if connErr != nil || dbErr != nil { - return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) + connErr := p.conn.Close(context.Background()) + if connErr != nil { + return fmt.Errorf("conn: %w", connErr) } return nil } @@ -229,7 +219,7 @@ func (p *Postgres) Lock() error { // This will wait indefinitely until the lock can be acquired. query := `SELECT pg_advisory_lock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + if _, err := p.conn.Exec(context.Background(), query, aid); err != nil { return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} } return nil @@ -244,7 +234,7 @@ func (p *Postgres) Unlock() error { } query := `SELECT pg_advisory_unlock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + if _, err := p.conn.Exec(context.Background(), query, aid); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } return nil @@ -282,7 +272,7 @@ func (p *Postgres) runStatement(statement []byte) error { if strings.TrimSpace(query) == "" { return nil } - if _, err := p.conn.ExecContext(ctx, query); err != nil { + if _, err := p.conn.Exec(ctx, query); err != nil { if pgErr, ok := err.(*pgconn.PgError); ok { var line uint @@ -339,14 +329,14 @@ func runesLastIndex(input []rune, target rune) int { } func (p *Postgres) SetVersion(version int, dirty bool) error { - tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) + tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) - if _, err := tx.Exec(query); err != nil { - if errRollback := tx.Rollback(); errRollback != nil { + if _, err := tx.Exec(context.Background(), query); err != nil { + if errRollback := tx.Rollback(context.Background()); errRollback != nil { err = multierror.Append(err, errRollback) } return &database.Error{OrigErr: err, Query: []byte(query)} @@ -357,15 +347,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == database.NilVersion && dirty) { query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` - if _, err := tx.Exec(query, version, dirty); err != nil { - if errRollback := tx.Rollback(); errRollback != nil { + if _, err := tx.Exec(context.Background(), query, version, dirty); err != nil { + if errRollback := tx.Rollback(context.Background()); errRollback != nil { err = multierror.Append(err, errRollback) } return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := tx.Commit(); err != nil { + if err := tx.Commit(context.Background()); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -374,9 +364,9 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { func (p *Postgres) Version() (version int, dirty bool, err error) { query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` - err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) + err = p.conn.QueryRow(context.Background(), query).Scan(&version, &dirty) switch { - case err == sql.ErrNoRows: + case err == pgx.ErrNoRows: return database.NilVersion, false, nil case err != nil: @@ -395,15 +385,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) { func (p *Postgres) Drop() (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` - tables, err := p.conn.QueryContext(context.Background(), query) + tables, err := p.conn.Query(context.Background(), query) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } - defer func() { - if errClose := tables.Close(); errClose != nil { - err = multierror.Append(err, errClose) - } - }() + defer tables.Close() // delete one table after another tableNames := make([]string, 0) @@ -424,7 +410,7 @@ func (p *Postgres) Drop() (err error) { // delete one by one ... for _, t := range tableNames { query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` - if _, err := p.conn.ExecContext(context.Background(), query); err != nil { + if _, err := p.conn.Exec(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } } @@ -456,7 +442,7 @@ func (p *Postgres) ensureVersionTable() (err error) { // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` - row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) + row := p.conn.QueryRow(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) var count int err = row.Scan(&count) @@ -469,18 +455,13 @@ func (p *Postgres) ensureVersionTable() (err error) { } query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` - if _, err = p.conn.ExecContext(context.Background(), query); err != nil { + if _, err = p.conn.Exec(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } return nil } -// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return `"` + strings.Replace(name, `"`, `""`, -1) + `"` + return pgx.Identifier([]string{name}).Sanitize() } diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index 3066376b9..6a5d14105 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -4,8 +4,6 @@ package pgx import ( "context" - "database/sql" - sqldriver "database/sql/driver" "errors" "fmt" "io" @@ -23,6 +21,8 @@ import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" _ "github.com/golang-migrate/migrate/v4/source/file" + + "github.com/jackc/pgx/v5" ) const ( @@ -54,20 +54,17 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return false } - db, err := sql.Open("pgx", pgConnectionString(ip, port)) + db, err := pgx.Connect(ctx, pgConnectionString(ip, port)) if err != nil { return false } defer func() { - if err := db.Close(); err != nil { + if err := db.Close(context.Background()); err != nil { log.Println("close error:", err) } }() - if err = db.PingContext(ctx); err != nil { - switch err { - case sqldriver.ErrBadConn, io.EOF: - return false - default: + if err := db.Ping(ctx); err != nil { + if !errors.Is(err, io.EOF) { log.Println(err) } return false @@ -156,7 +153,7 @@ func TestMultipleStatements(t *testing.T) { // make sure second table exists var exists bool - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -189,7 +186,7 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { // make sure created index exists var exists bool - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -363,7 +360,7 @@ func TestMigrationTableOption(t *testing.T) { // make sure migrate.schema_migrations table exists var exists bool - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -375,7 +372,7 @@ func TestMigrationTableOption(t *testing.T) { if err != nil { t.Fatal(err) } - if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + if err := d.(*Postgres).conn.QueryRow(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { t.Fatal(err) } if !exists { @@ -647,23 +644,7 @@ func TestWithInstance_Concurrent(t *testing.T) { // The number of concurrent processes running WithInstance const concurrency = 30 - - // We can instantiate a single database handle because it is - // actually a connection pool, and so, each of the below go - // routines will have a high probability of using a separate - // connection, which is something we want to exercise. - db, err := sql.Open("pgx", pgConnectionString(ip, port)) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := db.Close(); err != nil { - t.Error(err) - } - }() - - db.SetMaxIdleConns(concurrency) - db.SetMaxOpenConns(concurrency) + connString := pgConnectionString(ip, port) var wg sync.WaitGroup defer wg.Wait() @@ -672,7 +653,19 @@ func TestWithInstance_Concurrent(t *testing.T) { for i := 0; i < concurrency; i++ { go func(i int) { defer wg.Done() - _, err := WithInstance(db, &Config{}) + + db, err := pgx.Connect(context.Background(), connString) + if err != nil { + t.Errorf("process %d error: %s", i, err) + return + } + defer func() { + if err := db.Close(context.Background()); err != nil { + t.Error(err) + } + }() + + _, err = WithInstance(db, &Config{}) if err != nil { t.Errorf("process %d error: %s", i, err) }