diff --git a/core/xservice/option.go b/core/xservice/option.go index 273c7e9..35df956 100644 --- a/core/xservice/option.go +++ b/core/xservice/option.go @@ -16,6 +16,7 @@ import ( "github.com/xinpianchang/xservice/core" "github.com/xinpianchang/xservice/pkg/config" + "github.com/xinpianchang/xservice/pkg/gormx" "github.com/xinpianchang/xservice/pkg/netx" ) @@ -26,6 +27,7 @@ type Options struct { Build string Description string Config *viper.Viper + DbConfigureFn gormx.ConfigureFn GrpcServerOptions []grpc.ServerOption GrpcClientDialOptions []grpc.DialOption GrpcClientDialTimeout time.Duration @@ -74,6 +76,13 @@ func Config(config *viper.Viper) Option { } } +// WithDbConfigureFn set db configure function +func WithDbConfigureFn(fn gormx.ConfigureFn) Option { + return func(o *Options) { + o.DbConfigureFn = fn + } +} + // WithGrpcServerOptions add additional grpc server options func WithGrpcServerOptions(options ...grpc.ServerOption) Option { return func(o *Options) { diff --git a/core/xservice/xservice.go b/core/xservice/xservice.go index 1ed2b5a..0169224 100644 --- a/core/xservice/xservice.go +++ b/core/xservice/xservice.go @@ -101,7 +101,7 @@ func (t *serviceImpl) init() { } if t.options.Config.IsSet("database") { - gormx.Config(t.options.Config) + gormx.Config(t.options.Config, t.options.DbConfigureFn) } if t.options.Config.IsSet("mq") { diff --git a/pkg/gormx/config.go b/pkg/gormx/config.go index adcc6e8..87ce7a2 100644 --- a/pkg/gormx/config.go +++ b/pkg/gormx/config.go @@ -57,7 +57,7 @@ func Config(v *viper.Viper, configureFn ...ConfigureFn) { } var db *gorm.DB - if len(configureFn) > 0 { + if len(configureFn) > 0 && configureFn[0] != nil { db = configureFn[0](c) } else { db = MySQLDbConfig(c)