diff --git a/pkg/autoops/api/api.go b/pkg/autoops/api/api.go index f306acdf5e..06639c0747 100644 --- a/pkg/autoops/api/api.go +++ b/pkg/autoops/api/api.go @@ -182,25 +182,9 @@ func (s *AutoOpsService) CreateAutoOpsRule( return nil, dt.Err() } } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(tx) + + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(s.mysqlClient) handler, err := command.NewAutoOpsCommandHandler(editor, autoOpsRule, s.publisher, req.EnvironmentId) if err != nil { return err @@ -208,7 +192,7 @@ func (s *AutoOpsService) CreateAutoOpsRule( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return autoOpsRuleStorage.CreateAutoOpsRule(ctx, autoOpsRule, req.EnvironmentId) + return autoOpsRuleStorage.CreateAutoOpsRule(contextWithTx, autoOpsRule, req.EnvironmentId) }) if err != nil { if err == v2as.ErrAutoOpsRuleAlreadyExists { @@ -473,27 +457,10 @@ func (s *AutoOpsService) StopAutoOpsRule( if err := validateStopAutoOpsRuleRequest(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(tx) - autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(ctx, req.Id, req.EnvironmentId) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(s.mysqlClient) + autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -514,7 +481,7 @@ func (s *AutoOpsService) StopAutoOpsRule( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return autoOpsRuleStorage.UpdateAutoOpsRule(ctx, autoOpsRule, req.EnvironmentId) + return autoOpsRuleStorage.UpdateAutoOpsRule(contextWithTx, autoOpsRule, req.EnvironmentId) }) if err != nil { @@ -561,26 +528,10 @@ func (s *AutoOpsService) DeleteAutoOpsRule( if err := validateDeleteAutoOpsRuleRequest(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(tx) - autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(ctx, req.Id, req.EnvironmentId) + + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(s.mysqlClient) + autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -591,7 +542,7 @@ func (s *AutoOpsService) DeleteAutoOpsRule( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return autoOpsRuleStorage.UpdateAutoOpsRule(ctx, autoOpsRule, req.EnvironmentId) + return autoOpsRuleStorage.UpdateAutoOpsRule(contextWithTx, autoOpsRule, req.EnvironmentId) }) if err != nil { if err == v2as.ErrAutoOpsRuleNotFound || err == v2as.ErrAutoOpsRuleUnexpectedAffectedRows { @@ -696,26 +647,10 @@ func (s *AutoOpsService) UpdateAutoOpsRule( } } commands := s.createUpdateAutoOpsRuleCommands(req) - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(tx) - autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(ctx, req.Id, req.EnvironmentId) + + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(s.mysqlClient) + autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -817,7 +752,7 @@ func (s *AutoOpsService) UpdateAutoOpsRule( return err } } - return autoOpsRuleStorage.UpdateAutoOpsRule(ctx, autoOpsRule, req.EnvironmentId) + return autoOpsRuleStorage.UpdateAutoOpsRule(contextWithTx, autoOpsRule, req.EnvironmentId) }) if err != nil { if err == v2as.ErrAutoOpsRuleNotFound || err == v2as.ErrAutoOpsRuleUnexpectedAffectedRows { @@ -1187,26 +1122,10 @@ func (s *AutoOpsService) ExecuteAutoOps( if triggered { return &autoopsproto.ExecuteAutoOpsResponse{AlreadyTriggered: true}, nil } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(tx) - autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(ctx, req.Id, req.EnvironmentId) + + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, tx mysql.Transaction) error { + autoOpsRuleStorage := v2as.NewAutoOpsRuleStorage(s.mysqlClient) + autoOpsRule, err := autoOpsRuleStorage.GetAutoOpsRule(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -1231,15 +1150,15 @@ func (s *AutoOpsService) ExecuteAutoOps( } ftStorage := ftstorage.NewFeatureStorage(tx) - feature, err := ftStorage.GetFeature(ctx, autoOpsRule.FeatureId, req.EnvironmentId) + feature, err := ftStorage.GetFeature(contextWithTx, autoOpsRule.FeatureId, req.EnvironmentId) if err != nil { return err } - prStorage := v2as.NewProgressiveRolloutStorage(tx) + prStorage := v2as.NewProgressiveRolloutStorage(s.mysqlClient) // Stop the running progressive rollout if the operation type is disable if executeClause.ActionType == autoopsproto.ActionType_DISABLE { if err := s.stopProgressiveRollout( - ctx, + contextWithTx, req.EnvironmentId, autoOpsRule, prStorage, @@ -1249,7 +1168,7 @@ func (s *AutoOpsService) ExecuteAutoOps( } } if err := executeAutoOpsRuleOperation( - ctx, + contextWithTx, ftStorage, req.EnvironmentId, executeClause.ActionType, @@ -1280,7 +1199,7 @@ func (s *AutoOpsService) ExecuteAutoOps( return err } - if err = autoOpsRuleStorage.UpdateAutoOpsRule(ctx, autoOpsRule, req.EnvironmentId); err != nil { + if err = autoOpsRuleStorage.UpdateAutoOpsRule(contextWithTx, autoOpsRule, req.EnvironmentId); err != nil { if err == v2as.ErrAutoOpsRuleUnexpectedAffectedRows { s.logger.Warn( "No rows were affected", diff --git a/pkg/autoops/api/api_test.go b/pkg/autoops/api/api_test.go index 58985ab1d6..5615c5a85d 100644 --- a/pkg/autoops/api/api_test.go +++ b/pkg/autoops/api/api_test.go @@ -367,9 +367,8 @@ func TestCreateAutoOpsRuleMySQL(t *testing.T) { s.experimentClient.(*experimentclientmock.MockClient).EXPECT().GetGoal( gomock.Any(), gomock.Any(), ).Return(&experimentproto.GetGoalResponse{}, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.CreateAutoOpsRuleRequest{ @@ -393,9 +392,8 @@ func TestCreateAutoOpsRuleMySQL(t *testing.T) { { desc: "success schedule", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.CreateAutoOpsRuleRequest{ @@ -590,9 +588,8 @@ func TestUpdateAutoOpsRuleMySQL(t *testing.T) { { desc: "success", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.UpdateAutoOpsRuleRequest{ @@ -673,9 +670,8 @@ func TestStopAutoOpsRuleMySQL(t *testing.T) { { desc: "success", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.StopAutoOpsRuleRequest{ @@ -738,9 +734,8 @@ func TestDeleteAutoOpsRuleMySQL(t *testing.T) { { desc: "success", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.DeleteAutoOpsRuleRequest{ @@ -801,7 +796,11 @@ func TestGetAutoOpsRuleMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -821,7 +820,11 @@ func TestGetAutoOpsRuleMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -834,7 +837,11 @@ func TestGetAutoOpsRuleMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -888,7 +895,11 @@ func TestListAutoOpsRulesMySQL(t *testing.T) { rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) }, @@ -910,7 +921,12 @@ func TestListAutoOpsRulesMySQL(t *testing.T) { rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) }, @@ -1058,7 +1074,11 @@ func TestExecuteAutoOpsRuleMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -1076,12 +1096,15 @@ func TestExecuteAutoOpsRuleMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil).AnyTimes() - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.ExecuteAutoOpsRequest{ diff --git a/pkg/autoops/api/progressive_rollout.go b/pkg/autoops/api/progressive_rollout.go index 46ade364fe..a46f4dc161 100644 --- a/pkg/autoops/api/progressive_rollout.go +++ b/pkg/autoops/api/progressive_rollout.go @@ -57,24 +57,8 @@ func (s *AutoOpsService) CreateProgressiveRollout( if err := s.validateCreateProgressiveRolloutRequest(ctx, req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusProgressiveRolloutInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusProgressiveRolloutInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { + + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, tx mysql.Transaction) error { progressiveRollout, err := domain.NewProgressiveRollout( req.Command.FeatureId, req.Command.ProgressiveRolloutManualScheduleClause, @@ -83,7 +67,7 @@ func (s *AutoOpsService) CreateProgressiveRollout( if err != nil { return err } - storage := v2as.NewProgressiveRolloutStorage(tx) + storage := v2as.NewProgressiveRolloutStorage(s.mysqlClient) handler, err := command.NewProgressiveRolloutCommandHandler( editor, progressiveRollout, @@ -96,7 +80,7 @@ func (s *AutoOpsService) CreateProgressiveRollout( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return storage.CreateProgressiveRollout(ctx, progressiveRollout, req.EnvironmentId) + return storage.CreateProgressiveRollout(contextWithTx, progressiveRollout, req.EnvironmentId) }) if err != nil { switch err { @@ -213,26 +197,9 @@ func (s *AutoOpsService) updateProgressiveRollout( editor *eventproto.Editor, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusProgressiveRolloutInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusProgressiveRolloutInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - storage := v2as.NewProgressiveRolloutStorage(tx) - progressiveRollout, err := storage.GetProgressiveRollout(ctx, progressiveRolloutID, environmentId) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, tx mysql.Transaction) error { + storage := v2as.NewProgressiveRolloutStorage(s.mysqlClient) + progressiveRollout, err := storage.GetProgressiveRollout(contextWithTx, progressiveRolloutID, environmentId) if err != nil { return err } @@ -248,7 +215,7 @@ func (s *AutoOpsService) updateProgressiveRollout( if err := handler.Handle(ctx, cmd); err != nil { return err } - return storage.UpdateProgressiveRollout(ctx, progressiveRollout, environmentId) + return storage.UpdateProgressiveRollout(contextWithTx, progressiveRollout, environmentId) }) if err != nil { s.logger.Error( @@ -295,26 +262,10 @@ func (s *AutoOpsService) DeleteProgressiveRollout( if err := s.validateDeleteProgressiveRolloutRequest(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusProgressiveRolloutInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusProgressiveRolloutInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - storage := v2as.NewProgressiveRolloutStorage(tx) - progressiveRollout, err := storage.GetProgressiveRollout(ctx, req.Id, req.EnvironmentId) + + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, tx mysql.Transaction) error { + storage := v2as.NewProgressiveRolloutStorage(s.mysqlClient) + progressiveRollout, err := storage.GetProgressiveRollout(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -330,7 +281,7 @@ func (s *AutoOpsService) DeleteProgressiveRollout( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return storage.DeleteProgressiveRollout(ctx, req.Id, req.EnvironmentId) + return storage.DeleteProgressiveRollout(contextWithTx, req.Id, req.EnvironmentId) }) if err != nil { s.logger.Error( @@ -402,26 +353,10 @@ func (s *AutoOpsService) ExecuteProgressiveRollout( if err := s.validateExecuteProgressiveRolloutRequest(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusProgressiveRolloutInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusProgressiveRolloutInternal.Err() - } - return nil, dt.Err() - } + var event *eventproto.Event - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - storage := v2as.NewProgressiveRolloutStorage(tx) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, tx mysql.Transaction) error { + storage := v2as.NewProgressiveRolloutStorage(s.mysqlClient) progressiveRollout, err := storage.GetProgressiveRollout(ctx, req.Id, req.EnvironmentId) if err != nil { return err diff --git a/pkg/autoops/api/progressive_rollout_test.go b/pkg/autoops/api/progressive_rollout_test.go index 474a82a709..8f2cd01c70 100644 --- a/pkg/autoops/api/progressive_rollout_test.go +++ b/pkg/autoops/api/progressive_rollout_test.go @@ -643,41 +643,6 @@ func TestCreateProgressiveRolloutMySQL(t *testing.T) { }, expectedErr: createError(statusProgressiveRolloutInvalidScheduleSpans, localizer.MustLocalize(locale.AutoOpsInvalidScheduleSpans)), }, - { - desc: "err: begin transaction error", - setup: func(aos *AutoOpsService) { - aos.featureClient.(*featureclientmock.MockClient).EXPECT().GetFeature( - gomock.Any(), gomock.Any(), - ).Return(&featureproto.GetFeatureResponse{Feature: &featureproto.Feature{ - Variations: []*featureproto.Variation{ - { - Id: "vid-1", - }, - { - Id: "vid-2", - }, - }, - Enabled: true, - }}, nil) - aos.experimentClient.(*experimentclientmock.MockClient).EXPECT().ListExperiments(gomock.Any(), gomock.Any()).Return( - &experimentproto.ListExperimentsResponse{Experiments: []*experimentproto.Experiment{}}, - nil, - ) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, errors.New("error")) - }, - req: &autoopsproto.CreateProgressiveRolloutRequest{ - Command: &autoopsproto.CreateProgressiveRolloutCommand{ - FeatureId: "fid", - ProgressiveRolloutTemplateScheduleClause: &autoopsproto.ProgressiveRolloutTemplateScheduleClause{ - VariationId: "vid-1", - Schedules: validSchedules, - Interval: autoopsproto.ProgressiveRolloutTemplateScheduleClause_DAILY, - Increments: 2, - }, - }, - }, - expectedErr: createError(statusProgressiveRolloutInternal, localizer.MustLocalize(locale.InternalServerError)), - }, { desc: "err: transaction error", setup: func(aos *AutoOpsService) { @@ -698,9 +663,8 @@ func TestCreateProgressiveRolloutMySQL(t *testing.T) { &experimentproto.ListExperimentsResponse{Experiments: []*experimentproto.Experiment{}}, nil, ) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &autoopsproto.CreateProgressiveRolloutRequest{ @@ -736,9 +700,8 @@ func TestCreateProgressiveRolloutMySQL(t *testing.T) { &experimentproto.ListExperimentsResponse{Experiments: []*experimentproto.Experiment{}}, nil, ) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2as.ErrProgressiveRolloutAlreadyExists) }, req: &autoopsproto.CreateProgressiveRolloutRequest{ @@ -848,9 +811,8 @@ func TestCreateProgressiveRolloutMySQL(t *testing.T) { &experimentproto.ListExperimentsResponse{Experiments: []*experimentproto.Experiment{}}, nil, ) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + aos.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.CreateProgressiveRolloutRequest{ @@ -914,7 +876,13 @@ func TestGetProgressiveRolloutMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -926,7 +894,13 @@ func TestGetProgressiveRolloutMySQL(t *testing.T) { setup: func(s *AutoOpsService) { row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe := mysqlmock.NewMockQueryExecer(mockController) + + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -981,26 +955,11 @@ func TestStopProgressiveRolloutMySQL(t *testing.T) { req: &autoopsproto.StopProgressiveRolloutRequest{Id: "id", EnvironmentId: "ns"}, expectedErr: createError(statusProgressiveRolloutNoCommand, localizer.MustLocalizeWithTemplate(locale.RequiredFieldTemplate, "command")), }, - { - desc: "err: failed to begin transaction", - setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, errors.New("error")) - }, - req: &autoopsproto.StopProgressiveRolloutRequest{ - Id: "id", - EnvironmentId: "ns", - Command: &autoopsproto.StopProgressiveRolloutCommand{ - StoppedBy: autoopsproto.ProgressiveRollout_USER, - }, - }, - expectedErr: createError(statusProgressiveRolloutInternal, localizer.MustLocalize(locale.InternalServerError)), - }, { desc: "err: internal error during transaction", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &autoopsproto.StopProgressiveRolloutRequest{ @@ -1015,9 +974,8 @@ func TestStopProgressiveRolloutMySQL(t *testing.T) { { desc: "err: not found", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2as.ErrProgressiveRolloutNotFound) }, req: &autoopsproto.StopProgressiveRolloutRequest{ @@ -1032,9 +990,8 @@ func TestStopProgressiveRolloutMySQL(t *testing.T) { { desc: "err: unexpected affected rows", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2as.ErrProgressiveRolloutUnexpectedAffectedRows) }, req: &autoopsproto.StopProgressiveRolloutRequest{ @@ -1049,9 +1006,8 @@ func TestStopProgressiveRolloutMySQL(t *testing.T) { { desc: "success", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.StopProgressiveRolloutRequest{ @@ -1106,20 +1062,11 @@ func TestDeleteProgressiveRolloutMySQL(t *testing.T) { req: &autoopsproto.DeleteProgressiveRolloutRequest{}, expectedErr: createError(statusProgressiveRolloutIDRequired, localizer.MustLocalizeWithTemplate(locale.RequiredFieldTemplate, "id")), }, - { - desc: "err: failed to begin transaction", - setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, errors.New("error")) - }, - req: &autoopsproto.DeleteProgressiveRolloutRequest{Id: "wrongid", EnvironmentId: "ns0"}, - expectedErr: createError(statusProgressiveRolloutInternal, localizer.MustLocalize(locale.InternalServerError)), - }, { desc: "err: internal error during transaction", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &autoopsproto.DeleteProgressiveRolloutRequest{Id: "wrongid", EnvironmentId: "ns0"}, @@ -1128,9 +1075,8 @@ func TestDeleteProgressiveRolloutMySQL(t *testing.T) { { desc: "err: internal error during transaction", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &autoopsproto.DeleteProgressiveRolloutRequest{Id: "wrongid", EnvironmentId: "ns0"}, @@ -1139,9 +1085,8 @@ func TestDeleteProgressiveRolloutMySQL(t *testing.T) { { desc: "err: ErrProgressiveRolloutNotFound", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2as.ErrProgressiveRolloutNotFound) }, req: &autoopsproto.DeleteProgressiveRolloutRequest{Id: "wrongid", EnvironmentId: "ns0"}, @@ -1150,9 +1095,8 @@ func TestDeleteProgressiveRolloutMySQL(t *testing.T) { { desc: "err: ErrProgressiveRolloutUnexpectedAffectedRows", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2as.ErrProgressiveRolloutUnexpectedAffectedRows) }, req: &autoopsproto.DeleteProgressiveRolloutRequest{Id: "wrongid", EnvironmentId: "ns0"}, @@ -1161,9 +1105,8 @@ func TestDeleteProgressiveRolloutMySQL(t *testing.T) { { desc: "success", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.DeleteProgressiveRolloutRequest{Id: "aid1", EnvironmentId: "ns0"}, @@ -1218,16 +1161,20 @@ func TestListProgressiveRolloutsMySQL(t *testing.T) { { desc: "err: interal error", setup: func(s *AutoOpsService) { + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe).Times(2) rows := mysqlmock.NewMockRows(mockController) rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -1238,16 +1185,20 @@ func TestListProgressiveRolloutsMySQL(t *testing.T) { { desc: "success", setup: func(s *AutoOpsService) { + qe := mysqlmock.NewMockQueryExecer(mockController) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe).Times(2) rows := mysqlmock.NewMockRows(mockController) rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -1317,26 +1268,11 @@ func TestExecuteProgressiveRolloutMySQL(t *testing.T) { }, expectedErr: createError(statusProgressiveRolloutScheduleIDRequired, localizer.MustLocalizeWithTemplate(locale.RequiredFieldTemplate, "schedule_id")), }, - { - desc: "err: begin transaction error", - setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, errors.New("error")) - }, - req: &autoopsproto.ExecuteProgressiveRolloutRequest{ - Id: "aid1", - EnvironmentId: "ns0", - ChangeProgressiveRolloutTriggeredAtCommand: &autoopsproto.ChangeProgressiveRolloutScheduleTriggeredAtCommand{ - ScheduleId: "sid1", - }, - }, - expectedErr: createError(statusProgressiveRolloutInternal, localizer.MustLocalize(locale.InternalServerError)), - }, { desc: "success", setup: func(s *AutoOpsService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &autoopsproto.ExecuteProgressiveRolloutRequest{ diff --git a/pkg/autoops/storage/v2/auto_ops_rule.go b/pkg/autoops/storage/v2/auto_ops_rule.go index 4097a495bd..3b37058206 100644 --- a/pkg/autoops/storage/v2/auto_ops_rule.go +++ b/pkg/autoops/storage/v2/auto_ops_rule.go @@ -56,11 +56,11 @@ type AutoOpsRuleStorage interface { } type autoOpsRuleStorage struct { - qe mysql.QueryExecer + client mysql.Client } -func NewAutoOpsRuleStorage(qe mysql.QueryExecer) AutoOpsRuleStorage { - return &autoOpsRuleStorage{qe: qe} +func NewAutoOpsRuleStorage(client mysql.Client) AutoOpsRuleStorage { + return &autoOpsRuleStorage{client: client} } func (s *autoOpsRuleStorage) CreateAutoOpsRule( @@ -68,7 +68,7 @@ func (s *autoOpsRuleStorage) CreateAutoOpsRule( e *domain.AutoOpsRule, environmentId string, ) error { - _, err := s.qe.ExecContext( + _, err := s.client.Qe(ctx).ExecContext( ctx, insertAutoOpsRuleSQL, e.Id, @@ -95,7 +95,7 @@ func (s *autoOpsRuleStorage) UpdateAutoOpsRule( e *domain.AutoOpsRule, environmentId string, ) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, updateAutoOpsRuleSQL, e.FeatureId, @@ -127,7 +127,7 @@ func (s *autoOpsRuleStorage) GetAutoOpsRule( ) (*domain.AutoOpsRule, error) { autoOpsRule := proto.AutoOpsRule{} var opsType int32 - err := s.qe.QueryRowContext( + err := s.client.Qe(ctx).QueryRowContext( ctx, selectAutoOpsRuleSQL, id, @@ -162,7 +162,7 @@ func (s *autoOpsRuleStorage) ListAutoOpsRules( orderBySQL := mysql.ConstructOrderBySQLString(orders) limitOffsetSQL := mysql.ConstructLimitOffsetSQLString(limit, offset) query := fmt.Sprintf(selectAutoOpsRulesSQL, whereSQL, orderBySQL, limitOffsetSQL) - rows, err := s.qe.QueryContext(ctx, query, whereArgs...) + rows, err := s.client.Qe(ctx).QueryContext(ctx, query, whereArgs...) if err != nil { return nil, 0, err } diff --git a/pkg/autoops/storage/v2/auto_ops_rule_test.go b/pkg/autoops/storage/v2/auto_ops_rule_test.go index 939c808a41..ab209048d1 100644 --- a/pkg/autoops/storage/v2/auto_ops_rule_test.go +++ b/pkg/autoops/storage/v2/auto_ops_rule_test.go @@ -32,7 +32,8 @@ func TestNewAutoOpsRuleStorage(t *testing.T) { t.Parallel() mockController := gomock.NewController(t) defer mockController.Finish() - db := NewAutoOpsRuleStorage(mock.NewMockQueryExecer(mockController)) + client := mock.NewMockClient(mockController) + db := NewAutoOpsRuleStorage(client) assert.IsType(t, &autoOpsRuleStorage{}, db) } @@ -49,7 +50,11 @@ func TestCreateAutoOpsRule(t *testing.T) { }{ { setup: func(s *autoOpsRuleStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, mysql.ErrDuplicateEntry) }, @@ -61,7 +66,11 @@ func TestCreateAutoOpsRule(t *testing.T) { }, { setup: func(s *autoOpsRuleStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, nil) }, @@ -97,7 +106,11 @@ func TestUpdateAutoOpsRule(t *testing.T) { setup: func(s *autoOpsRuleStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(0), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -111,7 +124,11 @@ func TestUpdateAutoOpsRule(t *testing.T) { setup: func(s *autoOpsRuleStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(1), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -148,7 +165,11 @@ func TestGetAutoOpsRule(t *testing.T) { setup: func(s *autoOpsRuleStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -161,7 +182,11 @@ func TestGetAutoOpsRule(t *testing.T) { setup: func(s *autoOpsRuleStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -199,7 +224,11 @@ func TestListAutoOpsRules(t *testing.T) { }{ { setup: func(s *autoOpsRuleStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) }, @@ -217,7 +246,11 @@ func TestListAutoOpsRules(t *testing.T) { rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) }, @@ -254,5 +287,5 @@ func TestListAutoOpsRules(t *testing.T) { func newAutoOpsRuleStorageWithMock(t *testing.T, mockController *gomock.Controller) *autoOpsRuleStorage { t.Helper() - return &autoOpsRuleStorage{mock.NewMockQueryExecer(mockController)} + return &autoOpsRuleStorage{mock.NewMockClient(mockController)} } diff --git a/pkg/autoops/storage/v2/progressive_rollout.go b/pkg/autoops/storage/v2/progressive_rollout.go index 7864c2b57c..e9a150552c 100644 --- a/pkg/autoops/storage/v2/progressive_rollout.go +++ b/pkg/autoops/storage/v2/progressive_rollout.go @@ -48,7 +48,7 @@ var ( ) type progressiveRolloutStorage struct { - qe mysql.QueryExecer + client mysql.Client } type ProgressiveRolloutStorage interface { @@ -71,8 +71,8 @@ type ProgressiveRolloutStorage interface { ) error } -func NewProgressiveRolloutStorage(qe mysql.QueryExecer) ProgressiveRolloutStorage { - return &progressiveRolloutStorage{qe: qe} +func NewProgressiveRolloutStorage(client mysql.Client) ProgressiveRolloutStorage { + return &progressiveRolloutStorage{client: client} } func (s *progressiveRolloutStorage) CreateProgressiveRollout( @@ -80,7 +80,7 @@ func (s *progressiveRolloutStorage) CreateProgressiveRollout( progressiveRollout *domain.ProgressiveRollout, environmentId string, ) error { - _, err := s.qe.ExecContext( + _, err := s.client.Qe(ctx).ExecContext( ctx, insertOpsProgressiveRolloutSQL, progressiveRollout.Id, @@ -108,7 +108,7 @@ func (s *progressiveRolloutStorage) GetProgressiveRollout( id, environmentId string, ) (*domain.ProgressiveRollout, error) { progressiveRollout := autoopsproto.ProgressiveRollout{} - err := s.qe.QueryRowContext( + err := s.client.Qe(ctx).QueryRowContext( ctx, selectOpsProgressiveRolloutSQL, id, @@ -137,7 +137,7 @@ func (s *progressiveRolloutStorage) DeleteProgressiveRollout( ctx context.Context, id, environmentId string, ) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, deleteOpsProgressiveRolloutSQL, id, @@ -166,7 +166,7 @@ func (s *progressiveRolloutStorage) ListProgressiveRollouts( orderBySQL := mysql.ConstructOrderBySQLString(orders) limitOffsetSQL := mysql.ConstructLimitOffsetSQLString(limit, offset) query := fmt.Sprintf(selectOpsProgressiveRolloutsSQL, whereSQL, orderBySQL, limitOffsetSQL) - rows, err := s.qe.QueryContext(ctx, query, whereArgs...) + rows, err := s.client.Qe(ctx).QueryContext(ctx, query, whereArgs...) if err != nil { return nil, 0, 0, err } @@ -196,7 +196,7 @@ func (s *progressiveRolloutStorage) ListProgressiveRollouts( nextOffset := offset + len(progressiveRollouts) var totalCount int64 countQuery := fmt.Sprintf(countOpsProgressiveRolloutsSQL, whereSQL) - err = s.qe.QueryRowContext(ctx, countQuery, whereArgs...).Scan(&totalCount) + err = s.client.Qe(ctx).QueryRowContext(ctx, countQuery, whereArgs...).Scan(&totalCount) if err != nil { return nil, 0, 0, err } @@ -208,7 +208,7 @@ func (s *progressiveRolloutStorage) UpdateProgressiveRollout( progressiveRollout *domain.ProgressiveRollout, environmentId string, ) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, updateOpsProgressiveRolloutSQL, &progressiveRollout.FeatureId, diff --git a/pkg/autoops/storage/v2/progressive_rollout_test.go b/pkg/autoops/storage/v2/progressive_rollout_test.go index 17a461b73e..e402246073 100644 --- a/pkg/autoops/storage/v2/progressive_rollout_test.go +++ b/pkg/autoops/storage/v2/progressive_rollout_test.go @@ -31,7 +31,7 @@ func TestNewProgressiveRolloutStorage(t *testing.T) { t.Parallel() mockController := gomock.NewController(t) defer mockController.Finish() - db := NewProgressiveRolloutStorage(mock.NewMockQueryExecer(mockController)) + db := NewProgressiveRolloutStorage(mock.NewMockClient(mockController)) assert.IsType(t, &progressiveRolloutStorage{}, db) } @@ -48,9 +48,14 @@ func TestCreateProgressiveRollout(t *testing.T) { expectedErr error }{ { - desc: "", + desc: "error", setup: func(s *progressiveRolloutStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, mysql.ErrDuplicateEntry) }, @@ -61,8 +66,14 @@ func TestCreateProgressiveRollout(t *testing.T) { expectedErr: ErrProgressiveRolloutAlreadyExists, }, { + desc: "success", setup: func(s *progressiveRolloutStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, nil) }, @@ -85,5 +96,5 @@ func TestCreateProgressiveRollout(t *testing.T) { func newProgressiveRolloutStorageWithMock(t *testing.T, mockController *gomock.Controller) *progressiveRolloutStorage { t.Helper() - return &progressiveRolloutStorage{mock.NewMockQueryExecer(mockController)} + return &progressiveRolloutStorage{mock.NewMockClient(mockController)} } diff --git a/pkg/feature/api/feature.go b/pkg/feature/api/feature.go index d788f1eee3..6e677bd537 100644 --- a/pkg/feature/api/feature.go +++ b/pkg/feature/api/feature.go @@ -1396,26 +1396,10 @@ func (s *FeatureService) updateFeature( return dt.Err() } var handler *command.FeatureCommandHandler = command.NewEmptyFeatureCommandHandler() - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { + + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, tx mysql.Transaction) error { featureStorage := v2fs.NewFeatureStorage(tx) - feature, err := featureStorage.GetFeature(ctx, id, environmentId) + feature, err := featureStorage.GetFeature(contextWithTx, id, environmentId) if err != nil { s.logger.Error( "Failed to get feature", @@ -1444,7 +1428,7 @@ func (s *FeatureService) updateFeature( // We must stop the progressive rollout if it contains a `DisableFeatureCommand` switch cmd.(type) { case *featureproto.DisableFeatureCommand: - if err := s.stopProgressiveRollout(ctx, tx, environmentId, feature.Id); err != nil { + if err := s.stopProgressiveRollout(contextWithTx, environmentId, feature.Id); err != nil { return err } } @@ -1458,7 +1442,7 @@ func (s *FeatureService) updateFeature( ) return err } - if err := featureStorage.UpdateFeature(ctx, feature, environmentId); err != nil { + if err := featureStorage.UpdateFeature(contextWithTx, feature, environmentId); err != nil { s.logger.Error( "Failed to update feature", log.FieldsFromImcomingContext(ctx).AddFields( @@ -1863,7 +1847,7 @@ func (s *FeatureService) UpdateFeatureTargeting( // We must stop the progressive rollout if it contains a `DisableFeatureCommand` switch cmd.(type) { case *featureproto.DisableFeatureCommand: - if err := s.stopProgressiveRollout(ctx, tx, req.EnvironmentId, feature.Id); err != nil { + if err := s.stopProgressiveRollout(ctx, req.EnvironmentId, feature.Id); err != nil { return err } } @@ -1919,9 +1903,8 @@ func (s *FeatureService) UpdateFeatureTargeting( func (s *FeatureService) stopProgressiveRollout( ctx context.Context, - tx mysql.Transaction, EnvironmentId, featureID string) error { - storage := v2ao.NewProgressiveRolloutStorage(tx) + storage := v2ao.NewProgressiveRolloutStorage(s.mysqlClient) ids := convToInterfaceSlice([]string{featureID}) whereParts := []mysql.WherePart{ mysql.NewFilter("environment_id", "=", EnvironmentId), diff --git a/pkg/feature/api/feature_test.go b/pkg/feature/api/feature_test.go index 5566ffd570..1b5e73bac5 100644 --- a/pkg/feature/api/feature_test.go +++ b/pkg/feature/api/feature_test.go @@ -1760,9 +1760,8 @@ func TestEnableFeatureMySQL(t *testing.T) { { desc: "error: statusNotFound", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2fs.ErrFeatureNotFound) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2(gomock.Any(), gomock.Any()).Return( &envproto.GetEnvironmentV2Response{ @@ -1781,9 +1780,8 @@ func TestEnableFeatureMySQL(t *testing.T) { { desc: "success", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) s.batchClient.(*btclientmock.MockClient).EXPECT().ExecuteBatchJob(gomock.Any(), gomock.Any()) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2( @@ -1866,9 +1864,8 @@ func TestDisableFeatureMySQL(t *testing.T) { { desc: "error: statusNotFound", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2fs.ErrFeatureNotFound) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2(gomock.Any(), gomock.Any()).Return( &envproto.GetEnvironmentV2Response{ @@ -1887,9 +1884,8 @@ func TestDisableFeatureMySQL(t *testing.T) { { desc: "success", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) s.batchClient.(*btclientmock.MockClient).EXPECT().ExecuteBatchJob(gomock.Any(), gomock.Any()) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2( @@ -2021,9 +2017,8 @@ func TestUnarchiveFeatureMySQL(t *testing.T) { { desc: "error: statusNotFound", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2fs.ErrFeatureNotFound) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2(gomock.Any(), gomock.Any()).Return( &envproto.GetEnvironmentV2Response{ @@ -2042,9 +2037,8 @@ func TestUnarchiveFeatureMySQL(t *testing.T) { { desc: "success", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) s.batchClient.(*btclientmock.MockClient).EXPECT().ExecuteBatchJob(gomock.Any(), gomock.Any()) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2( @@ -2127,9 +2121,8 @@ func TestDeleteFeatureMySQL(t *testing.T) { { desc: "error: statusNotFound", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2fs.ErrFeatureNotFound) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2(gomock.Any(), gomock.Any()).Return( &envproto.GetEnvironmentV2Response{ @@ -2148,9 +2141,8 @@ func TestDeleteFeatureMySQL(t *testing.T) { { desc: "success", setup: func(s *FeatureService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) s.batchClient.(*btclientmock.MockClient).EXPECT().ExecuteBatchJob(gomock.Any(), gomock.Any()) s.environmentClient.(*envclientmock.MockClient).EXPECT().GetEnvironmentV2( diff --git a/pkg/storage/v2/mysql/client.go b/pkg/storage/v2/mysql/client.go index 8d219298ba..d71a2db8aa 100644 --- a/pkg/storage/v2/mysql/client.go +++ b/pkg/storage/v2/mysql/client.go @@ -28,6 +28,7 @@ import ( ) const dsnParams = "collation=utf8mb4_bin" +const transactionKey = "transaction" type options struct { connMaxLifetime time.Duration @@ -103,8 +104,14 @@ type QueryExecer interface { type Client interface { QueryExecer Close() error + // Deprecated BeginTx(ctx context.Context) (Transaction, error) RunInTransaction(ctx context.Context, tx Transaction, f func() error) error + // ToDo: + // Transaction is passed because it is required for storage that does not support storage architecture refactoring, + // but we plan to remove it once the refactoring is complete. + RunInTransactionV2(ctx context.Context, f func(ctx context.Context, tx Transaction) error) error + Qe(ctx context.Context) QueryExecer } type client struct { @@ -179,6 +186,7 @@ func (c *client) QueryRowContext(ctx context.Context, query string, args ...inte return r } +// Deprecated func (c *client) BeginTx(ctx context.Context) (Transaction, error) { var err error defer record()(operationBeginTx, &err) @@ -199,3 +207,31 @@ func (c *client) RunInTransaction(ctx context.Context, tx Transaction, f func() } return err } + +func (c *client) RunInTransactionV2( + ctx context.Context, + f func(ctx context.Context, ctxWithTx Transaction) error) error { + tx, err := c.BeginTx(ctx) + if err != nil { + return fmt.Errorf("client: begin tx: %w", err) + } + ctx = context.WithValue(ctx, transactionKey, tx) + defer record()(operationRunInTransaction, &err) + defer func() { + if err != nil { + tx.Rollback() // nolint:errcheck + } + }() + if err = f(ctx, tx); err == nil { + err = tx.Commit() + } + return err +} + +func (c *client) Qe(ctx context.Context) QueryExecer { + tx, ok := ctx.Value(transactionKey).(Transaction) + if ok { + return tx + } + return c +} diff --git a/pkg/storage/v2/mysql/mock/client.go b/pkg/storage/v2/mysql/mock/client.go index d3cc0b7e75..0c81649b57 100644 --- a/pkg/storage/v2/mysql/mock/client.go +++ b/pkg/storage/v2/mysql/mock/client.go @@ -277,6 +277,20 @@ func (mr *MockClientMockRecorder) ExecContext(ctx, query any, args ...any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockClient)(nil).ExecContext), varargs...) } +// Qe mocks base method. +func (m *MockClient) Qe(ctx context.Context) mysql.QueryExecer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Qe", ctx) + ret0, _ := ret[0].(mysql.QueryExecer) + return ret0 +} + +// Qe indicates an expected call of Qe. +func (mr *MockClientMockRecorder) Qe(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Qe", reflect.TypeOf((*MockClient)(nil).Qe), ctx) +} + // QueryContext mocks base method. func (m *MockClient) QueryContext(ctx context.Context, query string, args ...any) (mysql.Rows, error) { m.ctrl.T.Helper() @@ -329,3 +343,17 @@ func (mr *MockClientMockRecorder) RunInTransaction(ctx, tx, f any) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunInTransaction", reflect.TypeOf((*MockClient)(nil).RunInTransaction), ctx, tx, f) } + +// RunInTransactionV2 mocks base method. +func (m *MockClient) RunInTransactionV2(ctx context.Context, f func(context.Context, mysql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RunInTransactionV2", ctx, f) + ret0, _ := ret[0].(error) + return ret0 +} + +// RunInTransactionV2 indicates an expected call of RunInTransactionV2. +func (mr *MockClientMockRecorder) RunInTransactionV2(ctx, f any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunInTransactionV2", reflect.TypeOf((*MockClient)(nil).RunInTransactionV2), ctx, f) +}