Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stdlib matches native pgx scanning support #2029

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions stdlib/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"strings"
"testing"
"time"

"github.com/jackc/pgx/v5/pgtype"
)

func getSelectRowsCounts(b *testing.B) []int64 {
Expand Down Expand Up @@ -107,3 +109,52 @@ func BenchmarkSelectRowsScanNull(b *testing.B) {
})
}
}

func BenchmarkFlatArrayEncodeArgument(b *testing.B) {
db := openDB(b)
defer closeDB(b, db)

input := make(pgtype.FlatArray[string], 10)
for i := range input {
input[i] = fmt.Sprintf("String %d", i)
}

b.ResetTimer()

for i := 0; i < b.N; i++ {
var n int64
err := db.QueryRow("select cardinality($1::text[])", input).Scan(&n)
if err != nil {
b.Fatal(err)
}
if n != int64(len(input)) {
b.Fatalf("Expected %d, got %d", len(input), n)
}
}
}

func BenchmarkFlatArrayScanResult(b *testing.B) {
db := openDB(b)
defer closeDB(b, db)

var input string
for i := 0; i < 10; i++ {
if i > 0 {
input += ","
}
input += fmt.Sprintf(`'String %d'`, i)
}

b.ResetTimer()

for i := 0; i < b.N; i++ {
var result pgtype.FlatArray[string]
err := db.QueryRow(fmt.Sprintf("select array[%s]::text[]", input)).Scan(&result)
if err != nil {
b.Fatal(err)
}
if len(result) != 10 {
b.Fatalf("Expected %d, got %d", len(result), 10)
}
}
}
6 changes: 6 additions & 0 deletions stdlib/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,12 @@ func (r *Rows) Next(dest []driver.Value) error {
return nil
}

func (r *Rows) ScanColumn(index int, dest any) error {
m := r.conn.conn.TypeMap()
fd := r.rows.FieldDescriptions()[index]
return m.Scan(fd.DataTypeOID, fd.Format, r.rows.RawValues()[index], dest)
}

func valueToInterface(argsV []driver.Value) []any {
args := make([]any, 0, len(argsV))
for _, v := range argsV {
Expand Down
166 changes: 155 additions & 11 deletions stdlib/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,32 @@ func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
}
}

func testWithKnownOIDQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
for _, mode := range []pgx.QueryExecMode{
pgx.QueryExecModeCacheStatement,
pgx.QueryExecModeCacheDescribe,
pgx.QueryExecModeDescribeExec,
} {
t.Run(mode.String(),
func(t *testing.T) {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)

config.DefaultQueryExecMode = mode
db := stdlib.OpenDB(*config)
defer func() {
err := db.Close()
require.NoError(t, err)
}()

f(t, db)

ensureDBValid(t, db)
},
)
}
}

// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
// cover broken connections.
func ensureDBValid(t testing.TB, db *sql.DB) {
Expand Down Expand Up @@ -509,29 +535,99 @@ func TestConnQueryScanGoArray(t *testing.T) {
})
}

func TestConnQueryScanArray(t *testing.T) {
func TestGoArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
m := pgtype.NewMap()
var names []string

var a pgtype.Array[int64]
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
require.Equal(t, []string{"John", "Jane"}, names)

err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
require.EqualValues(t, 2, n)

err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}

func TestConnQueryScanRange(t *testing.T) {
func TestGoArrayOfDriverValuer(t *testing.T) {
// Because []sql.NullString is not a registered type on the connection, it will only work with known OIDs.
testWithKnownOIDQueryExecModes(t, func(t *testing.T, db *sql.DB) {
var names []sql.NullString

err := db.QueryRow("select array['John', null, 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
require.Equal(t, []sql.NullString{{String: "John", Valid: true}, {}, {String: "Jane", Valid: true}}, names)

var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)

err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}

func TestPGTypeFlatArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support int4range")
var names pgtype.FlatArray[string]

m := pgtype.NewMap()
err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
require.Equal(t, pgtype.FlatArray[string]{"John", "Jane"}, names)

var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)

err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}

func TestPGTypeArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support nested arrays")

var matrix pgtype.Array[int64]

err := db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[]").Scan(&matrix)
require.NoError(t, err)
require.Equal(t,
pgtype.Array[int64]{
Elements: []int64{1, 2, 3, 4, 5, 6},
Dims: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 1},
{Length: 3, LowerBound: 1},
},
Valid: true},
matrix)

var equal bool
err = db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[] = $1::bigint[]", matrix).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)

err = db.QueryRow("select null::bigint[]").Scan(&matrix)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, matrix)
})
}

func TestConnQueryPGTypeRange(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support int4range")

var r pgtype.Range[pgtype.Int4]
err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
err := db.QueryRow("select int4range(1, 5)").Scan(&r)
require.NoError(t, err)
assert.Equal(
t,
Expand All @@ -543,6 +639,54 @@ func TestConnQueryScanRange(t *testing.T) {
Valid: true,
},
r)

var equal bool
err = db.QueryRow("select int4range(1, 5) = $1::int4range", r).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)

err = db.QueryRow("select null::int4range").Scan(&r)
require.NoError(t, err)
assert.Equal(t, pgtype.Range[pgtype.Int4]{}, r)
})
}

func TestConnQueryPGTypeMultirange(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support int4range")
skipPostgreSQLVersionLessThan(t, db, 14)

var r pgtype.Multirange[pgtype.Range[pgtype.Int4]]
err := db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9))").Scan(&r)
require.NoError(t, err)
assert.Equal(
t,
pgtype.Multirange[pgtype.Range[pgtype.Int4]]{
{
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
{
Lower: pgtype.Int4{Int32: 7, Valid: true},
Upper: pgtype.Int4{Int32: 9, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
},
r)

var equal bool
err = db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9)) = $1::int4multirange", r).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)

err = db.QueryRow("select null::int4multirange").Scan(&r)
require.NoError(t, err)
require.Nil(t, r)
})
}

Expand Down
Loading