diff --git a/apis/kueue/v1alpha1/workload_webhook.go b/apis/kueue/v1alpha1/workload_webhook.go index e5bdf87c56..1b206a2ac3 100644 --- a/apis/kueue/v1alpha1/workload_webhook.go +++ b/apis/kueue/v1alpha1/workload_webhook.go @@ -18,6 +18,8 @@ package v1alpha1 import ( "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation" + "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" @@ -58,15 +60,45 @@ var _ webhook.Validator = &Workload{} // ValidateCreate implements webhook.Validator so a webhook will be registered for the type func (r *Workload) ValidateCreate() error { - return nil + return ValidateWorkload(r).ToAggregate() } // ValidateUpdate implements webhook.Validator so a webhook will be registered for the type func (r *Workload) ValidateUpdate(old runtime.Object) error { - return nil + return ValidateWorkload(r).ToAggregate() } // ValidateDelete implements webhook.Validator so a webhook will be registered for the type func (r *Workload) ValidateDelete() error { return nil } + +func ValidateWorkload(obj *Workload) field.ErrorList { + var allErrs field.ErrorList + specField := field.NewPath("spec") + podSetsField := specField.Child("podSets") + if len(obj.Spec.PodSets) == 0 { + allErrs = append(allErrs, field.Required(podSetsField, "at least one podSet is required")) + } + + for i, podSet := range obj.Spec.PodSets { + if podSet.Count <= 0 { + allErrs = append(allErrs, field.Invalid( + podSetsField.Index(i).Child("count"), + podSet.Count, + "count must be greater than 0"), + ) + } + } + + if len(obj.Spec.PriorityClassName) > 0 { + msgs := validation.IsDNS1123Subdomain(obj.Spec.PriorityClassName) + if len(msgs) > 0 { + classNameField := specField.Child("priorityClassName") + for _, msg := range msgs { + allErrs = append(allErrs, field.Invalid(classNameField, obj.Spec.PriorityClassName, msg)) + } + } + } + return allErrs +} diff --git a/apis/kueue/v1alpha1/workload_webhook_test.go b/apis/kueue/v1alpha1/workload_webhook_test.go new file mode 100644 index 0000000000..61a0feaaa6 --- /dev/null +++ b/apis/kueue/v1alpha1/workload_webhook_test.go @@ -0,0 +1,90 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1_test + +// Rename the package to avoid circular dependencies which is caused by "sigs.k8s.io/kueue/pkg/util/testing". +// See also: https://github.com/golang/go/wiki/CodeReviewComments#import-dot + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/validation/field" + + . "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + testingutil "sigs.k8s.io/kueue/pkg/util/testing" +) + +func TestValidateWorkload(t *testing.T) { + const ( + objName = "name" + objNs = "ns" + ) + specField := field.NewPath("spec") + podSetsField := specField.Child("podSets") + testCases := map[string]struct { + workload *Workload + wantErr field.ErrorList + }{ + "should have at least one podSet": { + workload: testingutil.MakeWorkload(objName, objNs).PodSets(nil).Obj(), + wantErr: field.ErrorList{ + field.Required(podSetsField, ""), + }, + }, + "count should be greater than 0": { + workload: testingutil.MakeWorkload(objName, objNs).PodSets([]PodSet{ + { + Name: "main", + Count: -1, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "c", + Resources: corev1.ResourceRequirements{ + Requests: make(corev1.ResourceList), + }, + }, + }, + }, + }, + }).Obj(), + wantErr: field.ErrorList{ + field.Invalid(podSetsField.Index(0).Child("count"), int32(-1), ""), + }, + }, + "should have valid priorityClassName": { + workload: testingutil.MakeWorkload(objName, objNs).PriorityClass("invalid_class").Obj(), + wantErr: field.ErrorList{ + field.Invalid(specField.Child("priorityClassName"), "invalid_class", ""), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + errList := ValidateWorkload(tc.workload) + if len(errList) != 1 { + t.Errorf("Unexpected error: %v, want %v", errList, tc.wantErr) + } + if diff := cmp.Diff(tc.wantErr[0], errList[0], cmpopts.IgnoreFields(field.Error{}, "Detail")); diff != "" { + t.Errorf("ValidateWorkload() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/util/testing/wrappers.go b/pkg/util/testing/wrappers.go index 6df5edbe13..1a8dab610a 100644 --- a/pkg/util/testing/wrappers.go +++ b/pkg/util/testing/wrappers.go @@ -204,6 +204,11 @@ func (w *WorkloadWrapper) Priority(priority *int32) *WorkloadWrapper { return w } +func (w *WorkloadWrapper) PodSets(podSets []kueue.PodSet) *WorkloadWrapper { + w.Spec.PodSets = podSets + return w +} + // AdmissionWrapper wraps an Admission type AdmissionWrapper struct{ kueue.Admission } diff --git a/test/integration/webhook/v1alpha1/workload_test.go b/test/integration/webhook/v1alpha1/workload_test.go index 4aa2effced..b08f086aea 100644 --- a/test/integration/webhook/v1alpha1/workload_test.go +++ b/test/integration/webhook/v1alpha1/workload_test.go @@ -17,6 +17,7 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -27,6 +28,8 @@ import ( var ns *corev1.Namespace +const workloadName = "workload-test" + var _ = ginkgo.BeforeEach(func() { ns = &corev1.Namespace{ ObjectMeta: metav1.ObjectMeta{ @@ -41,23 +44,10 @@ var _ = ginkgo.AfterEach(func() { }) var _ = ginkgo.Describe("Workload defaulting webhook", func() { - ginkgo.Context("When creating Workload", func() { + ginkgo.Context("When creating a Workload", func() { ginkgo.It("Should set default values", func() { ginkgo.By("Creating a new Workload") - workload := testing.MakeWorkload("workload1", ns.Name).Obj() - workload.Spec.PodSets = []v1alpha1.PodSet{ - { - Count: 2, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Image: "fake-image", - }, - }, - }, - }, - } + workload := testing.MakeWorkload(workloadName, ns.Name).Obj() gomega.Expect(k8sClient.Create(ctx, workload)).Should(gomega.Succeed()) created := &v1alpha1.Workload{} @@ -70,3 +60,47 @@ var _ = ginkgo.Describe("Workload defaulting webhook", func() { }) }) }) + +var _ = ginkgo.Describe("Workload validating webhook", func() { + ginkgo.Context("When creating a Workload", func() { + ginkgo.It("Should validate Workload", func() { + ginkgo.By("Creating a new Workload") + workload := testing.MakeWorkload(workloadName, ns.Name). + PodSets([]v1alpha1.PodSet{ + { + Name: "main", + Count: -1, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "c", + Resources: corev1.ResourceRequirements{ + Requests: make(corev1.ResourceList), + }, + }, + }, + }, + }, + }). + PriorityClass("invalid_class"). + Obj() + err := k8sClient.Create(ctx, workload) + gomega.Expect(err).Should(gomega.HaveOccurred()) + gomega.Expect(errors.IsForbidden(err)).Should(gomega.BeTrue(), "error: %v", err) + }) + }) + + ginkgo.Context("When updating a Workload", func() { + ginkgo.It("Should validate spec.podSet.count", func() { + ginkgo.By("Creating a new Workload") + workload := testing.MakeWorkload(workloadName, ns.Name).Obj() + gomega.Expect(k8sClient.Create(ctx, workload)).Should(gomega.Succeed()) + + ginkgo.By("Updating the Workload") + workload.Spec.PodSets[0].Count = -1 + err := k8sClient.Update(ctx, workload) + gomega.Expect(err).Should(gomega.HaveOccurred()) + gomega.Expect(errors.IsForbidden(err)).Should(gomega.BeTrue(), "error: %v", err) + }) + }) +})