diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 5cc892a16..fb256962b 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -97,6 +97,23 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return mx, nil } +// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from +// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL +func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { + if c == nil { + return nil, ErrNilConfig + } + customQueryParams := map[string]string{} + + for k, v := range c.Params { + if strings.HasPrefix(k, "x-") { + customQueryParams[k] = v + delete(c.Params, k) + } + } + return customQueryParams, nil +} + func urlToMySQLConfig(url string) (*mysql.Config, error) { config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) if err != nil { @@ -174,6 +191,13 @@ func (m *Mysql) Open(url string) (database.Driver, error) { if err != nil { return nil, err } + fmt.Printf("config: %+v\n", config) + + customParams, err := extractCustomQueryParams(config) + if err != nil { + return nil, err + } + fmt.Printf("config: %+v\n", config) db, err := sql.Open("mysql", config.FormatDSN()) if err != nil { @@ -182,7 +206,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) { mx, err := WithInstance(db, &Config{ DatabaseName: config.DBName, - MigrationsTable: config.Params["x-migrations-table"], + MigrationsTable: customParams["x-migrations-table"], }) if err != nil { return nil, err diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index 5d6e82e8b..4091dcba0 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -6,17 +6,17 @@ import ( sqldriver "database/sql/driver" "fmt" "log" - - "github.com/golang-migrate/migrate/v4" "testing" ) import ( "github.com/dhui/dktest" "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" ) import ( + "github.com/golang-migrate/migrate/v4" dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" _ "github.com/golang-migrate/migrate/v4/source/file" @@ -175,6 +175,62 @@ func TestLockWorks(t *testing.T) { }) } +func TestExtractCustomQueryParams(t *testing.T) { + testcases := []struct { + name string + config *mysql.Config + expectedParams map[string]string + expectedCustomParams map[string]string + expectedErr error + }{ + {name: "nil config", expectedErr: ErrNilConfig}, + { + name: "no params", + config: mysql.NewConfig(), + expectedCustomParams: map[string]string{}, + }, + { + name: "no custom params", + config: &mysql.Config{Params: map[string]string{"hello": "world"}}, + expectedParams: map[string]string{"hello": "world"}, + expectedCustomParams: map[string]string{}, + }, + { + name: "one param, one custom param", + config: &mysql.Config{ + Params: map[string]string{"hello": "world", "x-foo": "bar"}, + }, + expectedParams: map[string]string{"hello": "world"}, + expectedCustomParams: map[string]string{"x-foo": "bar"}, + }, + { + name: "multiple params, multiple custom params", + config: &mysql.Config{ + Params: map[string]string{ + "hello": "world", + "x-foo": "bar", + "dead": "beef", + "x-cat": "hat", + }, + }, + expectedParams: map[string]string{"hello": "world", "dead": "beef"}, + expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"}, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + customParams, err := extractCustomQueryParams(tc.config) + if tc.config != nil { + assert.Equal(t, tc.expectedParams, tc.config.Params, + "Expected config params have custom params properly removed") + } + assert.Equal(t, tc.expectedErr, err, "Expected errors to match") + assert.Equal(t, tc.expectedCustomParams, customParams, + "Expected custom params to be properly extracted") + }) + } +} + func TestURLToMySQLConfig(t *testing.T) { testcases := []struct { name string