Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull #107

Merged
merged 6 commits into from
Aug 10, 2020
Merged

pull #107

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ export namespace ValidationSchemas {
}),
pai_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
userName: joi.string().min(1).required(),
passWord: joi.string().min(1),
token: joi.string().min(1),
host: joi.string().min(1).required(),
reuse: joi.boolean(),
Expand Down
7 changes: 4 additions & 3 deletions src/nni_manager/training_service/pai/paiJobInfoCollector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export class PAIJobInfoCollector {
// Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = {
uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v2/jobs/${paiClusterConfig.userName}~${paiTrialJob.paiJobName}`,
method: 'GET',
json: true,
headers: {
Expand All @@ -63,8 +63,9 @@ export class PAIJobInfoCollector {

//TODO : pass in request timeout param?
request(getJobInfoRequest, (error: Error, response: request.Response, _body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 500) {
this.log.error(`PAI Training service: get job info for trial ${paiTrialJob.id} from PAI Cluster failed!`);
// Status code 200 for success
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
// The job refresh time could be ealier than job submission, so it might return 404 error code, need refactor
// Queried PAI job info failed, set job status to UNKNOWN
if (paiTrialJob.status === 'WAITING' || paiTrialJob.status === 'RUNNING') {
paiTrialJob.status = 'UNKNOWN';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,7 @@ class PAIK8STrainingService extends PAITrainingService {
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
if (this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if (this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
}
this.paiToken = this.paiClusterConfig.token;
break;

case TrialConfigMetadataKey.TRIAL_CONFIG: {
Expand Down Expand Up @@ -290,18 +285,20 @@ class PAIK8STrainingService extends PAITrainingService {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
followAllRedirects: true,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
// If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;

this.log.error(errorMessage);
trialJobDetail.status = 'FAILED';
deferred.reject(errorMessage);
} else {
trialJobDetail.submitTime = Date.now();
}
Expand Down
4 changes: 2 additions & 2 deletions src/nni_manager/training_service/pai/paiTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ abstract class PAITrainingService implements TrainingService {
}

const stopJobRequest: request.Options = {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\
/jobs/${trialJobDetail.paiJobName}/executionType`,
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${trialJobDetail.paiJobName}/executionType`,
method: 'PUT',
json: true,
body: { value: 'STOP' },
Expand All @@ -178,6 +177,7 @@ abstract class PAITrainingService implements TrainingService {
const deferred: Deferred<void> = new Deferred<void>();

request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
// Status code 202 for success.
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
deferred.reject((error !== undefined && error !== null) ? error.message :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class RemoteMachineTrainingService implements TrainingService {
private nniManagerIpConfig?: NNIManagerIpConfig;
private versionCheck: boolean = true;
private logCollection: string;
private sshConnectionPromises: any[];

constructor(@component.Inject timer: ObservableTimer) {
this.metricsEmitter = new EventEmitter();
Expand All @@ -65,6 +66,7 @@ class RemoteMachineTrainingService implements TrainingService {
this.machineCopyExpCodeDirPromiseMap = new Map<RemoteMachineMeta, Promise<void>>();
this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.jobQueue = [];
this.sshConnectionPromises = [];
this.expRootDir = getExperimentRootDir();
this.timer = timer;
this.log = getLogger();
Expand All @@ -80,6 +82,12 @@ class RemoteMachineTrainingService implements TrainingService {
await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info('Run remote machine training service.');
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// set sshConnectionPromises to [] to avoid log information duplicated
this.sshConnectionPromises = [];
}
while (!this.stopping) {
while (this.jobQueue.length > 0) {
this.updateGpuReservation();
Expand Down Expand Up @@ -408,7 +416,6 @@ class RemoteMachineTrainingService implements TrainingService {
//TO DO: verify if value's format is wrong, and json parse failed, how to handle error
const rmMetaList: RemoteMachineMeta[] = <RemoteMachineMeta[]>JSON.parse(machineList);

const connectionPromises = [];
for (const rmMeta of rmMetaList) {
rmMeta.occupiedGpuIndexMap = new Map<number, number>();
const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
Expand All @@ -417,11 +424,9 @@ class RemoteMachineTrainingService implements TrainingService {
this.log.debug(`reached ${executor.name}`);
this.machineExecutorManagerMap.set(rmMeta, executorManager);
this.log.debug(`initializing ${executor.name}`);
connectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
this.log.info(`connected to ${executor.name}`);
this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
this.log.info(`connecting to ${executor.name}`);
}

await Promise.all(connectionPromises);
}

private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig';
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { CommandChannel } from "../commandChannel";
import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from '../environment';
import { EnvironmentInformation, EnvironmentService } from '../environment';


/**
Expand Down Expand Up @@ -74,7 +74,7 @@ export class AMLEnvironmentService extends EnvironmentService {
environments.forEach(async (environment) => {
const amlClient = (environment as AMLEnvironmentInformation).amlClient;
if (!amlClient) {
throw new Error('AML client not initialized!');
return Promise.reject('AML client not initialized!');
}
const newStatus = await amlClient.updateStatus(environment.status);
switch (newStatus.toUpperCase()) {
Expand All @@ -90,8 +90,8 @@ export class AMLEnvironmentService extends EnvironmentService {
environment.setStatus('SUCCEEDED');
break;
case 'FAILED':
environment.setStatus(newStatus.toUpperCase() as EnvironmentStatus);
break;
environment.setStatus('FAILED');
return Promise.reject(`AML: job ${environment.envId} is failed!`);
case 'STOPPED':
case 'STOPPING':
environment.setStatus('USER_CANCELED');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private paiTrialConfig: NNIPAIK8STrialConfig | undefined;
private paiJobConfig: any;
private paiToken?: string;
private paiTokenUpdateTime?: number;
private readonly paiTokenUpdateInterval: number;
private protocol: string = 'http';

private experimentId: string;

constructor() {
super();
this.paiTokenUpdateInterval = 7200000; //2hours
this.experimentId = getExperimentId();
}

Expand All @@ -53,12 +50,7 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
if (this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if (this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
}
this.paiToken = this.paiClusterConfig.token;
break;

case TrialConfigMetadataKey.TRIAL_CONFIG: {
Expand Down Expand Up @@ -95,7 +87,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService {

public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
await this.refreshPlatform();

if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
Expand All @@ -115,9 +106,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
};

request(getJobInfoRequest, async (error: any, response: request.Response, body: any) => {
// Status code 200 for success
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`OpenPAI: get environment list from PAI Cluster failed!\nerror: ${error}`);
deferred.reject(error);
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`OpenPAI: get environment list from PAI Cluster failed!, http code:${response.statusCode}, http body: ${JSON.stringify(body)}`;
this.log.error(`${errorMessage}`);
deferred.reject(errorMessage);
} else {
const jobInfos = new Map<string, any>();
body.forEach((jobInfo: any) => {
Expand All @@ -133,8 +127,11 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
case 'RUNNING':
case 'WAITING':
case 'SUCCEEDED':
environment.setStatus(jobResponse.state);
break;
case 'FAILED':
environment.setStatus(jobResponse.state);
deferred.reject(`OpenPAI: job ${environment.envId} is failed!`);
break;
case 'STOPPED':
case 'STOPPING':
Expand Down Expand Up @@ -166,8 +163,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();

await this.refreshPlatform();

if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
Expand Down Expand Up @@ -195,18 +190,21 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
followAllRedirects: true,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error, response, body) => {
// Status code 202 for success, refer https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`start environment ${environment.envId} failed, http code:${response.statusCode}, http body: ${body}`;

this.log.error(errorMessage);
environment.status = 'FAILED';
deferred.reject(errorMessage);
}
deferred.resolve();
});
Expand Down Expand Up @@ -241,8 +239,11 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
try {
request(stopJobRequest, (error, response, _body) => {
try {
// Status code 202 for success.
if ((error !== undefined && error !== null) || (response && response.statusCode >= 400)) {
this.log.error(`OpenPAI: stop job ${environment.envId} failed with ${response.statusCode}\n${error}`);
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`OpenPAI: stop job ${environment.envId} failed, http code:${response.statusCode}, http body: ${_body}`;
this.log.error(`${errorMessage}`);
deferred.reject((error !== undefined && error !== null) ? error :
`Stop trial failed, http code: ${response.statusCode}`);
} else {
Expand All @@ -262,19 +263,6 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return deferred.promise;
}

private async refreshPlatform(): Promise<void> {
if (this.paiClusterConfig && this.paiClusterConfig.passWord) {
try {
await this.updatePaiToken();
} catch (error) {
this.log.error(`${error}`);
if (this.paiToken === undefined) {
throw new Error(error);
}
}
}
}

private generateJobConfigInYamlFormat(environment: EnvironmentInformation): any {
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
Expand Down Expand Up @@ -386,59 +374,4 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return host;
}
}
/**
* Update pai token by the interval time or initialize the pai token
*/
protected async updatePaiToken(): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();

const currentTime: number = new Date().getTime();
//If pai token initialized and not reach the interval time, do not update
if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
return Promise.resolve();
}

if (this.paiClusterConfig === undefined) {
const paiClusterConfigError: string = `pai cluster config not initialized!`;
this.log.error(`${paiClusterConfigError}`);
throw Error(`${paiClusterConfigError}`);
}

const authenticationReq: request.Options = {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
method: 'POST',
json: true,
body: {
username: this.paiClusterConfig.userName,
password: this.paiClusterConfig.passWord
}
};

request(authenticationReq, (error: any, response: request.Response, body: any) => {
if (error !== undefined && error !== null) {
this.log.error(`Get PAI token failed: ${error.message}, authenticationReq: ${authenticationReq}`);
deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
} else {
if (response.statusCode !== 200) {
this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}, authenticationReq: ${authenticationReq}`);
deferred.reject(new Error(`Get PAI token failed code: ${response.statusCode}, body: ${response.body}, authenticationReq: ${authenticationReq}, please check paiConfig username or password`));
} else {
this.paiToken = body.token;
this.paiTokenUpdateTime = new Date().getTime();
deferred.resolve();
}
}
});

let timeoutId: NodeJS.Timer;
const timeoutDelay: Promise<void> = new Promise<void>((_resolve: Function, reject: Function): void => {
// Set timeout and reject the promise once reach timeout (5 seconds)
timeoutId = setTimeout(
() => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
5000);
});

return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); });
}
}
8 changes: 8 additions & 0 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def infer_modules_masks(self):
"""
for module_name, mask in self.masks.items():
_logger.debug('Start mask inference from %s', module_name)
if module_name not in self.torch_graph.name_to_node:
# this module is not traced in the torch_graph,
# jit.trace only correctly records functions and
# modules which are not data dependent (e.g., do
# not have conditionals on data in tensors)
# so, if a node is not traced, we just skip it.
_logger.warning('%s has mask, but not found in the traced graph, just skip it.', module_name)
continue
self.infer_module_mask(module_name, None, mask=mask)

def replace_compressed_modules(self):
Expand Down
Loading