diff --git a/pkg/autoops/api/api.go b/pkg/autoops/api/api.go index f306acdf5e..4445aba4eb 100644 --- a/pkg/autoops/api/api.go +++ b/pkg/autoops/api/api.go @@ -1112,16 +1112,28 @@ func (s *AutoOpsService) listAutoOpsRules( localizer locale.Localizer, storage v2as.AutoOpsRuleStorage, ) ([]*autoopsproto.AutoOpsRule, string, error) { - whereParts := []mysql.WherePart{ - mysql.NewFilter("deleted", "=", false), - mysql.NewFilter("environment_id", "=", environmentId), + filters := []*mysql.FilterV2{ + { + Column: "deleted", + Operator: mysql.OperatorEqual, + Value: false, + }, + { + Column: "environment_id", + Operator: mysql.OperatorEqual, + Value: environmentId, + }, } fIDs := make([]interface{}, 0, len(featureIds)) for _, fID := range featureIds { fIDs = append(fIDs, fID) } + var infilter *mysql.InFilter if len(fIDs) > 0 { - whereParts = append(whereParts, mysql.NewInFilter("feature_id", fIDs)) + infilter = &mysql.InFilter{ + Column: "feature_id", + Values: fIDs, + } } limit := int(pageSize) if cursor == "" { @@ -1139,13 +1151,18 @@ func (s *AutoOpsService) listAutoOpsRules( return nil, "", dt.Err() } - autoOpsRules, nextCursor, err := storage.ListAutoOpsRules( - ctx, - whereParts, - nil, - limit, - offset, - ) + listOptions := &mysql.ListOptions{ + Limit: limit, + Offset: offset, + Filters: filters, + InFilter: infilter, + NullFilters: nil, + JSONFilters: nil, + SearchQuery: nil, + Orders: nil, + } + + autoOpsRules, nextCursor, err := storage.ListAutoOpsRules(ctx, listOptions) if err != nil { s.logger.Error( "Failed to list autoOpsRules", diff --git a/pkg/autoops/api/progressive_rollout.go b/pkg/autoops/api/progressive_rollout.go index 46ade364fe..98c7c8339b 100644 --- a/pkg/autoops/api/progressive_rollout.go +++ b/pkg/autoops/api/progressive_rollout.go @@ -609,8 +609,12 @@ func (s *AutoOpsService) listProgressiveRollouts( req *autoopsproto.ListProgressiveRolloutsRequest, localizer locale.Localizer, ) ([]*autoopsproto.ProgressiveRollout, int64, int, error) { - whereParts := []mysql.WherePart{ - mysql.NewFilter("environment_id", "=", req.EnvironmentId), + filters := []*mysql.FilterV2{ + { + Column: "environment_id", + Operator: mysql.OperatorEqual, + Value: req.EnvironmentId, + }, } limit := int(req.PageSize) cursor := req.Cursor @@ -628,9 +632,13 @@ func (s *AutoOpsService) listProgressiveRollouts( } return nil, 0, 0, dt.Err() } + var inFilter *mysql.InFilter = nil if len(req.FeatureIds) > 0 { fIDs := s.convToInterfaceSlice(req.FeatureIds) - whereParts = append(whereParts, mysql.NewInFilter("feature_id", fIDs)) + inFilter = &mysql.InFilter{ + Column: "feature_id", + Values: fIDs, + } } orders, err := s.newListProgressiveRolloutsOrdersMySQL( req.OrderBy, @@ -648,19 +656,24 @@ func (s *AutoOpsService) listProgressiveRollouts( return nil, 0, 0, err } if req.Type != nil { - whereParts = append(whereParts, mysql.NewFilter("type", "=", req.Type)) + filters = append(filters, &mysql.FilterV2{Column: "type", Operator: mysql.OperatorEqual, Value: req.Type}) } if req.Status != nil { - whereParts = append(whereParts, mysql.NewFilter("status", "=", req.Status)) + filters = append(filters, &mysql.FilterV2{Column: "status", Operator: mysql.OperatorEqual, Value: req.Status}) + } + listOptions := &mysql.ListOptions{ + Filters: filters, + Orders: orders, + InFilter: inFilter, + NullFilters: nil, + JSONFilters: nil, + SearchQuery: nil, + Limit: limit, + Offset: offset, } + storage := v2as.NewProgressiveRolloutStorage(s.mysqlClient) - progressiveRollouts, totalCount, nextOffset, err := storage.ListProgressiveRollouts( - ctx, - whereParts, - orders, - limit, - offset, - ) + progressiveRollouts, totalCount, nextOffset, err := storage.ListProgressiveRollouts(ctx, listOptions) if err != nil { s.logger.Error( "Failed to list progressive rollouts", diff --git a/pkg/autoops/api/stop_progressive_rollout_operation.go b/pkg/autoops/api/stop_progressive_rollout_operation.go index 49de85af38..58e4231edd 100644 --- a/pkg/autoops/api/stop_progressive_rollout_operation.go +++ b/pkg/autoops/api/stop_progressive_rollout_operation.go @@ -30,11 +30,28 @@ func executeStopProgressiveRolloutOperation( environmentId string, operation autoopsproto.ProgressiveRollout_StoppedBy, ) error { - whereParts := []mysql.WherePart{ - mysql.NewFilter("environment_id", "=", environmentId), - mysql.NewInFilter("feature_id", featureIDs), + filters := []*mysql.FilterV2{ + { + Column: "environment_id", + Operator: mysql.OperatorEqual, + Value: environmentId, + }, } - list, _, _, err := storage.ListProgressiveRollouts(ctx, whereParts, nil, 0, 0) + inFilter := &mysql.InFilter{ + Column: "feature_id", + Values: featureIDs, + } + listOptions := &mysql.ListOptions{ + Filters: filters, + Orders: nil, + InFilter: inFilter, + NullFilters: nil, + JSONFilters: nil, + SearchQuery: nil, + Limit: 0, + Offset: 0, + } + list, _, _, err := storage.ListProgressiveRollouts(ctx, listOptions) if err != nil { return err } diff --git a/pkg/autoops/storage/v2/auto_ops_rule.go b/pkg/autoops/storage/v2/auto_ops_rule.go index 4097a495bd..da22cd5ffb 100644 --- a/pkg/autoops/storage/v2/auto_ops_rule.go +++ b/pkg/autoops/storage/v2/auto_ops_rule.go @@ -19,7 +19,6 @@ import ( "context" _ "embed" "errors" - "fmt" "github.com/bucketeer-io/bucketeer/pkg/autoops/domain" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" @@ -49,9 +48,7 @@ type AutoOpsRuleStorage interface { GetAutoOpsRule(ctx context.Context, id, environmentId string) (*domain.AutoOpsRule, error) ListAutoOpsRules( ctx context.Context, - whereParts []mysql.WherePart, - orders []*mysql.Order, - limit, offset int, + options *mysql.ListOptions, ) ([]*proto.AutoOpsRule, int, error) } @@ -154,20 +151,15 @@ func (s *autoOpsRuleStorage) GetAutoOpsRule( func (s *autoOpsRuleStorage) ListAutoOpsRules( ctx context.Context, - whereParts []mysql.WherePart, - orders []*mysql.Order, - limit, offset int, + options *mysql.ListOptions, ) ([]*proto.AutoOpsRule, int, error) { - whereSQL, whereArgs := mysql.ConstructWhereSQLString(whereParts) - orderBySQL := mysql.ConstructOrderBySQLString(orders) - limitOffsetSQL := mysql.ConstructLimitOffsetSQLString(limit, offset) - query := fmt.Sprintf(selectAutoOpsRulesSQL, whereSQL, orderBySQL, limitOffsetSQL) + query, whereArgs := mysql.ConstructQueryAndWhereArgs(selectAutoOpsRulesSQL, options) rows, err := s.qe.QueryContext(ctx, query, whereArgs...) if err != nil { return nil, 0, err } defer rows.Close() - autoOpsRules := make([]*proto.AutoOpsRule, 0, limit) + autoOpsRules := make([]*proto.AutoOpsRule, 0) for rows.Next() { autoOpsRule := proto.AutoOpsRule{} var opsType int32 @@ -190,6 +182,10 @@ func (s *autoOpsRuleStorage) ListAutoOpsRules( if rows.Err() != nil { return nil, 0, err } + var offset int + if options != nil { + offset = options.Offset + } nextOffset := offset + len(autoOpsRules) return autoOpsRules, nextOffset, nil } diff --git a/pkg/autoops/storage/v2/auto_ops_rule_test.go b/pkg/autoops/storage/v2/auto_ops_rule_test.go index 939c808a41..e39a07eb6b 100644 --- a/pkg/autoops/storage/v2/auto_ops_rule_test.go +++ b/pkg/autoops/storage/v2/auto_ops_rule_test.go @@ -189,10 +189,7 @@ func TestListAutoOpsRules(t *testing.T) { defer mockController.Finish() patterns := []struct { setup func(*autoOpsRuleStorage) - whereParts []mysql.WherePart - orders []*mysql.Order - limit int - offset int + listOpts *mysql.ListOptions expected []*proto.AutoOpsRule expectedCursor int expectedErr error @@ -203,10 +200,7 @@ func TestListAutoOpsRules(t *testing.T) { gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) }, - whereParts: nil, - orders: nil, - limit: 0, - offset: 0, + listOpts: nil, expected: nil, expectedCursor: 0, expectedErr: errors.New("error"), @@ -221,14 +215,27 @@ func TestListAutoOpsRules(t *testing.T) { gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) }, - whereParts: []mysql.WherePart{ - mysql.NewFilter("num", ">=", 5), + listOpts: &mysql.ListOptions{ + Limit: 10, + Offset: 5, + Filters: []*mysql.FilterV2{ + { + Column: "num", + Operator: mysql.OperatorGreaterThanOrEqual, + Value: 5, + }, + }, + InFilter: nil, + NullFilters: nil, + JSONFilters: nil, + SearchQuery: nil, + Orders: []*mysql.Order{ + { + Column: "id", + Direction: mysql.OrderDirectionAsc, + }, + }, }, - orders: []*mysql.Order{ - mysql.NewOrder("id", mysql.OrderDirectionAsc), - }, - limit: 10, - offset: 5, expected: []*proto.AutoOpsRule{}, expectedCursor: 5, expectedErr: nil, @@ -241,10 +248,7 @@ func TestListAutoOpsRules(t *testing.T) { } autoOpsRules, cursor, err := storage.ListAutoOpsRules( context.Background(), - p.whereParts, - p.orders, - p.limit, - p.offset, + p.listOpts, ) assert.Equal(t, p.expected, autoOpsRules) assert.Equal(t, p.expectedCursor, cursor) diff --git a/pkg/autoops/storage/v2/mock/auto_ops_rule.go b/pkg/autoops/storage/v2/mock/auto_ops_rule.go index d2c237ca9b..72892dd6da 100644 --- a/pkg/autoops/storage/v2/mock/auto_ops_rule.go +++ b/pkg/autoops/storage/v2/mock/auto_ops_rule.go @@ -73,9 +73,9 @@ func (mr *MockAutoOpsRuleStorageMockRecorder) GetAutoOpsRule(ctx, id, environmen } // ListAutoOpsRules mocks base method. -func (m *MockAutoOpsRuleStorage) ListAutoOpsRules(ctx context.Context, whereParts []mysql.WherePart, orders []*mysql.Order, limit, offset int) ([]*autoops.AutoOpsRule, int, error) { +func (m *MockAutoOpsRuleStorage) ListAutoOpsRules(ctx context.Context, options *mysql.ListOptions) ([]*autoops.AutoOpsRule, int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListAutoOpsRules", ctx, whereParts, orders, limit, offset) + ret := m.ctrl.Call(m, "ListAutoOpsRules", ctx, options) ret0, _ := ret[0].([]*autoops.AutoOpsRule) ret1, _ := ret[1].(int) ret2, _ := ret[2].(error) @@ -83,9 +83,9 @@ func (m *MockAutoOpsRuleStorage) ListAutoOpsRules(ctx context.Context, wherePart } // ListAutoOpsRules indicates an expected call of ListAutoOpsRules. -func (mr *MockAutoOpsRuleStorageMockRecorder) ListAutoOpsRules(ctx, whereParts, orders, limit, offset any) *gomock.Call { +func (mr *MockAutoOpsRuleStorageMockRecorder) ListAutoOpsRules(ctx, options any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAutoOpsRules", reflect.TypeOf((*MockAutoOpsRuleStorage)(nil).ListAutoOpsRules), ctx, whereParts, orders, limit, offset) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAutoOpsRules", reflect.TypeOf((*MockAutoOpsRuleStorage)(nil).ListAutoOpsRules), ctx, options) } // UpdateAutoOpsRule mocks base method. diff --git a/pkg/autoops/storage/v2/mock/progressive_rollout.go b/pkg/autoops/storage/v2/mock/progressive_rollout.go index 19a71ba78f..213a825718 100644 --- a/pkg/autoops/storage/v2/mock/progressive_rollout.go +++ b/pkg/autoops/storage/v2/mock/progressive_rollout.go @@ -87,9 +87,9 @@ func (mr *MockProgressiveRolloutStorageMockRecorder) GetProgressiveRollout(ctx, } // ListProgressiveRollouts mocks base method. -func (m *MockProgressiveRolloutStorage) ListProgressiveRollouts(ctx context.Context, whereParts []mysql.WherePart, orders []*mysql.Order, limit, offset int) ([]*autoops.ProgressiveRollout, int64, int, error) { +func (m *MockProgressiveRolloutStorage) ListProgressiveRollouts(ctx context.Context, options *mysql.ListOptions) ([]*autoops.ProgressiveRollout, int64, int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListProgressiveRollouts", ctx, whereParts, orders, limit, offset) + ret := m.ctrl.Call(m, "ListProgressiveRollouts", ctx, options) ret0, _ := ret[0].([]*autoops.ProgressiveRollout) ret1, _ := ret[1].(int64) ret2, _ := ret[2].(int) @@ -98,9 +98,9 @@ func (m *MockProgressiveRolloutStorage) ListProgressiveRollouts(ctx context.Cont } // ListProgressiveRollouts indicates an expected call of ListProgressiveRollouts. -func (mr *MockProgressiveRolloutStorageMockRecorder) ListProgressiveRollouts(ctx, whereParts, orders, limit, offset any) *gomock.Call { +func (mr *MockProgressiveRolloutStorageMockRecorder) ListProgressiveRollouts(ctx, options any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListProgressiveRollouts", reflect.TypeOf((*MockProgressiveRolloutStorage)(nil).ListProgressiveRollouts), ctx, whereParts, orders, limit, offset) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListProgressiveRollouts", reflect.TypeOf((*MockProgressiveRolloutStorage)(nil).ListProgressiveRollouts), ctx, options) } // UpdateProgressiveRollout mocks base method. diff --git a/pkg/autoops/storage/v2/progressive_rollout.go b/pkg/autoops/storage/v2/progressive_rollout.go index 7864c2b57c..987fa5368c 100644 --- a/pkg/autoops/storage/v2/progressive_rollout.go +++ b/pkg/autoops/storage/v2/progressive_rollout.go @@ -19,7 +19,6 @@ import ( "context" _ "embed" "errors" - "fmt" "github.com/bucketeer-io/bucketeer/pkg/autoops/domain" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" @@ -61,9 +60,7 @@ type ProgressiveRolloutStorage interface { DeleteProgressiveRollout(ctx context.Context, id, environmentId string) error ListProgressiveRollouts( ctx context.Context, - whereParts []mysql.WherePart, - orders []*mysql.Order, - limit, offset int, + options *mysql.ListOptions, ) ([]*autoopsproto.ProgressiveRollout, int64, int, error) UpdateProgressiveRollout(ctx context.Context, progressiveRollout *domain.ProgressiveRollout, @@ -158,20 +155,15 @@ func (s *progressiveRolloutStorage) DeleteProgressiveRollout( func (s *progressiveRolloutStorage) ListProgressiveRollouts( ctx context.Context, - whereParts []mysql.WherePart, - orders []*mysql.Order, - limit, offset int, + options *mysql.ListOptions, ) ([]*autoopsproto.ProgressiveRollout, int64, int, error) { - whereSQL, whereArgs := mysql.ConstructWhereSQLString(whereParts) - orderBySQL := mysql.ConstructOrderBySQLString(orders) - limitOffsetSQL := mysql.ConstructLimitOffsetSQLString(limit, offset) - query := fmt.Sprintf(selectOpsProgressiveRolloutsSQL, whereSQL, orderBySQL, limitOffsetSQL) + query, whereArgs := mysql.ConstructQueryAndWhereArgs(selectOpsProgressiveRolloutsSQL, options) rows, err := s.qe.QueryContext(ctx, query, whereArgs...) if err != nil { return nil, 0, 0, err } defer rows.Close() - progressiveRollouts := make([]*autoopsproto.ProgressiveRollout, 0, limit) + progressiveRollouts := make([]*autoopsproto.ProgressiveRollout, 0) for rows.Next() { progressiveRollout := autoopsproto.ProgressiveRollout{} err := rows.Scan( @@ -193,9 +185,13 @@ func (s *progressiveRolloutStorage) ListProgressiveRollouts( if rows.Err() != nil { return nil, 0, 0, err } + var offset int + if options != nil { + offset = options.Offset + } nextOffset := offset + len(progressiveRollouts) var totalCount int64 - countQuery := fmt.Sprintf(countOpsProgressiveRolloutsSQL, whereSQL) + countQuery, whereArgs := mysql.ConstructQueryAndWhereArgs(countOpsProgressiveRolloutsSQL, options) err = s.qe.QueryRowContext(ctx, countQuery, whereArgs...).Scan(&totalCount) if err != nil { return nil, 0, 0, err diff --git a/pkg/autoops/storage/v2/progressive_rollout_test.go b/pkg/autoops/storage/v2/progressive_rollout_test.go index 17a461b73e..bfb3a004b9 100644 --- a/pkg/autoops/storage/v2/progressive_rollout_test.go +++ b/pkg/autoops/storage/v2/progressive_rollout_test.go @@ -16,6 +16,7 @@ package v2 import ( "context" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -83,6 +84,86 @@ func TestCreateProgressiveRollout(t *testing.T) { } } +func TestListProgressiveRollouts(t *testing.T) { + t.Parallel() + mockController := gomock.NewController(t) + defer mockController.Finish() + + patterns := []struct { + setup func(*progressiveRolloutStorage) + listOpts *mysql.ListOptions + expected []*proto.ProgressiveRollout + expectedCursor int + expectedTotalCount int64 + expectedErr error + }{ + { + setup: func(s *progressiveRolloutStorage) { + s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, errors.New("error")) + }, + listOpts: nil, + expected: nil, + expectedCursor: 0, + expectedTotalCount: 0, + expectedErr: errors.New("error"), + }, + { + setup: func(s *progressiveRolloutStorage) { + rows := mock.NewMockRows(mockController) + rows.EXPECT().Close().Return(nil) + rows.EXPECT().Next().Return(false) + rows.EXPECT().Err().Return(nil) + s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(rows, nil) + row := mock.NewMockRow(mockController) + s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(row) + row.EXPECT().Scan(gomock.Any()).Return(nil) + }, + listOpts: &mysql.ListOptions{ + Limit: 10, + Offset: 5, + Filters: []*mysql.FilterV2{ + { + Column: "num", + Operator: mysql.OperatorGreaterThanOrEqual, + Value: 5, + }, + }, + InFilter: nil, + NullFilters: nil, + JSONFilters: nil, + SearchQuery: nil, + Orders: []*mysql.Order{ + { + Column: "id", + Direction: mysql.OrderDirectionAsc, + }, + }, + }, + expected: []*proto.ProgressiveRollout{}, + expectedCursor: 5, + expectedTotalCount: 0, + expectedErr: nil, + }, + } + for _, p := range patterns { + storage := newProgressiveRolloutStorageWithMock(t, mockController) + if p.setup != nil { + p.setup(storage) + } + pr, totalCount, cursor, err := storage.ListProgressiveRollouts(context.Background(), p.listOpts) + assert.Equal(t, p.expected, pr) + assert.Equal(t, p.expectedCursor, cursor) + assert.Equal(t, p.expectedTotalCount, totalCount) + assert.Equal(t, p.expectedErr, err) + } +} + func newProgressiveRolloutStorageWithMock(t *testing.T, mockController *gomock.Controller) *progressiveRolloutStorage { t.Helper() return &progressiveRolloutStorage{mock.NewMockQueryExecer(mockController)} diff --git a/pkg/autoops/storage/v2/sql/auto_ops_rule/select_auto_ops_rules.sql b/pkg/autoops/storage/v2/sql/auto_ops_rule/select_auto_ops_rules.sql index 7bdeb94c4d..257aef8023 100644 --- a/pkg/autoops/storage/v2/sql/auto_ops_rule/select_auto_ops_rules.sql +++ b/pkg/autoops/storage/v2/sql/auto_ops_rule/select_auto_ops_rules.sql @@ -9,4 +9,3 @@ SELECT status FROM auto_ops_rule -%s %s %s diff --git a/pkg/autoops/storage/v2/sql/ops_progressive_rollout/count_ops_progressive_rollouts.sql b/pkg/autoops/storage/v2/sql/ops_progressive_rollout/count_ops_progressive_rollouts.sql index cf17c6551f..efc9c70cef 100644 --- a/pkg/autoops/storage/v2/sql/ops_progressive_rollout/count_ops_progressive_rollouts.sql +++ b/pkg/autoops/storage/v2/sql/ops_progressive_rollout/count_ops_progressive_rollouts.sql @@ -2,4 +2,3 @@ SELECT COUNT(1) FROM ops_progressive_rollout -%s diff --git a/pkg/autoops/storage/v2/sql/ops_progressive_rollout/select_ops_progressive_rollouts.sql b/pkg/autoops/storage/v2/sql/ops_progressive_rollout/select_ops_progressive_rollouts.sql index 8987b405fe..e1923084ea 100644 --- a/pkg/autoops/storage/v2/sql/ops_progressive_rollout/select_ops_progressive_rollouts.sql +++ b/pkg/autoops/storage/v2/sql/ops_progressive_rollout/select_ops_progressive_rollouts.sql @@ -10,4 +10,3 @@ SELECT updated_at FROM ops_progressive_rollout -%s %s %s diff --git a/pkg/feature/api/feature.go b/pkg/feature/api/feature.go index d788f1eee3..97b1be1515 100644 --- a/pkg/feature/api/feature.go +++ b/pkg/feature/api/feature.go @@ -1923,11 +1923,29 @@ func (s *FeatureService) stopProgressiveRollout( EnvironmentId, featureID string) error { storage := v2ao.NewProgressiveRolloutStorage(tx) ids := convToInterfaceSlice([]string{featureID}) - whereParts := []mysql.WherePart{ - mysql.NewFilter("environment_id", "=", EnvironmentId), - mysql.NewInFilter("feature_id", ids), + filters := []*mysql.FilterV2{ + { + Column: "environment_id", + Operator: mysql.OperatorEqual, + Value: EnvironmentId, + }, + } + inFilter := &mysql.InFilter{ + Column: "feature_id", + Values: ids, } - list, _, _, err := storage.ListProgressiveRollouts(ctx, whereParts, nil, 0, 0) + listOptions := &mysql.ListOptions{ + Filters: filters, + Orders: nil, + InFilter: inFilter, + NullFilters: nil, + JSONFilters: nil, + SearchQuery: nil, + Limit: 0, + Offset: 0, + } + + list, _, _, err := storage.ListProgressiveRollouts(ctx, listOptions) if err != nil { return err } diff --git a/pkg/storage/v2/mysql/query.go b/pkg/storage/v2/mysql/query.go index 1c62e4c45c..27bc5fef92 100644 --- a/pkg/storage/v2/mysql/query.go +++ b/pkg/storage/v2/mysql/query.go @@ -23,6 +23,41 @@ import ( const placeHolder = "?" +type Operator int + +const ( + // Operation to find the field is equal to the specified value. + OperatorEqual = iota + 1 + // Operation to find the field isn't equal to the specified value. + OperatorNotEqual + // Operation to find ones that contain any one of the multiple values. + OperatorIn + // Operation to find ones that do not contain any of the specified multiple values. + OperatorNotIn + // Operation to find ones the field is greater than the specified value. + OperatorGreaterThan + // Operation to find ones the field is greater or equal than the specified value. + OperatorGreaterThanOrEqual + // Operation to find ones the field is less than the specified value. + OperatorLessThan + // Operation to find ones the field is less or equal than the specified value. + OperatorLessThanOrEqual + // Operation to find ones that have a specified value in its array. + OperatorContains +) + +var operatorMap = map[Operator]string{ + OperatorEqual: "=", + OperatorNotEqual: "!=", + OperatorIn: "IN", + OperatorNotIn: "NOT IN", + OperatorGreaterThan: ">", + OperatorGreaterThanOrEqual: ">=", + OperatorLessThan: "<", + OperatorLessThanOrEqual: "<=", + OperatorContains: "MEMBER OF", +} + type WherePart interface { SQLString() (sql string, args []interface{}) } @@ -50,6 +85,21 @@ func (f *Filter) SQLString() (sql string, args []interface{}) { return } +type FilterV2 struct { + Column string + Operator Operator + Value interface{} +} + +func (f *FilterV2) SQLString() (sql string, args []interface{}) { + if f.Column == "" || f.Operator < OperatorEqual || f.Operator > OperatorContains { + return "", nil + } + sql = fmt.Sprintf("%s %s %s", f.Column, operatorMap[f.Operator], placeHolder) + args = append(args, f.Value) + return +} + type InFilter struct { Column string Values []interface{} @@ -63,7 +113,7 @@ func NewInFilter(column string, values []interface{}) WherePart { } func (f *InFilter) SQLString() (sql string, args []interface{}) { - if f.Column == "" { + if f.Column == "" || len(f.Values) == 0 { return "", nil } var sb strings.Builder @@ -236,7 +286,7 @@ func ConstructWhereSQLString(wps []WherePart) (sql string, args []interface{}) { sb.WriteString(wpSQL) args = append(args, wpArgs...) } - sql = sb.String() + sql = sb.String() + " " return } @@ -288,6 +338,42 @@ func ConstructOrderBySQLString(orders []*Order) string { return sb.String() } +func ConstructQueryAndWhereArgs(baseQuery string, options *ListOptions) (query string, whereArgs []interface{}) { + if options != nil { + var whereQuery string + whereParts := options.CreateWhereParts() + whereQuery, whereArgs = ConstructWhereSQLString(whereParts) + orderByQuery := ConstructOrderBySQLString(options.Orders) + limitOffsetQuery := ConstructLimitOffsetSQLString(options.Limit, options.Offset) + query = baseQuery + whereQuery + orderByQuery + limitOffsetQuery + } else { + query = baseQuery + whereArgs = []interface{}{} + } + return +} + +type Orders struct { + Orders []*Order +} + +func (o *Orders) SQLString() (sql string, args []interface{}) { + if len(o.Orders) == 0 { + return "", nil + } + var sb strings.Builder + sb.WriteString("ORDER BY ") + for i, o := range o.Orders { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(o.Column) + sb.WriteString(" ") + sb.WriteString(o.Direction.String()) + } + return sb.String(), nil +} + const ( QueryNoLimit = 0 QueryNoOffset = 0 @@ -309,3 +395,40 @@ func ConstructLimitOffsetSQLString(limit, offset int) string { } return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) } + +type ListOptions struct { + Limit int + Filters []*FilterV2 + InFilter *InFilter + NullFilters []*NullFilter + JSONFilters []*JSONFilter + SearchQuery *SearchQuery + Orders []*Order + Offset int +} + +func (lo *ListOptions) CreateWhereParts() []WherePart { + var whereParts []WherePart + if lo.Filters != nil { + for _, f := range lo.Filters { + whereParts = append(whereParts, f) + } + } + if lo.InFilter != nil { + whereParts = append(whereParts, lo.InFilter) + } + if lo.NullFilters != nil { + for _, f := range lo.NullFilters { + whereParts = append(whereParts, f) + } + } + if lo.JSONFilters != nil { + for _, f := range lo.JSONFilters { + whereParts = append(whereParts, f) + } + } + if lo.SearchQuery != nil { + whereParts = append(whereParts, lo.SearchQuery) + } + return whereParts +}