diff --git a/src/nni_manager/common/manager.ts b/src/nni_manager/common/manager.ts index 8a33ce6ab9..c29ac52ce1 100644 --- a/src/nni_manager/common/manager.ts +++ b/src/nni_manager/common/manager.ts @@ -34,6 +34,7 @@ interface ExperimentParams { searchSpace: string; trainingServicePlatform: string; multiPhase?: boolean; + multiThread?: boolean; tuner: { className: string; builtinTunerName?: string; diff --git a/src/nni_manager/common/utils.ts b/src/nni_manager/common/utils.ts index 850b88a652..3636725c07 100644 --- a/src/nni_manager/common/utils.ts +++ b/src/nni_manager/common/utils.ts @@ -158,12 +158,16 @@ function parseArg(names: string[]): string { * @param assessor: similiar as tuner * */ -function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string { +function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false, multiThread: boolean = false): string { let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`; if (multiPhase) { command += ' --multi_phase'; } + if (multiThread) { + command += ' --multi_thread'; + } + if (tuner.classArgs !== undefined) { command += ` --tuner_args ${JSON.stringify(JSON.stringify(tuner.classArgs))}`; } diff --git a/src/nni_manager/core/commands.ts b/src/nni_manager/core/commands.ts index 19204b2f31..ff6c9840bd 100644 --- a/src/nni_manager/core/commands.ts +++ b/src/nni_manager/core/commands.ts @@ -26,6 +26,7 @@ const ADD_CUSTOMIZED_TRIAL_JOB = 'AD'; const TRIAL_END = 'EN'; const TERMINATE = 'TE'; +const INITIALIZED = 'ID'; const NEW_TRIAL_JOB = 'TR'; const SEND_TRIAL_JOB_PARAMETER = 'SP'; const NO_MORE_TRIAL_JOBS = 'NO'; @@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set = new Set([ ADD_CUSTOMIZED_TRIAL_JOB, TERMINATE, + INITIALIZED, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, NO_MORE_TRIAL_JOBS @@ -61,6 +63,7 @@ export { ADD_CUSTOMIZED_TRIAL_JOB, TRIAL_END, TERMINATE, + INITIALIZED, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, KILL_TRIAL_JOB, diff --git a/src/nni_manager/core/nnimanager.ts b/src/nni_manager/core/nnimanager.ts index 089ddec2e9..3665a95e40 100644 --- a/src/nni_manager/core/nnimanager.ts +++ b/src/nni_manager/core/nnimanager.ts @@ -37,8 +37,8 @@ import { } from '../common/trainingService'; import { delay, getLogDir, getMsgDispatcherCommand } from '../common/utils'; import { - ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA, - REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE + ADD_CUSTOMIZED_TRIAL_JOB, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, + REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE } from './commands'; import { createDispatcherInterface, IpcInterface } from './ipcInterface'; @@ -127,7 +127,8 @@ class NNIManager implements Manager { this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString()); } - const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase); + const dispatcherCommand: string = getMsgDispatcherCommand( + expParams.tuner, expParams.assessor, expParams.multiPhase, expParams.multiThread); this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.setupTuner( //expParams.tuner.tunerCommand, @@ -159,7 +160,8 @@ class NNIManager implements Manager { this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString()); } - const dispatcherCommand: string = 
getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase); + const dispatcherCommand: string = getMsgDispatcherCommand( + expParams.tuner, expParams.assessor, expParams.multiPhase, expParams.multiThread); this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.setupTuner( dispatcherCommand, @@ -419,16 +421,20 @@ class NNIManager implements Manager { } else { this.trialConcurrencyChange = requestTrialNum; } - for (let i: number = 0; i < requestTrialNum; i++) { + + const requestCustomTrialNum: number = Math.min(requestTrialNum, this.customizedTrials.length); + for (let i: number = 0; i < requestCustomTrialNum; i++) { // ask tuner for more trials if (this.customizedTrials.length > 0) { const hyperParams: string | undefined = this.customizedTrials.shift(); this.dispatcher.sendCommand(ADD_CUSTOMIZED_TRIAL_JOB, hyperParams); - } else { - this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, '1'); } } + if (requestTrialNum - requestCustomTrialNum > 0) { + this.requestTrialJobs(requestTrialNum - requestCustomTrialNum); + } + // check maxtrialnum and maxduration here if (this.experimentProfile.execDuration > this.experimentProfile.params.maxExecDuration || this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { @@ -526,11 +532,9 @@ class NNIManager implements Manager { if (this.dispatcher === undefined) { throw new Error('Dispatcher error: tuner has not been setup'); } - // TO DO: we should send INITIALIZE command to tuner if user's tuner needs to run init method in tuner - this.log.debug(`Send tuner command: update search space: ${this.experimentProfile.params.searchSpace}`); - this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, this.experimentProfile.params.searchSpace); - this.log.debug(`Send tuner command: ${this.experimentProfile.params.trialConcurrency}`); - this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, String(this.experimentProfile.params.trialConcurrency)); + this.log.debug(`Send tuner command: INITIALIZE: ${this.experimentProfile.params.searchSpace}`); + // Tuner need to be initialized with search space before generating any hyper parameters + this.dispatcher.sendCommand(INITIALIZE, this.experimentProfile.params.searchSpace); } private async onTrialJobMetrics(metric: TrialJobMetric): Promise { @@ -541,9 +545,32 @@ class NNIManager implements Manager { this.dispatcher.sendCommand(REPORT_METRIC_DATA, metric.data); } + private requestTrialJobs(jobNum: number): void { + if (jobNum < 1) { + return; + } + if (this.dispatcher === undefined) { + throw new Error('Dispatcher error: tuner has not been setup'); + } + if (this.experimentProfile.params.multiThread) { + // Send multiple requests to ensure multiple hyper parameters are generated in non-blocking way. + // For a single REQUEST_TRIAL_JOBS request, hyper parameters are generated one by one + // sequentially. 
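+        // e.g. with jobNum = 4: multi-thread mode sends four REQUEST_TRIAL_JOBS '1' commands,
+        // while single-thread mode sends a single REQUEST_TRIAL_JOBS '4'.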
+ for (let i: number = 0; i < jobNum; i++) { + this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, '1'); + } + } else { + this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, String(jobNum)); + } + } + private async onTunerCommand(commandType: string, content: string): Promise { this.log.info(`Command from tuner: ${commandType}, ${content}`); switch (commandType) { + case INITIALIZED: + // Tuner is intialized, search space is set, request tuner to generate hyper parameters + this.requestTrialJobs(this.experimentProfile.params.trialConcurrency); + break; case NEW_TRIAL_JOB: this.waitingTrials.push(content); break; diff --git a/src/nni_manager/rest_server/restValidationSchemas.ts b/src/nni_manager/rest_server/restValidationSchemas.ts index d727e9d13e..a298869d15 100644 --- a/src/nni_manager/rest_server/restValidationSchemas.ts +++ b/src/nni_manager/rest_server/restValidationSchemas.ts @@ -39,8 +39,26 @@ export namespace ValidationSchemas { outputDir: joi.string(), cpuNum: joi.number().min(1), memoryMB: joi.number().min(100), - gpuNum: joi.number().min(0).required(), - command: joi.string().min(1).required() + gpuNum: joi.number().min(0), + command: joi.string().min(1), + worker: joi.object({ + replicas: joi.number().min(1).required(), + image: joi.string().min(1), + outputDir: joi.string(), + cpuNum: joi.number().min(1), + memoryMB: joi.number().min(100), + gpuNum: joi.number().min(0).required(), + command: joi.string().min(1).required() + }), + ps: joi.object({ + replicas: joi.number().min(1).required(), + image: joi.string().min(1), + outputDir: joi.string(), + cpuNum: joi.number().min(1), + memoryMB: joi.number().min(100), + gpuNum: joi.number().min(0).required(), + command: joi.string().min(1).required() + }) }), pai_config: joi.object({ userName: joi.string().min(1).required(), @@ -68,6 +86,7 @@ export namespace ValidationSchemas { searchSpace: joi.string().required(), maxExecDuration: joi.number().min(0).required(), multiPhase: joi.boolean(), + multiThread: joi.boolean(), tuner: joi.object({ builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch'), codeDir: joi.string(), diff --git a/src/nni_manager/training_service/kubeflow/kubeflowConfig.ts b/src/nni_manager/training_service/kubeflow/kubeflowConfig.ts index dca552ee11..d24c309b21 100644 --- a/src/nni_manager/training_service/kubeflow/kubeflowConfig.ts +++ b/src/nni_manager/training_service/kubeflow/kubeflowConfig.ts @@ -79,15 +79,44 @@ export class NFSConfig { /** * Trial job configuration for Kubeflow */ -export class KubeflowTrialConfig extends TrialConfig { +export class KubeflowTrialConfigTemplate { + /** replication number of current role */ + public readonly replicas: number; + + /** CPU number */ public readonly cpuNum: number; + + /** Memory */ public readonly memoryMB: number; + + /** Docker image */ public readonly image: string; + + /** Trail command */ + public readonly command : string; + + /** Required GPU number for trial job. 
The number should be in [0,100] */ + public readonly gpuNum : number; - constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number, image: string) { - super(command, codeDir, gpuNum); + constructor(replicas: number, command : string, gpuNum : number, + cpuNum: number, memoryMB: number, image: string) { + this.replicas = replicas; + this.command = command; + this.gpuNum = gpuNum; this.cpuNum = cpuNum; this.memoryMB = memoryMB; this.image = image; } +} + +export class KubeflowTrialConfig { + public readonly codeDir: string; + public readonly ps?: KubeflowTrialConfigTemplate; + public readonly worker: KubeflowTrialConfigTemplate; + + constructor(codeDir: string, worker: KubeflowTrialConfigTemplate, ps?: KubeflowTrialConfigTemplate) { + this.codeDir = codeDir; + this.worker = worker; + this.ps = ps; + } } \ No newline at end of file diff --git a/src/nni_manager/training_service/kubeflow/kubeflowData.ts b/src/nni_manager/training_service/kubeflow/kubeflowData.ts index f65d0cb603..0dce48732e 100644 --- a/src/nni_manager/training_service/kubeflow/kubeflowData.ts +++ b/src/nni_manager/training_service/kubeflow/kubeflowData.ts @@ -72,7 +72,7 @@ mkdir -p $NNI_OUTPUT_DIR cp -rT $NNI_CODE_DIR $NNI_SYS_DIR cd $NNI_SYS_DIR sh install_nni.sh # Check and install NNI pkg -python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR//trialkeeper_stderr +python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr ` export type KubeflowTFJobType = 'Created' | 'Running' | 'Failed' | 'Succeeded'; \ No newline at end of file diff --git a/src/nni_manager/training_service/kubeflow/kubeflowTrainingService.ts b/src/nni_manager/training_service/kubeflow/kubeflowTrainingService.ts index 1f3b30b404..96082fb984 100644 --- a/src/nni_manager/training_service/kubeflow/kubeflowTrainingService.ts +++ b/src/nni_manager/training_service/kubeflow/kubeflowTrainingService.ts @@ -37,13 +37,15 @@ import { TrialJobDetail, TrialJobMetric } from '../../common/trainingService'; import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils'; -import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSConfig } from './kubeflowConfig'; +import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, KubeflowTrialConfigTemplate, NFSConfig } from './kubeflowConfig'; import { KubeflowTrialJobDetail, KUBEFLOW_RUN_SHELL_FORMAT } from './kubeflowData'; import { KubeflowJobRestServer } from './kubeflowJobRestServer'; import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector'; var yaml = require('node-yaml'); +type DistTrainRole = 'worker' | 'ps'; + /** * Training Service implementation for Kubeflow * Refer https://github.com/kubeflow/kubeflow for more info about Kubeflow @@ -64,7 +66,7 @@ class KubeflowTrainingService implements TrainingService { private kubeflowJobInfoCollector: KubeflowJobInfoCollector; private kubeflowRestServerPort?: number; private kubeflowJobPlural?: string; - private readonly CONTAINER_MOUNT_PATH: string; + private readonly CONTAINER_MOUNT_PATH: string; constructor() { this.log = getLogger(); @@ -93,8 +95,8 @@ class KubeflowTrainingService implements TrainingService { throw new Error('Kubeflow Cluster config is not initialized'); } - if(!this.kubeflowTrialConfig) { - 
throw new Error('Kubeflow trial config is not initialized'); + if(!this.kubeflowTrialConfig || !this.kubeflowTrialConfig.worker) { + throw new Error('Kubeflow trial config or worker config is not initialized'); } if(!this.kubeflowJobPlural) { @@ -119,47 +121,57 @@ class KubeflowTrainingService implements TrainingService { // Write NNI installation file to local tmp files await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' }); - const kubeflowRunScriptContent: string = String.Format( - KUBEFLOW_RUN_SHELL_FORMAT, - `$PWD/nni/${trialJobId}`, - path.join(trialWorkingFolder, 'output'), - trialJobId, - getExperimentId(), - trialWorkingFolder, - curTrialSequenceId, - this.kubeflowTrialConfig.command, - getIPV4Address(), - this.kubeflowRestServerPort - ); - - //create tmp trial working folder locally. + // Create tmp trial working folder locally. await cpp.exec(`mkdir -p ${trialLocalTempFolder}`); - // Write file content ( run.sh and parameter.cfg ) to local tmp files - await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), kubeflowRunScriptContent, { encoding: 'utf8' }); + // Write worker file content run_worker.sh to local tmp folders + if(this.kubeflowTrialConfig.worker) { + const workerRunScriptContent: string = this.genereateRunScript(trialJobId, trialWorkingFolder, + this.kubeflowTrialConfig.worker.command, curTrialSequenceId.toString(), 'worker'); + + await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' }); + } + + // Write parameter server file content run_ps.sh to local tmp folders + if(this.kubeflowTrialConfig.ps) { + const psRunScriptContent: string = this.genereateRunScript(trialJobId, trialWorkingFolder, + this.kubeflowTrialConfig.ps.command, curTrialSequenceId.toString(), 'ps'); + + await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' }); + } // Write file content ( parameter.cfg ) to local tmp folders const trialForm : TrialJobApplicationForm = (form) if(trialForm && trialForm.hyperParameters) { await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)), trialForm.hyperParameters.value, { encoding: 'utf8' }); - } + } const kubeflowJobYamlPath = path.join(trialLocalTempFolder, `kubeflow-job-${trialJobId}.yaml`); const kubeflowJobName = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase(); - const podResources : any = {}; - podResources.requests = { - 'memory': `${this.kubeflowTrialConfig.memoryMB}Mi`, - 'cpu': `${this.kubeflowTrialConfig.cpuNum}`, - 'nvidia.com/gpu': `${this.kubeflowTrialConfig.gpuNum}` + const workerPodResources : any = {}; + workerPodResources.requests = { + 'memory': `${this.kubeflowTrialConfig.worker.memoryMB}Mi`, + 'cpu': `${this.kubeflowTrialConfig.worker.cpuNum}`, + 'nvidia.com/gpu': `${this.kubeflowTrialConfig.worker.gpuNum}` } - - podResources.limits = Object.assign({}, podResources.requests); + workerPodResources.limits = Object.assign({}, workerPodResources.requests); + + let psPodResources : any = undefined; + if(this.kubeflowTrialConfig.ps) { + psPodResources = {}; + psPodResources.requests = { + 'memory': `${this.kubeflowTrialConfig.ps.memoryMB}Mi`, + 'cpu': `${this.kubeflowTrialConfig.ps.cpuNum}`, + 'nvidia.com/gpu': `${this.kubeflowTrialConfig.ps.gpuNum}` + } + psPodResources.limits = Object.assign({}, psPodResources.requests); + } // Generate kubeflow job resource yaml file for K8S 
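+        // For reference, the resource written below is a 'kubeflow.org/v1alpha2' TFJob whose
+        // spec.tfReplicaSpecs holds a Worker section built from the worker template and, when a
+        // ps template is configured, a Ps section (see generateKubeflowJobConfig).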
yaml.write( kubeflowJobYamlPath, - this.generateKubeflowJobConfig(trialJobId, trialWorkingFolder, kubeflowJobName, podResources), + this.generateKubeflowJobConfig(trialJobId, trialWorkingFolder, kubeflowJobName, workerPodResources, psPodResources), 'utf-8' ); @@ -281,6 +293,7 @@ class KubeflowTrainingService implements TrainingService { } this.kubeflowTrialConfig = JSON.parse(value); + assert(this.kubeflowClusterConfig !== undefined && this.kubeflowTrialConfig.worker !== undefined); break; default: break; @@ -339,7 +352,15 @@ class KubeflowTrainingService implements TrainingService { return this.metricsEmitter; } - private generateKubeflowJobConfig(trialJobId: string, trialWorkingFolder: string, kubeflowJobName : string, podResources : any) : any { + /** + * Generate kubeflow resource config file + * @param trialJobId trial job id + * @param trialWorkingFolder working folder + * @param kubeflowJobName job name + * @param workerPodResources worker pod template + * @param psPodResources ps pod template + */ + private generateKubeflowJobConfig(trialJobId: string, trialWorkingFolder: string, kubeflowJobName : string, workerPodResources : any, psPodResources?: any) : any { if(!this.kubeflowClusterConfig) { throw new Error('Kubeflow Cluster config is not initialized'); } @@ -348,6 +369,15 @@ class KubeflowTrainingService implements TrainingService { throw new Error('Kubeflow trial config is not initialized'); } + const tfReplicaSpecsObj: any = {}; + tfReplicaSpecsObj.Worker = this.generateReplicaConfig(trialWorkingFolder, this.kubeflowTrialConfig.worker.replicas, + this.kubeflowTrialConfig.worker.image, 'run_worker.sh', workerPodResources); + + if(this.kubeflowTrialConfig.ps) { + tfReplicaSpecsObj.Ps = this.generateReplicaConfig(trialWorkingFolder, this.kubeflowTrialConfig.ps.replicas, + this.kubeflowTrialConfig.ps.image, 'run_ps.sh', psPodResources); + } + return { apiVersion: 'kubeflow.org/v1alpha2', kind: 'TFJob', @@ -361,44 +391,84 @@ class KubeflowTrainingService implements TrainingService { } }, spec: { - tfReplicaSpecs: { - Worker: { - replicas: 1, - template: { - metadata: { - creationTimestamp: null - }, - spec: { - containers: [ - { - // Kubeflow tensorflow operator requires that containers' name must be tensorflow - // TODO: change the name based on operator's type - name: 'tensorflow', - image: this.kubeflowTrialConfig.image, - args: ["sh", `${path.join(trialWorkingFolder, 'run.sh')}`], - volumeMounts: [{ - name: 'nni-nfs-vol', - mountPath: this.CONTAINER_MOUNT_PATH - }], - resources: podResources//, - //workingDir: '/tmp/nni/nuDEP' - }], - restartPolicy: 'ExitCode', - volumes: [{ - name: 'nni-nfs-vol', - nfs: { - server: `${this.kubeflowClusterConfig.nfs.server}`, - path: `${this.kubeflowClusterConfig.nfs.path}` - } - }] - } - } - } - } + tfReplicaSpecs: tfReplicaSpecsObj } }; } + /** + * Generate tf-operator's tfjobs replica config section + * @param trialWorkingFolder trial working folder + * @param replicaNumber replica number + * @param replicaImage image + * @param runScriptFile script file name + * @param podResources pod resource config section + */ + private generateReplicaConfig(trialWorkingFolder: string, replicaNumber: number, replicaImage: string, runScriptFile: string, podResources: any): any { + if(!this.kubeflowClusterConfig) { + throw new Error('Kubeflow Cluster config is not initialized'); + } + + if(!this.kubeflowTrialConfig) { + throw new Error('Kubeflow trial config is not initialized'); + } + + return { + replicas: replicaNumber, + template: { + metadata: { + 
creationTimestamp: null + }, + spec: { + containers: [ + { + // Kubeflow tensorflow operator requires that containers' name must be tensorflow + // TODO: change the name based on operator's type + name: 'tensorflow', + image: replicaImage, + args: ["sh", `${path.join(trialWorkingFolder, runScriptFile)}`], + volumeMounts: [{ + name: 'nni-nfs-vol', + mountPath: this.CONTAINER_MOUNT_PATH + }], + resources: podResources + }], + restartPolicy: 'ExitCode', + volumes: [{ + name: 'nni-nfs-vol', + nfs: { + server: `${this.kubeflowClusterConfig.nfs.server}`, + path: `${this.kubeflowClusterConfig.nfs.path}` + } + }] + } + } + }; + } + + /** + * Genereate run script for different roles(like worker or ps) + * @param trialJobId trial job id + * @param trialWorkingFolder working folder + * @param command + * @param trialSequenceId sequence id + */ + private genereateRunScript(trialJobId: string, trialWorkingFolder: string, + command: string, trialSequenceId: string, roleType: DistTrainRole): string { + return String.Format( + KUBEFLOW_RUN_SHELL_FORMAT, + `$PWD/nni/${trialJobId}`, + path.join(trialWorkingFolder, `${roleType}_output`), + trialJobId, + getExperimentId(), + trialWorkingFolder, + trialSequenceId, + command, + getIPV4Address(), + this.kubeflowRestServerPort + ); + } + private generateSequenceId(): number { if (this.nextTrialSequenceId === -1) { this.nextTrialSequenceId = getInitTrialSequenceId(); diff --git a/src/nni_manager/training_service/local/localTrainingServiceForGPU.ts b/src/nni_manager/training_service/local/localTrainingServiceForGPU.ts index 28a908480f..b7f9d91efe 100644 --- a/src/nni_manager/training_service/local/localTrainingServiceForGPU.ts +++ b/src/nni_manager/training_service/local/localTrainingServiceForGPU.ts @@ -61,7 +61,7 @@ class LocalTrainingServiceForGPU extends LocalTrainingService { this.requiredGPUNum = 0; } this.log.info('required GPU number is ' + this.requiredGPUNum); - if (this.gpuScheduler === undefined) { + if (this.gpuScheduler === undefined && this.requiredGPUNum > 0) { this.gpuScheduler = new GPUScheduler(); } break; @@ -78,7 +78,7 @@ class LocalTrainingServiceForGPU extends LocalTrainingService { } protected onTrialJobStatusChanged(trialJob: LocalTrialJobDetailForGPU, oldStatus: TrialJobStatus): void { - if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length !== 0) { + if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length !== 0 && this.gpuScheduler !== undefined) { if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') { for (const index of trialJob.gpuIndices) { this.availableGPUIndices[index] = false; @@ -93,7 +93,7 @@ class LocalTrainingServiceForGPU extends LocalTrainingService { const variables: { key: string; value: string }[] = super.getEnvironmentVariables(trialJobDetail, resource); variables.push({ key: 'CUDA_VISIBLE_DEVICES', - value: resource.gpuIndices.join(',') + value: this.gpuScheduler === undefined ? 
'' : resource.gpuIndices.join(',') }); return variables; @@ -125,8 +125,10 @@ class LocalTrainingServiceForGPU extends LocalTrainingService { protected occupyResource(resource: { gpuIndices: number[] }): void { super.occupyResource(resource); - for (const index of resource.gpuIndices) { - this.availableGPUIndices[index] = true; + if (this.gpuScheduler !== undefined) { + for (const index of resource.gpuIndices) { + this.availableGPUIndices[index] = true; + } } } } diff --git a/src/sdk/pynni/nni/__main__.py b/src/sdk/pynni/nni/__main__.py index 5454e98343..27b1994a2f 100644 --- a/src/sdk/pynni/nni/__main__.py +++ b/src/sdk/pynni/nni/__main__.py @@ -28,6 +28,7 @@ import importlib from .constants import ModuleName, ClassName, ClassArgs +from nni.common import enable_multi_thread from nni.msg_dispatcher import MsgDispatcher from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher logger = logging.getLogger('nni.main') @@ -91,6 +92,7 @@ def parse_args(): parser.add_argument('--assessor_class_filename', type=str, required=False, help='Assessor class file path') parser.add_argument('--multi_phase', action='store_true') + parser.add_argument('--multi_thread', action='store_true') flags, _ = parser.parse_known_args() return flags @@ -101,6 +103,8 @@ def main(): ''' args = parse_args() + if args.multi_thread: + enable_multi_thread() tuner = None assessor = None diff --git a/src/sdk/pynni/nni/common.py b/src/sdk/pynni/nni/common.py index 79ee214aa2..cb21efda64 100644 --- a/src/sdk/pynni/nni/common.py +++ b/src/sdk/pynni/nni/common.py @@ -78,3 +78,12 @@ def init_logger(logger_file_path): logging.getLogger('matplotlib').setLevel(logging.INFO) sys.stdout = _LoggerFileWrapper(logger_file) + +_multi_thread = False + +def enable_multi_thread(): + global _multi_thread + _multi_thread = True + +def multi_thread_enabled(): + return _multi_thread diff --git a/src/sdk/pynni/nni/msg_dispatcher.py b/src/sdk/pynni/nni/msg_dispatcher.py index 1667d53562..b1489fd981 100644 --- a/src/sdk/pynni/nni/msg_dispatcher.py +++ b/src/sdk/pynni/nni/msg_dispatcher.py @@ -21,6 +21,7 @@ import logging from collections import defaultdict import json_tricks +import threading from .protocol import CommandType, send from .msg_dispatcher_base import MsgDispatcherBase @@ -69,7 +70,7 @@ def _pack_parameter(parameter_id, params, customized=False): class MsgDispatcher(MsgDispatcherBase): def __init__(self, tuner, assessor=None): - super() + super().__init__() self.tuner = tuner self.assessor = assessor if assessor is None: @@ -85,6 +86,14 @@ def save_checkpoint(self): if self.assessor is not None: self.assessor.save_checkpoint() + def handle_initialize(self, data): + ''' + data is search space + ''' + self.tuner.update_search_space(data) + send(CommandType.Initialized, '') + return True + def handle_request_trial_jobs(self, data): # data: number or trial jobs ids = [_create_parameter_id() for _ in range(data)] @@ -127,7 +136,7 @@ def handle_report_metric_data(self, data): if self.assessor is not None: self._handle_intermediate_metric_data(data) else: - pass + pass else: raise ValueError('Data type not supported: {}'.format(data['type'])) diff --git a/src/sdk/pynni/nni/msg_dispatcher_base.py b/src/sdk/pynni/nni/msg_dispatcher_base.py index ecd249fe49..d8a3f21c34 100644 --- a/src/sdk/pynni/nni/msg_dispatcher_base.py +++ b/src/sdk/pynni/nni/msg_dispatcher_base.py @@ -22,8 +22,8 @@ import os import logging import json_tricks - -from .common import init_logger +from multiprocessing.dummy import Pool as ThreadPool +from .common 
import init_logger, multi_thread_enabled from .recoverable import Recoverable from .protocol import CommandType, receive @@ -31,6 +31,10 @@ _logger = logging.getLogger(__name__) class MsgDispatcherBase(Recoverable): + def __init__(self): + if multi_thread_enabled(): + self.pool = ThreadPool() + def run(self): """Run the tuner. This function will never return unless raise. @@ -39,17 +43,24 @@ def run(self): if mode == 'resume': self.load_checkpoint() - while self.handle_request(): - pass + while True: + _logger.debug('waiting receive_message') + command, data = receive() + if command is None: + break + if multi_thread_enabled(): + self.pool.map_async(self.handle_request, [(command, data)]) + else: + self.handle_request((command, data)) - _logger.info('Terminated by NNI manager') + if multi_thread_enabled(): + self.pool.close() + self.pool.join() - def handle_request(self): - _logger.debug('waiting receive_message') + _logger.info('Terminated by NNI manager') - command, data = receive() - if command is None: - return False + def handle_request(self, request): + command, data = request _logger.debug('handle request: command: [{}], data: [{}]'.format(command, data)) @@ -60,6 +71,7 @@ def handle_request(self): command_handlers = { # Tunner commands: + CommandType.Initialize: self.handle_initialize, CommandType.RequestTrialJobs: self.handle_request_trial_jobs, CommandType.UpdateSearchSpace: self.handle_update_search_space, CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial, @@ -74,6 +86,9 @@ def handle_request(self): return command_handlers[command](data) + def handle_initialize(self, data): + raise NotImplementedError('handle_initialize not implemented') + def handle_request_trial_jobs(self, data): raise NotImplementedError('handle_request_trial_jobs not implemented') diff --git a/src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py b/src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py index ec7d2be0f1..39b5c20039 100644 --- a/src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py +++ b/src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py @@ -91,6 +91,14 @@ def save_checkpoint(self): if self.assessor is not None: self.assessor.save_checkpoint() + def handle_initialize(self, data): + ''' + data is search space + ''' + self.tuner.update_search_space(data) + send(CommandType.Initialized, '') + return True + def handle_request_trial_jobs(self, data): # data: number or trial jobs ids = [_create_parameter_id() for _ in range(data)] diff --git a/src/sdk/pynni/nni/protocol.py b/src/sdk/pynni/nni/protocol.py index ada5527bfa..b547294432 100644 --- a/src/sdk/pynni/nni/protocol.py +++ b/src/sdk/pynni/nni/protocol.py @@ -19,7 +19,9 @@ # ================================================================================================== import logging +import threading from enum import Enum +from .common import multi_thread_enabled class CommandType(Enum): @@ -33,6 +35,7 @@ class CommandType(Enum): Terminate = b'TE' # out + Initialized = b'ID' NewTrialJob = b'TR' SendTrialJobParameter = b'SP' NoMoreTrialJobs = b'NO' @@ -42,6 +45,7 @@ class CommandType(Enum): try: _in_file = open(3, 'rb') _out_file = open(4, 'wb') + _lock = threading.Lock() except OSError: _msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?' import logging @@ -53,12 +57,19 @@ def send(command, data): command: CommandType object. data: string payload. 
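    Wire format, for reference: the 2-byte command value, a 6-digit ASCII payload length,
    then the UTF-8 payload; e.g. send(CommandType.Initialized, '') writes b'ID000000'.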
""" - data = data.encode('utf8') - assert len(data) < 1000000, 'Command too long' - msg = b'%b%06d%b' % (command.value, len(data), data) - logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg) - _out_file.write(msg) - _out_file.flush() + global _lock + try: + if multi_thread_enabled(): + _lock.acquire() + data = data.encode('utf8') + assert len(data) < 1000000, 'Command too long' + msg = b'%b%06d%b' % (command.value, len(data), data) + logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg) + _out_file.write(msg) + _out_file.flush() + finally: + if multi_thread_enabled(): + _lock.release() def receive(): diff --git a/src/webui/src/components/SlideBar.tsx b/src/webui/src/components/SlideBar.tsx index c0cbabd4e6..5bfc83ade9 100644 --- a/src/webui/src/components/SlideBar.tsx +++ b/src/webui/src/components/SlideBar.tsx @@ -8,7 +8,9 @@ class SlideBar extends React.Component<{}, {}> { return (
  •
-            <img … alt="NNI logo" … />
+            <a …>
+                <img … alt="NNI logo" … />
+            </a>
  • diff --git a/src/webui/src/components/TrialsDetail.tsx b/src/webui/src/components/TrialsDetail.tsx index 477df4d4aa..4115cd4e68 100644 --- a/src/webui/src/components/TrialsDetail.tsx +++ b/src/webui/src/components/TrialsDetail.tsx @@ -3,7 +3,7 @@ import axios from 'axios'; import { MANAGER_IP } from '../static/const'; import { Row, Col, Button, Tabs, Input } from 'antd'; const Search = Input.Search; -import { TableObj, Parameters, DetailAccurPoint, TooltipForAccuracy, } from '../static/interface'; +import { TableObj, Parameters, DetailAccurPoint, TooltipForAccuracy } from '../static/interface'; import Accuracy from './overview/Accuracy'; import Duration from './trial-detail/Duration'; import Title1 from './overview/Title1'; @@ -16,6 +16,7 @@ interface TrialDetailState { accSource: object; accNodata: string; tableListSource: Array; + tableBaseSource: Array; } class TrialsDetail extends React.Component<{}, TrialDetailState> { @@ -30,7 +31,8 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { this.state = { accSource: {}, accNodata: '', - tableListSource: [] + tableListSource: [], + tableBaseSource: [] }; } // trial accuracy graph @@ -129,17 +131,21 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { drawTableList = () => { - axios(`${MANAGER_IP}/trial-jobs`, { - method: 'GET' - }) - .then(res => { - if (res.status === 200) { + axios + .all([ + axios.get(`${MANAGER_IP}/trial-jobs`), + axios.get(`${MANAGER_IP}/metric-data`) + ]) + .then(axios.spread((res, res1) => { + if (res.status === 200 && res1.status === 200) { const trialJobs = res.data; + const metricSource = res1.data; const trialTable: Array = []; Object.keys(trialJobs).map(item => { // only succeeded trials have finalMetricData let desc: Parameters = { - parameters: {} + parameters: {}, + intermediate: [] }; let acc; let tableAcc = 0; @@ -171,6 +177,14 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { desc.isLink = true; } } + let mediate: Array = []; + Object.keys(metricSource).map(key => { + const items = metricSource[key]; + if (items.trialJobId === id) { + mediate.push(items.data); + } + }); + desc.intermediate = mediate; if (trialJobs[item].finalMetricData !== undefined) { acc = JSON.parse(trialJobs[item].finalMetricData.data); if (typeof (acc) === 'object') { @@ -193,11 +207,12 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { }); if (this._isMounted) { this.setState(() => ({ - tableListSource: trialTable + tableListSource: trialTable, + tableBaseSource: trialTable })); } } - }); + })); } callback = (key: string) => { @@ -228,10 +243,10 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { searchTrialNo = (value: string) => { window.clearInterval(this.interTableList); - const { tableListSource } = this.state; + const { tableBaseSource } = this.state; const searchResultList: Array = []; - Object.keys(tableListSource).map(key => { - const item = tableListSource[key]; + Object.keys(tableBaseSource).map(key => { + const item = tableBaseSource[key]; if (item.sequenceId.toString() === value) { searchResultList.push(item); } @@ -271,7 +286,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> { accSource, accNodata, tableListSource } = this.state; - + const titleOfacc = ( ); diff --git a/src/webui/src/components/overview/Progress.tsx b/src/webui/src/components/overview/Progress.tsx index 45de55fb59..5d8ed1638f 100644 --- a/src/webui/src/components/overview/Progress.tsx +++ b/src/webui/src/components/overview/Progress.tsx @@ 
-29,12 +29,16 @@ class Progressed extends React.Component { trialNumber, bestAccuracy, status, errors } = this.props; - // remaining time const bar2 = trialNumber.totalCurrentTrial - trialNumber.waitTrial - trialNumber.unknowTrial; const bar2Percent = (bar2 / trialProfile.MaxTrialNum) * 100; const percent = (trialProfile.execDuration / trialProfile.maxDuration) * 100; const runDuration = convertTime(trialProfile.execDuration); - const remaining = convertTime(trialProfile.maxDuration - trialProfile.execDuration); + let remaining; + if (status === 'DONE') { + remaining = '0'; + } else { + remaining = convertTime(trialProfile.maxDuration - trialProfile.execDuration); + } let errorContent; if (errors !== '') { errorContent = ( @@ -81,7 +85,7 @@ class Progressed extends React.Component { maxString={`MaxTrialNumber: ${trialProfile.MaxTrialNum}`} /> -
-            Best Accuracy
+            Best Default Metric
             {bestAccuracy}
@@ -99,7 +103,7 @@ class Progressed extends React.Component {
-            Duration
+            MaxDuration
             {convertTime(trialProfile.maxDuration)}
    diff --git a/src/webui/src/components/trial-detail/Duration.tsx b/src/webui/src/components/trial-detail/Duration.tsx index 8f4dc6fe20..367f916faa 100644 --- a/src/webui/src/components/trial-detail/Duration.tsx +++ b/src/webui/src/components/trial-detail/Duration.tsx @@ -44,14 +44,6 @@ class Duration extends React.Component<{}, DurationState> { type: 'shadow' } }, - // title: { - // left: 'center', - // text: 'Trial Duration', - // textStyle: { - // fontSize: 18, - // color: '#333' - // } - // }, grid: { bottom: '3%', containLabel: true, @@ -108,7 +100,7 @@ class Duration extends React.Component<{}, DurationState> { } else { duration = (new Date().getTime() - start) / 1000; } - trialId.push(trialJobs[item].id); + trialId.push(trialJobs[item].sequenceId); trialTime.push(duration); } }); diff --git a/src/webui/src/components/trial-detail/Para.tsx b/src/webui/src/components/trial-detail/Para.tsx index 7508e050f0..6172f6b30a 100644 --- a/src/webui/src/components/trial-detail/Para.tsx +++ b/src/webui/src/components/trial-detail/Para.tsx @@ -3,7 +3,7 @@ import axios from 'axios'; import { MANAGER_IP } from '../../static/const'; import ReactEcharts from 'echarts-for-react'; import { Row, Col, Select, Button, message } from 'antd'; -import { HoverName, ParaObj, VisualMapValue, Dimobj } from '../../static/interface'; +import { ParaObj, VisualMapValue, Dimobj } from '../../static/interface'; const Option = Select.Option; require('echarts/lib/chart/parallel'); require('echarts/lib/component/tooltip'); @@ -243,30 +243,19 @@ class Para extends React.Component<{}, ParaState> { }; } else { visualMapObj = { + bottom: '20px', type: 'continuous', precision: 3, min: visualValue.minAccuracy, max: visualValue.maxAccuracy, - color: ['#CA0000', '#FFC400', '#90EE90'] + color: ['#CA0000', '#FFC400', '#90EE90'], + calculable: true }; } let optionown = { parallelAxis, tooltip: { - trigger: 'item', - formatter: function (params: HoverName) { - return params.name; - } - }, - toolbox: { - show: true, - left: 'right', - iconStyle: { - normal: { - borderColor: '#ddd' - } - }, - z: 202 + trigger: 'item' }, parallel: { parallelAxisDefault: { @@ -276,9 +265,6 @@ class Para extends React.Component<{}, ParaState> { } }, visualMap: visualMapObj, - highlight: { - type: 'highlight' - }, series: { type: 'parallel', smooth: true, diff --git a/src/webui/src/components/trial-detail/TableList.tsx b/src/webui/src/components/trial-detail/TableList.tsx index 9e3ccd7a4c..ead44bc974 100644 --- a/src/webui/src/components/trial-detail/TableList.tsx +++ b/src/webui/src/components/trial-detail/TableList.tsx @@ -98,7 +98,7 @@ class TableList extends React.Component { data: sequence }, yAxis: { - name: 'Accuracy', + name: 'Default Metric', type: 'value', data: intermediateArr }, @@ -165,7 +165,7 @@ class TableList extends React.Component { key: 'sequenceId', width: 120, className: 'tableHead', - sorter: (a: TableObj, b: TableObj) => (a.sequenceId as number) - (b.sequenceId as number), + sorter: (a: TableObj, b: TableObj) => (a.sequenceId as number) - (b.sequenceId as number) }, { title: 'Id', dataIndex: 'id', @@ -305,6 +305,11 @@ class TableList extends React.Component { const parametersRow = { parameters: record.description.parameters }; + const intermediate = record.description.intermediate; + let showIntermediate = ''; + if (intermediate && intermediate.length > 0) { + showIntermediate = intermediate.join(', '); + } let isLogLink: boolean = false; const logPathRow = record.description.logPath; if (record.description.isLink !== undefined) 
{ @@ -340,6 +345,10 @@ class TableList extends React.Component { {logPathRow} } + + Intermediate Result: + {showIntermediate} + ); }; diff --git a/src/webui/src/static/interface.ts b/src/webui/src/static/interface.ts index 4faf16e8d6..578f67224d 100644 --- a/src/webui/src/static/interface.ts +++ b/src/webui/src/static/interface.ts @@ -15,6 +15,7 @@ interface Parameters { parameters: ErrorParameter; logPath?: string; isLink?: boolean; + intermediate?: Array; } interface Experiment { @@ -76,10 +77,6 @@ interface Dimobj { data?: string[]; } -interface HoverName { - name: string; -} - interface ParaObj { data: number[][]; parallelAxis: Array; @@ -90,8 +87,9 @@ interface VisualMapValue { minAccuracy: number; } -export {TableObj, Parameters, Experiment, +export { + TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob, DetailAccurPoint, TooltipForAccuracy, - HoverName, ParaObj, VisualMapValue, Dimobj + ParaObj, VisualMapValue, Dimobj }; \ No newline at end of file diff --git a/src/webui/src/static/style/logPath.scss b/src/webui/src/static/style/logPath.scss index d9786e7998..7b5eb10f8f 100644 --- a/src/webui/src/static/style/logPath.scss +++ b/src/webui/src/static/style/logPath.scss @@ -1,8 +1,9 @@ .logpath{ margin-left: 10px; - + font-size: 14px; .logName{ color: #268BD2; + margin-right: 5px; } .logContent{ @@ -18,3 +19,8 @@ color: blue; text-decoration: underline; } + +.intermediate{ + white-space: normal; + font-size: 14px; +} diff --git a/tools/nni_cmd/config_schema.py b/tools/nni_cmd/config_schema.py index f18c1cefcd..ae65364339 100644 --- a/tools/nni_cmd/config_schema.py +++ b/tools/nni_cmd/config_schema.py @@ -31,6 +31,7 @@ 'trainingServicePlatform': And(str, lambda x: x in ['remote', 'local', 'pai', 'kubeflow']), Optional('searchSpacePath'): os.path.exists, Optional('multiPhase'): bool, +Optional('multiThread'): bool, 'useAnnotation': bool, 'tuner': Or({ 'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch'), @@ -87,12 +88,23 @@ kubeflow_trial_schema = { 'trial':{ - 'command': str, - 'codeDir': os.path.exists, - 'gpuNum': And(int, lambda x: 0 <= x <= 99999), - 'cpuNum': And(int, lambda x: 0 <= x <= 99999), - 'memoryMB': int, - 'image': str + 'codeDir': os.path.exists, + Optional('ps'): { + 'replicas': int, + 'command': str, + 'gpuNum': And(int, lambda x: 0 <= x <= 99999), + 'cpuNum': And(int, lambda x: 0 <= x <= 99999), + 'memoryMB': int, + 'image': str + }, + 'worker':{ + 'replicas': int, + 'command': str, + 'gpuNum': And(int, lambda x: 0 <= x <= 99999), + 'cpuNum': And(int, lambda x: 0 <= x <= 99999), + 'memoryMB': int, + 'image': str + } } } diff --git a/tools/nni_cmd/launcher.py b/tools/nni_cmd/launcher.py index c902765c8f..5ea6c33d37 100644 --- a/tools/nni_cmd/launcher.py +++ b/tools/nni_cmd/launcher.py @@ -99,21 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None def set_trial_config(experiment_config, port, config_file_name): '''set trial configuration''' request_data = dict() - value_dict = dict() - value_dict['command'] = experiment_config['trial']['command'] - value_dict['codeDir'] = experiment_config['trial']['codeDir'] - value_dict['gpuNum'] = experiment_config['trial']['gpuNum'] - if experiment_config['trial'].get('cpuNum'): - value_dict['cpuNum'] = experiment_config['trial']['cpuNum'] - if experiment_config['trial'].get('memoryMB'): - value_dict['memoryMB'] = experiment_config['trial']['memoryMB'] - if experiment_config['trial'].get('image'): - value_dict['image'] = 
experiment_config['trial']['image'] - if experiment_config['trial'].get('dataDir'): - value_dict['dataDir'] = experiment_config['trial']['dataDir'] - if experiment_config['trial'].get('outputDir'): - value_dict['outputDir'] = experiment_config['trial']['outputDir'] - request_data['trial_config'] = value_dict + request_data['trial_config'] = experiment_config['trial'] response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20) if check_response(response): return True @@ -196,6 +182,8 @@ def set_experiment(experiment_config, mode, port, config_file_name): request_data['description'] = experiment_config['description'] if experiment_config.get('multiPhase'): request_data['multiPhase'] = experiment_config.get('multiPhase') + if experiment_config.get('multiThread'): + request_data['multiThread'] = experiment_config.get('multiThread') request_data['tuner'] = experiment_config['tuner'] if 'assessor' in experiment_config: request_data['assessor'] = experiment_config['assessor'] @@ -209,31 +197,18 @@ def set_experiment(experiment_config, mode, port, config_file_name): elif experiment_config['trainingServicePlatform'] == 'remote': request_data['clusterMetaData'].append( {'key': 'machine_list', 'value': experiment_config['machineList']}) - value_dict = dict() - value_dict['command'] = experiment_config['trial']['command'] - value_dict['codeDir'] = experiment_config['trial']['codeDir'] - value_dict['gpuNum'] = experiment_config['trial']['gpuNum'] request_data['clusterMetaData'].append( - {'key': 'trial_config', 'value': value_dict}) + {'key': 'trial_config', 'value': experiment_config['trial']}) elif experiment_config['trainingServicePlatform'] == 'pai': request_data['clusterMetaData'].append( - {'key': 'pai_config', 'value': experiment_config['paiConfig']}) - value_dict = dict() - value_dict['command'] = experiment_config['trial']['command'] - value_dict['codeDir'] = experiment_config['trial']['codeDir'] - value_dict['gpuNum'] = experiment_config['trial']['gpuNum'] - if experiment_config['trial'].get('cpuNum'): - value_dict['cpuNum'] = experiment_config['trial']['cpuNum'] - if experiment_config['trial'].get('memoryMB'): - value_dict['memoryMB'] = experiment_config['trial']['memoryMB'] - if experiment_config['trial'].get('image'): - value_dict['image'] = experiment_config['trial']['image'] - if experiment_config['trial'].get('dataDir'): - value_dict['dataDir'] = experiment_config['trial']['dataDir'] - if experiment_config['trial'].get('outputDir'): - value_dict['outputDir'] = experiment_config['trial']['outputDir'] + {'key': 'pai_config', 'value': experiment_config['paiConfig']}) request_data['clusterMetaData'].append( - {'key': 'trial_config', 'value': value_dict}) + {'key': 'trial_config', 'value': experiment_config['trial']}) + elif experiment_config['trainingServicePlatform'] == 'kubeflow': + request_data['clusterMetaData'].append( + {'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']}) + request_data['clusterMetaData'].append( + {'key': 'trial_config', 'value': experiment_config['trial']}) response = rest_post(experiment_url(port), json.dumps(request_data), 20) if check_response(response):
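
For reference, below is a minimal sketch of the 'trial' section that the new kubeflow_trial_schema accepts: a codeDir plus a required worker template and an optional ps template. All concrete values (paths, commands, image names, resource sizes) are hypothetical.

# Hypothetical kubeflow trial config; keys follow kubeflow_trial_schema above.
example_kubeflow_trial = {
    'codeDir': '.',  # the schema requires an existing path (os.path.exists)
    'worker': {      # required role
        'replicas': 2,
        'command': 'python3 dist_mnist.py',
        'gpuNum': 1,
        'cpuNum': 2,
        'memoryMB': 8192,
        'image': 'msranni/nni:latest'
    },
    'ps': {          # optional parameter-server role
        'replicas': 1,
        'command': 'python3 dist_mnist.py',
        'gpuNum': 0,
        'cpuNum': 1,
        'memoryMB': 4096,
        'image': 'msranni/nni:latest'
    }
}

With the launcher.py changes above, this dict is forwarded unmodified as the trial_config cluster metadata (request_data['trial_config'] = experiment_config['trial']), so the training service receives the worker and ps templates directly.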