Skip to content

Commit

Permalink
Use IsManagedByKueue on LeaderWorkerSet.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi committed Jan 17, 2025
1 parent b482fff commit e90983f
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 31 deletions.
11 changes: 11 additions & 0 deletions pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
apivalidation "k8s.io/apimachinery/pkg/api/validation"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/validation"
"k8s.io/apimachinery/pkg/util/validation/field"
Expand Down Expand Up @@ -172,3 +173,13 @@ func ValidateImmutablePodSpec(newPodSpec *corev1.PodSpec, oldPodSpec *corev1.Pod

return apivalidation.ValidateImmutableField(mungedPodSpec, oldPodSpec, fieldPath)
}

// IsManagedByKueue reports whether obj itself should be treated as managed
// by Kueue.
//
// It returns false when obj has a controller owner reference for which
// IsOwnerManagedByKueue is true. Otherwise it returns true exactly when
// QueueNameForObject yields a non-empty queue name for obj.
func IsManagedByKueue(obj client.Object) bool {
	// The owner check runs first; the queue name is not consulted when it matches.
	if owner := metav1.GetControllerOf(obj); owner != nil && IsOwnerManagedByKueue(owner) {
		return false
	}
	return QueueNameForObject(obj) != ""
}
21 changes: 14 additions & 7 deletions pkg/controller/jobs/leaderworkerset/leaderworkerset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package leaderworkerset

import (
"context"
"strings"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
Expand All @@ -38,13 +40,14 @@ const (

// init registers the LeaderWorkerSet integration with the jobframework under
// FrameworkName, passing the callbacks, job type, scheme hook, dependency
// list and GVK listed below. The registration result is wrapped in
// utilruntime.Must.
//
// NOTE(review): the callback field list appears twice below — once without
// IsManagingObjectsOwner and once with it. This looks like the before/after
// sides of a diff rendered together; confirm against the actual file, where
// only one set of fields can exist.
func init() {
utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
SetupIndexes: SetupIndexes,
NewReconciler: NewPodReconciler,
SetupWebhook: SetupWebhook,
JobType: &leaderworkersetv1.LeaderWorkerSet{},
AddToScheme: leaderworkersetv1.AddToScheme,
DependencyList: []string{"pod"},
GVK: gvk,
SetupIndexes: SetupIndexes,
NewReconciler: NewPodReconciler,
SetupWebhook: SetupWebhook,
JobType: &leaderworkersetv1.LeaderWorkerSet{},
AddToScheme: leaderworkersetv1.AddToScheme,
DependencyList: []string{"pod"},
IsManagingObjectsOwner: isLeaderWorkerSet,
GVK: gvk,
}))
}

Expand All @@ -54,6 +57,10 @@ func fromObject(o runtime.Object) *LeaderWorkerSet {
return (*LeaderWorkerSet)(o.(*leaderworkersetv1.LeaderWorkerSet))
}

// isLeaderWorkerSet reports whether ref has Kind "LeaderWorkerSet" and an
// APIVersion that starts with "leaderworkerset.x-k8s.io/v1".
//
// ref is dereferenced unconditionally, so it must be non-nil.
func isLeaderWorkerSet(ref *metav1.OwnerReference) bool {
	if ref.Kind != "LeaderWorkerSet" {
		return false
	}
	// Prefix match: any APIVersion beginning with this group/version string is accepted.
	return strings.HasPrefix(ref.APIVersion, "leaderworkerset.x-k8s.io/v1")
}

// Object returns lws converted back to its *leaderworkersetv1.LeaderWorkerSet
// form, exposed as a client.Object. The conversion is a pointer cast; no copy
// is made.
func (lws *LeaderWorkerSet) Object() client.Object {
	underlying := (*leaderworkersetv1.LeaderWorkerSet)(lws)
	return underlying
}
Expand Down
12 changes: 8 additions & 4 deletions pkg/controller/jobs/leaderworkerset/leaderworkerset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ func (wh *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (warn
log.V(5).Info("Validating create")

allErrs := jobframework.ValidateQueueName(lws.Object())
allErrs = append(allErrs, validateStartupPolicy(lws)...)

if jobframework.IsManagedByKueue(lws.Object()) {
allErrs = append(allErrs, validateStartupPolicy(lws)...)
}

return nil, allErrs.ToAggregate()
}
Expand All @@ -125,9 +128,10 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
jobframework.QueueNameForObject(oldLeaderWorkerSet.Object()),
queueNameLabelPath,
)
allErrs = append(allErrs, validateStartupPolicy(newLeaderWorkerSet)...)

