diff --git a/src/nni_manager/config/kubeflow/tfjob-crd-v1.json b/src/nni_manager/config/kubeflow/tfjob-crd-v1.json new file mode 100644 index 0000000000..bcc9d26002 --- /dev/null +++ b/src/nni_manager/config/kubeflow/tfjob-crd-v1.json @@ -0,0 +1,17 @@ +{ + "kind": "CustomResourceDefinition", + "spec": { + "scope": "Namespaced", + "version": "v1", + "group": "kubeflow.org", + "names": { + "kind": "TFJob", + "plural": "tfjobs", + "singular": "tfjob" + } + }, + "apiVersion": "apiextensions.k8s.io/v1beta1", + "metadata": { + "name": "tfjobs.kubeflow.org" + } +} diff --git a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts index 60208aa24a..ea942163ae 100644 --- a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts +++ b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowApiClient.ts @@ -65,6 +65,25 @@ class TFOperatorClientV1Beta2 extends KubernetesCRDClient { } } +class TFOperatorClientV1 extends KubernetesCRDClient { + /** + * constructor, to initialize tfjob CRD definition + */ + public constructor() { + super(); + this.crdSchema = JSON.parse(fs.readFileSync('./config/kubeflow/tfjob-crd-v1.json', 'utf8')); + this.client.addCustomResourceDefinition(this.crdSchema); + } + + protected get operator(): any { + return this.client.apis['kubeflow.org'].v1.namespaces('default').tfjobs; + } + + public get containerName(): string { + return 'tensorflow'; + } +} + class PyTorchOperatorClientV1Alpha2 extends KubernetesCRDClient { /** * constructor, to initialize tfjob CRD definition @@ -142,6 +161,9 @@ class KubeflowOperatorClientFactory { case 'v1beta2': { return new TFOperatorClientV1Beta2(); } + case 'v1': { + return new TFOperatorClientV1(); + } default: throw new Error(`Invalid tf-operator apiVersion ${operatorApiVersion}`); }