diff --git a/src/nni_manager/training_service/reusable/aml/amlClient.ts b/src/nni_manager/training_service/reusable/aml/amlClient.ts index a7e4c61f7f..04af301e5e 100644 --- a/src/nni_manager/training_service/reusable/aml/amlClient.ts +++ b/src/nni_manager/training_service/reusable/aml/amlClient.ts @@ -74,13 +74,11 @@ export class AMLClient { throw Error('python shell client not initialized!'); } this.pythonShellClient.send('tracking_url'); - let trackingUrl = ''; - this.pythonShellClient.on('message', function (status: any) { - const items = status.split(':'); - if (items[0] === 'tracking_url') { - trackingUrl = items.splice(1, items.length).join('') + this.pythonShellClient.on('message', (status: any) => { + const trackingUrl = this.parseContent('tracking_url', status); + if (trackingUrl !== '') { + deferred.resolve(trackingUrl); } - deferred.resolve(trackingUrl); }); this.monitorError(this.pythonShellClient, deferred); return deferred.promise; @@ -91,12 +89,11 @@ export class AMLClient { if (this.pythonShellClient === undefined) { throw Error('python shell client not initialized!'); } - let newStatus = oldStatus; this.pythonShellClient.send('update_status'); - this.pythonShellClient.on('message', function (status: any) { - const items = status.split(':'); - if (items[0] === 'status') { - newStatus = items.splice(1, items.length).join('') + this.pythonShellClient.on('message', (status: any) => { + let newStatus = this.parseContent('status', status); + if (newStatus === '') { + newStatus = oldStatus; } deferred.resolve(newStatus); }); @@ -117,10 +114,10 @@ export class AMLClient { throw Error('python shell client not initialized!'); } this.pythonShellClient.send('receive'); - this.pythonShellClient.on('message', function (command: any) { - const items = command.split(':') - if (items[0] === 'receive') { - deferred.resolve(JSON.parse(command.slice(8))) + this.pythonShellClient.on('message', (command: any) => { + const message = this.parseContent('receive', command); + if (message !== '') { + deferred.resolve(JSON.parse(message)) } }); this.monitorError(this.pythonShellClient, deferred); @@ -136,4 +133,13 @@ export class AMLClient { deferred.reject(error); }); } + + // Parse command content, command format is {head}:{content} + public parseContent(head: string, command: string): string { + const items = command.split(':'); + if (items[0] === head) { + return command.slice(head.length + 1); + } + return ''; + } } diff --git a/src/nni_manager/training_service/reusable/test/amlClient.test.ts b/src/nni_manager/training_service/reusable/test/amlClient.test.ts new file mode 100644 index 0000000000..608e34fec1 --- /dev/null +++ b/src/nni_manager/training_service/reusable/test/amlClient.test.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +import * as chai from 'chai'; +import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils'; +import chaiAsPromised = require("chai-as-promised"); +import { AMLClient } from '../aml/amlClient'; + + +describe('Unit Test for amlClient', () => { + + before(() => { + chai.should(); + chai.use(chaiAsPromised); + prepareUnitTest(); + }); + + after(() => { + cleanupUnitTest(); + }); + + it('test parseContent', async () => { + + let amlClient: AMLClient = new AMLClient('', '', '', '', '', '', '', ''); + + chai.assert.equal(amlClient.parseContent('test', 'test:1234'), '1234', "The content should be 1234"); + chai.assert.equal(amlClient.parseContent('test', 'abcd:1234'), '', "The content should be null"); + }); +});