if jobframework.QueueNameForObject(oldLeaderWorkerSet.Object()) != "" {
if jobframework.IsManagedByKueue(newLeaderWorkerSet.Object()) {
allErrs = append(allErrs, validateStartupPolicy(newLeaderWorkerSet)...)

allErrs = append(allErrs, validateImmutablePodTemplateSpec(
newLeaderWorkerSet.Spec.LeaderWorkerTemplate.LeaderTemplate,
oldLeaderWorkerSet.Spec.LeaderWorkerTemplate.LeaderTemplate,
Expand Down Expand Up @@ -155,7 +159,7 @@ func GetWorkloadName(lws *leaderworkersetv1.LeaderWorkerSet, groupIndex string)
func validateStartupPolicy(lws *LeaderWorkerSet) field.ErrorList {
allErrors := field.ErrorList{}
// TODO(#3232): Support LeaderReady StartupPolicy
if jobframework.QueueNameForObject(lws.Object()) != "" && lws.Spec.StartupPolicy == leaderworkersetv1.LeaderReadyStartupPolicy {
if lws.Spec.StartupPolicy == leaderworkersetv1.LeaderReadyStartupPolicy {
allErrors = append(allErrors,
field.Invalid(startupPolicyPath, lws.Spec.StartupPolicy, "only the LeaderCreated startup policy is allowed when using the kueue.x-k8s.io/queue-name label or annotation"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
awv1beta2 "github.com/project-codeflare/appwrapper/api/v1beta2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
leaderworkersetv1 "sigs.k8s.io/lws/api/leaderworkerset/v1"

"sigs.k8s.io/kueue/pkg/cache"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobs/appwrapper"
podcontroller "sigs.k8s.io/kueue/pkg/controller/jobs/pod"
"sigs.k8s.io/kueue/pkg/features"
"sigs.k8s.io/kueue/pkg/queue"
Expand Down Expand Up @@ -124,9 +128,10 @@ func TestDefault(t *testing.T) {

func TestValidateCreate(t *testing.T) {
testCases := map[string]struct {
lws *leaderworkersetv1.LeaderWorkerSet
wantErr error
wantWarns admission.Warnings
integrations []string
lws *leaderworkersetv1.LeaderWorkerSet
wantErr error
wantWarns admission.Warnings
}{
"without queue": {
lws: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
Expand Down Expand Up @@ -167,11 +172,26 @@ func TestValidateCreate(t *testing.T) {
&field.Error{Type: field.ErrorTypeInvalid, Field: "spec.startupPolicy"},
}.ToAggregate(),
},
"leader ready startup policy with owner reference": {
integrations: []string{appwrapper.FrameworkName},
lws: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
WithOwnerReference(metav1.OwnerReference{
APIVersion: awv1beta2.GroupVersion.String(),
Kind: "AppWrapper",
Controller: ptr.To(true),
}).
LeaderTemplate(corev1.PodTemplateSpec{}).
StartupPolicy(leaderworkersetv1.LeaderReadyStartupPolicy).
Obj(),
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
t.Cleanup(jobframework.EnableIntegrationsForTest(t, "pod"))
for _, integration := range tc.integrations {
jobframework.EnableIntegrationsForTest(t, integration)
}
builder := utiltesting.NewClientBuilder()
client := builder.Build()
w := &Webhook{client: client}
Expand All @@ -189,9 +209,10 @@ func TestValidateCreate(t *testing.T) {

func TestValidateUpdate(t *testing.T) {
testCases := map[string]struct {
oldObj *leaderworkersetv1.LeaderWorkerSet
newObj *leaderworkersetv1.LeaderWorkerSet
wantErr error
integrations []string
oldObj *leaderworkersetv1.LeaderWorkerSet
newObj *leaderworkersetv1.LeaderWorkerSet
wantErr error
}{
"no changes": {
oldObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
Expand Down Expand Up @@ -223,6 +244,46 @@ func TestValidateUpdate(t *testing.T) {
},
}.ToAggregate(),
},
"change startup policy without queue-name": {
oldObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
StartupPolicy(leaderworkersetv1.LeaderCreatedStartupPolicy).
Obj(),
newObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
StartupPolicy(leaderworkersetv1.LeaderReadyStartupPolicy).
Obj(),
},
"change startup policy with queue-name": {
oldObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
StartupPolicy(leaderworkersetv1.LeaderCreatedStartupPolicy).
Queue("test-queue").
Obj(),
newObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
StartupPolicy(leaderworkersetv1.LeaderReadyStartupPolicy).
Queue("test-queue").
Obj(),
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: startupPolicyPath.String(),
},
}.ToAggregate(),
},
"change startup policy with owner reference": {
integrations: []string{appwrapper.FrameworkName},
oldObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
StartupPolicy(leaderworkersetv1.LeaderCreatedStartupPolicy).
Queue("test-queue").
Obj(),
newObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
WithOwnerReference(metav1.OwnerReference{
APIVersion: awv1beta2.GroupVersion.String(),
Kind: "AppWrapper",
Controller: ptr.To(true),
}).
Queue("test-queue").
StartupPolicy(leaderworkersetv1.LeaderReadyStartupPolicy).
Obj(),
},
"change image": {
oldObj: testingleaderworkerset.MakeLeaderWorkerSet("test-lws", "").
LeaderTemplate(corev1.PodTemplateSpec{
Expand Down Expand Up @@ -515,6 +576,10 @@ func TestValidateUpdate(t *testing.T) {

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
for _, integration := range tc.integrations {
jobframework.EnableIntegrationsForTest(t, integration)
}

ctx := context.Background()

wh := &Webhook{}
Expand Down
13 changes: 1 addition & 12 deletions pkg/controller/jobs/statefulset/statefulset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (

appsv1 "k8s.io/api/apps/v1"
apivalidation "k8s.io/apimachinery/pkg/api/validation"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
Expand Down Expand Up @@ -143,7 +142,7 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
groupNameLabelPath,
)...)

if isManagedByKueue(newStatefulSet.Object()) {
if jobframework.IsManagedByKueue(newStatefulSet.Object()) {
oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1)
newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1)

Expand Down Expand Up @@ -173,13 +172,3 @@ func GetWorkloadName(statefulSetName string) string {
// Passing empty UID as it is not available before object creation
return jobframework.GetWorkloadNameForOwnerWithGVK(statefulSetName, "", gvk)
}

// isManagedByKueue reports whether obj itself should be treated as managed by
// Kueue: it must not have a controller owner for which
// jobframework.IsOwnerManagedByKueue is true, and
// jobframework.QueueNameForObject must return a non-empty name for it.
func isManagedByKueue(obj client.Object) bool {
	owner := metav1.GetControllerOf(obj)
	ownerHandledByKueue := owner != nil && jobframework.IsOwnerManagedByKueue(owner)
	// Short-circuit keeps the original order: the queue name is only looked up
	// when the owner check did not match.
	return !ownerHandledByKueue && jobframework.QueueNameForObject(obj) != ""
}
40 changes: 38 additions & 2 deletions pkg/controller/jobs/statefulset/statefulset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ import (
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
leaderworkersetv1 "sigs.k8s.io/lws/api/leaderworkerset/v1"

"sigs.k8s.io/kueue/pkg/cache"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobs/appwrapper"
"sigs.k8s.io/kueue/pkg/controller/jobs/leaderworkerset"
"sigs.k8s.io/kueue/pkg/controller/jobs/pod"
"sigs.k8s.io/kueue/pkg/features"
"sigs.k8s.io/kueue/pkg/queue"
Expand Down Expand Up @@ -397,7 +399,7 @@ func TestValidateUpdate(t *testing.T) {
Replicas(4).
Obj(),
},
"change in replicas (scale up with ownerReference while the previous scaling operation is still in progress)": {
"change in replicas (scale up with AppWrapper ownerReference while the previous scaling operation is still in progress)": {
integrations: []string{appwrapper.FrameworkName},
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
Expand All @@ -415,7 +417,7 @@ func TestValidateUpdate(t *testing.T) {
StatusReplicas(1).
Obj(),
},
"change in replicas (scale up with ownerReference)": {
"change in replicas (scale up with AppWrapper ownerReference)": {
integrations: []string{appwrapper.FrameworkName},
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
Expand All @@ -431,6 +433,40 @@ func TestValidateUpdate(t *testing.T) {
Replicas(4).
Obj(),
},
"change in replicas (scale up with LeaderWorkerSet ownerReference while the previous scaling operation is still in progress)": {
integrations: []string{leaderworkerset.FrameworkName},
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
Replicas(0).
StatusReplicas(3).
Obj(),
newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
WithOwnerReference(metav1.OwnerReference{
APIVersion: leaderworkersetv1.GroupVersion.String(),
Kind: "LeaderWorkerSet",
Controller: ptr.To(true),
}).
Queue("test-queue").
Replicas(3).
StatusReplicas(1).
Obj(),
},
"change in replicas (scale up with LeaderWorkerSet ownerReference)": {
integrations: []string{leaderworkerset.FrameworkName},
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
Replicas(3).
Obj(),
newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
WithOwnerReference(metav1.OwnerReference{
APIVersion: leaderworkersetv1.GroupVersion.String(),
Kind: "LeaderWorkerSet",
Controller: ptr.To(true),
}).
Queue("test-queue").
Replicas(4).
Obj(),
},
}

for name, tc := range testCases {
Expand Down
5 changes: 5 additions & 0 deletions pkg/util/testingjobs/leaderworkerset/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ func (w *LeaderWorkerSetWrapper) Name(n string) *LeaderWorkerSetWrapper {
return w
}

// WithOwnerReference appends ownerReference to the wrapper's OwnerReferences
// and returns the same wrapper so calls can be chained.
func (w *LeaderWorkerSetWrapper) WithOwnerReference(ownerReference metav1.OwnerReference) *LeaderWorkerSetWrapper {
	refs := append(w.OwnerReferences, ownerReference)
	w.OwnerReferences = refs
	return w
}

func (w *LeaderWorkerSetWrapper) StartupPolicy(startupPolicyType leaderworkersetv1.StartupPolicyType) *LeaderWorkerSetWrapper {
w.Spec.StartupPolicy = startupPolicyType
return w
Expand Down

0 comments on commit e90983f

Please sign in to comment.