Skip to content

Commit

Permalink
Merge pull request #107 from microsoft/master
Browse files Browse the repository at this point in the history
pull
  • Loading branch information
chicm-ms authored Aug 10, 2020
2 parents 058c8b7 + 654e824 commit 9abd8c8
Show file tree
Hide file tree
Showing 20 changed files with 163 additions and 268 deletions.
1 change: 0 additions & 1 deletion src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ export namespace ValidationSchemas {
}),
pai_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
userName: joi.string().min(1).required(),
passWord: joi.string().min(1),
token: joi.string().min(1),
host: joi.string().min(1).required(),
reuse: joi.boolean(),
Expand Down
7 changes: 4 additions & 3 deletions src/nni_manager/training_service/pai/paiJobInfoCollector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export class PAIJobInfoCollector {
// Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = {
uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v2/jobs/${paiClusterConfig.userName}~${paiTrialJob.paiJobName}`,
method: 'GET',
json: true,
headers: {
Expand All @@ -63,8 +63,9 @@ export class PAIJobInfoCollector {

//TODO : pass in request timeout param?
request(getJobInfoRequest, (error: Error, response: request.Response, _body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 500) {
this.log.error(`PAI Training service: get job info for trial ${paiTrialJob.id} from PAI Cluster failed!`);
// Status code 200 for success
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
// The job status refresh could run earlier than the job submission, so the query might return a 404 error code; needs refactoring
// Queried PAI job info failed, set job status to UNKNOWN
if (paiTrialJob.status === 'WAITING' || paiTrialJob.status === 'RUNNING') {
paiTrialJob.status = 'UNKNOWN';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,7 @@ class PAIK8STrainingService extends PAITrainingService {
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
if (this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if (this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
}
this.paiToken = this.paiClusterConfig.token;
break;

case TrialConfigMetadataKey.TRIAL_CONFIG: {
Expand Down Expand Up @@ -290,18 +285,20 @@ class PAIK8STrainingService extends PAITrainingService {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
followAllRedirects: true,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
// If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;

this.log.error(errorMessage);
trialJobDetail.status = 'FAILED';
deferred.reject(errorMessage);
} else {
trialJobDetail.submitTime = Date.now();
}
Expand Down
4 changes: 2 additions & 2 deletions src/nni_manager/training_service/pai/paiTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ abstract class PAITrainingService implements TrainingService {
}

const stopJobRequest: request.Options = {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\
/jobs/${trialJobDetail.paiJobName}/executionType`,
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${trialJobDetail.paiJobName}/executionType`,
method: 'PUT',
json: true,
body: { value: 'STOP' },
Expand All @@ -178,6 +177,7 @@ abstract class PAITrainingService implements TrainingService {
const deferred: Deferred<void> = new Deferred<void>();

request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
// Status code 202 for success.
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
deferred.reject((error !== undefined && error !== null) ? error.message :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class RemoteMachineTrainingService implements TrainingService {
private nniManagerIpConfig?: NNIManagerIpConfig;
private versionCheck: boolean = true;
private logCollection: string;
private sshConnectionPromises: any[];

constructor(@component.Inject timer: ObservableTimer) {
this.metricsEmitter = new EventEmitter();
Expand All @@ -65,6 +66,7 @@ class RemoteMachineTrainingService implements TrainingService {
this.machineCopyExpCodeDirPromiseMap = new Map<RemoteMachineMeta, Promise<void>>();
this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.jobQueue = [];
this.sshConnectionPromises = [];
this.expRootDir = getExperimentRootDir();
this.timer = timer;
this.log = getLogger();
Expand All @@ -80,6 +82,12 @@ class RemoteMachineTrainingService implements TrainingService {
await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info('Run remote machine training service.');
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// reset sshConnectionPromises to [] so that the log message above is not printed again
this.sshConnectionPromises = [];
}
while (!this.stopping) {
while (this.jobQueue.length > 0) {
this.updateGpuReservation();
Expand Down Expand Up @@ -408,7 +416,6 @@ class RemoteMachineTrainingService implements TrainingService {
//TODO: decide how to handle the error when the value is malformed and JSON parsing fails
const rmMetaList: RemoteMachineMeta[] = <RemoteMachineMeta[]>JSON.parse(machineList);

const connectionPromises = [];
for (const rmMeta of rmMetaList) {
rmMeta.occupiedGpuIndexMap = new Map<number, number>();
const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
Expand All @@ -417,11 +424,9 @@ class RemoteMachineTrainingService implements TrainingService {
this.log.debug(`reached ${executor.name}`);
this.machineExecutorManagerMap.set(rmMeta, executorManager);
this.log.debug(`initializing ${executor.name}`);
connectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
this.log.info(`connected to ${executor.name}`);
this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
this.log.info(`connecting to ${executor.name}`);
}

await Promise.all(connectionPromises);
}

private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig';
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { CommandChannel } from "../commandChannel";
import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from '../environment';
import { EnvironmentInformation, EnvironmentService } from '../environment';


/**
Expand Down Expand Up @@ -74,7 +74,7 @@ export class AMLEnvironmentService extends EnvironmentService {
environments.forEach(async (environment) => {
const amlClient = (environment as AMLEnvironmentInformation).amlClient;
if (!amlClient) {
throw new Error('AML client not initialized!');
return Promise.reject('AML client not initialized!');
}
const newStatus = await amlClient.updateStatus(environment.status);
switch (newStatus.toUpperCase()) {
Expand All @@ -90,8 +90,8 @@ export class AMLEnvironmentService extends EnvironmentService {
environment.setStatus('SUCCEEDED');
break;
case 'FAILED':
environment.setStatus(newStatus.toUpperCase() as EnvironmentStatus);
break;
environment.setStatus('FAILED');
return Promise.reject(`AML: job ${environment.envId} is failed!`);
case 'STOPPED':
case 'STOPPING':
environment.setStatus('USER_CANCELED');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private paiTrialConfig: NNIPAIK8STrialConfig | undefined;
private paiJobConfig: any;
private paiToken?: string;
private paiTokenUpdateTime?: number;
private readonly paiTokenUpdateInterval: number;
private protocol: string = 'http';

private experimentId: string;

constructor() {
super();
this.paiTokenUpdateInterval = 7200000; //2hours
this.experimentId = getExperimentId();
}

Expand All @@ -53,12 +50,7 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
if (this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if (this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
}
this.paiToken = this.paiClusterConfig.token;
break;

case TrialConfigMetadataKey.TRIAL_CONFIG: {
Expand Down Expand Up @@ -95,7 +87,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService {

public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
await this.refreshPlatform();

if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
Expand All @@ -115,9 +106,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
};

request(getJobInfoRequest, async (error: any, response: request.Response, body: any) => {
// Status code 200 for success
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`OpenPAI: get environment list from PAI Cluster failed!\nerror: ${error}`);
deferred.reject(error);
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`OpenPAI: get environment list from PAI Cluster failed!, http code:${response.statusCode}, http body: ${JSON.stringify(body)}`;
this.log.error(`${errorMessage}`);
deferred.reject(errorMessage);
} else {
const jobInfos = new Map<string, any>();
body.forEach((jobInfo: any) => {
Expand All @@ -133,8 +127,11 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
case 'RUNNING':
case 'WAITING':
case 'SUCCEEDED':
environment.setStatus(jobResponse.state);
break;
case 'FAILED':
environment.setStatus(jobResponse.state);
deferred.reject(`OpenPAI: job ${environment.envId} is failed!`);
break;
case 'STOPPED':
case 'STOPPING':
Expand Down Expand Up @@ -166,8 +163,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();

await this.refreshPlatform();

if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
Expand Down Expand Up @@ -195,18 +190,21 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
followAllRedirects: true,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error, response, body) => {
// Status code 202 for success, refer https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`start environment ${environment.envId} failed, http code:${response.statusCode}, http body: ${body}`;

this.log.error(errorMessage);
environment.status = 'FAILED';
deferred.reject(errorMessage);
}
deferred.resolve();
});
Expand Down Expand Up @@ -241,8 +239,11 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
try {
request(stopJobRequest, (error, response, _body) => {
try {
// Status code 202 for success.
if ((error !== undefined && error !== null) || (response && response.statusCode >= 400)) {
this.log.error(`OpenPAI: stop job ${environment.envId} failed with ${response.statusCode}\n${error}`);
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`OpenPAI: stop job ${environment.envId} failed, http code:${response.statusCode}, http body: ${_body}`;
this.log.error(`${errorMessage}`);
deferred.reject((error !== undefined && error !== null) ? error :
`Stop trial failed, http code: ${response.statusCode}`);
} else {
Expand All @@ -262,19 +263,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return deferred.promise;
}

/**
 * Refresh the PAI token when password authentication is configured.
 *
 * Does nothing unless the cluster config is set and carries a password.
 * A failed refresh is logged; it is rethrown (wrapped in a new Error) only
 * when no token is currently held, otherwise the existing token is kept.
 */
private async refreshPlatform(): Promise<void> {
    // Nothing to refresh without a cluster config that carries a password.
    if (!this.paiClusterConfig || !this.paiClusterConfig.passWord) {
        return;
    }

    try {
        await this.updatePaiToken();
    } catch (tokenError) {
        this.log.error(`${tokenError}`);
        // Keep going with the token we already hold, if there is one.
        const hasExistingToken: boolean = this.paiToken !== undefined;
        if (!hasExistingToken) {
            throw new Error(tokenError);
        }
    }
}

private generateJobConfigInYamlFormat(environment: EnvironmentInformation): any {
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
Expand Down Expand Up @@ -386,59 +374,4 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return host;
}
}
/**
 * Update the PAI token once the refresh interval has elapsed, or fetch it
 * for the first time.
 *
 * If a token was obtained less than `paiTokenUpdateInterval` milliseconds
 * ago, no request is made. Otherwise the configured user name and password
 * are POSTed to `/rest-server/api/v1/token`, and on a 200 response
 * `paiToken` and `paiTokenUpdateTime` are updated.
 *
 * @returns a promise that resolves once the token is up to date; it rejects
 *          with an Error when the request fails, the response status is not
 *          200, or no answer arrives within 5 seconds
 * @throws Error when the PAI cluster config has not been initialized
 */
protected async updatePaiToken(): Promise<void> {
    const currentTime: number = new Date().getTime();
    // If the pai token is initialized and the interval has not elapsed, do not update.
    if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
        return;
    }

    if (this.paiClusterConfig === undefined) {
        const paiClusterConfigError: string = `pai cluster config not initialized!`;
        this.log.error(`${paiClusterConfigError}`);
        throw new Error(`${paiClusterConfigError}`);
    }

    const deferred: Deferred<void> = new Deferred<void>();

    // Kept as a separate string so diagnostics can name the endpoint.
    // The request object itself must never be logged: interpolating it
    // yields "[object Object]", and serialising it would leak the password.
    const tokenUri: string = `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`;
    const authenticationReq: request.Options = {
        uri: tokenUri,
        method: 'POST',
        json: true,
        body: {
            username: this.paiClusterConfig.userName,
            password: this.paiClusterConfig.passWord
        }
    };

    request(authenticationReq, (error: any, response: request.Response, body: any) => {
        if (error !== undefined && error !== null) {
            this.log.error(`Get PAI token failed: ${error.message}, uri: ${tokenUri}`);
            deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
        } else {
            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}, uri: ${tokenUri}`);
                // With `json: true` the body may already be a parsed object, so serialise it explicitly.
                deferred.reject(new Error(`Get PAI token failed code: ${response.statusCode}, body: ${JSON.stringify(response.body)}, uri: ${tokenUri}, please check paiConfig username or password`));
            } else {
                this.paiToken = body.token;
                this.paiTokenUpdateTime = new Date().getTime();
                deferred.resolve();
            }
        }
    });

    let timeoutId: NodeJS.Timer;
    const timeoutDelay: Promise<void> = new Promise<void>((_resolve: () => void, reject: (reason: Error) => void): void => {
        // Set timeout and reject the promise once the timeout is reached (5 seconds).
        timeoutId = setTimeout(
            () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
            5000);
    });

    // Whichever settles first wins; the timer is always cleared so it
    // cannot fire after the token request has already finished.
    return Promise.race([timeoutDelay, deferred.promise])
        .finally(() => { clearTimeout(timeoutId); });
}
}
8 changes: 8 additions & 0 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def infer_modules_masks(self):
"""
for module_name, mask in self.masks.items():
_logger.debug('Start mask inference from %s', module_name)
if module_name not in self.torch_graph.name_to_node:
# this module is not traced in the torch_graph,
# jit.trace only correctly records functions and
# modules which are not data dependent (e.g., do
# not have conditionals on data in tensors)
# so, if a node is not traced, we just skip it.
_logger.warning('%s has mask, but not found in the traced graph, just skip it.', module_name)
continue
self.infer_module_mask(module_name, None, mask=mask)

def replace_compressed_modules(self):
Expand Down
Loading

0 comments on commit 9abd8c8

Please sign in to comment.