diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d84c74f..9f770f4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,5 +1,11 @@ name: Tests -on: [push] +on: + push: + branches: + - "main" + tags: + - "v*" + pull_request: jobs: postgres: strategy: diff --git a/dialector.go b/dialector.go new file mode 100644 index 0000000..70c565f --- /dev/null +++ b/dialector.go @@ -0,0 +1,77 @@ +package sharding + +import ( + "fmt" + + "gorm.io/gorm" +) + +type ShardingDialector struct { + gorm.Dialector + sharding *Sharding +} + +type ShardingMigrator struct { + gorm.Migrator + sharding *Sharding + dialector gorm.Dialector +} + +func NewShardingDialector(d gorm.Dialector, s *Sharding) ShardingDialector { + return ShardingDialector{ + Dialector: d, + sharding: s, + } +} + +func (d ShardingDialector) Migrator(db *gorm.DB) gorm.Migrator { + m := d.Dialector.Migrator(db) + return ShardingMigrator{ + Migrator: m, + sharding: d.sharding, + dialector: d.Dialector, + } +} + +func (m ShardingMigrator) AutoMigrate(dst ...interface{}) error { + noShardingDsts := make([]interface{}, 0) + for _, model := range dst { + stmt := &gorm.Statement{DB: m.sharding.DB} + if err := stmt.Parse(model); err == nil { + if cfg, ok := m.sharding.configs[stmt.Table]; ok { + // support sharding table + suffixs := cfg.ShardingSuffixs() + if len(suffixs) == 0 { + return fmt.Errorf("sharding table:%s suffixs is empty", stmt.Table) + } + + for _, suffix := range suffixs { + shardingTable := stmt.Table + suffix + tx := stmt.DB.Session(&gorm.Session{}).Table(shardingTable) + if err := m.dialector.Migrator(tx).AutoMigrate(model); err != nil { + return err + } + } + + if cfg.DoubleWrite { + noShardingDsts = append(noShardingDsts, model) + } + } else { + noShardingDsts = append(noShardingDsts, model) + } + } else { + return err + } + } + + if len(noShardingDsts) > 0 { + if err := m.Migrator.AutoMigrate(noShardingDsts...); err != nil { + return err + } + } + return nil +} + +// TODO: DropTable drop sharding table +// func (m ShardingMigrator) DropTable(dst ...interface{}) error { +// } diff --git a/sharding.go b/sharding.go index e68863a..cbb8d1e 100644 --- a/sharding.go +++ b/sharding.go @@ -56,6 +56,19 @@ type Config struct { // } ShardingAlgorithm func(columnValue interface{}) (suffix string, err error) + // ShardingSuffixs specifies a function to generate all table's suffix. + // Used to support Migrator. + // For example, this function get a mod all sharding suffixs. + // + // func () (suffixs []string) { + // numberOfShards := 5 + // for i := 0; i < numberOfShards; i++ { + // suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards)) + // } + // return + // } + ShardingSuffixs func() (suffixs []string) + // ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding // table's suffix by the primary key. Used when no sharding key specified. // For example, this function use the Snowflake library to generate the suffix. @@ -160,10 +173,24 @@ func (s *Sharding) compile() error { return "", fmt.Errorf("default algorithm only support integer and string column," + "if you use other type, specify you own ShardingAlgorithm") } + return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil } } + if c.ShardingSuffixs == nil { + c.ShardingSuffixs = func() (suffixs []string) { + for i := 0; i < int(c.NumberOfShards); i++ { + suffix, err := c.ShardingAlgorithm(i) + if err != nil { + return nil + } + suffixs = append(suffixs, suffix) + } + return + } + } + if c.ShardingAlgorithmByPrimaryKey == nil { if c.PrimaryKeyGenerator == PKSnowflake { c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) { @@ -193,6 +220,7 @@ func (s *Sharding) LastQuery() string { // Initialize implement for Gorm plugin interface func (s *Sharding) Initialize(db *gorm.DB) error { + db.Dialector = NewShardingDialector(db.Dialector, s) s.DB = db s.registerCallbacks(db) diff --git a/sharding_test.go b/sharding_test.go index eaeb8d0..7492ffe 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "regexp" + "sort" "strings" "testing" @@ -151,6 +152,20 @@ func dropTables() { } } +func TestAutoMigrate(t *testing.T) { + targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} + for _, table := range targetTables { + db.Exec("DROP TABLE IF EXISTS " + table) + db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq")) + } + + db.AutoMigrate(&Order{}, &Category{}) + tables, _ := db.Migrator().GetTables() + sort.Strings(tables) + sort.Strings(targetTables) + assert.Equal(t, tables, targetTables) +} + func TestInsert(t *testing.T) { tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"}) assertQueryResult(t, `INSERT INTO orders_0 ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)