diff --git a/go.mod b/go.mod index 2f2933948..80763df0f 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/openshift/api v0.0.0-20230823114715-5fdd7511b790 github.com/openshift/client-go v0.0.0-20221019143426-16aed247da5c github.com/project-codeflare/appwrapper v0.30.0 - github.com/project-codeflare/codeflare-common v0.0.0-20241216183607-222395d38924 + github.com/project-codeflare/codeflare-common v0.0.0-20250117134355-5748d670cd4a github.com/ray-project/kuberay/ray-operator v1.2.1 go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index e71bd4404..d87400841 100644 --- a/go.sum +++ b/go.sum @@ -226,8 +226,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/project-codeflare/appwrapper v0.30.0 h1:tb9LJ/QmC2xyKdM0oVf+WAz9cKIGt3gllDrRgzySgyo= github.com/project-codeflare/appwrapper v0.30.0/go.mod h1:7FpO90DLv0BAq4rwZtXKS9aRRfkR9RvXsj3pgYF0HtQ= -github.com/project-codeflare/codeflare-common v0.0.0-20241216183607-222395d38924 h1:jM+gYqn8eGmUoeQLGGYxlJgXZ1gbZgB2UtpKU9z0x9s= -github.com/project-codeflare/codeflare-common v0.0.0-20241216183607-222395d38924/go.mod h1:DPSv5khRiRDFUD43SF8da+MrVQTWmxNhuKJmwSLOyO0= +github.com/project-codeflare/codeflare-common v0.0.0-20250117134355-5748d670cd4a h1:1F5xsxncIL5Bpboup8d5osQ8iWy/hzkCTtGSBZM2tQM= +github.com/project-codeflare/codeflare-common v0.0.0-20250117134355-5748d670cd4a/go.mod h1:DPSv5khRiRDFUD43SF8da+MrVQTWmxNhuKJmwSLOyO0= github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/pkg/controllers/raycluster_controller.go b/pkg/controllers/raycluster_controller.go index 902570bba..55bd34657 100644 --- a/pkg/controllers/raycluster_controller.go +++ b/pkg/controllers/raycluster_controller.go @@ -219,8 +219,11 @@ func (r *RayClusterReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{RequeueAfter: requeueTime}, err } - if err := r.deleteHeadPodIfMissingImagePullSecrets(ctx, cluster); err != nil { - return ctrl.Result{RequeueAfter: requeueTime}, err + if len(cluster.Spec.HeadGroupSpec.Template.Spec.ImagePullSecrets) == 0 { + // Delete head pod only if user doesn't specify own imagePullSecrets and imagePullSecrets from OAuth ServiceAccount are not present in the head Pod + if err := r.deleteHeadPodIfMissingImagePullSecrets(ctx, cluster); err != nil { + return ctrl.Result{RequeueAfter: requeueTime}, err + } } _, err = r.kubeClient.RbacV1().ClusterRoleBindings().Apply(ctx, desiredOAuthClusterRoleBinding(cluster), metav1.ApplyOptions{FieldManager: controllerName, Force: true}) diff --git a/pkg/controllers/raycluster_controller_test.go b/pkg/controllers/raycluster_controller_test.go index e0a7e969c..db49adf98 100644 --- a/pkg/controllers/raycluster_controller_test.go +++ b/pkg/controllers/raycluster_controller_test.go @@ -192,6 +192,73 @@ var _ = Describe("RayCluster controller", func() { }).WithTimeout(time.Second * 10).Should(Satisfy(errors.IsNotFound)) }) + It("should not delete the head pod if RayCluster CR provides image pull secrets", func(ctx SpecContext) { + By("creating an instance of the RayCluster CR with imagePullSecret") + rayclusterWithPullSecret := &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pull-secret-cluster", + Namespace: namespaceName, + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + ImagePullSecrets: []corev1.LocalObjectReference{{Name: "custom-pull-secret"}}, + Containers: []corev1.Container{}, + }, + }, + RayStartParams: map[string]string{}, + }, + }, + } + _, err := rayClient.RayV1().RayClusters(namespaceName).Create(ctx, rayclusterWithPullSecret, metav1.CreateOptions{}) + Expect(err).To(Not(HaveOccurred())) + + Eventually(func() (*corev1.ServiceAccount, error) { + return k8sClient.CoreV1().ServiceAccounts(namespaceName).Get(ctx, oauthServiceAccountNameFromCluster(rayclusterWithPullSecret), metav1.GetOptions{}) + }).WithTimeout(time.Second * 10).Should(WithTransform(OwnerReferenceKind, Equal("RayCluster"))) + + headPodName := "head-pod" + headPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: headPodName, + Namespace: namespaceName, + Labels: map[string]string{ + "ray.io/node-type": "head", + "ray.io/cluster": rayclusterWithPullSecret.Name, + }, + }, + Spec: corev1.PodSpec{ + ImagePullSecrets: []corev1.LocalObjectReference{ + {Name: "custom-pull-secret"}, + }, + Containers: []corev1.Container{ + { + Name: "head-container", + Image: "busybox", + }, + }, + }, + } + _, err = k8sClient.CoreV1().Pods(namespaceName).Create(ctx, headPod, metav1.CreateOptions{}) + Expect(err).To(Not(HaveOccurred())) + + Eventually(func() (*corev1.Pod, error) { + return k8sClient.CoreV1().Pods(namespaceName).Get(ctx, headPodName, metav1.GetOptions{}) + }).WithTimeout(time.Second * 10).ShouldNot(BeNil()) + + sa, err := k8sClient.CoreV1().ServiceAccounts(namespaceName).Get(ctx, oauthServiceAccountNameFromCluster(rayclusterWithPullSecret), metav1.GetOptions{}) + Expect(err).To(Not(HaveOccurred())) + + sa.ImagePullSecrets = append(sa.ImagePullSecrets, corev1.LocalObjectReference{Name: "test-image-pull-secret"}) + _, err = k8sClient.CoreV1().ServiceAccounts(namespaceName).Update(ctx, sa, metav1.UpdateOptions{}) + Expect(err).To(Not(HaveOccurred())) + + Consistently(func() (*corev1.Pod, error) { + return k8sClient.CoreV1().Pods(namespaceName).Get(ctx, headPodName, metav1.GetOptions{}) + }).WithTimeout(time.Second * 5).Should(Not(BeNil())) + }) + It("should remove CRB when the RayCluster is deleted", func(ctx SpecContext) { foundRayCluster, err := rayClient.RayV1().RayClusters(namespaceName).Get(ctx, rayClusterName, metav1.GetOptions{}) Expect(err).To(Not(HaveOccurred())) diff --git a/test/e2e/mnist_rayjob_raycluster_test.go b/test/e2e/mnist_rayjob_raycluster_test.go index 0b2b01761..3a31c7135 100644 --- a/test/e2e/mnist_rayjob_raycluster_test.go +++ b/test/e2e/mnist_rayjob_raycluster_test.go @@ -211,6 +211,52 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number test.Eventually(AppWrappers(test, namespace), TestTimeoutShort).Should(BeEmpty()) } +// Verifying https://github.com/project-codeflare/codeflare-operator/issues/649 +func TestRayClusterImagePullSecret(t *testing.T) { + test := With(t) + + // Create a namespace + namespace := test.NewTestNamespace() + + // Create Kueue resources + resourceFlavor := CreateKueueResourceFlavor(test, v1beta1.ResourceFlavorSpec{}) + defer func() { + _ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{}) + }() + clusterQueue := createClusterQueue(test, resourceFlavor, 0) + defer func() { + _ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{}) + }() + CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue) + + // Create ServiceAccount, wait until corresponding imagePullSecret is available + sa := CreateServiceAccount(test, namespace.Name) + test.Eventually(ServiceAccount(test, sa.Namespace, sa.Name), TestTimeoutShort). + Should( + HaveField("ImagePullSecrets", HaveLen(1)), + ) + sa = GetServiceAccount(test, sa.Namespace, sa.Name) + + // Create MNIST training script + mnist := constructMNISTConfigMap(test, namespace) + mnist, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), mnist, metav1.CreateOptions{}) + test.Expect(err).NotTo(HaveOccurred()) + test.T().Logf("Created ConfigMap %s/%s successfully", mnist.Namespace, mnist.Name) + + // Create RayCluster with imagePullSecret and assign it to the localqueue + rayCluster := constructRayCluster(test, namespace, mnist, 0) + rayCluster.Spec.HeadGroupSpec.Template.Spec.ImagePullSecrets = sa.ImagePullSecrets + rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Create(test.Ctx(), rayCluster, metav1.CreateOptions{}) + test.Expect(err).NotTo(HaveOccurred()) + test.T().Logf("Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + test.T().Logf("Waiting for RayCluster %s/%s to be running", rayCluster.Namespace, rayCluster.Name) + test.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) +} + +// Helper functions + func constructMNISTConfigMap(test Test, namespace *corev1.Namespace) *corev1.ConfigMap { return &corev1.ConfigMap{ TypeMeta: metav1.TypeMeta{ @@ -274,11 +320,11 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf Resources: corev1.ResourceRequirements{ Requests: corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("250m"), - corev1.ResourceMemory: resource.MustParse("512Mi"), + corev1.ResourceMemory: resource.MustParse("2G"), }, Limits: corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("1"), - corev1.ResourceMemory: resource.MustParse("2G"), + corev1.ResourceMemory: resource.MustParse("4G"), }, }, },