diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cf14162e..3b40e03d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Supported of `database/sql/driver.Valuer` interfaces for params which passed to query using sql driver * Exposed `credentials/credentials.OAuth2Config` OAuth2 config ## v3.95.2 diff --git a/internal/bind/params.go b/internal/bind/params.go index 2597ba687..ddf2e0c1e 100644 --- a/internal/bind/params.go +++ b/internal/bind/params.go @@ -161,6 +161,17 @@ func toValue(v interface{}) (_ value.Value, err error) { return x, nil } + if valuer, ok := v.(driver.Valuer); ok { + v, err = valuer.Value() + if err != nil { + return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err) + } + } + + if x, ok := asUUID(v); ok { + return x, nil + } + switch x := v.(type) { case nil: return value.VoidValue(), nil @@ -337,16 +348,17 @@ func supportNewTypeLink(x interface{}) string { } func toYdbParam(name string, value interface{}) (*params.Parameter, error) { - if na, ok := value.(driver.NamedValue); ok { - n, v := na.Name, na.Value + switch tv := value.(type) { + case driver.NamedValue: + n, v := tv.Name, tv.Value if n != "" { name = n } value = v + case *params.Parameter: + return tv, nil } - if v, ok := value.(*params.Parameter); ok { - return v, nil - } + v, err := toValue(value) if err != nil { return nil, xerrors.WithStackTrace(err) diff --git a/internal/bind/params_test.go b/internal/bind/params_test.go index 0bb8085b8..67970f993 100644 --- a/internal/bind/params_test.go +++ b/internal/bind/params_test.go @@ -17,6 +17,14 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/table" ) +type testValuer struct { + value driver.Value +} + +func (v testValuer) Value() (driver.Value, error) { + return v.value, nil +} + func TestToValue(t *testing.T) { for _, tt := range []struct { name string @@ -601,6 +609,24 @@ func TestToValue(t *testing.T) { dst: nil, err: value.ErrIssue1501BadUUID, }, + { + name: xtest.CurrentFileLine(), + src: testValuer{value: "1234567890"}, + dst: value.TextValue("1234567890"), + err: nil, + }, + { + name: xtest.CurrentFileLine(), + src: testValuer{value: uuid.UUID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}}, + dst: value.Uuid(uuid.UUID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), + err: nil, + }, + { + name: xtest.CurrentFileLine(), + src: testValuer{value: func() *string { return nil }()}, + dst: value.NullValue(types.Text), + err: nil, + }, } { t.Run(tt.name, func(t *testing.T) { dst, err := toValue(tt.src) diff --git a/tests/integration/database_sql_regression_test.go b/tests/integration/database_sql_regression_test.go index c8b3b06e9..7e936897a 100644 --- a/tests/integration/database_sql_regression_test.go +++ b/tests/integration/database_sql_regression_test.go @@ -6,6 +6,7 @@ package integration import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "math/rand" @@ -451,3 +452,11 @@ func TestUUIDSerializationDatabaseSQLIssue1501(t *testing.T) { require.Equal(t, id.String(), res.String()) }) } + +type testValuer struct { + value driver.Value +} + +func (v *testValuer) Value() (driver.Value, error) { + return v.value, nil +}