Skip to content

Commit

Permalink
add nnictl command to list trial results with highest/lowest metric (m…
Browse files Browse the repository at this point in the history
  • Loading branch information
tabVersion authored Aug 12, 2020
1 parent 10c177c commit 44954e0
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
6 changes: 5 additions & 1 deletion docs/en_US/Tutorial/Nnictl.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,23 @@ Debug mode will disable version check function in Trialkeeper.

* Description

You can use this command to show trial's information.
You can use this command to show trial's information. Note that if `head` or `tail` is set, only complete trials will be listed.
* Usage
```bash
nnictl trial ls
nnictl trial ls --head 10
nnictl trial ls --tail 10
```
* Options
|Name, shorthand|Required|Default|Description|
|------|------|------ |------|
|id| False| |ID of the experiment you want to set|
|--head|False||the number of items to be listed with the highest default metric|
|--tail|False||the number of items to be listed with the lowest default metric|
* __nnictl trial kill__
Expand Down
2 changes: 2 additions & 0 deletions tools/nni_cmd/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def parse_args():
parser_trial_subparsers = parser_trial.add_subparsers()
parser_trial_ls = parser_trial_subparsers.add_parser('ls', help='list trial jobs')
parser_trial_ls.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_ls.add_argument('--head', type=int, help='list the highest experiments on the default metric')
parser_trial_ls.add_argument('--tail', type=int, help='list the lowest experiments on the default metric')
parser_trial_ls.set_defaults(func=trial_ls)
parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs')
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
Expand Down
23 changes: 23 additions & 0 deletions tools/nni_cmd/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re
import shutil
import subprocess
from functools import cmp_to_key
from datetime import datetime, timezone
from pathlib import Path
from subprocess import Popen
Expand Down Expand Up @@ -248,6 +249,20 @@ def stop_experiment(args):

def trial_ls(args):
'''List trial'''
def final_metric_data_cmp(lhs, rhs):
metric_l = json.loads(json.loads(lhs['finalMetricData'][0]['data']))
metric_r = json.loads(json.loads(rhs['finalMetricData'][0]['data']))
if isinstance(metric_l, float):
return metric_l - metric_r
elif isinstance(metric_l, dict):
return metric_l['default'] - metric_r['default']
else:
print_error('Unexpected data format. Please check your data.')
raise ValueError

if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.')
return
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
Expand All @@ -259,6 +274,14 @@ def trial_ls(args):
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
if args.head:
assert args.head > 0, 'The number of requested data must be greater than 0.'
content = sorted(filter(lambda x: 'finalMetricData' in x, content),
key=cmp_to_key(final_metric_data_cmp), reverse=True)[:args.head]
elif args.tail:
assert args.tail > 0, 'The number of requested data must be greater than 0.'
content = sorted(filter(lambda x: 'finalMetricData' in x, content),
key=cmp_to_key(final_metric_data_cmp))[:args.tail]
for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value)
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
Expand Down
1 change: 0 additions & 1 deletion tools/nni_cmd/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def metric_data_url(port):
'''get metric_data url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API)


def check_status_url(port):
'''get check_status url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API)
Expand Down

0 comments on commit 44954e0

Please sign in to comment.