Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pgx: don't use database/sql interface #1198

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 41 additions & 65 deletions database/pgx/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package pgx

import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
Expand All @@ -22,8 +21,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/lib/pq"
"github.com/jackc/pgx/v4"
)

const (
Expand Down Expand Up @@ -69,27 +67,26 @@ type Config struct {

type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
conn *pgx.Conn
isLocked atomic.Bool

// Open and WithInstance need to guarantee that config is never nil
config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithInstance(instance *pgx.Conn, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := instance.Ping(context.Background()); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := instance.QueryRow(context.Background(), query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -103,7 +100,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := instance.QueryRow(context.Background(), query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand Down Expand Up @@ -139,15 +136,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
}
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Postgres{
conn: conn,
db: instance,
conn: instance,
config: config,
}

Expand All @@ -173,7 +163,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
purl.Scheme = "postgres"

db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
db, err := pgx.Connect(context.Background(), migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -240,10 +230,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}

func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
connErr := p.conn.Close(context.Background())
if connErr != nil {
return fmt.Errorf("conn: %w", connErr)
}
return nil
}
Expand Down Expand Up @@ -283,19 +272,19 @@ func (p *Postgres) applyAdvisoryLock() error {

// This will wait indefinitely until the lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
return nil
}

func (p *Postgres) applyTableLock() error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
defer func() {
errRollback := tx.Rollback()
errRollback := tx.Rollback(context.Background())
if errRollback != nil {
err = multierror.Append(err, errRollback)
}
Expand All @@ -306,30 +295,25 @@ func (p *Postgres) applyTableLock() error {
return err
}

query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
rows, err := tx.Query(query, aid)
query := "SELECT * FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
rows, err := tx.Query(context.Background(), query, aid)
if err != nil {
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
}

defer func() {
if errClose := rows.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
defer rows.Close()

// If row exists at all, lock is present
locked := rows.Next()
if locked {
return database.ErrLocked
}

query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(query, aid); err != nil {
query = "INSERT INTO " + quoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(context.Background(), query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
}

return tx.Commit()
return tx.Commit(context.Background())
}

func (p *Postgres) releaseAdvisoryLock() error {
Expand All @@ -339,7 +323,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
}

query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -352,8 +336,8 @@ func (p *Postgres) releaseTableLock() error {
return err
}

query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
if _, err := p.db.Exec(query, aid); err != nil {
query := "DELETE FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
}

Expand Down Expand Up @@ -391,7 +375,7 @@ func (p *Postgres) runStatement(statement []byte) error {
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if _, err := p.conn.Exec(ctx, query); err != nil {

if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
Expand Down Expand Up @@ -448,14 +432,14 @@ func runesLastIndex(input []rune, target rune) int {
}

func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
if _, err := tx.Exec(context.Background(), query); err != nil {
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -466,15 +450,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
if _, err := tx.Exec(context.Background(), query, version, dirty); err != nil {
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}

if err := tx.Commit(); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}

Expand All @@ -483,9 +467,9 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {

func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
err = p.conn.QueryRow(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
case err == pgx.ErrNoRows:
return database.NilVersion, false, nil

case err != nil:
Expand All @@ -504,15 +488,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
func (p *Postgres) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
tables, err := p.conn.Query(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
defer tables.Close()

// delete one table after another
tableNames := make([]string, 0)
Expand All @@ -539,7 +519,7 @@ func (p *Postgres) Drop() (err error) {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if _, err := p.conn.Exec(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
Expand Down Expand Up @@ -571,7 +551,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
row := p.conn.QueryRow(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)

var count int
err = row.Scan(&count)
Expand All @@ -584,7 +564,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
}

query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
if _, err = p.conn.Exec(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -598,15 +578,15 @@ func (p *Postgres) ensureLockTable() error {

var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
if err := p.conn.QueryRow(context.Background(), query, p.config.LockTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}

query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
if _, err := p.db.Exec(query); err != nil {
query = `CREATE TABLE ` + quoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
if _, err := p.conn.Exec(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -615,9 +595,5 @@ func (p *Postgres) ensureLockTable() error {

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
return pgx.Identifier([]string{name}).Sanitize()
}
Loading
Loading