Skip to content

Commit

Permalink
Merge pull request #1591 from 4el0ve4ek/fix_sql_driver_Valuer
Browse files Browse the repository at this point in the history
fix sql/driver.Valuer interface usage
  • Loading branch information
asmyasnikov authored Dec 16, 2024
2 parents cb3bf6b + fb2182e commit 62b3351
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Supported of `database/sql/driver.Valuer` interfaces for params which passed to query using sql driver
* Exposed `credentials/credentials.OAuth2Config` OAuth2 config

## v3.95.2
Expand Down
22 changes: 17 additions & 5 deletions internal/bind/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,17 @@ func toValue(v interface{}) (_ value.Value, err error) {
return x, nil
}

if valuer, ok := v.(driver.Valuer); ok {
v, err = valuer.Value()
if err != nil {
return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err)
}
}

if x, ok := asUUID(v); ok {
return x, nil
}

switch x := v.(type) {
case nil:
return value.VoidValue(), nil
Expand Down Expand Up @@ -337,16 +348,17 @@ func supportNewTypeLink(x interface{}) string {
}

func toYdbParam(name string, value interface{}) (*params.Parameter, error) {
if na, ok := value.(driver.NamedValue); ok {
n, v := na.Name, na.Value
switch tv := value.(type) {
case driver.NamedValue:
n, v := tv.Name, tv.Value
if n != "" {
name = n
}
value = v
case *params.Parameter:
return tv, nil
}
if v, ok := value.(*params.Parameter); ok {
return v, nil
}

v, err := toValue(value)
if err != nil {
return nil, xerrors.WithStackTrace(err)
Expand Down
26 changes: 26 additions & 0 deletions internal/bind/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/table"
)

type testValuer struct {
value driver.Value
}

func (v testValuer) Value() (driver.Value, error) {
return v.value, nil
}

func TestToValue(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down Expand Up @@ -601,6 +609,24 @@ func TestToValue(t *testing.T) {
dst: nil,
err: value.ErrIssue1501BadUUID,
},
{
name: xtest.CurrentFileLine(),
src: testValuer{value: "1234567890"},
dst: value.TextValue("1234567890"),
err: nil,
},
{
name: xtest.CurrentFileLine(),
src: testValuer{value: uuid.UUID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}},
dst: value.Uuid(uuid.UUID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}),
err: nil,
},
{
name: xtest.CurrentFileLine(),
src: testValuer{value: func() *string { return nil }()},
dst: value.NullValue(types.Text),
err: nil,
},
} {
t.Run(tt.name, func(t *testing.T) {
dst, err := toValue(tt.src)
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/database_sql_regression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package integration
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"math/rand"
Expand Down Expand Up @@ -451,3 +452,11 @@ func TestUUIDSerializationDatabaseSQLIssue1501(t *testing.T) {
require.Equal(t, id.String(), res.String())
})
}

type testValuer struct {
value driver.Value
}

func (v *testValuer) Value() (driver.Value, error) {
return v.value, nil
}

0 comments on commit 62b3351

Please sign in to comment.