Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

add nnictl command to list trial results with highest/lowest metric #2747

Merged
merged 15 commits into from
Aug 12, 2020
4 changes: 4 additions & 0 deletions docs/en_US/Tutorial/Nnictl.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,17 @@ Debug mode will disable version check function in Trialkeeper.

```bash
nnictl trial ls
nnictl trial ls --head 10
nnictl trial ls --tail 10
tabVersion marked this conversation as resolved.
Show resolved Hide resolved
```

* Options

|Name, shorthand|Required|Default|Description|
|------|------|------ |------|
|id| False| |ID of the experiment you want to set|
|--head|False||the number of items to be listed with the highest default metric|
|--tail|False||the number of items to be listed with the lowest default metric|

* __nnictl trial kill__

Expand Down
2 changes: 2 additions & 0 deletions tools/nni_cmd/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def parse_args():
parser_trial_subparsers = parser_trial.add_subparsers()
parser_trial_ls = parser_trial_subparsers.add_parser('ls', help='list trial jobs')
parser_trial_ls.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_ls.add_argument('--head', type=int, help='list the highest experiments on the default metric')
parser_trial_ls.add_argument('--tail', type=int, help='list the lowest experiments on the default metric')
parser_trial_ls.set_defaults(func=trial_ls)
parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs')
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
Expand Down
24 changes: 24 additions & 0 deletions tools/nni_cmd/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re
import shutil
import subprocess
from functools import cmp_to_key
from datetime import datetime, timezone
from pathlib import Path
from subprocess import Popen
Expand Down Expand Up @@ -248,6 +249,20 @@ def stop_experiment(args):

def trial_ls(args):
'''List trial'''
def final_metric_data_cmp(lhs, rhs):
metric_l = json.loads(lhs['finalMetricData'][0]['data'])
metric_r = json.loads(rhs['finalMetricData'][0]['data'])
if isinstance(metric_l, float):
return metric_l - metric_r
elif isinstance(metric_l, dict):
return metric_l['default'] - metric_r['default']
else:
print_error('Unexpected data format. Please check your data.')
raise ValueError

if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.')
return
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
Expand All @@ -259,6 +274,15 @@ def trial_ls(args):
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
if args.head:
assert int(args.head) > 0, 'The number of requested data must be greater than 0.'
args.head = min(int(args.head), len(content))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use int(args.head)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't know if specify the arg's type as int, it will be converted to int.

content = sorted(filter(lambda x: 'finalMetricData' in x, content),
key=cmp_to_key(final_metric_data_cmp), reverse=True)[:args.head]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this logic is not correct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
elif args.tail:
assert int(args.tail) > 0, 'The number of requested data must be greater than 0.'
args.tail = min(int(args.tail), len(content))
content = sorted(content, key=cmp_to_key(final_metric_data_cmp))[:args.tail]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no filter here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value)
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
Expand Down
1 change: 0 additions & 1 deletion tools/nni_cmd/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def metric_data_url(port):
'''get metric_data url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API)


def check_status_url(port):
'''get check_status url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API)
Expand Down