diff --git a/src/nni_manager/training_service/reusable/aml/amlConfig.ts b/src/nni_manager/training_service/reusable/aml/amlConfig.ts index 2c101883e1..dd8c2345d4 100644 --- a/src/nni_manager/training_service/reusable/aml/amlConfig.ts +++ b/src/nni_manager/training_service/reusable/aml/amlConfig.ts @@ -36,4 +36,5 @@ export class AMLTrialConfig extends TrialConfig { export class AMLEnvironmentInformation extends EnvironmentInformation { public amlClient?: AMLClient; + public currentMessageIndex: number = -1; } diff --git a/src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts b/src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts index d57befd507..6fbdf40cef 100644 --- a/src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts +++ b/src/nni_manager/training_service/reusable/channels/amlCommandChannel.ts @@ -14,7 +14,6 @@ class AMLRunnerConnection extends RunnerConnection { export class AMLCommandChannel extends CommandChannel { private stopping: boolean = false; - private currentMessageIndex: number = -1; private sendQueues: [EnvironmentInformation, string][] = []; private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?.*?)'`; @@ -89,7 +88,9 @@ export class AMLCommandChannel extends CommandChannel { const runnerConnections = [...this.runnerConnections.values()] as AMLRunnerConnection[]; for (const runnerConnection of runnerConnections) { // to loop all commands - const amlClient = (runnerConnection.environment as AMLEnvironmentInformation).amlClient; + const amlEnvironmentInformation: AMLEnvironmentInformation = runnerConnection.environment as AMLEnvironmentInformation; + const amlClient = amlEnvironmentInformation.amlClient; + let currentMessageIndex = amlEnvironmentInformation.currentMessageIndex; if (!amlClient) { throw new Error('AML client not initialized!'); } @@ -97,15 +98,16 @@ export class AMLCommandChannel extends CommandChannel { if (command && Object.prototype.hasOwnProperty.call(command, "trial_runner")) { const messages = command['trial_runner']; if (messages) { - if (messages instanceof Object && this.currentMessageIndex < messages.length - 1) { - for (let index = this.currentMessageIndex + 1; index < messages.length; index ++) { + if (messages instanceof Object && currentMessageIndex < messages.length - 1) { + for (let index = currentMessageIndex + 1; index < messages.length; index ++) { this.handleCommand(runnerConnection.environment, messages[index]); } - this.currentMessageIndex = messages.length - 1; - } else if (this.currentMessageIndex === -1){ + currentMessageIndex = messages.length - 1; + } else if (currentMessageIndex === -1){ this.handleCommand(runnerConnection.environment, messages); - this.currentMessageIndex += 1; + currentMessageIndex += 1; } + amlEnvironmentInformation.currentMessageIndex = currentMessageIndex; } } }