diff --git a/src/nni_manager/config/kubeflow/pytorchjob-crd-v1.json b/src/nni_manager/config/kubeflow/pytorchjob-crd-v1.json new file mode 100644 index 0000000000..70fed16922 --- /dev/null +++ b/src/nni_manager/config/kubeflow/pytorchjob-crd-v1.json @@ -0,0 +1,17 @@ +{ + "kind": "CustomResourceDefinition", + "spec": { + "scope": "Namespaced", + "version": "v1", + "group": "kubeflow.org", + "names": { + "kind": "PyTorchJob", + "plural": "pytorchjobs", + "singular": "pytorchjob" + } + }, + "apiVersion": "kubeflow.org/v1", + "metadata": { + "name": "pytorchjobs.kubeflow.org" + } +} \ No newline at end of file diff --git a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts index ea942163ae..d2c52c20a3 100644 --- a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts +++ b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts @@ -83,7 +83,24 @@ class TFOperatorClientV1 extends KubernetesCRDClient { return 'tensorflow'; } } +class PyTorchOperatorClientV1 extends KubernetesCRDClient { + /** + * constructor, to initialize pytorchjob CRD definition + */ + public constructor() { + super(); + this.crdSchema = JSON.parse(fs.readFileSync('./config/kubeflow/pytorchjob-crd-v1.json', 'utf8')); + this.client.addCustomResourceDefinition(this.crdSchema); + } + + protected get operator(): any { + return this.client.apis['kubeflow.org'].v1.namespaces('default').pytorchjobs; + } + public get containerName(): string { + return 'pytorch'; + } +} class PyTorchOperatorClientV1Alpha2 extends KubernetesCRDClient { /** * constructor, to initialize tfjob CRD definition @@ -179,6 +196,9 @@ class KubeflowOperatorClientFactory { case 'v1beta2': { return new PyTorchOperatorClientV1Beta2(); } + case 'v1': { + return new PyTorchOperatorClientV1(); + } default: throw new Error(`Invalid pytorch-operator apiVersion ${operatorApiVersion}`); } diff --git 
a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowConfig.ts b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowConfig.ts index 6aea0bb879..89d1c8a19e 100644 --- a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowConfig.ts +++ b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowConfig.ts @@ -13,7 +13,7 @@ import { AzureStorage, KeyVaultConfig, KubernetesClusterConfig, KubernetesCluste export type KubeflowOperator = 'tf-operator' | 'pytorch-operator' ; export type DistTrainRole = 'worker' | 'ps' | 'master'; export type KubeflowJobStatus = 'Created' | 'Running' | 'Failed' | 'Succeeded'; -export type OperatorApiVersion = 'v1alpha2' | 'v1beta1' | 'v1beta2'; +export type OperatorApiVersion = 'v1alpha2' | 'v1beta1' | 'v1beta2' | 'v1'; /** * Kubeflow Cluster Configuration