Skip to content

Commit

Permalink
add validation for two nproc_per_node, use auto for defaulter
Browse files Browse the repository at this point in the history
  • Loading branch information
kuizhiqing committed Jun 30, 2023
1 parent aae9541 commit e3953d6
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 10 deletions.
5 changes: 5 additions & 0 deletions pkg/apis/kubeflow.org/v1/pytorch_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,9 @@ func SetDefaults_PyTorchJob(job *PyTorchJob) {
}
// Set default elastic policy.
setElasticPolicy(job)

if job.Spec.NprocPerNode == nil {
defaultNprocPerNode := "auto"
job.Spec.NprocPerNode = &defaultNprocPerNode
}
}
4 changes: 4 additions & 0 deletions pkg/apis/kubeflow.org/v1/pytorch_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type PyTorchJob struct {
Status JobStatus `json:"status,omitempty"`
}

// For PyTorch launch/run related spec declaration, please see the following doc for more detail:
// https://pytorch.org/docs/stable/elastic/run.html
// Or run command `torchrun --help` for a brief description.

// PyTorchJobSpec is a desired state description of the PyTorchJob.
type PyTorchJobSpec struct {
// RunPolicy encapsulates various runtime policies of the distributed training
Expand Down
7 changes: 7 additions & 0 deletions pkg/apis/kubeflow.org/v1/pytorch_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ func ValidateV1PyTorchJob(pytorchJob *PyTorchJob) error {
return nil
}

func validateNprocPerNode(pytorchJob *PyTorchJob) error {
if pytorchJob.Spec.NprocPerNode != nil && pytorchJob.Spec.ElasticPolicy.NProcPerNode != nil {
return fmt.Errorf(".spec.elasticPolicy.nProcPerNode is deprecated, use .spec.nprocPerNode instead")
}
return nil
}

func validatePyTorchReplicaSpecs(specs map[ReplicaType]*ReplicaSpec) error {
if specs == nil {
return fmt.Errorf("PyTorchJobSpec is not valid")
Expand Down
10 changes: 1 addition & 9 deletions pkg/controller.v1/pytorch/envvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
})
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: EnvNprocPerNode,
Value: getNprocPerNodeEnv(pytorchjob),
Value: *pytorchjob.Spec.NprocPerNode,
})
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
Name: EnvNodeRank,
Expand Down Expand Up @@ -134,14 +134,6 @@ func getNprocPerNodeInt(job *kubeflowv1.PyTorchJob) int {
return 1
}

func getNprocPerNodeEnv(job *kubeflowv1.PyTorchJob) string {
if job.Spec.NprocPerNode == nil {
return "auto"
} else {
return *job.Spec.NprocPerNode
}
}

func getTotalReplicas(job *kubeflowv1.PyTorchJob) int32 {
jobReplicas := int32(0)
for _, r := range job.Spec.PyTorchReplicaSpecs {
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller.v1/pytorch/pytorchjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ var _ = Describe("PyTorchJob controller", func() {
},
},
}
job.Spec.NprocPerNode = &nprocPerNode
job.Spec.NprocPerNode = nil

Expect(testK8sClient.Create(ctx, job)).Should(Succeed())

Expand Down

0 comments on commit e3953d6

Please sign in to comment.