diff --git a/src/nni_manager/rest_server/restValidationSchemas.ts b/src/nni_manager/rest_server/restValidationSchemas.ts index a480501a79..b845dcc30e 100644 --- a/src/nni_manager/rest_server/restValidationSchemas.ts +++ b/src/nni_manager/rest_server/restValidationSchemas.ts @@ -103,7 +103,6 @@ export namespace ValidationSchemas { }), pai_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase userName: joi.string().min(1).required(), - passWord: joi.string().min(1), token: joi.string().min(1), host: joi.string().min(1).required(), reuse: joi.boolean(), diff --git a/src/nni_manager/training_service/pai/paiJobInfoCollector.ts b/src/nni_manager/training_service/pai/paiJobInfoCollector.ts index 2590547849..5f6ccf4d9c 100644 --- a/src/nni_manager/training_service/pai/paiJobInfoCollector.ts +++ b/src/nni_manager/training_service/pai/paiJobInfoCollector.ts @@ -52,7 +52,7 @@ export class PAIJobInfoCollector { // Rest call to get PAI job info and update status // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API const getJobInfoRequest: request.Options = { - uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`, + uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v2/jobs/${paiClusterConfig.userName}~${paiTrialJob.paiJobName}`, method: 'GET', json: true, headers: { @@ -63,8 +63,9 @@ export class PAIJobInfoCollector { //TODO : pass in request timeout param? request(getJobInfoRequest, (error: Error, response: request.Response, _body: any) => { - if ((error !== undefined && error !== null) || response.statusCode >= 500) { - this.log.error(`PAI Training service: get job info for trial ${paiTrialJob.id} from PAI Cluster failed!`); + // Status code 200 for success + if ((error !== undefined && error !== null) || response.statusCode >= 400) { + // The job refresh time could be ealier than job submission, so it might return 404 error code, need refactor // Queried PAI job info failed, set job status to UNKNOWN if (paiTrialJob.status === 'WAITING' || paiTrialJob.status === 'RUNNING') { paiTrialJob.status = 'UNKNOWN'; diff --git a/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts b/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts index 59bd994535..6d6169a6b2 100644 --- a/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts +++ b/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts @@ -55,12 +55,7 @@ class PAIK8STrainingService extends PAITrainingService { this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService)); this.paiClusterConfig = JSON.parse(value); this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host); - if (this.paiClusterConfig.passWord) { - // Get PAI authentication token - await this.updatePaiToken(); - } else if (this.paiClusterConfig.token) { - this.paiToken = this.paiClusterConfig.token; - } + this.paiToken = this.paiClusterConfig.token; break; case TrialConfigMetadataKey.TRIAL_CONFIG: { @@ -290,18 +285,20 @@ class PAIK8STrainingService extends PAITrainingService { uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`, method: 'POST', body: paiJobConfig, + followAllRedirects: true, headers: { 'Content-Type': 'text/yaml', Authorization: `Bearer ${this.paiToken}` } }; request(submitJobRequest, (error: Error, response: request.Response, body: any) => { + // If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml if ((error !== undefined && error !== null) || response.statusCode >= 400) { const errorMessage: string = (error !== undefined && error !== null) ? error.message : `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`; - this.log.error(errorMessage); trialJobDetail.status = 'FAILED'; + deferred.reject(errorMessage); } else { trialJobDetail.submitTime = Date.now(); } diff --git a/src/nni_manager/training_service/pai/paiTrainingService.ts b/src/nni_manager/training_service/pai/paiTrainingService.ts index e26c16ecee..aff583de54 100644 --- a/src/nni_manager/training_service/pai/paiTrainingService.ts +++ b/src/nni_manager/training_service/pai/paiTrainingService.ts @@ -162,8 +162,7 @@ abstract class PAITrainingService implements TrainingService { } const stopJobRequest: request.Options = { - uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\ -/jobs/${trialJobDetail.paiJobName}/executionType`, + uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${trialJobDetail.paiJobName}/executionType`, method: 'PUT', json: true, body: { value: 'STOP' }, @@ -178,6 +177,7 @@ abstract class PAITrainingService implements TrainingService { const deferred: Deferred = new Deferred(); request(stopJobRequest, (error: Error, response: request.Response, _body: any) => { + // Status code 202 for success. if ((error !== undefined && error !== null) || response.statusCode >= 400) { this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`); deferred.reject((error !== undefined && error !== null) ? error.message : diff --git a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts index c997b03a01..b291690a0b 100644 --- a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts +++ b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts @@ -57,6 +57,7 @@ class RemoteMachineTrainingService implements TrainingService { private nniManagerIpConfig?: NNIManagerIpConfig; private versionCheck: boolean = true; private logCollection: string; + private sshConnectionPromises: any[]; constructor(@component.Inject timer: ObservableTimer) { this.metricsEmitter = new EventEmitter(); @@ -65,6 +66,7 @@ class RemoteMachineTrainingService implements TrainingService { this.machineCopyExpCodeDirPromiseMap = new Map>(); this.machineExecutorManagerMap = new Map(); this.jobQueue = []; + this.sshConnectionPromises = []; this.expRootDir = getExperimentRootDir(); this.timer = timer; this.log = getLogger(); @@ -80,6 +82,12 @@ class RemoteMachineTrainingService implements TrainingService { await restServer.start(); restServer.setEnableVersionCheck = this.versionCheck; this.log.info('Run remote machine training service.'); + if (this.sshConnectionPromises.length > 0) { + await Promise.all(this.sshConnectionPromises); + this.log.info('ssh connection initialized!'); + // set sshConnectionPromises to [] to avoid log information duplicated + this.sshConnectionPromises = []; + } while (!this.stopping) { while (this.jobQueue.length > 0) { this.updateGpuReservation(); @@ -408,7 +416,6 @@ class RemoteMachineTrainingService implements TrainingService { //TO DO: verify if value's format is wrong, and json parse failed, how to handle error const rmMetaList: RemoteMachineMeta[] = JSON.parse(machineList); - const connectionPromises = []; for (const rmMeta of rmMetaList) { rmMeta.occupiedGpuIndexMap = new Map(); const executorManager: ExecutorManager = new ExecutorManager(rmMeta); @@ -417,11 +424,9 @@ class RemoteMachineTrainingService implements TrainingService { this.log.debug(`reached ${executor.name}`); this.machineExecutorManagerMap.set(rmMeta, executorManager); this.log.debug(`initializing ${executor.name}`); - connectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor)); - this.log.info(`connected to ${executor.name}`); + this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor)); + this.log.info(`connecting to ${executor.name}`); } - - await Promise.all(connectionPromises); } private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise { diff --git a/src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index fea393e75d..aefd94fbb5 100644 --- a/src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/src/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -16,7 +16,7 @@ import { AMLClient } from '../aml/amlClient'; import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; import { AMLCommandChannel } from '../channels/amlCommandChannel'; import { CommandChannel } from "../commandChannel"; -import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from '../environment'; +import { EnvironmentInformation, EnvironmentService } from '../environment'; /** @@ -74,7 +74,7 @@ export class AMLEnvironmentService extends EnvironmentService { environments.forEach(async (environment) => { const amlClient = (environment as AMLEnvironmentInformation).amlClient; if (!amlClient) { - throw new Error('AML client not initialized!'); + return Promise.reject('AML client not initialized!'); } const newStatus = await amlClient.updateStatus(environment.status); switch (newStatus.toUpperCase()) { @@ -90,8 +90,8 @@ export class AMLEnvironmentService extends EnvironmentService { environment.setStatus('SUCCEEDED'); break; case 'FAILED': - environment.setStatus(newStatus.toUpperCase() as EnvironmentStatus); - break; + environment.setStatus('FAILED'); + return Promise.reject(`AML: job ${environment.envId} is failed!`); case 'STOPPED': case 'STOPPING': environment.setStatus('USER_CANCELED'); diff --git a/src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts b/src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts index 3d92df5c99..596c81dbe9 100644 --- a/src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts +++ b/src/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts @@ -28,15 +28,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService { private paiTrialConfig: NNIPAIK8STrialConfig | undefined; private paiJobConfig: any; private paiToken?: string; - private paiTokenUpdateTime?: number; - private readonly paiTokenUpdateInterval: number; private protocol: string = 'http'; private experimentId: string; constructor() { super(); - this.paiTokenUpdateInterval = 7200000; //2hours this.experimentId = getExperimentId(); } @@ -53,12 +50,7 @@ export class OpenPaiEnvironmentService extends EnvironmentService { case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: this.paiClusterConfig = JSON.parse(value); this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host); - if (this.paiClusterConfig.passWord) { - // Get PAI authentication token - await this.updatePaiToken(); - } else if (this.paiClusterConfig.token) { - this.paiToken = this.paiClusterConfig.token; - } + this.paiToken = this.paiClusterConfig.token; break; case TrialConfigMetadataKey.TRIAL_CONFIG: { @@ -95,7 +87,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService { public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { const deferred: Deferred = new Deferred(); - await this.refreshPlatform(); if (this.paiClusterConfig === undefined) { throw new Error('PAI Cluster config is not initialized'); @@ -115,9 +106,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService { }; request(getJobInfoRequest, async (error: any, response: request.Response, body: any) => { + // Status code 200 for success if ((error !== undefined && error !== null) || response.statusCode >= 400) { - this.log.error(`OpenPAI: get environment list from PAI Cluster failed!\nerror: ${error}`); - deferred.reject(error); + const errorMessage: string = (error !== undefined && error !== null) ? error.message : + `OpenPAI: get environment list from PAI Cluster failed!, http code:${response.statusCode}, http body: ${JSON.stringify(body)}`; + this.log.error(`${errorMessage}`); + deferred.reject(errorMessage); } else { const jobInfos = new Map(); body.forEach((jobInfo: any) => { @@ -133,8 +127,11 @@ export class OpenPaiEnvironmentService extends EnvironmentService { case 'RUNNING': case 'WAITING': case 'SUCCEEDED': + environment.setStatus(jobResponse.state); + break; case 'FAILED': environment.setStatus(jobResponse.state); + deferred.reject(`OpenPAI: job ${environment.envId} is failed!`); break; case 'STOPPED': case 'STOPPING': @@ -166,8 +163,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService { public async startEnvironment(environment: EnvironmentInformation): Promise { const deferred: Deferred = new Deferred(); - await this.refreshPlatform(); - if (this.paiClusterConfig === undefined) { throw new Error('PAI Cluster config is not initialized'); } @@ -195,18 +190,21 @@ export class OpenPaiEnvironmentService extends EnvironmentService { uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`, method: 'POST', body: paiJobConfig, + followAllRedirects: true, headers: { 'Content-Type': 'text/yaml', Authorization: `Bearer ${this.paiToken}` } }; request(submitJobRequest, (error, response, body) => { + // Status code 202 for success, refer https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml if ((error !== undefined && error !== null) || response.statusCode >= 400) { const errorMessage: string = (error !== undefined && error !== null) ? error.message : `start environment ${environment.envId} failed, http code:${response.statusCode}, http body: ${body}`; this.log.error(errorMessage); environment.status = 'FAILED'; + deferred.reject(errorMessage); } deferred.resolve(); }); @@ -241,8 +239,11 @@ export class OpenPaiEnvironmentService extends EnvironmentService { try { request(stopJobRequest, (error, response, _body) => { try { + // Status code 202 for success. if ((error !== undefined && error !== null) || (response && response.statusCode >= 400)) { - this.log.error(`OpenPAI: stop job ${environment.envId} failed with ${response.statusCode}\n${error}`); + const errorMessage: string = (error !== undefined && error !== null) ? error.message : + `OpenPAI: stop job ${environment.envId} failed, http code:${response.statusCode}, http body: ${_body}`; + this.log.error(`${errorMessage}`); deferred.reject((error !== undefined && error !== null) ? error : `Stop trial failed, http code: ${response.statusCode}`); } else { @@ -262,19 +263,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService { return deferred.promise; } - private async refreshPlatform(): Promise { - if (this.paiClusterConfig && this.paiClusterConfig.passWord) { - try { - await this.updatePaiToken(); - } catch (error) { - this.log.error(`${error}`); - if (this.paiToken === undefined) { - throw new Error(error); - } - } - } - } - private generateJobConfigInYamlFormat(environment: EnvironmentInformation): any { if (this.paiTrialConfig === undefined) { throw new Error('trial config is not initialized'); @@ -386,59 +374,4 @@ export class OpenPaiEnvironmentService extends EnvironmentService { return host; } } - /** - * Update pai token by the interval time or initialize the pai token - */ - protected async updatePaiToken(): Promise { - const deferred: Deferred = new Deferred(); - - const currentTime: number = new Date().getTime(); - //If pai token initialized and not reach the interval time, do not update - if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) { - return Promise.resolve(); - } - - if (this.paiClusterConfig === undefined) { - const paiClusterConfigError: string = `pai cluster config not initialized!`; - this.log.error(`${paiClusterConfigError}`); - throw Error(`${paiClusterConfigError}`); - } - - const authenticationReq: request.Options = { - uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`, - method: 'POST', - json: true, - body: { - username: this.paiClusterConfig.userName, - password: this.paiClusterConfig.passWord - } - }; - - request(authenticationReq, (error: any, response: request.Response, body: any) => { - if (error !== undefined && error !== null) { - this.log.error(`Get PAI token failed: ${error.message}, authenticationReq: ${authenticationReq}`); - deferred.reject(new Error(`Get PAI token failed: ${error.message}`)); - } else { - if (response.statusCode !== 200) { - this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}, authenticationReq: ${authenticationReq}`); - deferred.reject(new Error(`Get PAI token failed code: ${response.statusCode}, body: ${response.body}, authenticationReq: ${authenticationReq}, please check paiConfig username or password`)); - } else { - this.paiToken = body.token; - this.paiTokenUpdateTime = new Date().getTime(); - deferred.resolve(); - } - } - }); - - let timeoutId: NodeJS.Timer; - const timeoutDelay: Promise = new Promise((_resolve: Function, reject: Function): void => { - // Set timeout and reject the promise once reach timeout (5 seconds) - timeoutId = setTimeout( - () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')), - 5000); - }); - - return Promise.race([timeoutDelay, deferred.promise]) - .finally(() => { clearTimeout(timeoutId); }); - } } diff --git a/src/sdk/pynni/nni/compression/torch/speedup/compressor.py b/src/sdk/pynni/nni/compression/torch/speedup/compressor.py index b31acfe664..41753e1c9f 100644 --- a/src/sdk/pynni/nni/compression/torch/speedup/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/speedup/compressor.py @@ -141,6 +141,14 @@ def infer_modules_masks(self): """ for module_name, mask in self.masks.items(): _logger.debug('Start mask inference from %s', module_name) + if module_name not in self.torch_graph.name_to_node: + # this module is not traced in the torch_graph, + # jit.trace only correctly records functions and + # modules which are not data dependent (e.g., do + # not have conditionals on data in tensors) + # so, if a node is not traced, we just skip it. + _logger.warning('%s has mask, but not found in the traced graph, just skip it.', module_name) + continue self.infer_module_mask(module_name, None, mask=mask) def replace_compressed_modules(self): diff --git a/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py b/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py index 47aa8087df..2635617031 100644 --- a/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py +++ b/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py @@ -222,6 +222,7 @@ def __repr__(self): 'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask), 'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask), + 'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask), 'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask), 'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), @@ -241,7 +242,8 @@ def __repr__(self): 'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited), 'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape), 'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask), - 'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask) + 'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask), + 'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask) } """ @@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask): return module_masks.output_mask # if alreay visited assert module_masks.input_mask <= mask - if module_masks.input_mask == mask: - return None + # It should be the same, we pass the masks by the reference(not the value), + # so they acutually are two references of the same object(mask, + # module_masks.input_mask). So we should continue pass the mask + # to the following nodes even module_masks.input_mask == mask. + # if pass the mask by copy.deepcopy(), then we can stop when + # module_masks.input_mask == mask. + # if module_masks.input_mask == mask: + # return None module_masks.set_input_mask(mask) module_masks.set_output_mask(mask) return module_masks.output_mask @@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask): """ assert isinstance(mask, CoarseMask) assert mask.mask_index[0] is None - assert module_masks.input_mask is None + if module_masks.input_mask is not None: + assert module_masks.input_mask <= mask module_masks.set_input_mask(mask) return None @@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape): assert mask.mask_index[0] is None assert mask.mask_index[2] is None assert mask.mask_index[3] is None - assert module_masks.input_mask is None + # due to the cat operation, the same node may be + # accessed more than once + if module_masks.input_mask is not None: + assert module_masks.input_mask <= mask module_masks.set_input_mask(mask) output_cmask = CoarseMask(num_dim=2) index = [] @@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask): The mask of its output tensor """ assert isinstance(mask, CoarseMask) - # TODO: double check this assert, is it possible that a module is passed twice if module_masks.input_mask is not None: # check if has a mask conflict - assert module_masks.input_mask == mask - # No need to pass the mask again - return None + assert module_masks.input_mask <= mask # assert module_masks.input_mask is None, "A relu op can only be processed once" module_masks.set_input_mask(mask) module_masks.set_output_mask(mask) diff --git a/src/sdk/pynni/tests/test_model_speedup.py b/src/sdk/pynni/tests/test_model_speedup.py index a06f991c97..845ed793ff 100644 --- a/src/sdk/pynni/tests/test_model_speedup.py +++ b/src/sdk/pynni/tests/test_model_speedup.py @@ -145,18 +145,18 @@ def test_speedup_bigmodel(self): assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY) def test_speedup_integration(self): - for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']: + for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']: Model = getattr(models, model_name) net = Model(pretrained=True, progress=False).to(device) + speedup_model = Model().to(device) net.eval() # this line is necessary + speedup_model.eval() # random generate the prune config for the pruner cfgs = generate_random_sparsity(net) pruner = L1FilterPruner(net, cfgs) pruner.compress() pruner.export_model(MODEL_FILE, MASK_FILE) pruner._unwrap_model() - speedup_model = Model().to(device) - speedup_model.eval() state_dict = torch.load(MODEL_FILE) speedup_model.load_state_dict(state_dict) zero_bn_bias(net) diff --git a/src/webui/src/components/Modals/ExperimentDrawer.tsx b/src/webui/src/components/Modals/ExperimentPanel.tsx similarity index 56% rename from src/webui/src/components/Modals/ExperimentDrawer.tsx rename to src/webui/src/components/Modals/ExperimentPanel.tsx index 142af89f59..cbc674ec9a 100644 --- a/src/webui/src/components/Modals/ExperimentDrawer.tsx +++ b/src/webui/src/components/Modals/ExperimentPanel.tsx @@ -1,17 +1,16 @@ import * as React from 'react'; -import axios from 'axios'; import { downFile } from '../../static/function'; import { Stack, PrimaryButton, DefaultButton, Panel, StackItem, Pivot, PivotItem } from 'office-ui-fabric-react'; -import { MANAGER_IP, DRAWEROPTION } from '../../static/const'; +import { DRAWEROPTION } from '../../static/const'; +import { EXPERIMENT, TRIALS } from '../../static/datamodel'; import MonacoEditor from 'react-monaco-editor'; import '../../static/style/logDrawer.scss'; -import { TrialManager } from '../../static/model/trialmanager'; interface ExpDrawerProps { - isVisble: boolean; closeExpDrawer: () => void; + experimentProfile: object; } interface ExpDrawerState { @@ -21,7 +20,9 @@ interface ExpDrawerState { class ExperimentDrawer extends React.Component { - public _isCompareMount!: boolean; + public _isExperimentMount!: boolean; + private refreshId!: number | undefined; + constructor(props: ExpDrawerProps) { super(props); @@ -32,42 +33,40 @@ class ExperimentDrawer extends React.Component { } getExperimentContent = (): void => { - axios - .all([ - axios.get(`${MANAGER_IP}/experiment`), - axios.get(`${MANAGER_IP}/trial-jobs`), - axios.get(`${MANAGER_IP}/metric-data`) - ]) - .then(axios.spread((resExperiment, resTrialJobs, resMetricData) => { - if (resExperiment.status === 200 && resTrialJobs.status === 200 && resMetricData.status === 200) { - if (resExperiment.data.params.searchSpace) { - resExperiment.data.params.searchSpace = JSON.parse(resExperiment.data.params.searchSpace); - } - const trialMessagesArr = TrialManager.expandJobsToTrials(resTrialJobs.data); - const interResultList = resMetricData.data; - Object.keys(trialMessagesArr).map(item => { - // not deal with trial's hyperParameters - const trialId = trialMessagesArr[item].id; - // add intermediate result message - trialMessagesArr[item].intermediate = []; - Object.keys(interResultList).map(key => { - const interId = `${interResultList[key].trialJobId}-${interResultList[key].parameterId}`; - if (trialId === interId) { - trialMessagesArr[item].intermediate.push(interResultList[key]); - } - }); - }); - const result = { - experimentParameters: resExperiment.data, - trialMessage: trialMessagesArr - }; - if (this._isCompareMount === true) { - this.setState({ experiment: JSON.stringify(result, null, 4) }); - } + const experimentData = JSON.parse(JSON.stringify(this.props.experimentProfile)); + if (experimentData.params.searchSpace) { + experimentData.params.searchSpace = JSON.parse(experimentData.params.searchSpace); + } + const trialMessagesArr = TRIALS.getTrialJobList(); + const interResultList = TRIALS.getMetricsList(); + Object.keys(trialMessagesArr).map(item => { + // not deal with trial's hyperParameters + const trialId = trialMessagesArr[item].jobId; + // add intermediate result message + trialMessagesArr[item].intermediate = []; + Object.keys(interResultList).map(key => { + const interId = interResultList[key].trialJobId; + if (trialId === interId) { + trialMessagesArr[item].intermediate.push(interResultList[key]); } - })); - } + }); + }); + const result = { + experimentParameters: experimentData, + trialMessage: trialMessagesArr + }; + if (this._isExperimentMount === true) { + this.setState({ experiment: JSON.stringify(result, null, 4) }); + } + if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) { + if(this.refreshId !== null || this.refreshId !== undefined){ + window.clearInterval(this.refreshId); + } + } + + } + downExperimentParameters = (): void => { const { experiment } = this.state; downFile(experiment, 'experiment.json'); @@ -78,31 +77,28 @@ class ExperimentDrawer extends React.Component { } componentDidMount(): void { - this._isCompareMount = true; + this._isExperimentMount = true; this.getExperimentContent(); + this.refreshId = window.setInterval(this.getExperimentContent, 10000); window.addEventListener('resize', this.onWindowResize); } - componentWillReceiveProps(nextProps: ExpDrawerProps): void { - const { isVisble } = nextProps; - if (isVisble === true) { - this.getExperimentContent(); - } - } - componentWillUnmount(): void { - this._isCompareMount = false; + this._isExperimentMount = false; + window.clearTimeout(this.refreshId); window.removeEventListener('resize', this.onWindowResize); } render(): React.ReactNode { - const { isVisble, closeExpDrawer } = this.props; + const { closeExpDrawer } = this.props; const { experiment, expDrawerHeight } = this.state; return ( diff --git a/src/webui/src/components/Modals/Killjob.tsx b/src/webui/src/components/Modals/Killjob.tsx index 580ff5ff24..2f4c7a1833 100644 --- a/src/webui/src/components/Modals/Killjob.tsx +++ b/src/webui/src/components/Modals/Killjob.tsx @@ -77,7 +77,7 @@ class KillJob extends React.Component { onKill = (): void => { this.setState({ isCalloutVisible: false }, () => { const { trial } = this.props; - killJob(trial.key, trial.jobId, trial.status); + killJob(trial.key, trial.id, trial.status); }); } diff --git a/src/webui/src/components/Modals/LogDrawer.tsx b/src/webui/src/components/Modals/LogPanel.tsx similarity index 98% rename from src/webui/src/components/Modals/LogDrawer.tsx rename to src/webui/src/components/Modals/LogPanel.tsx index a54b0f4c25..97a408fe9d 100644 --- a/src/webui/src/components/Modals/LogDrawer.tsx +++ b/src/webui/src/components/Modals/LogPanel.tsx @@ -92,6 +92,8 @@ class LogDrawer extends React.Component { isOpen={true} hasCloseButton={false} isFooterAtBottom={true} + isLightDismiss={true} + onLightDismissClick={closeDrawer} >
{ openDocs = (): void => { window.open(WEBUIDOC); } - + openGithubNNI = (): void => { - const {version} = this.state; + const { version } = this.state; const nniLink = `https://github.com/Microsoft/nni/tree/${version}`; window.open(nniLink); } @@ -178,8 +179,8 @@ class NavCon extends React.Component { {/* the drawer for dispatcher & nnimanager log message */} - {isvisibleLogDrawer && } - + {isvisibleLogDrawer && } + {isvisibleExperimentDrawer && } ); } diff --git a/src/webui/src/components/overview/Progress.tsx b/src/webui/src/components/overview/Progress.tsx index c63e23827c..c21e9e8921 100644 --- a/src/webui/src/components/overview/Progress.tsx +++ b/src/webui/src/components/overview/Progress.tsx @@ -9,7 +9,7 @@ import { EXPERIMENT, TRIALS } from '../../static/datamodel'; import { convertTime } from '../../static/function'; import ConcurrencyInput from './NumInput'; import ProgressBar from './ProgressItem'; -import LogDrawer from '../Modals/LogDrawer'; +import LogDrawer from '../Modals/LogPanel'; import MessageInfo from '../Modals/MessageInfo'; import { infoIcon } from "../Buttons/Icon"; import '../../static/style/progress.scss'; diff --git a/src/webui/src/components/trial-detail/Para.tsx b/src/webui/src/components/trial-detail/Para.tsx index f2f130949a..6a018b71e3 100644 --- a/src/webui/src/components/trial-detail/Para.tsx +++ b/src/webui/src/components/trial-detail/Para.tsx @@ -162,21 +162,32 @@ class Para extends React.Component { const scale = this.convertToD3Scale(v); if (k === primaryMetricKey && scale !== undefined && scale.interpolate) { // set color for primary metrics - colorScale = this.convertToD3Scale(v, false) - .range(['green', 'red']) - .interpolate(d3.interpolateHsl); - colorDim = k; + // `colorScale` is used to produce a color range, while `scale` is to produce a pixel range + colorScale = this.convertToD3Scale(v, false); + convertedTrials.sort((a, b) => EXPERIMENT.optimizeMode === 'minimize' ? a[k] - b[k] : b[k] - a[k]); // filter top trials if (percent != 1) { const keptTrialNum = Math.max(Math.ceil(convertedTrials.length * percent), 1); - convertedTrials.sort((a, b) => EXPERIMENT.optimizeMode === 'minimize' ? a[k] - b[k] : b[k] - a[k]); convertedTrials = convertedTrials.slice(0, keptTrialNum); const domain = d3.extent(convertedTrials, item => item[k]); scale.domain([domain[0], domain[1]]); + colorScale.domain([domain[0], domain[1]]); if (colorScale !== undefined) { colorScale.domain(domain); } } + // reverse the converted trials to show the top ones upfront + convertedTrials.reverse(); + const assignColors = (scale: any): void => { + scale.range([0, 1]); // fake a range to perform invert + const [scaleMin, scaleMax] = scale.domain(); + const pivot = scale.invert(0.5); + scale.domain([scaleMin, pivot, scaleMax]) + .range(['#90EE90', '#FFC400', '#CA0000']) + .interpolate(d3.interpolateHsl); + }; + assignColors(colorScale); + colorDim = k; } dimensions.push([k, { type: 'number', diff --git a/src/webui/src/components/trial-detail/TableList.tsx b/src/webui/src/components/trial-detail/TableList.tsx index 94af7e8fb9..4ce1ccd7bb 100644 --- a/src/webui/src/components/trial-detail/TableList.tsx +++ b/src/webui/src/components/trial-detail/TableList.tsx @@ -269,7 +269,7 @@ class TableList extends React.Component { showIntermediateModal = async (record: TrialJobInfo, event: React.SyntheticEvent): Promise => { event.preventDefault(); event.stopPropagation(); - const res = await axios.get(`${MANAGER_IP}/metric-data/${record.jobId}`); + const res = await axios.get(`${MANAGER_IP}/metric-data/${record.id}`); if (res.status === 200) { const intermediateArr: number[] = []; // support intermediate result is dict because the last intermediate result is @@ -277,14 +277,10 @@ class TableList extends React.Component { // get intermediate result dict keys array const { intermediateKey } = this.state; const otherkeys: string[] = []; - // One trial job may contains multiple parameter id - // only show current trial's metric data - const metricDatas = res.data.filter(item => { - return item.parameterId == record.parameterId; - }); + const metricDatas = res.data; if (metricDatas.length !== 0) { // just add type=number keys - const intermediateMetrics = parseMetrics(res.data[0].data); + const intermediateMetrics = parseMetrics(metricDatas[0].data); for (const key in intermediateMetrics) { if (typeof intermediateMetrics[key] === 'number') { otherkeys.push(key); diff --git a/src/webui/src/static/interface.ts b/src/webui/src/static/interface.ts index c033c225f4..493ffc4b41 100644 --- a/src/webui/src/static/interface.ts +++ b/src/webui/src/static/interface.ts @@ -43,8 +43,6 @@ interface TableRecord { startTime: number; endTime?: number; id: string; - jobId: string; - parameterId: string; duration: number; status: string; intermediateCount: number; @@ -126,8 +124,6 @@ interface Intermedia { interface MetricDataRecord { timestamp: number; trialJobId: string; - trialId: string; - parameterId: string; type: string; sequence: number; data: string; @@ -135,8 +131,6 @@ interface MetricDataRecord { interface TrialJobInfo { id: string; - jobId: string; - parameterId: string; sequenceId: number; status: string; startTime?: number; diff --git a/src/webui/src/static/model/trial.ts b/src/webui/src/static/model/trial.ts index ebdd35bc77..b1431a0a5f 100644 --- a/src/webui/src/static/model/trial.ts +++ b/src/webui/src/static/model/trial.ts @@ -115,8 +115,6 @@ class Trial implements TableObj { key: this.info.id, sequenceId: this.info.sequenceId, id: this.info.id, - jobId: this.info.jobId, - parameterId: this.info.parameterId, // eslint-disable-next-line @typescript-eslint/no-non-null-assertion startTime: this.info.startTime!, endTime: this.info.endTime, diff --git a/src/webui/src/static/model/trialmanager.ts b/src/webui/src/static/model/trialmanager.ts index bc613e1ba1..ffc0f85f55 100644 --- a/src/webui/src/static/model/trialmanager.ts +++ b/src/webui/src/static/model/trialmanager.ts @@ -7,29 +7,13 @@ import { requestAxios } from '../function'; function groupMetricsByTrial(metrics: MetricDataRecord[]): Map { const ret = new Map(); for (const metric of metrics) { - const trialId = `${metric.trialJobId}-${metric.parameterId}`; - metric.trialId = trialId; - if (ret.has(trialId)) { + if (ret.has(metric.trialJobId)) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - ret.get(trialId)!.push(metric); + ret.get(metric.trialJobId)!.push(metric); } else { - ret.set(trialId, [metric]); + ret.set(metric.trialJobId, [ metric ]); } } - // to compatiable with multi-trial in same job, fix offset of sequence - ret.forEach((trialMetrics) => { - let minSequenceNumber = Number.POSITIVE_INFINITY; - trialMetrics.map((item) => { - if (item.sequence < minSequenceNumber && item.type !== "FINAL") { - minSequenceNumber = item.sequence; - } - }); - trialMetrics.map((item) => { - if (item.type !== "FINAL") { - item.sequence -= minSequenceNumber; - } - }); - }); return ret; } @@ -48,6 +32,16 @@ class TrialManager { private latestMetricdataErrorMessage: string = ''; // metric-data-latest error message private isMetricdataRangeError: boolean = false; // metric-data-range api error filed private metricdataRangeErrorMessage: string = ''; // metric-data-latest error message + private metricsList: Array = []; + private trialJobList: Array = []; + + public getMetricsList(): Array { + return this.metricsList; + } + + public getTrialJobList(): Array { + return this.trialJobList; + } public async init(): Promise { while (!this.infoInitialized || !this.metricInitialized) { @@ -135,57 +129,6 @@ class TrialManager { return new MetricSpace([...this.trials.values()]); } - public static expandJobsToTrials(jobs: TrialJobInfo[]): TrialJobInfo[] { - const trials: TrialJobInfo[] = []; - - for (const jobInfo of jobs as TrialJobInfo[]) { - if (jobInfo.hyperParameters) { - let trial: TrialJobInfo | undefined; - let lastTrial: TrialJobInfo | undefined; - for (let i = 0; i < jobInfo.hyperParameters.length; i++) { - const hyperParameters = jobInfo.hyperParameters[i] - const hpObject = JSON.parse(hyperParameters); - const parameterId = hpObject["parameter_id"]; - trial = { - id: `${jobInfo.id}-${parameterId}`, - jobId: jobInfo.id, - parameterId: parameterId, - sequenceId: parameterId, - status: "SUCCEEDED", - startTime: jobInfo.startTime, - endTime: jobInfo.startTime, - hyperParameters: [hyperParameters], - logPath: jobInfo.logPath, - stderrPath: jobInfo.stderrPath, - }; - if (jobInfo.finalMetricData) { - for (const metricData of jobInfo.finalMetricData) { - if (metricData.parameterId == parameterId) { - trial.finalMetricData = [metricData]; - trial.endTime = metricData.timestamp; - break; - } - } - } - if (lastTrial) { - trial.startTime = lastTrial.endTime; - } else { - trial.startTime = jobInfo.startTime; - } - lastTrial = trial; - trials.push(trial); - } - if (lastTrial !== undefined) { - lastTrial.status = jobInfo.status; - lastTrial.endTime = jobInfo.endTime; - } - } else { - trials.push(jobInfo); - } - } - return trials; - } - // if this.jobListError = true, show trial error message [/trial-jobs] public jobListError(): boolean { return this.isJobListError; @@ -229,8 +172,7 @@ class TrialManager { let updated = false; requestAxios(`${MANAGER_IP}/trial-jobs`) .then(data => { - const newTrials = TrialManager.expandJobsToTrials(data as any); - for (const trialInfo of newTrials as TrialJobInfo[]) { + for (const trialInfo of data as TrialJobInfo[]) { if (this.trials.has(trialInfo.id)) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion updated = this.trials.get(trialInfo.id)!.updateTrialJobInfo(trialInfo) || updated; @@ -265,7 +207,10 @@ class TrialManager { private async updateAllMetrics(): Promise { return requestAxios(`${MANAGER_IP}/metric-data`) - .then(data => this.doUpdateMetrics(data as any, false)) + .then(data => { + this.metricsList = data; + return this.doUpdateMetrics(data as any, false); + }) .catch(error => { this.isMetricdataError = true; this.MetricdataErrorMessage = `${error.message}`;