Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
unify trial job id name (#3053)
Browse files Browse the repository at this point in the history
Co-authored-by: Ning Shang <[email protected]>
  • Loading branch information
J-shang and Ning Shang authored Nov 10, 2020
1 parent be652aa commit 050ee2b
Show file tree
Hide file tree
Showing 15 changed files with 37 additions and 43 deletions.
10 changes: 2 additions & 8 deletions nni/experiment/nni_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ def __init__(self, json_obj):
self.value = None
self.trialJobId = None
for key in json_obj.keys():
if key == 'id':
setattr(self, 'trialJobId', json_obj[key])
elif hasattr(self, key):
setattr(self, key, json_obj[key])
setattr(self, key, json_obj[key])
self.value = json.loads(self.value)

def __repr__(self):
Expand Down Expand Up @@ -219,10 +216,7 @@ def __init__(self, json_obj):
self.finalMetricData = None
self.stderrPath = None
for key in json_obj.keys():
if key == 'id':
setattr(self, 'trialJobId', json_obj[key])
elif hasattr(self, key):
setattr(self, key, json_obj[key])
setattr(self, key, json_obj[key])
if self.hyperParameters:
self.hyperParameters = [TrialHyperParameters(json.loads(e)) for e in self.hyperParameters]
if self.finalMetricData:
Expand Down
12 changes: 6 additions & 6 deletions nni/tools/nnictl/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,9 @@ def log_trial(args):
if response and check_response(response):
content = json.loads(response.text)
for trial in content:
trial_id_list.append(trial.get('id'))
trial_id_list.append(trial.get('trialJobId'))
if trial.get('logPath'):
trial_id_path_dict[trial.get('id')] = trial['logPath']
trial_id_path_dict[trial.get('trialJobId')] = trial['logPath']
else:
print_error('Restful server is not running...')
exit(1)
Expand Down Expand Up @@ -674,7 +674,7 @@ def show_experiment_info():
content = json.loads(response.text)
for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value)
print(TRIAL_MONITOR_CONTENT % (content[index].get('id'), content[index].get('startTime'), \
print(TRIAL_MONITOR_CONTENT % (content[index].get('trialJobId'), content[index].get('startTime'), \
content[index].get('endTime'), content[index].get('status')))
print(TRIAL_MONITOR_TAIL)

Expand Down Expand Up @@ -747,7 +747,7 @@ def groupby_trial_id(intermediate_results):
return
intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text))
for record in content:
record['intermediate'] = intermediate_results[record['id']]
record['intermediate'] = intermediate_results[record['trialJobId']]
if args.type == 'json':
with open(args.path, 'w') as file:
file.write(json.dumps(content))
Expand All @@ -759,9 +759,9 @@ def groupby_trial_id(intermediate_results):
formated_record['intermediate'] = '[' + ','.join(record['intermediate']) + ']'
record_value = json.loads(record['value'])
if not isinstance(record_value, (float, int)):
formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}})
formated_record.update({**record['parameter'], **record_value, **{'trialJobId': record['trialJobId']}})
else:
formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}})
formated_record.update({**record['parameter'], **{'reward': record_value, 'trialJobId': record['trialJobId']}})
trial_records.append(formated_record)
if not trial_records:
print_error('No trial results collected! Please check your trial log...')
Expand Down
4 changes: 2 additions & 2 deletions nni/tools/nnictl/tensorboard_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def parse_log_path(args, trial_content):
path_list = []
host_list = []
for trial in trial_content:
if args.trial_id and args.trial_id != 'all' and trial.get('id') != args.trial_id:
if args.trial_id and args.trial_id != 'all' and trial.get('trialJobId') != args.trial_id:
continue
pattern = r'(?P<head>.+)://(?P<host>.+):(?P<path>.*)'
match = re.search(pattern, trial['logPath'])
Expand All @@ -40,7 +40,7 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username'],
'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')}
for index, host in enumerate(host_list):
local_path = os.path.join(temp_nni_path, trial_content[index].get('id'))
local_path = os.path.join(temp_nni_path, trial_content[index].get('trialJobId'))
local_path_list.append(local_path)
print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path))
sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'],
Expand Down
2 changes: 1 addition & 1 deletion test/nni_test/nnitest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def print_file_content(filepath):
def print_trial_job_log(training_service, trial_jobs_url):
trial_jobs = get_trial_jobs(trial_jobs_url)
for trial_job in trial_jobs:
trial_log_dir = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials', trial_job['id'])
trial_log_dir = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials', trial_job['trialJobId'])
log_files = ['stderr', 'trial.log'] if training_service == 'local' else ['stdout_log_collection.log']
for log_file in log_files:
print_file_content(os.path.join(trial_log_dir, log_file))
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/common/datastore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ interface MetricDataRecord {
}

