From 109d9a3210f618d1803cc751436cdf3dfa2ba589 Mon Sep 17 00:00:00 2001 From: SparkSnail Date: Fri, 7 Aug 2020 11:12:27 +0800 Subject: [PATCH] Fix remote machine connection logic (#2725) --- .../remoteMachineTrainingService.ts | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts index c997b03a01..b291690a0b 100644 --- a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts +++ b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts @@ -57,6 +57,7 @@ class RemoteMachineTrainingService implements TrainingService { private nniManagerIpConfig?: NNIManagerIpConfig; private versionCheck: boolean = true; private logCollection: string; + private sshConnectionPromises: any[]; constructor(@component.Inject timer: ObservableTimer) { this.metricsEmitter = new EventEmitter(); @@ -65,6 +66,7 @@ class RemoteMachineTrainingService implements TrainingService { this.machineCopyExpCodeDirPromiseMap = new Map>(); this.machineExecutorManagerMap = new Map(); this.jobQueue = []; + this.sshConnectionPromises = []; this.expRootDir = getExperimentRootDir(); this.timer = timer; this.log = getLogger(); @@ -80,6 +82,12 @@ class RemoteMachineTrainingService implements TrainingService { await restServer.start(); restServer.setEnableVersionCheck = this.versionCheck; this.log.info('Run remote machine training service.'); + if (this.sshConnectionPromises.length > 0) { + await Promise.all(this.sshConnectionPromises); + this.log.info('ssh connection initialized!'); + // set sshConnectionPromises to [] to avoid log information duplicated + this.sshConnectionPromises = []; + } while (!this.stopping) { while (this.jobQueue.length > 0) { this.updateGpuReservation(); @@ -408,7 +416,6 @@ class RemoteMachineTrainingService implements TrainingService { //TO DO: verify if value's format is wrong, and json parse failed, how to handle error const rmMetaList: RemoteMachineMeta[] = JSON.parse(machineList); - const connectionPromises = []; for (const rmMeta of rmMetaList) { rmMeta.occupiedGpuIndexMap = new Map(); const executorManager: ExecutorManager = new ExecutorManager(rmMeta); @@ -417,11 +424,9 @@ class RemoteMachineTrainingService implements TrainingService { this.log.debug(`reached ${executor.name}`); this.machineExecutorManagerMap.set(rmMeta, executorManager); this.log.debug(`initializing ${executor.name}`); - connectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor)); - this.log.info(`connected to ${executor.name}`); + this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor)); + this.log.info(`connecting to ${executor.name}`); } - - await Promise.all(connectionPromises); } private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise {