diff --git a/.gitignore b/.gitignore index 2880c6d9e8..685dffe2c8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ bin/ /tf-operator vendor/ testbin/* +dep-crds/ cover.out # IDEs diff --git a/Makefile b/Makefile index 44a011c093..1c26e6212a 100644 --- a/Makefile +++ b/Makefile @@ -74,15 +74,16 @@ HAS_SETUP_ENVTEST := $(shell command -v setup-envtest;) testall: manifests generate fmt vet golangci-lint test ## Run tests. test: envtest - KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./... -coverprofile cover.out + KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" \ + go test ./pkg/apis/kubeflow.org/v1/... ./pkg/cert/... ./pkg/common/... ./pkg/config/... ./pkg/controller.v1/... ./pkg/core/... ./pkg/util/... ./pkg/webhooks/... -coverprofile cover.out .PHONY: test-integrationv2 -test-integrationv2: envtest +test-integrationv2: envtest jobset-operator-crd scheduler-plugins-crd KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./test/... -coverprofile cover.out .PHONY: testv2 testv2: - go test ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out + go test ./pkg/apis/kubeflow.org/v2alpha1/... ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out envtest: ifndef HAS_SETUP_ENVTEST @@ -127,3 +128,18 @@ controller-gen: ## Download controller-gen locally if necessary. KUSTOMIZE = $(shell pwd)/bin/kustomize kustomize: ## Download kustomize locally if necessary. GOBIN=$(PROJECT_DIR)/bin go install sigs.k8s.io/kustomize/kustomize/v4@v4.5.7 + +## Download external CRDs for the integration testings. +EXTERNAL_CRDS_DIR ?= $(PROJECT_DIR)/dep-crds + +JOBSET_ROOT = $(shell go list -m -mod=readonly -f "{{.Dir}}" sigs.k8s.io/jobset) +.PHONY: jobset-operator-crd +jobset-operator-crd: ## Copy the CRDs from the jobset-operator to the dep-crds directory. + mkdir -p $(EXTERNAL_CRDS_DIR)/jobset-operator/ + cp -f $(JOBSET_ROOT)/config/components/crd/bases/* $(EXTERNAL_CRDS_DIR)/jobset-operator/ + +SCHEDULER_PLUGINS_ROOT = $(shell go list -m -f "{{.Dir}}" sigs.k8s.io/scheduler-plugins) +.PHONY: scheduler-plugins-crd +scheduler-plugins-crd: + mkdir -p $(EXTERNAL_CRDS_DIR)/scheduler-plugins/ + cp -f $(SCHEDULER_PLUGINS_ROOT)/manifests/coscheduling/* $(PROJECT_DIR)/dep-crds/scheduler-plugins diff --git a/pkg/controller.v2/setup.go b/pkg/controller.v2/setup.go index e2fadd3a96..3fb8b98d9d 100644 --- a/pkg/controller.v2/setup.go +++ b/pkg/controller.v2/setup.go @@ -26,7 +26,8 @@ func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (st if err := NewTrainJobReconciler( mgr.GetClient(), mgr.GetEventRecorderFor("training-operator-trainjob-controller"), - ).SetupWithManager(mgr, runtimes); err != nil { + runtimes, + ).SetupWithManager(mgr); err != nil { return "TrainJob", err } return "", nil diff --git a/pkg/controller.v2/trainjob_controller.go b/pkg/controller.v2/trainjob_controller.go index ef2f3242ce..a31962c199 100644 --- a/pkg/controller.v2/trainjob_controller.go +++ b/pkg/controller.v2/trainjob_controller.go @@ -20,10 +20,13 @@ import ( "context" "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" + "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/apiutil" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" @@ -33,13 +36,15 @@ type TrainJobReconciler struct { log logr.Logger client client.Client recorder record.EventRecorder + runtimes map[string]runtime.Runtime } -func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder) *TrainJobReconciler { +func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runs map[string]runtime.Runtime) *TrainJobReconciler { return &TrainJobReconciler{ log: ctrl.Log.WithName("trainjob-controller"), client: client, recorder: recorder, + runtimes: runs, } } @@ -49,15 +54,70 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return ctrl.Result{}, client.IgnoreNotFound(err) } log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob)) - ctrl.LoggerInto(ctx, log) + ctx = ctrl.LoggerInto(ctx, log) log.V(2).Info("Reconciling TrainJob") + if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil { + return ctrl.Result{}, err + } + // TODO (tenzen-y): Do update the status. return ctrl.Result{}, nil } -func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) error { +func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error { + log := ctrl.LoggerFrom(ctx) + + // Controller assumes the runtime existence has already verified in the webhook on TrainJob creation. + run := r.runtimes[runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()] + objs, err := run.NewObjects(ctx, trainJob) + if err != nil { + return err + } + for _, obj := range objs { + var gvk schema.GroupVersionKind + if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil { + return err + } + logKeysAndValues := []any{ + "groupVersionKind", gvk.String(), + "namespace", obj.GetNamespace(), + "name", obj.GetName(), + } + // TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence. + // Non-empty resourceVersion indicates UPDATE operation. + var creationErr error + var created bool + if obj.GetResourceVersion() == "" { + creationErr = r.client.Create(ctx, obj) + created = creationErr == nil + } + switch { + case created: + log.V(5).Info("Succeeded to create object", logKeysAndValues) + continue + case client.IgnoreAlreadyExists(creationErr) != nil: + return creationErr + default: + // This indicates CREATE operation has not been performed or the object has already existed in the cluster. + if err = r.client.Update(ctx, obj); err != nil { + return err + } + log.V(5).Info("Succeeded to update object", logKeysAndValues) + } + } + return nil +} + +func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind { + return schema.GroupKind{ + Group: ptr.Deref(runtimeRef.APIGroup, ""), + Kind: ptr.Deref(runtimeRef.Kind, ""), + } +} + +func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error { b := ctrl.NewControllerManagedBy(mgr). For(&kubeflowv2.TrainJob{}) - for _, run := range runtimes { + for _, run := range r.runtimes { for _, registrar := range run.EventHandlerRegistrars() { if registrar != nil { b = registrar(b, mgr.GetClient()) diff --git a/pkg/runtime.v2/core/clustertrainingruntime_test.go b/pkg/runtime.v2/core/clustertrainingruntime_test.go index 696d486ab5..84c3e39fe9 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime_test.go +++ b/pkg/runtime.v2/core/clustertrainingruntime_test.go @@ -46,6 +46,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { }{ "succeeded to build JobSet and PodGroup": { trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + Suspend(true). UID("uid"). RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( @@ -57,7 +58,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { clusterTrainingRuntime: baseRuntime.RuntimeSpec( testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). ContainerImage("test:runtime"). - PodGroupPolicySchedulingTimeout(120). + PodGroupPolicyCoschedulingSchedulingTimeout(120). MLPolicyNumNodes(20). ResourceRequests(0, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("1"), @@ -69,6 +70,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { ).Obj(), wantObjs: []client.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). + Suspend(true). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). ContainerImage(ptr.To("test:trainjob")). JobCompletionMode(batchv1.IndexedCompletion). diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go index a32ad33852..9c2deadbb6 100644 --- a/pkg/runtime.v2/core/trainingruntime_test.go +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -46,6 +46,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }{ "succeeded to build JobSet and PodGroup": { trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + Suspend(true). UID("uid"). RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). SpecLabel("conflictLabel", "override"). @@ -62,7 +63,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { RuntimeSpec( testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). ContainerImage("test:runtime"). - PodGroupPolicySchedulingTimeout(120). + PodGroupPolicyCoschedulingSchedulingTimeout(120). MLPolicyNumNodes(20). ResourceRequests(0, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("1"), @@ -74,6 +75,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { ).Obj(), wantObjs: []client.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). + Suspend(true). Label("conflictLabel", "override"). Annotation("conflictAnnotation", "override"). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go index c3b630d923..0a1edb266f 100644 --- a/pkg/runtime.v2/framework/core/framework_test.go +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -334,13 +334,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) { ResourceRequests(1, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("1"), corev1.ResourceMemory: resource.MustParse("2Gi"), - }). - Clone() + }) jobSetWithPropagatedTrainJobParams := jobSetBase. + Clone(). JobCompletionMode(batchv1.IndexedCompletion). ContainerImage(ptr.To("foo:bar")). - ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). - Clone() + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid") cases := map[string]struct { runtimeInfo *runtime.Info @@ -361,6 +360,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { Obj(), runtimeInfo: &runtime.Info{ Obj: jobSetBase. + Clone(). Obj(), Policy: runtime.Policy{ MLPolicy: &kubeflowv2.MLPolicy{ @@ -403,10 +403,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) { ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). Obj(), jobSetWithPropagatedTrainJobParams. + Clone(). Obj(), }, wantRuntimeInfo: &runtime.Info{ Obj: jobSetWithPropagatedTrainJobParams. + Clone(). Obj(), Policy: runtime.Policy{ MLPolicy: &kubeflowv2.MLPolicy{ diff --git a/pkg/runtime.v2/framework/plugins/jobset/builder.go b/pkg/runtime.v2/framework/plugins/jobset/builder.go index ed336edfc9..8b7a2b4571 100644 --- a/pkg/runtime.v2/framework/plugins/jobset/builder.go +++ b/pkg/runtime.v2/framework/plugins/jobset/builder.go @@ -28,12 +28,12 @@ import ( ) type Builder struct { - *jobsetv1alpha2.JobSet + jobsetv1alpha2.JobSet } func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec) *Builder { return &Builder{ - JobSet: &jobsetv1alpha2.JobSet{ + JobSet: jobsetv1alpha2.JobSet{ TypeMeta: metav1.TypeMeta{ APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(), Kind: "JobSet", @@ -76,8 +76,13 @@ func (b *Builder) PodLabels(labels map[string]string) *Builder { return b } +func (b *Builder) Suspend(suspend *bool) *Builder { + b.Spec.Suspend = suspend + return b +} + // TODO: Need to support all TrainJob fields. func (b *Builder) Build() *jobsetv1alpha2.JobSet { - return b.JobSet + return &b.JobSet } diff --git a/pkg/runtime.v2/framework/plugins/jobset/jobset.go b/pkg/runtime.v2/framework/plugins/jobset/jobset.go index 82eca0ef7f..b5ebbebf14 100644 --- a/pkg/runtime.v2/framework/plugins/jobset/jobset.go +++ b/pkg/runtime.v2/framework/plugins/jobset/jobset.go @@ -74,15 +74,30 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl if !ok { return nil, nil } - jobSetBuilder := NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: info.Labels, - Annotations: info.Annotations, - }, - Spec: raw.Spec, - }) + + var jobSetBuilder *Builder + oldJobSet := &jobsetv1alpha2.JobSet{} + if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil { + if !apierrors.IsNotFound(err) { + return nil, err + } + jobSetBuilder = NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: info.Labels, + Annotations: info.Annotations, + }, + Spec: raw.Spec, + }) + oldJobSet = nil + } else { + jobSetBuilder = &Builder{ + JobSet: *oldJobSet.DeepCopy(), + } + } + // TODO (tenzen-y): We should support all field propagation in builder. jobSet := jobSetBuilder. + Suspend(trainJob.Spec.Suspend). ContainerImage(trainJob.Spec.Trainer.Image). JobCompletionMode(batchv1.IndexedCompletion). PodLabels(info.PodLabels). @@ -90,13 +105,6 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil { return nil, err } - oldJobSet := &jobsetv1alpha2.JobSet{} - if err := j.client.Get(ctx, client.ObjectKeyFromObject(jobSet), oldJobSet); err != nil { - if !apierrors.IsNotFound(err) { - return nil, err - } - oldJobSet = nil - } if err := info.Update(jobSet); err != nil { return nil, err } @@ -106,9 +114,14 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl return nil, nil } -func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, suspended bool) bool { +func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, trainJobIsSuspended bool) bool { return old == nil || - suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)) + (!trainJobIsSuspended && !ptr.Equal(old.Spec.Suspend, new.Spec.Suspend)) || + (trainJobIsSuspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))) +} + +func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool { + return ptr.Deref(jobSet.Spec.Suspend, false) } func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder { diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index 3be7f4f194..de83294aae 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -86,6 +86,11 @@ func MakeJobSetWrapper(namespace, name string) *JobSetWrapper { } } +func (j *JobSetWrapper) Suspend(suspend bool) *JobSetWrapper { + j.Spec.Suspend = &suspend + return j +} + func (j *JobSetWrapper) Completions(idx int, completions int32) *JobSetWrapper { if len(j.Spec.ReplicatedJobs) < idx { return j @@ -204,6 +209,11 @@ func MakeTrainJobWrapper(namespace, name string) *TrainJobWrapper { } } +func (t *TrainJobWrapper) Suspend(suspend bool) *TrainJobWrapper { + t.Spec.Suspend = &suspend + return t +} + func (t *TrainJobWrapper) UID(uid string) *TrainJobWrapper { t.ObjectMeta.UID = types.UID(uid) return t @@ -225,6 +235,24 @@ func (t *TrainJobWrapper) SpecAnnotation(key, value string) *TrainJobWrapper { return t } +func (t *TrainJobWrapper) RuntimeRef(gvk schema.GroupVersionKind, name string) *TrainJobWrapper { + t.Spec.RuntimeRef = kubeflowv2.RuntimeRef{ + APIGroup: &gvk.Group, + Kind: &gvk.Kind, + Name: name, + } + return t +} + +func (t *TrainJobWrapper) Trainer(trainer *kubeflowv2.Trainer) *TrainJobWrapper { + t.Spec.Trainer = trainer + return t +} + +func (t *TrainJobWrapper) Obj() *kubeflowv2.TrainJob { + return &t.TrainJob +} + type TrainJobTrainerWrapper struct { kubeflowv2.Trainer } @@ -244,24 +272,6 @@ func (t *TrainJobTrainerWrapper) Obj() *kubeflowv2.Trainer { return &t.Trainer } -func (t *TrainJobWrapper) Trainer(trainer *kubeflowv2.Trainer) *TrainJobWrapper { - t.Spec.Trainer = trainer - return t -} - -func (t *TrainJobWrapper) RuntimeRef(gvk schema.GroupVersionKind, name string) *TrainJobWrapper { - t.Spec.RuntimeRef = kubeflowv2.RuntimeRef{ - APIGroup: &gvk.Group, - Kind: &gvk.Kind, - Name: name, - } - return t -} - -func (t *TrainJobWrapper) Obj() *kubeflowv2.TrainJob { - return &t.TrainJob -} - type TrainingRuntimeWrapper struct { kubeflowv2.TrainingRuntime } @@ -455,15 +465,19 @@ func (s *TrainingRuntimeSpecWrapper) ResourceRequests(idx int, res corev1.Resour return s } -func (s *TrainingRuntimeSpecWrapper) PodGroupPolicySchedulingTimeout(timeout int32) *TrainingRuntimeSpecWrapper { +func (s *TrainingRuntimeSpecWrapper) PodGroupPolicyCoscheduling(src *kubeflowv2.CoschedulingPodGroupPolicySource) *TrainingRuntimeSpecWrapper { + if s.PodGroupPolicy == nil { + s.PodGroupPolicy = &kubeflowv2.PodGroupPolicy{} + } + s.PodGroupPolicy.Coscheduling = src + return s +} + +func (s *TrainingRuntimeSpecWrapper) PodGroupPolicyCoschedulingSchedulingTimeout(timeout int32) *TrainingRuntimeSpecWrapper { if s.PodGroupPolicy == nil || s.PodGroupPolicy.Coscheduling == nil { - s.PodGroupPolicy = &kubeflowv2.PodGroupPolicy{ - PodGroupPolicySource: kubeflowv2.PodGroupPolicySource{ - Coscheduling: &kubeflowv2.CoschedulingPodGroupPolicySource{ - ScheduleTimeoutSeconds: &timeout, - }, - }, - } + return s.PodGroupPolicyCoscheduling(&kubeflowv2.CoschedulingPodGroupPolicySource{ + ScheduleTimeoutSeconds: &timeout, + }) } s.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds = &timeout return s diff --git a/test/integration/controller.v2/trainjob_controller_test.go b/test/integration/controller.v2/trainjob_controller_test.go index e31b7e3f79..24f3d401f2 100644 --- a/test/integration/controller.v2/trainjob_controller_test.go +++ b/test/integration/controller.v2/trainjob_controller_test.go @@ -19,13 +19,19 @@ package controllerv2 import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" + batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" "github.com/kubeflow/training-operator/test/integration/framework" + "github.com/kubeflow/training-operator/test/util" ) var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { @@ -54,22 +60,185 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }) ginkgo.When("Reconciling TrainJob", func() { + var ( + trainJob *kubeflowv2.TrainJob + trainJobKey client.ObjectKey + trainingRuntime *kubeflowv2.TrainingRuntime + ) + ginkgo.AfterEach(func() { gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainJob{}, client.InNamespace(ns.Name))).Should(gomega.Succeed()) }) - ginkgo.It("Should succeed to create TrainJob", func() { - trainJob := &kubeflowv2.TrainJob{ - TypeMeta: metav1.TypeMeta{ - APIVersion: kubeflowv2.SchemeGroupVersion.String(), - Kind: "TrainJob", - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "alpha", - Namespace: ns.Name, - }, - } + ginkgo.BeforeEach(func() { + trainJob = testingutil.MakeTrainJobWrapper(ns.Name, "alpha"). + Suspend(true). + RuntimeRef(kubeflowv2.GroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "alpha"). + SpecLabel("testingKey", "testingVal"). + SpecAnnotation("testingKey", "testingVal"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(). + ContainerImage("trainJob"). + Obj()). + Obj() + trainJobKey = client.ObjectKeyFromObject(trainJob) + baseRuntime := testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha") + trainingRuntime = baseRuntime.Clone(). + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Clone().Spec). + ContainerImage("trainingRuntime"). + PodGroupPolicyCoscheduling(&kubeflowv2.CoschedulingPodGroupPolicySource{}). + MLPolicyNumNodes(100). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("5"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("10"), + }). + Obj()). + Obj() + }) + + ginkgo.It("Should succeed to create TrainJob with TrainingRuntime", func() { + ginkgo.By("Creating TrainingRuntime and TrainJob") + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if appropriately JobSet and PodGroup are created") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + g.Expect(jobSet).Should(gomega.BeComparableTo( + testingutil.MakeJobSetWrapper(ns.Name, trainJobKey.Name). + Suspend(true). + Label("testingKey", "testingVal"). + Annotation("testingKey", "testingVal"). + PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, trainJobKey.Name). + ContainerImage(ptr.To("trainJob")). + JobCompletionMode(batchv1.IndexedCompletion). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("5"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("10"), + }). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), trainJobKey.Name, string(trainJob.UID)). + Obj(), + util.IgnoreObjectMetadata)) + pg := &schedulerpluginsv1alpha1.PodGroup{} + g.Expect(k8sClient.Get(ctx, trainJobKey, pg)).Should(gomega.Succeed()) + g.Expect(pg).Should(gomega.BeComparableTo( + testingutil.MakeSchedulerPluginsPodGroup(ns.Name, trainJobKey.Name). + MinMember(200). + MinResources(corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1500"), + }). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), trainJobKey.Name, string(trainJob.UID)). + Obj(), + util.IgnoreObjectMetadata)) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.It("Should succeeded to update JobSet only when TrainJob is suspended", func() { + ginkgo.By("Creating TrainingRuntime and suspended TrainJob") + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet and PodGroup are created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating suspended TrainJob Trainer image") + updatedImageName := "updated-trainer-image" + originImageName := *trainJob.Spec.Trainer.Image + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, trainJob)).Should(gomega.Succeed()) + trainJob.Spec.Trainer.Image = &updatedImageName + g.Expect(k8sClient.Update(ctx, trainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Trainer image should be updated") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + g.Expect(jobSet).Should(gomega.BeComparableTo( + testingutil.MakeJobSetWrapper(ns.Name, trainJobKey.Name). + Suspend(true). + Label("testingKey", "testingVal"). + Annotation("testingKey", "testingVal"). + PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, trainJobKey.Name). + ContainerImage(&updatedImageName). + JobCompletionMode(batchv1.IndexedCompletion). + ResourceRequests(0, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("5"), + }). + ResourceRequests(1, corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("10"), + }). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), trainJobKey.Name, string(trainJob.UID)). + Obj(), + util.IgnoreObjectMetadata)) + pg := &schedulerpluginsv1alpha1.PodGroup{} + g.Expect(k8sClient.Get(ctx, trainJobKey, pg)).Should(gomega.Succeed()) + g.Expect(pg).Should(gomega.BeComparableTo( + testingutil.MakeSchedulerPluginsPodGroup(ns.Name, trainJobKey.Name). + MinMember(200). + MinResources(corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1500"), + }). + ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), trainJobKey.Name, string(trainJob.UID)). + Obj(), + util.IgnoreObjectMetadata)) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Unsuspending TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, trainJob)).Should(gomega.Succeed()) + trainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, trainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + g.Expect(ptr.Deref(jobSet.Spec.Suspend, false)).Should(gomega.BeFalse()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Trying to restore trainer image") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, trainJob)).Should(gomega.Succeed()) + trainJob.Spec.Trainer.Image = &originImageName + g.Expect(k8sClient.Update(ctx, trainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet keep having updated image") + gomega.Consistently(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + for _, rJob := range jobSet.Spec.ReplicatedJobs { + g.Expect(rJob.Template.Spec.Template.Spec.Containers[0].Image).Should(gomega.Equal(updatedImageName)) + } + }, util.ConsistentDuration, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Trying to re-suspend TrainJob and restore trainer image") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, trainJob)) + trainJob.Spec.Suspend = ptr.To(true) + trainJob.Spec.Trainer.Image = &originImageName + g.Expect(k8sClient.Update(ctx, trainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet image is restored") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + g.Expect(jobSet.Spec.Suspend).ShouldNot(gomega.BeNil()) + g.Expect(*jobSet.Spec.Suspend).Should(gomega.BeTrue()) + for _, rJob := range jobSet.Spec.ReplicatedJobs { + g.Expect(rJob.Template.Spec.Template.Spec.Containers[0].Image).Should(gomega.Equal(originImageName)) + } + }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) }) diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index 0a3a7fb774..62beabf593 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -36,6 +36,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" "sigs.k8s.io/controller-runtime/pkg/webhook" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" controllerv2 "github.com/kubeflow/training-operator/pkg/controller.v2" @@ -52,7 +54,11 @@ func (f *Framework) Init() *rest.Config { log.SetLogger(zap.New(zap.WriteTo(ginkgo.GinkgoWriter), zap.UseDevMode(true))) ginkgo.By("bootstrapping test environment") f.testEnv = &envtest.Environment{ - CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "v2", "base", "crds")}, + CRDDirectoryPaths: []string{ + filepath.Join("..", "..", "..", "manifests", "v2", "base", "crds"), + filepath.Join("..", "..", "..", "dep-crds", "scheduler-plugins", "crd.yaml"), + filepath.Join("..", "..", "..", "dep-crds", "jobset-operator"), + }, WebhookInstallOptions: envtest.WebhookInstallOptions{ Paths: []string{filepath.Join("..", "..", "..", "manifests", "v2", "base", "webhook")}, }, @@ -67,8 +73,8 @@ func (f *Framework) Init() *rest.Config { func (f *Framework) RunManager(cfg *rest.Config) (context.Context, client.Client) { webhookInstallOpts := &f.testEnv.WebhookInstallOptions gomega.ExpectWithOffset(1, kubeflowv2.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) - - // +kubebuilder:scaffold:scheme + gomega.ExpectWithOffset(1, jobsetv1alpha2.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) + gomega.ExpectWithOffset(1, schedulerpluginsv1alpha1.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) k8sClient, err := client.New(cfg, client.Options{Scheme: scheme.Scheme}) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) diff --git a/test/util/constants.go b/test/util/constants.go new file mode 100644 index 0000000000..a0b9d8a665 --- /dev/null +++ b/test/util/constants.go @@ -0,0 +1,37 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "time" +) + +const ( + Timeout = 5 * time.Second + ConsistentDuration = time.Second + Interval = time.Millisecond * 250 +) + +var ( + IgnoreObjectMetadata = cmp.Options{ + cmpopts.IgnoreTypes(metav1.TypeMeta{}), + cmpopts.IgnoreFields(metav1.ObjectMeta{}, "UID", "ResourceVersion", "Generation", "CreationTimestamp", "ManagedFields"), + } +)