diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook.go b/pkg/controller/jobs/statefulset/statefulset_webhook.go index b4cea359353..dbae86fc5b6 100644 --- a/pkg/controller/jobs/statefulset/statefulset_webhook.go +++ b/pkg/controller/jobs/statefulset/statefulset_webhook.go @@ -142,21 +142,23 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob groupNameLabelPath, )...) - oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1) - newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1) - - // Allow only scale down to zero and scale up from zero. - // TODO(#3279): Support custom resizes later - if newReplicas != 0 && oldReplicas != 0 { - allErrs = append(allErrs, apivalidation.ValidateImmutableField( - newStatefulSet.Spec.Replicas, - oldStatefulSet.Spec.Replicas, - replicasPath, - )...) - } + if newQueueName != "" { + oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1) + newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1) + + // Allow only scale down to zero and scale up from zero. + // TODO(#3279): Support custom resizes later + if newReplicas != 0 && oldReplicas != 0 { + allErrs = append(allErrs, apivalidation.ValidateImmutableField( + newStatefulSet.Spec.Replicas, + oldStatefulSet.Spec.Replicas, + replicasPath, + )...) + } - if oldReplicas == 0 && newReplicas > 0 && newStatefulSet.Status.Replicas > 0 { - allErrs = append(allErrs, field.Forbidden(replicasPath, "scaling down is still in progress")) + if oldReplicas == 0 && newReplicas > 0 && newStatefulSet.Status.Replicas > 0 { + allErrs = append(allErrs, field.Forbidden(replicasPath, "scaling down is still in progress")) + } } return warnings, allErrs.ToAggregate() diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go index e400623bb99..b5bfc337728 100644 --- a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go +++ b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go @@ -342,22 +342,16 @@ func TestValidateUpdate(t *testing.T) { }, }, "change in replicas (scale up while the previous scaling operation is still in progress)": { - oldObj: &appsv1.StatefulSet{ - Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To(int32(0)), - }, - Status: appsv1.StatefulSetStatus{ - Replicas: 3, - }, - }, - newObj: &appsv1.StatefulSet{ - Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To(int32(3)), - }, - Status: appsv1.StatefulSetStatus{ - Replicas: 1, - }, - }, + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + Replicas(0). + StatusReplicas(3). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + Replicas(3). + StatusReplicas(1). + Obj(), wantErr: field.ErrorList{ &field.Error{ Type: field.ErrorTypeForbidden, @@ -366,16 +360,14 @@ func TestValidateUpdate(t *testing.T) { }.ToAggregate(), }, "change in replicas (scale up)": { - oldObj: &appsv1.StatefulSet{ - Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To(int32(3)), - }, - }, - newObj: &appsv1.StatefulSet{ - Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To(int32(4)), - }, - }, + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + Replicas(3). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + Replicas(4). + Obj(), wantErr: field.ErrorList{ &field.Error{ Type: field.ErrorTypeInvalid, @@ -383,6 +375,25 @@ func TestValidateUpdate(t *testing.T) { }, }.ToAggregate(), }, + + "change in replicas (scale up without queue-name while the previous scaling operation is still in progress)": { + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Replicas(0). + StatusReplicas(3). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Replicas(3). + StatusReplicas(1). + Obj(), + }, + "change in replicas (scale up without queue-name)": { + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Replicas(3). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Replicas(4). + Obj(), + }, } for name, tc := range testCases { diff --git a/pkg/util/testingjobs/statefulset/wrappers.go b/pkg/util/testingjobs/statefulset/wrappers.go index fe0c4214fd9..3072eabd951 100644 --- a/pkg/util/testingjobs/statefulset/wrappers.go +++ b/pkg/util/testingjobs/statefulset/wrappers.go @@ -134,6 +134,11 @@ func (ss *StatefulSetWrapper) Replicas(r int32) *StatefulSetWrapper { return ss } +func (ss *StatefulSetWrapper) StatusReplicas(r int32) *StatefulSetWrapper { + ss.Status.Replicas = r + return ss +} + func (ss *StatefulSetWrapper) PodTemplateSpecPodGroupNameLabel( ownerName string, ownerUID types.UID, ownerGVK schema.GroupVersionKind, ) *StatefulSetWrapper {