Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Improve integration test #21

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/naive/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__pycache__

tuner_search_space.json
tuner_result.txt
assessor_result.txt
29 changes: 29 additions & 0 deletions test/naive/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## Usage

To test before installing:

./run.py --preinstall

To test the integrity of installation:

./run.py

It will print `PASS` in green eventually if everything works well.

## Details

This test case tests the communication between trials and tuner/assessor.

The naive trials receive an integer `x` as parameter, and reports `x`, `x²`, `x³`, ... , `x¹⁰` as metrics.

The naive tuner simply generates the sequence of natural numbers, and print received metrics to `tuner_result.txt`.

The naive assessor kills trials when `sum(metrics) % 11 == 1`, and print killed trials to `assessor_result.txt`.

When tuner and assessor exit without exception, they will append `DONE` to corresponding result file. Otherwise they append `ERROR`.

## Issues

* Private APIs are used to detect whether tuner and assessor have terminated successfully.
* The output of REST server is not tested.
* Remote machine training service is not tested.
5 changes: 4 additions & 1 deletion test/naive/naive_assessor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import os

from nni.assessor import Assessor, AssessResult

_logger = logging.getLogger('NaiveAssessor')
_logger.info('start')
_result = open('/tmp/nni_assessor_result.txt', 'w')

_pwd = os.path.dirname(__file__)
_result = open(os.path.join(_pwd, 'assessor_result.txt'), 'w')

class NaiveAssessor(Assessor):
def __init__(self, optimize_mode):
Expand Down
7 changes: 5 additions & 2 deletions test/naive/naive_tuner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
import logging
import os

from nni.tuner import Tuner

_logger = logging.getLogger('NaiveTuner')
_logger.info('start')
_result = open('/tmp/nni_tuner_result.txt', 'w')

_pwd = os.path.dirname(__file__)
_result = open(os.path.join(_pwd, 'tuner_result.txt'), 'w')

class NaiveTuner(Tuner):
def __init__(self, optimize_mode):
Expand All @@ -24,7 +27,7 @@ def receive_trial_result(self, parameter_id, parameters, reward):

def update_search_space(self, search_space):
_logger.info('update_search_space: %s' % search_space)
with open('/tmp/nni_tuner_search_space.json', 'w') as file_:
with open(os.path.join(_pwd, 'tuner_search_space.json'), 'w') as file_:
json.dump(search_space, file_)

def _on_exit(self):
Expand Down
Empty file modified test/naive/nnictl
100644 → 100755
Empty file.
Empty file modified test/naive/nnimanager
100644 → 100755
Empty file.
37 changes: 26 additions & 11 deletions test/naive/run.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import time
import traceback
import sys

GREEN = '\33[32m'
RED = '\33[31m'
Expand All @@ -18,15 +19,24 @@ def read_last_line(file_name):
except (FileNotFoundError, ValueError):
return None

def run():
os.environ['PATH'] = os.environ['PATH'] + ':' + os.environ['PWD']
def run(installed = True):
if not installed:
os.environ['PATH'] = os.environ['PATH'] + ':' + os.environ['PWD']
sdk_path = os.path.abspath('../../src/sdk/pynni')
cmd_path = os.path.abspath('../../tools')
pypath = os.environ.get('PYTHONPATH')
if pypath:
pypath = ':'.join([pypath, sdk_path, cmd_path])
else:
pypath = ':'.join([sdk_path, cmd_path])
os.environ['PYTHONPATH'] = pypath

with contextlib.suppress(FileNotFoundError):
os.remove('tuner_search_space.txt')
with contextlib.suppress(FileNotFoundError):
os.remove('tuner_result.txt')
with contextlib.suppress(FileNotFoundError):
os.remove('/tmp/nni_assessor_result.txt')
os.remove('assessor_result.txt')

proc = subprocess.run(['nnictl', 'create', '--config', 'local.yml'])
assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode
Expand All @@ -37,8 +47,8 @@ def run():
for _ in range(60):
time.sleep(1)

tuner_status = read_last_line('/tmp/nni_tuner_result.txt')
assessor_status = read_last_line('/tmp/nni_assessor_result.txt')
tuner_status = read_last_line('tuner_result.txt')
assessor_status = read_last_line('assessor_result.txt')

assert tuner_status != 'ERROR', 'Tuner exited with error'
assert assessor_status != 'ERROR', 'Assessor exited with error'
Expand All @@ -47,7 +57,7 @@ def run():
break

if tuner_status is not None:
for line in open('/tmp/nni_tuner_result.txt'):
for line in open('tuner_result.txt'):
if line.strip() in ('DONE', 'ERROR'):
break
trial = int(line.split(' ')[0])
Expand All @@ -58,28 +68,33 @@ def run():
assert tuner_status == 'DONE' and assessor_status == 'DONE', 'Failed to finish in 1 min'

ss1 = json.load(open('search_space.json'))
ss2 = json.load(open('/tmp/nni_tuner_search_space.json'))
ss2 = json.load(open('tuner_search_space.json'))
assert ss1 == ss2, 'Tuner got wrong search space'

tuner_result = set(open('/tmp/nni_tuner_result.txt'))
tuner_result = set(open('tuner_result.txt'))
expected = set(open('expected_tuner_result.txt'))
# Trials may complete before NNI gets assessor's result,
# so it is possible to have more final result than expected
assert tuner_result.issuperset(expected), 'Bad tuner result'

assessor_result = set(open('/tmp/nni_assessor_result.txt'))
assessor_result = set(open('assessor_result.txt'))
expected = set(open('expected_assessor_result.txt'))
assert assessor_result == expected, 'Bad assessor result'

if __name__ == '__main__':
installed = (sys.argv[-1] != '--preinstall')

try:
run()
run(installed)
# TODO: check the output of rest server
print(GREEN + 'PASS' + CLEAR)
ret_code = 0
except Exception as error:
print(RED + 'FAIL' + CLEAR)
print('%r' % error)
traceback.print_exc()
raise error
ret_code = 1

subprocess.run(['nnictl', 'stop'])

sys.exit(ret_code)