interface TrialJobInfo {
id: string;
trialJobId: string;
sequenceId?: number;
status: TrialJobStatus;
startTime?: number;
Expand All @@ -63,7 +63,7 @@ interface HyperParameterFormat {
interface ExportedDataFormat {
parameter: Record<string, any>;
value: Record<string, any>;
id: string;
trialJobId: string;
}

abstract class DataStore {
Expand Down
12 changes: 6 additions & 6 deletions ts/nni_manager/core/nniDataStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class NNIDataStore implements DataStore {
const oneEntry: ExportedDataFormat = {
parameter: parameters.parameters,
value: JSON.parse(job.finalMetricData[0].data),
id: job.id
trialJobId: job.trialJobId
};
exportedData.push(oneEntry);
} else {
Expand All @@ -188,7 +188,7 @@ class NNIDataStore implements DataStore {
const oneEntry: ExportedDataFormat = {
parameter: value,
value: metricValue,
id: job.id
trialJobId: job.trialJobId
};
exportedData.push(oneEntry);
}
Expand Down Expand Up @@ -229,7 +229,7 @@ class NNIDataStore implements DataStore {
}
if (!(status !== undefined && jobInfo.status !== status)) {
if (jobInfo.status === 'SUCCEEDED') {
jobInfo.finalMetricData = finalMetricsMap.get(jobInfo.id);
jobInfo.finalMetricData = finalMetricsMap.get(jobInfo.trialJobId);
}
result.push(jobInfo);
}
Expand Down Expand Up @@ -320,7 +320,7 @@ class NNIDataStore implements DataStore {
jobInfo = map.get(record.trialJobId);
} else {
jobInfo = {
id: record.trialJobId,
trialJobId: record.trialJobId,
status: this.getJobStatusByLatestEvent('UNKNOWN', record.event),
hyperParameters: []
};
Expand Down Expand Up @@ -364,14 +364,14 @@ class NNIDataStore implements DataStore {
const newHParam: any = this.parseHyperParameter(record.data);
if (newHParam !== undefined) {
if (jobInfo.hyperParameters !== undefined) {
let hParamIds: Set<number> | undefined = hParamIdMap.get(jobInfo.id);
let hParamIds: Set<number> | undefined = hParamIdMap.get(jobInfo.trialJobId);
if (hParamIds === undefined) {
hParamIds = new Set();
}
if (!hParamIds.has(newHParam.parameter_index)) {
jobInfo.hyperParameters.push(JSON.stringify(newHParam));
hParamIds.add(newHParam.parameter_index);
hParamIdMap.set(jobInfo.id, hParamIds);
hParamIdMap.set(jobInfo.trialJobId, hParamIds);
}
} else {
assert(false, 'jobInfo.hyperParameters is undefined');
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class NNIManager implements Manager {
// Check the final status for WAITING and RUNNING jobs
await Promise.all(allTrialJobs
.filter((job: TrialJobInfo) => job.status === 'WAITING' || job.status === 'RUNNING')
.map((job: TrialJobInfo) => this.dataStore.storeTrialJobEvent('FAILED', job.id)));
.map((job: TrialJobInfo) => this.dataStore.storeTrialJobEvent('FAILED', job.trialJobId)));

// Collect generated trials and imported trials
const finishedTrialData: string = await this.exportData();
Expand Down Expand Up @@ -304,7 +304,7 @@ class NNIManager implements Manager {
// FIXME: can this be undefined?
trial.sequenceId !== undefined && minSeqId <= trial.sequenceId && trial.sequenceId <= maxSeqId
));
const targetTrialIds = new Set(targetTrials.map(trial => trial.id));
const targetTrialIds = new Set(targetTrials.map(trial => trial.trialJobId));

const allMetrics = await this.dataStore.getMetricData();
return allMetrics.filter(metric => targetTrialIds.has(metric.trialJobId));
Expand Down
6 changes: 3 additions & 3 deletions ts/nni_manager/core/test/mockedDatastore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class MockedDataStore implements DataStore {
}
if (!(status && jobInfo.status !== status)) {
if (jobInfo.status === 'SUCCEEDED') {
jobInfo.finalMetricData = await this.getFinalMetricData(jobInfo.id);
jobInfo.finalMetricData = await this.getFinalMetricData(jobInfo.trialJobId);
}
result.push(jobInfo);
}
Expand Down Expand Up @@ -206,7 +206,7 @@ class MockedDataStore implements DataStore {

public getTrialJob(trialJobId: string): Promise<TrialJobInfo> {
return Promise.resolve({
id: '1234',
trialJobId: '1234',
status: 'SUCCEEDED',
startTime: Date.now(),
endTime: Date.now()
Expand Down Expand Up @@ -242,7 +242,7 @@ class MockedDataStore implements DataStore {
jobInfo = map.get(record.trialJobId);
} else {
jobInfo = {
id: record.trialJobId,
trialJobId: record.trialJobId,
status: this.getJobStatusByLatestEvent(record.event),
};
}
Expand Down
2 changes: 1 addition & 1 deletion ts/nni_manager/core/test/nnimanager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ describe('Unit test for nnimanager', function () {
it('test getTrialJob valid', () => {
//query a exist id
return nniManager.getTrialJob('1234').then(function (trialJobDetail) {
expect(trialJobDetail.id).to.be.equal('1234');
expect(trialJobDetail.trialJobId).to.be.equal('1234');
}).catch((error) => {
assert.fail(error);
})
Expand Down
6 changes: 3 additions & 3 deletions ts/nni_manager/rest_server/test/mockedNNIManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ export class MockedNNIManager extends Manager {
public getTrialJob(trialJobId: string): Promise<TrialJobInfo> {
const deferred: Deferred<TrialJobInfo> = new Deferred<TrialJobInfo>();
const jobInfo: TrialJobInfo = {
id: '1234',
trialJobId: '1234',
status: 'SUCCEEDED',
startTime: Date.now(),
endTime: Date.now()
Expand Down Expand Up @@ -152,7 +152,7 @@ export class MockedNNIManager extends Manager {
}
public listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]> {
const job1: TrialJobInfo = {
id: '1234',
trialJobId: '1234',
status: 'SUCCEEDED',
startTime: Date.now(),
endTime: Date.now(),
Expand All @@ -166,7 +166,7 @@ export class MockedNNIManager extends Manager {
}]
};
const job2: TrialJobInfo = {
id: '3456',
trialJobId: '3456',
status: 'FAILED',
startTime: Date.now(),
endTime: Date.now(),
Expand Down
2 changes: 1 addition & 1 deletion ts/nni_manager/rest_server/test/restserver.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ describe('Unit test for rest server', () => {
assert.fail(err.message);
} else {
expect(res.statusCode).to.equal(200);
expect(JSON.parse(body).id).to.equal('1234');
expect(JSON.parse(body).trialJobId).to.equal('1234');
}
done();
});
Expand Down
2 changes: 1 addition & 1 deletion ts/webui/src/components/Overview.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class Overview extends React.Component<{}, OverviewState> {
</Stack>
</div>
</Stack>
<SuccessTable trialIds={bestTrials.map(trial => trial.info.id)} />
<SuccessTable trialIds={bestTrials.map(trial => trial.info.trialJobId)} />
</div>
<div className='overviewCommand1'>
<Command1 />
Expand Down
2 changes: 1 addition & 1 deletion ts/webui/src/static/interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ interface MetricDataRecord {
}

interface TrialJobInfo {
id: string;
trialJobId: string;
sequenceId: number;
status: string;
startTime?: number;
Expand Down
6 changes: 3 additions & 3 deletions ts/webui/src/static/model/trial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ class Trial implements TableObj {
}

return {
key: this.info.id,
key: this.info.trialJobId,
sequenceId: this.info.sequenceId,
id: this.info.id,
id: this.info.trialJobId,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
startTime: this.info.startTime!,
endTime: this.info.endTime,
Expand All @@ -169,7 +169,7 @@ class Trial implements TableObj {
}

get id(): string {
return this.info.id;
return this.info.trialJobId;
}

get duration(): number {
Expand Down
6 changes: 3 additions & 3 deletions ts/webui/src/static/model/trialmanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ class TrialManager {
requestAxios(`${MANAGER_IP}/trial-jobs`)
.then(data => {
for (const trialInfo of data as TrialJobInfo[]) {
if (this.trials.has(trialInfo.id)) {
if (this.trials.has(trialInfo.trialJobId)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
updated = this.trials.get(trialInfo.id)!.updateTrialJobInfo(trialInfo) || updated;
updated = this.trials.get(trialInfo.trialJobId)!.updateTrialJobInfo(trialInfo) || updated;
} else {
this.trials.set(trialInfo.id, new Trial(trialInfo, undefined));
this.trials.set(trialInfo.trialJobId, new Trial(trialInfo, undefined));
updated = true;
}
this.maxSequenceId = Math.max(this.maxSequenceId, trialInfo.sequenceId);
Expand Down

0 comments on commit 050ee2b

Please sign in to comment.