Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Improve Classic NAS mutator #1865

Merged
merged 5 commits into from
Dec 23, 2019

Conversation

ultmaster
Copy link
Contributor

  1. Use mutator instead of BaseMutator as base class to simplify the code.
  2. Add more validation to detect possible errors and give reasonable error messages.
  3. Fix bug when search space path is an absolute path.

New utility functions in preparation for SPOS:

  1. ModelCheckpoint
  2. Add access shortcut for average meters.

if 'NNI_GEN_SEARCH_SPACE' in os.environ:
self._chosen_arch = {}
self._search_space = self._generate_search_space()
if "NNI_GEN_SEARCH_SPACE" in os.environ:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the "NNI_GEN_SEARCH_SPACE" as an enum string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

n_chosen = val['_value']['n_chosen']
chosen_arch[key] = {'_value': choices[:n_chosen], '_idx': list(range(n_chosen))}
for key, val in self._search_space.items():
if val["_type"] == "layer_choice":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.

class ModelCheckpoint(Callback):
def __init__(self, checkpoint_dir, every="epoch"):
super().__init__()
assert every == "epoch"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

support only one specific value? so why we make it a parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The roadmap is to support checkpoint on every mini-batch. But it's not implemented yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove, because it is confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in another PR.

self.trainer.export(dest_path)


class ModelCheckpoint(Callback):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can checkpointed model be used for resume? e.g., learning rate, epoch number

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is in another PR. And no, not supported.

else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
self._cache = self.sample_final()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is _cache used for in classic_nas?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_cache is stored so that on_forward_layer_choice implemented in Mutator can retrieve the decision.

Comment on lines 89 to 94
# doesn't support multihot for layer choice yet
onehot_list = [False] * mutable.length
assert 0 <= idx < mutable.length and search_space_ref[idx] == value, \
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_ref, value)
onehot_list[idx] = True
result[mutable.key] = torch.tensor(onehot_list, dtype=torch.bool) # pylint: disable=not-callable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest to extract this part into a private member function

Comment on lines 96 to 102
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
assert 0 <= i < mutable.n_candidates and search_space_ref[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, search_space_ref, v)
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
multihot_list[i] = True
result[mutable.key] = torch.tensor(multihot_list, dtype=torch.bool) # pylint: disable=not-callable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest to extract this part into a private member function

@ultmaster ultmaster merged commit 7a55811 into microsoft:master Dec 23, 2019
@leckie-chn leckie-chn mentioned this pull request Dec 25, 2019
19 tasks
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants