diff --git a/pkg/engine/check_object_task.go b/pkg/engine/check_object_task.go index 4d24e32..050d01b 100644 --- a/pkg/engine/check_object_task.go +++ b/pkg/engine/check_object_task.go @@ -37,7 +37,7 @@ type CheckObjTask struct { } // newCheckObjTask initializes and returns CheckObjTask -func newCheckObjTask(log logr.Logger, client *dynamic.DynamicClient, getter ObjGetter, cfg *config.Task) (*CheckObjTask, error) { +func newCheckObjTask(log logr.Logger, client *dynamic.DynamicClient, accessor ObjInfoAccessor, cfg *config.Task) (*CheckObjTask, error) { if client == nil { return nil, fmt.Errorf("%s/%s: DynamicClient is not set", cfg.Type, cfg.ID) } @@ -49,8 +49,8 @@ func newCheckObjTask(log logr.Logger, client *dynamic.DynamicClient, getter ObjG taskType: cfg.Type, taskID: cfg.ID, }, - client: client, - getter: getter, + client: client, + accessor: accessor, }, } @@ -63,7 +63,7 @@ func newCheckObjTask(log logr.Logger, client *dynamic.DynamicClient, getter ObjG // Exec implements Runnable interface func (task *CheckObjTask) Exec(ctx context.Context) error { - info, err := task.getter.GetObjInfo(task.RefTaskID) + info, err := task.accessor.GetObjInfo(task.RefTaskID) if err != nil { return err } diff --git a/pkg/engine/check_object_task_test.go b/pkg/engine/check_object_task_test.go index a9d8ce1..d742121 100644 --- a/pkg/engine/check_object_task_test.go +++ b/pkg/engine/check_object_task_test.go @@ -90,7 +90,7 @@ func TestNewCheckObjTask(t *testing.T) { eng, err := New(testLogger, nil, tc.simClients) require.NoError(t, err) if len(tc.refTaskId) != 0 { - eng.objMap[tc.refTaskId] = nil + eng.objInfoMap[tc.refTaskId] = nil } task, err := eng.GetTask(&config.Task{ @@ -102,7 +102,7 @@ func TestNewCheckObjTask(t *testing.T) { require.EqualError(t, err, tc.err) require.Nil(t, tc.task) } else { - tc.task.getter = eng + tc.task.accessor = eng require.NoError(t, err) require.NotNil(t, tc.task) require.Equal(t, tc.task, task) diff --git a/pkg/engine/check_pod_task.go b/pkg/engine/check_pod_task.go index b92bda2..ffe147d 100644 --- a/pkg/engine/check_pod_task.go +++ b/pkg/engine/check_pod_task.go @@ -19,6 +19,7 @@ package engine import ( "context" "fmt" + "regexp" "time" "github.com/go-logr/logr" @@ -41,8 +42,8 @@ type CheckPodTask struct { BaseTask checkPodTaskParams - client *kubernetes.Clientset - getter ObjGetter + client *kubernetes.Clientset + accessor ObjInfoAccessor } type checkPodTaskParams struct { @@ -53,7 +54,7 @@ type checkPodTaskParams struct { } // newCheckPodTask initializes and returns CheckPodTask -func newCheckPodTask(log logr.Logger, client *kubernetes.Clientset, getter ObjGetter, cfg *config.Task) (*CheckPodTask, error) { +func newCheckPodTask(log logr.Logger, client *kubernetes.Clientset, accessor ObjInfoAccessor, cfg *config.Task) (*CheckPodTask, error) { if client == nil { return nil, fmt.Errorf("%s/%s: Kubernetes client is not set", cfg.Type, cfg.ID) } @@ -64,8 +65,8 @@ func newCheckPodTask(log logr.Logger, client *kubernetes.Clientset, getter ObjGe taskType: cfg.Type, taskID: cfg.ID, }, - client: client, - getter: getter, + client: client, + accessor: accessor, } if err := task.validate(cfg.Params); err != nil { @@ -98,13 +99,13 @@ func (task *CheckPodTask) validate(params map[string]interface{}) error { // Exec implements Runnable interface func (task *CheckPodTask) Exec(ctx context.Context) error { - info, err := task.getter.GetObjInfo(task.RefTaskID) + info, err := task.accessor.GetObjInfo(task.RefTaskID) if err != nil { return err } - if len(info.Pods) == 0 { - return nil + if len(info.PodRegexp) == 0 { + return fmt.Errorf("%s: no pods to check", task.ID()) } if task.Timeout == 0 { @@ -114,37 +115,57 @@ func (task *CheckPodTask) Exec(ctx context.Context) error { } func (task *CheckPodTask) checkPods(ctx context.Context, info *ObjInfo) error { - for _, name := range info.Pods { - pod, err := task.client.CoreV1().Pods(info.Namespace).Get(ctx, name, metav1.GetOptions{}) - if err != nil { - return fmt.Errorf("%s: failed to get pod '%s': %v", task.ID(), name, err) - } + list, err := task.client.CoreV1().Pods(info.Namespace).List(ctx, metav1.ListOptions{}) + if err != nil { + return fmt.Errorf("%s: failed to list pods: %v", task.ID(), err) + } - status := string(pod.Status.Phase) - if status != task.Status { - return fmt.Errorf("%s: pod %s, status %s, expected %s", task.ID(), name, status, task.Status) - } + re, err := utils.Exp2Regexp(info.PodRegexp) + if err != nil { + return fmt.Errorf("%s: %v", task.ID(), err) + } - if err := task.verifyLabels(ctx, pod); err != nil { - return err + var count int + for i := range list.Items { + pod := &list.Items[i] + for _, r := range re { + if r.MatchString(pod.Name) { + task.log.V(4).Info("Matched pod", "name", pod.Name) + count++ + + status := string(pod.Status.Phase) + if status != task.Status { + return fmt.Errorf("%s: pod %s, status %s, expected %s", task.ID(), pod.Name, status, task.Status) + } + + if err := task.verifyLabels(ctx, pod); err != nil { + return err + } + } } } + if count != info.PodCount { + return fmt.Errorf("%s: verified %d pods, expected %d", task.ID(), count, info.PodCount) + } + return nil } // watchPods watches statuses of given pods and compares them with the expected status. // The function runs until all statuses are equal to the expected one, or until the timeout, whichever comes first. func (task *CheckPodTask) watchPods(ctx context.Context, info *ObjInfo) error { - task.log.Info("Create pod informer", "#pods", len(info.Pods), "timeout", task.Timeout.String()) + task.log.Info("Create pod informer", "#pod", info.PodCount, "timeout", task.Timeout.String()) + + re, err := utils.Exp2Regexp(info.PodRegexp) + if err != nil { + return fmt.Errorf("%s: %v", task.ID(), err) + } ctx, cancel := context.WithTimeout(ctx, task.Timeout) defer cancel() podMap := utils.NewSyncMap() - for _, pod := range info.Pods { - podMap.Set(pod, true) - } errs := make(chan error) @@ -153,12 +174,12 @@ func (task *CheckPodTask) watchPods(ctx context.Context, info *ObjInfo) error { informer := factory.Core().V1().Pods().Informer() - _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + _, err = informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: func(obj interface{}) { - task.verifyPod(ctx, podMap, obj, errs) + task.verifyPod(ctx, re, podMap, info.PodCount, obj, errs) }, UpdateFunc: func(_, obj interface{}) { - task.verifyPod(ctx, podMap, obj, errs) + task.verifyPod(ctx, re, podMap, info.PodCount, obj, errs) }, }) if err != nil { @@ -173,10 +194,10 @@ func (task *CheckPodTask) watchPods(ctx context.Context, info *ObjInfo) error { return } for i := range list.Items { - if podMap.Size() == 0 { + if podMap.Size() == info.PodCount { break } - task.verifyPod(ctx, podMap, &list.Items[i], errs) + task.verifyPod(ctx, re, podMap, info.PodCount, &list.Items[i], errs) } }() @@ -209,24 +230,33 @@ func (task *CheckPodTask) verifyLabels(ctx context.Context, pod *v1.Pod) error { return nil } -func (task *CheckPodTask) verifyPod(ctx context.Context, podMap *utils.SyncMap, obj interface{}, errs chan error) { +func (task *CheckPodTask) verifyPod(ctx context.Context, re []*regexp.Regexp, podMap *utils.SyncMap, count int, obj interface{}, errs chan error) { pod, ok := obj.(*v1.Pod) if !ok { errs <- fmt.Errorf("%s: unexpected object type %T, expected *v1.Pod", task.ID(), obj) return } - if _, ok := podMap.Get(pod.Name); ok { - status := string(pod.Status.Phase) - task.log.V(4).Info("Informer event", "pod", pod.Name, "status", status) - if err := task.verifyLabels(ctx, pod); err != nil { - errs <- err - return - } - if sz := podMap.Delete(pod.Name); sz == 0 { - task.log.Info("Accounted for all pods") - errs <- nil - return + for _, r := range re { + if r.MatchString(pod.Name) { + task.log.V(4).Info("Matched pod", "name", pod.Name) + if _, ok := podMap.Get(pod.Name); ok { + return + } + status := string(pod.Status.Phase) + task.log.V(4).Info("Informer event", "pod", pod.Name, "status", status) + if status != task.Status { + return + } + if err := task.verifyLabels(ctx, pod); err != nil { + errs <- err + return + } + if sz := podMap.Set(pod.Name, true); sz == count { + task.log.Info("Accounted for all pods") + errs <- nil + return + } } } } diff --git a/pkg/engine/check_pod_task_test.go b/pkg/engine/check_pod_task_test.go index 53045c0..b6aef2a 100644 --- a/pkg/engine/check_pod_task_test.go +++ b/pkg/engine/check_pod_task_test.go @@ -105,7 +105,7 @@ func TestCheckPodParams(t *testing.T) { eng, err := New(testLogger, nil, tc.simClients) require.NoError(t, err) if len(tc.refTaskId) != 0 { - eng.objMap[tc.refTaskId] = nil + eng.objInfoMap[tc.refTaskId] = nil } task, err := eng.GetTask(&config.Task{ ID: taskID, @@ -116,7 +116,7 @@ func TestCheckPodParams(t *testing.T) { require.EqualError(t, err, tc.err) require.Nil(t, tc.task) } else { - tc.task.getter = eng + tc.task.accessor = eng require.NoError(t, err) require.NotNil(t, tc.task) require.Equal(t, tc.task, task) diff --git a/pkg/engine/delete_object_task.go b/pkg/engine/delete_object_task.go index 84dd95e..2b6591d 100644 --- a/pkg/engine/delete_object_task.go +++ b/pkg/engine/delete_object_task.go @@ -33,7 +33,7 @@ type DeleteObjTask struct { deleteObjTaskParams client *dynamic.DynamicClient - getter ObjGetter + getter ObjInfoAccessor } type deleteObjTaskParams struct { @@ -41,7 +41,7 @@ type deleteObjTaskParams struct { } // newDeleteObjTask initializes and returns DeleteObjTask -func newDeleteObjTask(log logr.Logger, client *dynamic.DynamicClient, getter ObjGetter, cfg *config.Task) (*DeleteObjTask, error) { +func newDeleteObjTask(log logr.Logger, client *dynamic.DynamicClient, getter ObjInfoAccessor, cfg *config.Task) (*DeleteObjTask, error) { if client == nil { return nil, fmt.Errorf("%s/%s: DynamicClient is not set", cfg.Type, cfg.ID) } diff --git a/pkg/engine/delete_object_task_test.go b/pkg/engine/delete_object_task_test.go index 7328446..acdc6d9 100644 --- a/pkg/engine/delete_object_task_test.go +++ b/pkg/engine/delete_object_task_test.go @@ -79,7 +79,7 @@ func TestNewDeleteObjTask(t *testing.T) { eng, err := New(testLogger, nil, tc.simClients) require.NoError(t, err) if len(tc.refTaskId) != 0 { - eng.objMap[tc.refTaskId] = nil + eng.objInfoMap[tc.refTaskId] = nil } task, err := eng.GetTask(&config.Task{ diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 398c3f4..1ffd39d 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -23,6 +23,7 @@ import ( "time" "github.com/go-logr/logr" + "k8s.io/client-go/discovery" "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" @@ -36,30 +37,37 @@ type Engine interface { } type Eng struct { - log logr.Logger - mutex sync.Mutex - k8sClient *kubernetes.Clientset - dynamicClient *dynamic.DynamicClient - objMap map[string]*ObjInfo + log logr.Logger + mutex sync.Mutex + k8sClient *kubernetes.Clientset + dynamicClient *dynamic.DynamicClient + discoveryClient *discovery.DiscoveryClient + objTypeMap map[string]*RegisterObjParams + objInfoMap map[string]*ObjInfo } func New(log logr.Logger, config *rest.Config, sim ...bool) (*Eng, error) { eng := &Eng{ - log: log, - objMap: make(map[string]*ObjInfo), + log: log, + objTypeMap: make(map[string]*RegisterObjParams), + objInfoMap: make(map[string]*ObjInfo), } if len(sim) == 0 { // len(sim) != 0 in unit tests var err error + if eng.k8sClient, err = kubernetes.NewForConfig(config); err != nil { + return nil, err + } if eng.dynamicClient, err = dynamic.NewForConfig(config); err != nil { return nil, err } - if eng.k8sClient, err = kubernetes.NewForConfig(config); err != nil { + if eng.discoveryClient, err = discovery.NewDiscoveryClientForConfig(config); err != nil { return nil, err } } else if sim[0] { - eng.dynamicClient = &dynamic.DynamicClient{} eng.k8sClient = &kubernetes.Clientset{} + eng.dynamicClient = &dynamic.DynamicClient{} + eng.discoveryClient = &discovery.DiscoveryClient{} } return eng, nil @@ -98,56 +106,64 @@ func (eng *Eng) GetTask(cfg *config.Task) (Runnable, error) { eng.log.Info("Creating task", "name", cfg.Type, "id", cfg.ID) switch cfg.Type { + case TaskRegisterObj: + return newRegisterObjTask(eng.log, eng.discoveryClient, eng, cfg) + case TaskSubmitObj: task, err := newSubmitObjTask(eng.log, eng.dynamicClient, eng, cfg) if err != nil { return nil, err } + if _, ok := eng.objTypeMap[task.RefTaskID]; !ok { + return nil, fmt.Errorf("%s: unreferenced task ID %s", task.ID(), task.RefTaskID) + } return task, nil + case TaskUpdateObj: task, err := newUpdateObjTask(eng.log, eng.dynamicClient, eng, cfg) if err != nil { return nil, err } - if _, ok := eng.objMap[task.RefTaskID]; !ok { + if _, ok := eng.objInfoMap[task.RefTaskID]; !ok { return nil, fmt.Errorf("%s: unreferenced task ID %s", task.ID(), task.RefTaskID) } return task, nil + case TaskCheckObj: task, err := newCheckObjTask(eng.log, eng.dynamicClient, eng, cfg) if err != nil { return nil, err } - if _, ok := eng.objMap[task.RefTaskID]; !ok { + if _, ok := eng.objInfoMap[task.RefTaskID]; !ok { return nil, fmt.Errorf("%s: unreferenced task ID %s", task.ID(), task.RefTaskID) } return task, nil + case TaskDeleteObj: task, err := newDeleteObjTask(eng.log, eng.dynamicClient, eng, cfg) if err != nil { return nil, err } - if _, ok := eng.objMap[task.RefTaskID]; !ok { + if _, ok := eng.objInfoMap[task.RefTaskID]; !ok { return nil, fmt.Errorf("%s: unreferenced task ID %s", task.ID(), task.RefTaskID) } return task, nil + case TaskUpdateNodes: return newUpdateNodesTask(eng.log, eng.k8sClient, cfg) + case TaskCheckPod: task, err := newCheckPodTask(eng.log, eng.k8sClient, eng, cfg) if err != nil { return nil, err } - if _, ok := eng.objMap[task.RefTaskID]; !ok { + if _, ok := eng.objInfoMap[task.RefTaskID]; !ok { return nil, fmt.Errorf("%s: unreferenced task ID %s", task.ID(), task.RefTaskID) } return task, nil + case TaskSleep: - task, err := newSleepTask(eng.log, cfg) - if err != nil { - return nil, err - } - return task, nil + return newSleepTask(eng.log, cfg) case TaskPause: return newPauseTask(eng.log, cfg), nil @@ -157,16 +173,47 @@ func (eng *Eng) GetTask(cfg *config.Task) (Runnable, error) { } } +// SetObjType implements ObjSetter interface and maps object type to RegisterObjParams +func (eng *Eng) SetObjType(taskID string, params *RegisterObjParams) error { + eng.mutex.Lock() + defer eng.mutex.Unlock() + + if _, ok := eng.objTypeMap[taskID]; ok { + return fmt.Errorf("SetObjType: duplicate task ID %s", taskID) + } + + eng.objTypeMap[taskID] = params + + eng.log.V(4).Info("Registering object for task ID", "name", taskID) + + return nil +} + +// GetObjType implements ObjGetter interface returns RegisterObjParams for given object type +func (eng *Eng) GetObjType(objType string) (*RegisterObjParams, error) { + eng.mutex.Lock() + defer eng.mutex.Unlock() + + info, ok := eng.objTypeMap[objType] + if !ok { + return nil, fmt.Errorf("GetObjType: missing object type %s", objType) + } + + eng.log.V(4).Info("Getting object type", "name", objType) + + return info, nil +} + // SetObjInfo implements ObjSetter interface and maps task ID to the corresponding ObjInfo func (eng *Eng) SetObjInfo(taskID string, info *ObjInfo) error { eng.mutex.Lock() defer eng.mutex.Unlock() - if _, ok := eng.objMap[taskID]; ok { + if _, ok := eng.objInfoMap[taskID]; ok { return fmt.Errorf("SetObjInfo: duplicate task ID %s", taskID) } - eng.objMap[taskID] = info + eng.objInfoMap[taskID] = info eng.log.V(4).Info("Setting task info", "taskID", taskID) @@ -178,7 +225,7 @@ func (eng *Eng) GetObjInfo(taskID string) (*ObjInfo, error) { eng.mutex.Lock() defer eng.mutex.Unlock() - info, ok := eng.objMap[taskID] + info, ok := eng.objInfoMap[taskID] if !ok { return nil, fmt.Errorf("GetObjInfo: missing task ID %s", taskID) } diff --git a/pkg/engine/engine_test.go b/pkg/engine/engine_test.go index 9f592cc..ec2e720 100644 --- a/pkg/engine/engine_test.go +++ b/pkg/engine/engine_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "k8s.io/client-go/discovery" "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" "k8s.io/klog/v2/textlogger" @@ -30,11 +31,12 @@ import ( ) var ( - errExec = fmt.Errorf("exec error") - errReset = fmt.Errorf("reset error") - testLogger = textlogger.NewLogger(textlogger.NewConfig()) - testK8sClient = &kubernetes.Clientset{} - testDynamicClient = &dynamic.DynamicClient{} + errExec = fmt.Errorf("exec error") + errReset = fmt.Errorf("reset error") + testLogger = textlogger.NewLogger(textlogger.NewConfig()) + testK8sClient = &kubernetes.Clientset{} + testDynamicClient = &dynamic.DynamicClient{} + testDiscoveryClient = &discovery.DiscoveryClient{} ) type testEngine struct { diff --git a/pkg/engine/object_state.go b/pkg/engine/object_state.go index 7e26716..425225a 100644 --- a/pkg/engine/object_state.go +++ b/pkg/engine/object_state.go @@ -28,8 +28,8 @@ type ObjStateTask struct { BaseTask StateParams - client *dynamic.DynamicClient - getter ObjGetter + client *dynamic.DynamicClient + accessor ObjInfoAccessor } // validate initializes and validates parameters for ObjStateTask diff --git a/pkg/engine/register_object_task.go b/pkg/engine/register_object_task.go new file mode 100644 index 0000000..c5cc13d --- /dev/null +++ b/pkg/engine/register_object_task.go @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package engine + +import ( + "context" + "fmt" + "os" + "text/template" + + "github.com/go-logr/logr" + "gopkg.in/yaml.v3" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/discovery" + + "github.com/NVIDIA/knavigator/pkg/config" +) + +type RegisterObjTask struct { + BaseTask + RegisterObjParams + + client *discovery.DiscoveryClient + accessor ObjInfoAccessor + + gvk schema.GroupVersionKind +} + +// newRegisterObjTask initializes and returns RegisterObjTask +func newRegisterObjTask(log logr.Logger, client *discovery.DiscoveryClient, accessor ObjInfoAccessor, cfg *config.Task) (*RegisterObjTask, error) { + if client == nil { + return nil, fmt.Errorf("%s/%s: DiscoveryClient is not set", cfg.Type, cfg.ID) + } + + task := &RegisterObjTask{ + BaseTask: BaseTask{ + log: log, + taskType: cfg.Type, + taskID: cfg.ID, + }, + client: client, + accessor: accessor, + } + + if err := task.validate(cfg.Params); err != nil { + return nil, err + } + + return task, nil +} + +// validate initializes and validates parameters for RegisterObjTask. +func (task *RegisterObjTask) validate(params map[string]interface{}) error { + data, err := yaml.Marshal(params) + if err != nil { + return fmt.Errorf("%s: failed to parse parameters: %v", task.ID(), err) + } + if err = yaml.Unmarshal(data, &task.RegisterObjParams); err != nil { + return fmt.Errorf("%s: failed to parse parameters: %v", task.ID(), err) + } + + if len(task.Template) == 0 { + return fmt.Errorf("%s: must specify template", task.ID()) + } + + tplData, err := os.ReadFile(task.Template) + if err != nil { + return fmt.Errorf("%s: failed to read %s: %v", task.ID(), task.Template, err) + } + + var typeMeta TypeMeta + err = yaml.Unmarshal(tplData, &typeMeta) + if err != nil { + return fmt.Errorf("%s: failed to parse template %s: %v", task.ID(), task.Template, err) + } + + task.gvk = schema.FromAPIVersionAndKind(typeMeta.APIVersion, typeMeta.Kind) + + task.objTpl, err = template.ParseFiles(task.Template) + if err != nil { + return fmt.Errorf("%s: failed to parse template %s: %v", task.ID(), task.Template, err) + } + + if len(task.NameFormat) == 0 { + return fmt.Errorf("%s: must specify nameFormat", task.ID()) + } + + if len(task.PodNameFormat) != 0 { + if task.podNameTpl, err = template.New("podname").Parse(task.PodNameFormat); err != nil { + return fmt.Errorf("%s: failed to parse podname template: %v", task.ID(), err) + } + } + + if len(task.PodCount) != 0 { + if task.podNameTpl == nil { + return fmt.Errorf("%s: must define podNameFormat with podCount", task.ID()) + } + if task.podCountTpl, err = template.New("podcount").Parse(task.PodCount); err != nil { + return fmt.Errorf("%s: failed to parse podcount template: %v", task.ID(), err) + } + } else if task.podNameTpl != nil { + return fmt.Errorf("%s: must define podCount with podNameFormat", task.ID()) + } + + return nil +} + +// Exec implements Runnable interface +func (task *RegisterObjTask) Exec(ctx context.Context) error { + switch task.gvk.String() { + case "batch/v1, Kind=Job": + task.gvr = schema.GroupVersionResource{ + Group: task.gvk.Group, + Version: task.gvk.Version, + Resource: "jobs", + } + default: + if err := task.getGVR(); err != nil { + return err + } + } + + return task.accessor.SetObjType(task.taskID, &task.RegisterObjParams) +} + +func (task *RegisterObjTask) getGVR() error { + apiResourceList, err := task.client.ServerPreferredResources() + if err != nil { + return fmt.Errorf("%s: failed to retrieve API resources: %v", task.ID(), err) + } + + for _, list := range apiResourceList { + for _, r := range list.APIResources { + if r.Group == task.gvk.Group && r.Kind == task.gvk.Kind { + task.gvr = schema.GroupVersionResource{Group: r.Group, Version: r.Version, Resource: r.Name} + return nil + } + } + } + + return fmt.Errorf("%s: failed to find resource for %s", task.ID(), task.gvk.String()) +} diff --git a/pkg/engine/register_object_task_test.go b/pkg/engine/register_object_task_test.go new file mode 100644 index 0000000..122797f --- /dev/null +++ b/pkg/engine/register_object_task_test.go @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package engine + +import ( + "testing" + + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/runtime/schema" + + "github.com/NVIDIA/knavigator/pkg/config" +) + +func TestNewRegisterObjTask(t *testing.T) { + taskID := "register" + testCases := []struct { + name string + params map[string]interface{} + simClients bool + err string + task *RegisterObjTask + pods []string + }{ + { + name: "Case 1: no client", + params: nil, + simClients: false, + err: "RegisterObj/register: DiscoveryClient is not set", + }, + { + name: "Case 2: missing template", + params: map[string]interface{}{}, + simClients: true, + err: "RegisterObj/register: must specify template", + }, + { + name: "Case 3: bad template path", + params: map[string]interface{}{ + "template": "/does/not/exist", + }, + simClients: true, + err: "RegisterObj/register: failed to read /does/not/exist: open /does/not/exist: no such file or directory", + }, + { + name: "Case 4: missing nameFormat", + params: map[string]interface{}{ + "template": "../../resources/templates/example.yml", + }, + simClients: true, + err: "RegisterObj/register: must specify nameFormat", + }, + { + name: "Case 5: bad podNameFormat", + params: map[string]interface{}{ + "template": "../../resources/templates/example.yml", + "nameFormat": "test", + "podNameFormat": "test{{", + }, + simClients: true, + err: "RegisterObj/register: failed to parse podname template: template: podname:1: unclosed action", + }, + { + name: "Case 6: bad podCount", + params: map[string]interface{}{ + "template": "../../resources/templates/example.yml", + "nameFormat": "test", + "podNameFormat": "test{{._NAME_}}", + "podCount": "test{{", + }, + simClients: true, + err: "RegisterObj/register: failed to parse podcount template: template: podcount:1: unclosed action", + }, + { + name: "Case 7: missing podCount", + params: map[string]interface{}{ + "template": "../../resources/templates/example.yml", + "nameFormat": "test", + "podNameFormat": "test{{._NAME_}}", + }, + simClients: true, + err: "RegisterObj/register: must define podCount with podNameFormat", + }, + { + name: "Case 8: missing podNameFormat", + params: map[string]interface{}{ + "template": "../../resources/templates/example.yml", + "nameFormat": "test", + "podCount": "2", + }, + simClients: true, + err: "RegisterObj/register: must define podNameFormat with podCount", + }, + { + name: "Case 9: valid input", + params: map[string]interface{}{ + "template": "../../resources/templates/example.yml", + "nameFormat": "test", + "podNameFormat": "test{{._NAME_}}", + "podCount": "2", + }, + simClients: true, + task: &RegisterObjTask{ + BaseTask: BaseTask{ + log: testLogger, + taskType: TaskRegisterObj, + taskID: taskID, + }, + RegisterObjParams: RegisterObjParams{ + Template: "../../resources/templates/example.yml", + NameFormat: "test", + PodNameFormat: "test{{._NAME_}}", + PodCount: "2", + }, + client: testDiscoveryClient, + gvk: schema.GroupVersionKind{ + Group: "example.com", + Version: "v1", + Kind: "MyObject", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + eng, err := New(testLogger, nil, tc.simClients) + require.NoError(t, err) + + runnable, err := eng.GetTask(&config.Task{ + ID: taskID, + Type: TaskRegisterObj, + Params: tc.params, + }) + if len(tc.err) != 0 { + require.EqualError(t, err, tc.err) + require.Nil(t, tc.task) + } else { + tc.task.accessor = eng + require.NoError(t, err) + require.NotNil(t, tc.task) + + task := runnable.(*RegisterObjTask) + task.objTpl, task.podNameTpl, task.podCountTpl = nil, nil, nil + require.Equal(t, tc.task, task) + } + }) + } +} diff --git a/pkg/engine/submit_object_task.go b/pkg/engine/submit_object_task.go index 8ab50ce..0ef39c6 100644 --- a/pkg/engine/submit_object_task.go +++ b/pkg/engine/submit_object_task.go @@ -19,13 +19,13 @@ package engine import ( "context" "fmt" - "text/template" + "strconv" + "strings" "github.com/go-logr/logr" "gopkg.in/yaml.v3" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" - "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" "github.com/NVIDIA/knavigator/pkg/config" @@ -35,40 +35,17 @@ import ( type SubmitObjTask struct { BaseTask submitObjTaskParams - client *dynamic.DynamicClient - setter ObjSetter - - // derived - obj []GenericObject + client *dynamic.DynamicClient + accessor ObjInfoAccessor } type submitObjTaskParams struct { + // RefTaskID: task ID of the corresponding RegisterObjTask + RefTaskID string `yaml:"refTaskId"` // Count: number of objects to submit; default 1. - Count int `json:"count"` - // GRV: Group/Version/Resource of the object. - GRV groupVersionResource `json:"grv"` - // Template: path to the object template; see examples in resources/templates/ - Template string `json:"template"` - // NameFormat: a Go-template parameter for generating unique object names. - // It utilizes the '_ENUM_' keyword for an incrementing counter and - // adds the '_NAME_' key to the Overrides map with the templated value. - // Example: "job{{._ENUM_}}" - NameFormat string `json:"nameformat"` - // Overrides: a map of key:value pairs to be used when executing object and name templates. - Overrides map[string]interface{} `json:"overrides"` - // Pods: an optional parameter for specifying the naming format of pods spawned by the object(s). - Pods utils.NameSelector `json:"pods,omitempty"` -} - -type groupVersionResource struct { - Group string `json:"group" yaml:"group"` - Version string `json:"version" yaml:"version"` - Resource string `json:"resource" yaml:"resource"` -} - -type typeMeta struct { - Kind string `json:"kind" yaml:"kind"` - APIVersion string `json:"apiVersion" yaml:"apiVersion"` + Count int `yaml:"count"` + // Params: a map of key:value pairs to be used when executing object and name templates. + Params map[string]interface{} `yaml:"params"` } type objectMeta struct { @@ -79,13 +56,13 @@ type objectMeta struct { } type GenericObject struct { - typeMeta `json:",inline" yaml:",inline"` + TypeMeta `json:",inline" yaml:",inline"` Metadata objectMeta `json:"metadata" yaml:"metadata"` Spec interface{} `json:"spec" yaml:"spec"` } // newSubmitObjTask initializes and returns SubmitObjTask -func newSubmitObjTask(log logr.Logger, client *dynamic.DynamicClient, setter ObjSetter, cfg *config.Task) (*SubmitObjTask, error) { +func newSubmitObjTask(log logr.Logger, client *dynamic.DynamicClient, accessor ObjInfoAccessor, cfg *config.Task) (*SubmitObjTask, error) { if client == nil { return nil, fmt.Errorf("%s/%s: DynamicClient is not set", cfg.Type, cfg.ID) } @@ -96,8 +73,8 @@ func newSubmitObjTask(log logr.Logger, client *dynamic.DynamicClient, setter Obj taskType: cfg.Type, taskID: cfg.ID, }, - client: client, - setter: setter, + client: client, + accessor: accessor, } if err := task.validate(cfg.Params); err != nil { @@ -117,103 +94,102 @@ func (task *SubmitObjTask) validate(params map[string]interface{}) error { return fmt.Errorf("%s: failed to parse parameters: %v", task.ID(), err) } + if len(task.RefTaskID) == 0 { + return fmt.Errorf("%s: must specify refTaskId", task.ID()) + } + if task.Count == 0 { task.Count = 1 // default } else if task.Count < 0 { return fmt.Errorf("%s: 'count' must be a positive number", task.ID()) } - if len(task.Template) == 0 { - return fmt.Errorf("%s: 'template' must be a filepath", task.ID()) + return nil +} + +// Exec implements Runnable interface +func (task *SubmitObjTask) Exec(ctx context.Context) error { + regObjParams, err := task.accessor.GetObjType(task.RefTaskID) + if err != nil { + return fmt.Errorf("%s: failed to get object type: %v", task.ID(), err) } - tpl, err := template.ParseFiles(task.Template) + objs, podCount, podRegexp, err := task.getGenericObjects(regObjParams) if err != nil { - return fmt.Errorf("%s: failed to parse template %s: %v", task.ID(), task.Template, err) + return err } - if len(task.NameFormat) == 0 { - if task.Count > 1 { - return fmt.Errorf("%s: must specify name format for multiple object submissions", task.ID()) + for _, obj := range objs { + crd := &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": obj.APIVersion, + "kind": obj.Kind, + "metadata": obj.Metadata, + "spec": obj.Spec, + }, + } + + if _, err := task.client.Resource(regObjParams.gvr).Namespace(obj.Metadata.Namespace).Create(ctx, crd, metav1.CreateOptions{}); err != nil { + return err } } - task.obj = make([]GenericObject, task.Count) - names, err := utils.GenerateNames(task.NameFormat, task.Count, task.Overrides) + return task.accessor.SetObjInfo(task.taskID, + NewObjInfo([]string{objs[0].Metadata.Name}, objs[0].Metadata.Namespace, regObjParams.gvr, podCount, podRegexp...)) +} + +func (task *SubmitObjTask) getGenericObjects(regObjParams *RegisterObjParams) ([]GenericObject, int, []string, error) { + names, err := utils.GenerateNames(regObjParams.NameFormat, task.Count, task.Params) if err != nil { - return fmt.Errorf("%s: failed to generate object names: %v", task.ID(), err) + return nil, 0, nil, fmt.Errorf("%s: failed to generate object names: %v", task.ID(), err) } - task.Pods.Init() - if task.Pods.List != nil { - if task.Pods.List.Params == nil { - task.Pods.List.Params = make(map[string]interface{}) - } - } - if task.Pods.Range != nil { - if task.Pods.Range.Params == nil { - task.Pods.Range.Params = make(map[string]interface{}) - } - } + objs := make([]GenericObject, task.Count) + podRegexp := []string{} for i := 0; i < task.Count; i++ { - task.Overrides["_NAME_"] = names[i] + task.Params["_NAME_"] = names[i] - data, err = utils.ExecTemplate(tpl, task.Overrides) + data, err := utils.ExecTemplate(regObjParams.objTpl, task.Params) if err != nil { - return err + return nil, 0, nil, err } - if err = yaml.Unmarshal(data, &task.obj[i]); err != nil { - return err + if err = yaml.Unmarshal(data, &objs[i]); err != nil { + return nil, 0, nil, err } - if task.Pods.List != nil { - task.Pods.List.Params["_NAME_"] = task.obj[i].Metadata.Name - } - if task.Pods.Range != nil { - task.Pods.Range.Params["_NAME_"] = task.obj[i].Metadata.Name + if regObjParams.podNameTpl != nil { + data, err = utils.ExecTemplate(regObjParams.podNameTpl, task.Params) + if err != nil { + return nil, 0, nil, err + } + re := strings.Trim(strings.TrimSpace(string(data)), "\"") + podRegexp = append(podRegexp, re) } - if err = task.Pods.Finalize(); err != nil { - return err - } - } - - if pods := task.Pods.Names(); len(pods) != 0 { - task.log.V(4).Info("Expected pods", "names", pods) } - return nil -} -// Exec implements Runnable interface -func (task *SubmitObjTask) Exec(ctx context.Context) error { - gvr := schema.GroupVersionResource{ - Group: task.GRV.Group, - Version: task.GRV.Version, - Resource: task.GRV.Resource, - } - for _, obj := range task.obj { - crd := &unstructured.Unstructured{ - Object: map[string]interface{}{ - "apiVersion": obj.APIVersion, - "kind": obj.Kind, - "metadata": obj.Metadata, - "spec": obj.Spec, - }, + var podCount int + if regObjParams.podCountTpl != nil { + data, err := utils.ExecTemplate(regObjParams.podCountTpl, task.Params) + if err != nil { + return nil, 0, nil, err } - - if _, err := task.client.Resource(gvr).Namespace(obj.Metadata.Namespace).Create(ctx, crd, metav1.CreateOptions{}); err != nil { - return err + str := string(data) + podCount, err = strconv.Atoi(str) + if err != nil { + return nil, 0, nil, fmt.Errorf("%s: failed to convert pod count %s to int: %v", task.ID(), str, err) } + podCount *= task.Count } + task.log.V(4).Info("Generating object specs", "podCount", podCount, "podRegexp", podRegexp) - return task.setter.SetObjInfo(task.taskID, - NewObjInfo([]string{task.obj[0].Metadata.Name}, task.obj[0].Metadata.Namespace, gvr, task.Pods.Names()...)) + return objs, podCount, podRegexp, nil } func (obj *GenericObject) UnmarshalYAML(unmarshal func(interface{}) error) error { var o struct { - typeMeta `yaml:",inline"` + TypeMeta `yaml:",inline"` Metadata objectMeta `yaml:"metadata"` Spec map[string]interface{} `yaml:"spec"` } @@ -223,7 +199,7 @@ func (obj *GenericObject) UnmarshalYAML(unmarshal func(interface{}) error) error return err } - obj.typeMeta = o.typeMeta + obj.TypeMeta = o.TypeMeta obj.Metadata = o.Metadata obj.Spec = convertMap(o.Spec) return nil diff --git a/pkg/engine/submit_object_task_test.go b/pkg/engine/submit_object_task_test.go index 52e2394..cac1bf3 100644 --- a/pkg/engine/submit_object_task_test.go +++ b/pkg/engine/submit_object_task_test.go @@ -18,6 +18,7 @@ package engine import ( "testing" + "text/template" "github.com/stretchr/testify/require" @@ -27,12 +28,8 @@ import ( func TestNewSubmitObjTask(t *testing.T) { taskID := "submit" - grv := map[string]interface{}{ - "group": "example.com", - "version": "v1", - "resource": "myobjects", - } - overrides := map[string]interface{}{ + params := map[string]interface{}{ + "replicas": 2, "instance": "lnx2000", "command": "sleep infinity", "image": "ubuntu", @@ -83,12 +80,16 @@ func TestNewSubmitObjTask(t *testing.T) { }, } testCases := []struct { - name string - params map[string]interface{} - simClients bool - err string - task *SubmitObjTask - pods []string + name string + params map[string]interface{} + simClients bool + regObjParams *RegisterObjParams + refTaskID string + err string + task *SubmitObjTask + objs []GenericObject + podCount int + podRegexp []string }{ { name: "Case 1: no client", @@ -99,77 +100,57 @@ func TestNewSubmitObjTask(t *testing.T) { { name: "Case 2a: parsing error", params: map[string]interface{}{ - "count": false, - "grv": grv, - "template": "../../resources/templates/example.yml", - "overrides": overrides, + "count": false, + "params": params, }, simClients: true, err: "SubmitObj/submit: failed to parse parameters: yaml: unmarshal errors:\n line 1: cannot unmarshal !!bool `false` into int", }, { - name: "Case 2b: negative count", + name: "Case 2b: missing refTaskId", params: map[string]interface{}{ - "count": -3, - "grv": grv, - "template": "../../resources/templates/example.yml", - "overrides": overrides, + "params": params, }, simClients: true, - err: "SubmitObj/submit: 'count' must be a positive number", + err: "SubmitObj/submit: must specify refTaskId", }, { - name: "Case 2c: no template", + name: "Case 2c: negative count", params: map[string]interface{}{ - "grv": grv, - "overrides": overrides, + "refTaskId": "register", + "count": -3, + "params": params, }, simClients: true, - err: "SubmitObj/submit: 'template' must be a filepath", + err: "SubmitObj/submit: 'count' must be a positive number", }, { - name: "Case 2d: bad template", + name: "Case 2d: negative count", params: map[string]interface{}{ - "grv": grv, - "template": "/does/not/exist", - "overrides": overrides, + "refTaskId": "register", + "count": 1, + "params": params, }, simClients: true, - err: "SubmitObj/submit: failed to parse template /does/not/exist: open /does/not/exist: no such file or directory", - }, - { - name: "Case 2e: no name format", - params: map[string]interface{}{ - "count": 3, - "grv": grv, - "template": "../../resources/templates/example.yml", - "overrides": overrides, + regObjParams: &RegisterObjParams{ + Template: "../../resources/templates/example.yml", + NameFormat: "job{{._ENUM_}}", }, - simClients: true, - err: "SubmitObj/submit: must specify name format for multiple object submissions", + err: "SubmitObj/submit: unreferenced task ID register", }, { - name: "Case 2f: name format error", + name: "Case 3: Valid parameters without pods", params: map[string]interface{}{ - "count": 3, - "grv": grv, - "template": "../../resources/templates/example.yml", - "nameformat": "{{{.}}", - "overrides": overrides, + "refTaskId": "register", + "count": 1, + "params": params, }, simClients: true, - err: "SubmitObj/submit: failed to generate object names: template: name:1: unexpected \"{\" in command", - }, - { - name: "Case 3: Valid parameters without pod name selector", - params: map[string]interface{}{ - "count": 1, - "grv": grv, - "template": "../../resources/templates/example.yml", - "nameformat": "job{{._ENUM_}}", - "overrides": overrides, + regObjParams: &RegisterObjParams{ + Template: "../../resources/templates/example.yml", + NameFormat: "job{{._ENUM_}}", }, - simClients: true, + refTaskID: "register", task: &SubmitObjTask{ BaseTask: BaseTask{ log: testLogger, @@ -177,52 +158,42 @@ func TestNewSubmitObjTask(t *testing.T) { taskID: taskID, }, submitObjTaskParams: submitObjTaskParams{ - Count: 1, - GRV: groupVersionResource{ - Group: "example.com", - Version: "v1", - Resource: "myobjects", - }, - Template: "../../resources/templates/example.yml", - NameFormat: "job{{._ENUM_}}", - Overrides: overrides, + RefTaskID: "register", + Count: 1, + Params: params, }, client: testDynamicClient, - obj: []GenericObject{ - { - typeMeta: typeMeta{ - APIVersion: "example.com/v1", - Kind: "MyObject", - }, - Metadata: objectMeta{ - Name: "job1", - Namespace: "test", - }, - Spec: spec, + }, + objs: []GenericObject{ + { + TypeMeta: TypeMeta{ + APIVersion: "example.com/v1", + Kind: "MyObject", + }, + Metadata: objectMeta{ + Name: "job1", + Namespace: "test", }, + Spec: spec, }, }, - pods: []string{}, + podRegexp: []string{}, }, { - name: "Case 4: Valid parameters with pod name selector", + name: "Case 4: Valid parameters with pods", params: map[string]interface{}{ - "count": 2, - "grv": grv, - "template": "../../resources/templates/example.yml", - "nameformat": "job{{._ENUM_}}", - "overrides": overrides, - "pods": map[string]interface{}{ - "list": map[string]interface{}{ - "patterns": []string{"pod{{._NAME_}}"}, - }, - "range": map[string]interface{}{ - "pattern": "{{._NAME_}}-{{._INDEX_}}", - "ranges": []string{"0-1"}, - }, - }, + "refTaskId": "register", + "count": 2, + "params": params, }, simClients: true, + regObjParams: &RegisterObjParams{ + Template: "../../resources/templates/example.yml", + NameFormat: "job{{._ENUM_}}", + PodNameFormat: "{{._NAME_}}-test-[0-9]+", + PodCount: "{{.replicas}}", + }, + refTaskID: "register", task: &SubmitObjTask{ BaseTask: BaseTask{ log: testLogger, @@ -230,43 +201,38 @@ func TestNewSubmitObjTask(t *testing.T) { taskID: taskID, }, submitObjTaskParams: submitObjTaskParams{ - Count: 2, - GRV: groupVersionResource{ - Group: "example.com", - Version: "v1", - Resource: "myobjects", - }, - Template: "../../resources/templates/example.yml", - NameFormat: "job{{._ENUM_}}", - Overrides: overrides, + RefTaskID: "register", + Count: 2, + Params: params, }, client: testDynamicClient, - obj: []GenericObject{ - { - typeMeta: typeMeta{ - APIVersion: "example.com/v1", - Kind: "MyObject", - }, - Metadata: objectMeta{ - Name: "job1", - Namespace: "test", - }, - Spec: spec, + }, + objs: []GenericObject{ + { + TypeMeta: TypeMeta{ + APIVersion: "example.com/v1", + Kind: "MyObject", }, - { - typeMeta: typeMeta{ - APIVersion: "example.com/v1", - Kind: "MyObject", - }, - Metadata: objectMeta{ - Name: "job2", - Namespace: "test", - }, - Spec: spec, + Metadata: objectMeta{ + Name: "job1", + Namespace: "test", + }, + Spec: spec, + }, + { + TypeMeta: TypeMeta{ + APIVersion: "example.com/v1", + Kind: "MyObject", }, + Metadata: objectMeta{ + Name: "job2", + Namespace: "test", + }, + Spec: spec, }, }, - pods: []string{"podjob1", "job1-0", "job1-1", "podjob2", "job2-0", "job2-1"}, + podCount: 4, + podRegexp: []string{"job1-test-[0-9]+", "job2-test-[0-9]+"}, }, } @@ -277,6 +243,10 @@ func TestNewSubmitObjTask(t *testing.T) { eng, err := New(testLogger, nil, tc.simClients) require.NoError(t, err) + if len(tc.refTaskID) != 0 { + eng.objTypeMap[tc.refTaskID] = tc.regObjParams + } + runnable, err := eng.GetTask(&config.Task{ ID: taskID, Type: TaskSubmitObj, @@ -286,18 +256,34 @@ func TestNewSubmitObjTask(t *testing.T) { require.EqualError(t, err, tc.err) require.Nil(t, tc.task) } else { - tc.task.setter = eng + tc.task.accessor = eng require.NoError(t, err) require.NotNil(t, tc.task) task := runnable.(*SubmitObjTask) - delete(task.Overrides, "_NAME_") - delete(task.Overrides, "_ENUM_") - - require.Equal(t, tc.pods, task.Pods.Names()) - task.Pods = utils.NameSelector{} + delete(task.Params, "_NAME_") + delete(task.Params, "_ENUM_") require.Equal(t, tc.task, task) + + tc.regObjParams.objTpl, err = template.ParseFiles(tc.regObjParams.Template) + require.NoError(t, err) + + if len(tc.regObjParams.PodNameFormat) != 0 { + tc.regObjParams.podNameTpl, err = template.New("podname").Parse(tc.regObjParams.PodNameFormat) + require.NoError(t, err) + } + + if len(tc.regObjParams.PodCount) != 0 { + tc.regObjParams.podCountTpl, err = template.New("podcount").Parse(tc.regObjParams.PodCount) + require.NoError(t, err) + } + + objs, podCount, podRegexp, err := task.getGenericObjects(tc.regObjParams) + require.NoError(t, err) + require.Equal(t, tc.objs, objs) + require.Equal(t, tc.podCount, podCount) + require.Equal(t, tc.podRegexp, podRegexp) } }) } diff --git a/pkg/engine/types.go b/pkg/engine/types.go index efac353..30bd056 100644 --- a/pkg/engine/types.go +++ b/pkg/engine/types.go @@ -19,6 +19,7 @@ package engine import ( "context" "fmt" + "text/template" "time" "github.com/go-logr/logr" @@ -26,6 +27,7 @@ import ( ) const ( + TaskRegisterObj = "RegisterObj" TaskSubmitObj = "SubmitObj" TaskUpdateObj = "UpdateObj" TaskCheckObj = "CheckObj" @@ -58,32 +60,65 @@ type StateParams struct { Timeout time.Duration `yaml:"timeout"` } +type TypeMeta struct { + Kind string `json:"kind" yaml:"kind"` + APIVersion string `json:"apiVersion" yaml:"apiVersion"` +} + +type RegisterObjParams struct { + // Template: path to the object template; see examples in resources/templates/ + Template string `yaml:"template"` + // NameFormat: a Go-template parameter for generating unique object names. + // It utilizes the '_ENUM_' keyword for an incrementing counter and + // adds the '_NAME_' key to the parameter map with the templated value. + // Example: "job{{._ENUM_}}" + NameFormat string `yaml:"nameFormat"` + // PodNameFormat: an optional Go-template parameter for specifying regexp for the naming format + // of pods spawned by the object(s). It utilizes the '_NAME_' keyword for the object name. + // PodNameFormat should be specified when a user intends to use 'CheckPod' task. + // Example: "{{._NAME_}}-\d+-\S+" + PodNameFormat string `yaml:"podNameFormat,omitempty"` + // PodCount: an optional Go-template parameter for specifying number of spawned pods per object. + // It can contain a numerical value or refer to the template parameter. + // PodCount should be specified when a user intends to use 'CheckPod' task. + // Example: "2" or "{{.replicas}}" + PodCount string `yaml:"podCount,omitempty"` + + // derived + gvr schema.GroupVersionResource + objTpl *template.Template + podNameTpl *template.Template + podCountTpl *template.Template +} + // ObjInfo contains object GVR and an optional list of derived pod names type ObjInfo struct { Names []string Namespace string GVR schema.GroupVersionResource - Pods []string + PodCount int + PodRegexp []string } // NewObjInfo creates new ObjInfo -func NewObjInfo(names []string, ns string, gvr schema.GroupVersionResource, pods ...string) *ObjInfo { +func NewObjInfo(names []string, ns string, gvr schema.GroupVersionResource, podCount int, podRegexp ...string) *ObjInfo { return &ObjInfo{ Names: names, Namespace: ns, GVR: gvr, - Pods: pods, + PodCount: podCount, + PodRegexp: podRegexp, } } -// ObjSetter defines interface for setting ObjInfo -type ObjSetter interface { +// ObjInfoAccessor defines interface for getting and setting object info +type ObjInfoAccessor interface { + // SetObjType maps object type to RegisterObjParams + SetObjType(string, *RegisterObjParams) error + // GetObjType returns RegisterObjParams for given object type, where object type is formatted as "." + GetObjType(string) (*RegisterObjParams, error) // SetObjInfo maps task ID to ObjInfo SetObjInfo(string, *ObjInfo) error -} - -// ObjGetter defines interface for retrieving ObjInfo -type ObjGetter interface { // GetObjInfo returns ObjInfo for given task ID GetObjInfo(string) (*ObjInfo, error) } diff --git a/pkg/engine/update_object_task.go b/pkg/engine/update_object_task.go index 54a4876..3b4f749 100644 --- a/pkg/engine/update_object_task.go +++ b/pkg/engine/update_object_task.go @@ -35,7 +35,7 @@ type UpdateObjTask struct { ObjStateTask } -func newUpdateObjTask(log logr.Logger, client *dynamic.DynamicClient, getter ObjGetter, cfg *config.Task) (*UpdateObjTask, error) { +func newUpdateObjTask(log logr.Logger, client *dynamic.DynamicClient, accessor ObjInfoAccessor, cfg *config.Task) (*UpdateObjTask, error) { if client == nil { return nil, fmt.Errorf("%s/%s: DynamicClient is not set", cfg.Type, cfg.ID) } @@ -47,8 +47,8 @@ func newUpdateObjTask(log logr.Logger, client *dynamic.DynamicClient, getter Obj taskType: cfg.Type, taskID: cfg.ID, }, - client: client, - getter: getter, + client: client, + accessor: accessor, }, } @@ -61,7 +61,7 @@ func newUpdateObjTask(log logr.Logger, client *dynamic.DynamicClient, getter Obj // Exec implements Runnable interface func (task *UpdateObjTask) Exec(ctx context.Context) error { - info, err := task.getter.GetObjInfo(task.RefTaskID) + info, err := task.accessor.GetObjInfo(task.RefTaskID) if err != nil { return err } diff --git a/pkg/engine/update_task_test.go b/pkg/engine/update_object_task_test.go similarity index 97% rename from pkg/engine/update_task_test.go rename to pkg/engine/update_object_task_test.go index 4108e4d..fcac2da 100644 --- a/pkg/engine/update_task_test.go +++ b/pkg/engine/update_object_task_test.go @@ -90,7 +90,7 @@ func TestNewUpdateObjTask(t *testing.T) { eng, err := New(testLogger, nil, tc.simClients) require.NoError(t, err) if len(tc.refTaskId) != 0 { - eng.objMap[tc.refTaskId] = nil + eng.objInfoMap[tc.refTaskId] = nil } task, err := eng.GetTask(&config.Task{ @@ -102,7 +102,7 @@ func TestNewUpdateObjTask(t *testing.T) { require.EqualError(t, err, tc.err) require.Nil(t, tc.task) } else { - tc.task.getter = eng + tc.task.accessor = eng require.NoError(t, err) require.NotNil(t, tc.task) require.Equal(t, tc.task, task) diff --git a/pkg/utils/sync_map.go b/pkg/utils/sync_map.go index 4bd8852..e9811f7 100644 --- a/pkg/utils/sync_map.go +++ b/pkg/utils/sync_map.go @@ -35,10 +35,11 @@ func NewSyncMap() *SyncMap { } // Set sets a key:value pair -func (m *SyncMap) Set(key interface{}, val interface{}) { +func (m *SyncMap) Set(key interface{}, val interface{}) int { m.mutex.Lock() defer m.mutex.Unlock() m.data[key] = val + return len(m.data) } // Get return a value for a key (first returned argument) if found (second returned argument) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 4705752..d32d54f 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -20,6 +20,7 @@ import ( "bytes" "flag" "fmt" + "regexp" "sync/atomic" "text/template" @@ -79,3 +80,14 @@ func GenerateNames(pattern string, n int, params map[string]interface{}) ([]stri } return names, nil } + +func Exp2Regexp(expr []string) ([]*regexp.Regexp, error) { + re := make([]*regexp.Regexp, len(expr)) + for i, r := range expr { + var err error + if re[i], err = regexp.Compile(r); err != nil { + return nil, fmt.Errorf("failed to compile regexp '%s': %v", r, err) + } + } + return re, nil +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 684cefa..19460d9 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -18,6 +18,7 @@ package utils import ( "flag" + "regexp" "testing" "github.com/stretchr/testify/require" @@ -106,3 +107,38 @@ func TestGenerateNames(t *testing.T) { }) } } + +func TestExp2Regexp(t *testing.T) { + testCases := []struct { + name string + in []string + out []*regexp.Regexp + err string + }{ + { + name: "Case 1: invalid input", + in: []string{"^name[0-9]+$", "(foo(bar)"}, + err: "failed to compile regexp '(foo(bar)': error parsing regexp: missing closing ): `(foo(bar)`", + }, + { + name: "Case 2: valid input", + in: []string{"^name[0-9]+$", "text"}, + out: []*regexp.Regexp{ + regexp.MustCompile("^name[0-9]+$"), + regexp.MustCompile("text"), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + re, err := Exp2Regexp(tc.in) + if len(tc.err) != 0 { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + require.Equal(t, tc.out, re) + } + }) + } +} diff --git a/resources/tests/k8s/test-failed-job.yml b/resources/tests/k8s/test-failed-job.yml index bfaa897..9a114cb 100644 --- a/resources/tests/k8s/test-failed-job.yml +++ b/resources/tests/k8s/test-failed-job.yml @@ -1,17 +1,19 @@ name: test-k8s-job description: submit and validate a k8s job tasks: +- id: register + type: RegisterObj + params: + template: "resources/templates/k8s/failed-job.yml" + nameFormat: "job{{._ENUM_}}" + podNameFormat: "{{._NAME_}}-[0-9]-.*" + podCount: "{{.parallelism}}" - id: job type: SubmitObj params: + refTaskId: register count: 1 - grv: - group: batch - version: v1 - resource: jobs - template: "resources/templates/k8s/failed-job.yml" - nameformat: "job{{._ENUM_}}" - overrides: + params: namespace: default parallelism: 1 completions: 1 diff --git a/resources/tests/k8s/test-job.yml b/resources/tests/k8s/test-job.yml index 167376b..a47a9e5 100644 --- a/resources/tests/k8s/test-job.yml +++ b/resources/tests/k8s/test-job.yml @@ -1,17 +1,19 @@ name: test-k8s-job description: submit and validate a k8s job tasks: +- id: register + type: RegisterObj + params: + template: "resources/templates/k8s/job.yml" + nameFormat: "job{{._ENUM_}}" + podNameFormat: "{{._NAME_}}-[0-9]-.*" + podCount: "{{.parallelism}}" - id: job type: SubmitObj params: + refTaskId: register count: 1 - grv: - group: batch - version: v1 - resource: jobs - template: "resources/templates/k8s/job.yml" - nameformat: "job{{._ENUM_}}" - overrides: + params: namespace: k8s-test parallelism: 2 completions: 2 @@ -22,3 +24,9 @@ tasks: cpu: 100m memory: 512M gpu: 8 +- id: status + type: CheckPod + params: + refTaskId: job + status: Running + timeout: 5s diff --git a/resources/tests/k8s/test-jobset-with-driver.yml b/resources/tests/k8s/test-jobset-with-driver.yml index 4d42a92..2202a9b 100644 --- a/resources/tests/k8s/test-jobset-with-driver.yml +++ b/resources/tests/k8s/test-jobset-with-driver.yml @@ -1,17 +1,19 @@ name: test-k8s-jobset-with-driver description: submit and validate a k8s jobset with 1 driver and 1 worker job tasks: +- id: register + type: RegisterObj + params: + template: "resources/templates/k8s/jobset-with-driver.yml" + nameFormat: "jobset{{._ENUM_}}" + podNameFormat: "{{._NAME_}}-(worker|driver)-[0-9]+-[0-9]+-.+" + podCount: "{{.replicas}}" - id: jobset type: SubmitObj params: - count: 1 - grv: - group: jobset.x-k8s.io - version: v1alpha2 - resource: jobsets - template: "resources/templates/k8s/jobset-with-driver.yml" - nameformat: "jobset{{._ENUM_}}" - overrides: + refTaskId: register + count: 2 + params: namespace: default replicas: 1 parallelism: 1 @@ -23,3 +25,9 @@ tasks: cpu: 100m memory: 512M gpu: 8 +- id: status + type: CheckPod + params: + refTaskId: jobset + status: Running + timeout: 5s diff --git a/resources/tests/k8s/test-jobset.yml b/resources/tests/k8s/test-jobset.yml index 43b1f81..6daa879 100644 --- a/resources/tests/k8s/test-jobset.yml +++ b/resources/tests/k8s/test-jobset.yml @@ -1,19 +1,21 @@ name: test-k8s-jobset description: submit and validate a k8s jobset with 1 worker job tasks: +- id: register + type: RegisterObj + params: + template: "resources/templates/k8s/jobset.yml" + nameFormat: "jobset{{._ENUM_}}" + podNameFormat: "{{._NAME_}}-workers-[0-9]+-[0-9]+-.+" + podCount: "{{.replicas}}" - id: jobset type: SubmitObj params: + refTaskId: register count: 1 - grv: - group: jobset.x-k8s.io - version: v1alpha2 - resource: jobsets - template: "resources/templates/k8s/jobset.yml" - nameformat: "jobset{{._ENUM_}}" - overrides: + params: namespace: default - replicas: 1 + replicas: 2 parallelism: 1 completions: 1 backoffLimit: 0 @@ -23,3 +25,9 @@ tasks: cpu: 100m memory: 512M gpu: 8 +- id: status + type: CheckPod + params: + refTaskId: jobset + status: Running + timeout: 5s diff --git a/resources/tests/kueue/test-job.yml b/resources/tests/kueue/test-job.yml index 3e5959e..c29316a 100644 --- a/resources/tests/kueue/test-job.yml +++ b/resources/tests/kueue/test-job.yml @@ -1,17 +1,19 @@ name: test-kueue-job description: submit and validate a kueue job tasks: +- id: register + type: RegisterObj + params: + template: "resources/templates/kueue/job.yml" + nameFormat: "job{{._ENUM_}}" + podNameFormat: "{{._NAME_}}-[0-9]-.*" + podCount: "{{.parallelism}}" - id: job type: SubmitObj params: + refTaskId: register count: 1 - grv: - group: batch - version: v1 - resource: jobs - template: "resources/templates/kueue/job.yml" - nameformat: "job{{._ENUM_}}" - overrides: + params: queueName: team-a-queue namespace: default parallelism: 3 @@ -21,3 +23,9 @@ tasks: cpu: 100m memory: 512M gpu: 1 +- id: status + type: CheckPod + params: + refTaskId: job + status: Running + timeout: 5s diff --git a/resources/tests/volcano/test-job.yml b/resources/tests/volcano/test-job.yml index 69ae9ea..73ca893 100644 --- a/resources/tests/volcano/test-job.yml +++ b/resources/tests/volcano/test-job.yml @@ -1,17 +1,19 @@ name: test-volcano-job description: submit and manage volcano job tasks: +- id: register + type: RegisterObj + params: + template: "resources/templates/volcano/job.yml" + nameFormat: "j{{._ENUM_}}" + podNameFormat: "{{._NAME_}}-test-[0-9]+" + podCount: "{{.replicas}}" - id: job type: SubmitObj params: - count: 1 - grv: - group: batch.volcano.sh - version: v1alpha1 - resource: jobs - template: "resources/templates/volcano/job.yml" - nameformat: "j{{._ENUM_}}" - overrides: + refTaskId: register + count: 2 + params: namespace: default replicas: 2 priorityClassName: normal-priority @@ -19,15 +21,9 @@ tasks: cpu: 100m memory: 512M gpu: 8 - pods: - range: - pattern: "{{._NAME_}}-test-{{._INDEX_}}" - ranges: ["0-1"] - id: status type: CheckPod params: refTaskId: job status: Running - nodeLabels: - nodeType: gpu timeout: 5s