From 3ae457a239a18f3f2fa92aec23c5c1d03f752f12 Mon Sep 17 00:00:00 2001 From: Daniil Aksenov Date: Thu, 12 Dec 2024 19:10:22 +0300 Subject: [PATCH] fix & test driver.Valuer interface for sql --- internal/bind/params.go | 18 ++++++--- internal/bind/params_test.go | 14 +++++++ .../database_sql_regression_test.go | 38 +++++++++++++++++++ 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/internal/bind/params.go b/internal/bind/params.go index 2597ba687..6ae265a0d 100644 --- a/internal/bind/params.go +++ b/internal/bind/params.go @@ -161,6 +161,13 @@ func toValue(v interface{}) (_ value.Value, err error) { return x, nil } + if valuer, ok := v.(driver.Valuer); ok { + v, err = valuer.Value() + if err != nil { + return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err) + } + } + switch x := v.(type) { case nil: return value.VoidValue(), nil @@ -337,16 +344,17 @@ func supportNewTypeLink(x interface{}) string { } func toYdbParam(name string, value interface{}) (*params.Parameter, error) { - if na, ok := value.(driver.NamedValue); ok { - n, v := na.Name, na.Value + switch tv := value.(type) { + case driver.NamedValue: + n, v := tv.Name, tv.Value if n != "" { name = n } value = v + case *params.Parameter: + return tv, nil } - if v, ok := value.(*params.Parameter); ok { - return v, nil - } + v, err := toValue(value) if err != nil { return nil, xerrors.WithStackTrace(err) diff --git a/internal/bind/params_test.go b/internal/bind/params_test.go index 0bb8085b8..210b150b2 100644 --- a/internal/bind/params_test.go +++ b/internal/bind/params_test.go @@ -17,6 +17,14 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/table" ) +type testValuer struct { + value driver.Value +} + +func (v testValuer) Value() (driver.Value, error) { + return v.value, nil +} + func TestToValue(t *testing.T) { for _, tt := range []struct { name string @@ -601,6 +609,12 @@ func TestToValue(t *testing.T) { dst: nil, err: value.ErrIssue1501BadUUID, }, + { + name: xtest.CurrentFileLine(), + src: testValuer{value: "1234567890"}, + dst: value.TextValue("1234567890"), + err: nil, + }, } { t.Run(tt.name, func(t *testing.T) { dst, err := toValue(tt.src) diff --git a/tests/integration/database_sql_regression_test.go b/tests/integration/database_sql_regression_test.go index aeb6d6054..4b183596c 100644 --- a/tests/integration/database_sql_regression_test.go +++ b/tests/integration/database_sql_regression_test.go @@ -6,6 +6,7 @@ package integration import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "math/rand" @@ -435,3 +436,40 @@ func TestUUIDSerializationDatabaseSQLIssue1501(t *testing.T) { require.Equal(t, id.String(), res.String()) }) } + +type testValuer struct { + value driver.Value +} + +func (v *testValuer) Value() (driver.Value, error) { + return v.value, nil +} + +func TestSQLX(t *testing.T) { + // test sqlx + + t.Run("named-exec-context", func(t *testing.T) { + // test old behavior - for test way of safe work with data, written with bagged API version + var ( + scope = newScope(t) + db = scope.SQLDriver() + ) + + id := "6E73B41C-4EDE-4D08-9CFB-B7462D9E498B" + v := testValuer{value: id} + + row := db.QueryRow(` + DECLARE $val AS String; + SELECT $val as result_s`, + sql.Named("val", v), + ) + + require.NoError(t, row.Err()) + + var res string + err := row.Scan(&res) + require.NoError(t, err) + + require.Equal(t, id, res) + }) +}