Improve Classic NAS mutator #1865
Conversation
```python
if "NNI_GEN_SEARCH_SPACE" in os.environ:
    self._chosen_arch = {}
    self._search_space = self._generate_search_space()
```
Reviewer: Change "NNI_GEN_SEARCH_SPACE" to an enum string?
Author: Updated.
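Per the suggestion above, one lightweight way to replace the magic string is a module-level constant. This is only a sketch; the constant's name and placement are illustrative, not NNI's actual code:

```python
import os

# Hypothetical named constant replacing the scattered "NNI_GEN_SEARCH_SPACE"
# magic string (the reviewer's "enum string" suggestion, sketched here).
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"

def gen_search_space_mode():
    """True when NNI asked the trial to dump the search space instead of training."""
    return NNI_GEN_SEARCH_SPACE in os.environ
```

A constant keeps the string defined in exactly one place, so a typo at a call site becomes a NameError instead of a silently false branch.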
```python
n_chosen = val['_value']['n_chosen']
chosen_arch[key] = {'_value': choices[:n_chosen], '_idx': list(range(n_chosen))}
for key, val in self._search_space.items():
    if val["_type"] == "layer_choice":
```
Reviewer: Same here.
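For context, the hunk above builds a default architecture choice directly from the generated search space. A self-contained sketch of that logic follows; the JSON shape (`candidates`, `n_chosen` fields under `_value`) is inferred from the snippet, so treat the exact structure as an assumption:

```python
# Inferred shape of one search-space entry; field names come from the hunk
# above, but the overall structure is an assumption for illustration.
search_space = {
    "mutable_1": {
        "_type": "input_choice",
        "_value": {"candidates": ["a", "b", "c"], "n_chosen": 2},
    },
}

chosen_arch = {}
for key, val in search_space.items():
    if val["_type"] == "input_choice":
        choices = val["_value"]["candidates"]
        n_chosen = val["_value"]["n_chosen"]
        # Default choice: the first n_chosen candidates, with their indices.
        chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}

print(chosen_arch["mutable_1"])  # {'_value': ['a', 'b'], '_idx': [0, 1]}
```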
```python
class ModelCheckpoint(Callback):
    def __init__(self, checkpoint_dir, every="epoch"):
        super().__init__()
        assert every == "epoch"
```
Reviewer: It supports only one specific value, so why make it a parameter?
Author: The roadmap is to support checkpointing on every mini-batch, but it isn't implemented yet.
Reviewer: Please remove it, because it is confusing.
Author: Fixed in another PR.
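Until mini-batch checkpointing lands, the constraint could also be made explicit rather than left as a bare assert. A hedged sketch, where the `Callback` base class, attribute names, and error message are stand-ins rather than the PR's actual code:

```python
class Callback:
    """Stand-in for the trainer's callback base class (assumption)."""

class ModelCheckpoint(Callback):
    # "batch" is on the roadmap per the thread above, but not implemented yet.
    SUPPORTED_INTERVALS = ("epoch",)

    def __init__(self, checkpoint_dir, every="epoch"):
        super().__init__()
        if every not in self.SUPPORTED_INTERVALS:
            raise NotImplementedError(
                "every={!r} is not supported yet; expected one of {}".format(
                    every, self.SUPPORTED_INTERVALS))
        self.checkpoint_dir = checkpoint_dir
        self.every = every
```

Raising `NotImplementedError` documents the intent (future values are planned) in a way a bare `assert` does not.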
```python
self.trainer.export(dest_path)

class ModelCheckpoint(Callback):
```
Reviewer: Can the checkpointed model be used for resuming, e.g., restoring the learning rate and epoch number?
Author: That part is in another PR. And no, it is not supported.
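For reference, the usual pattern for making a checkpoint resumable (learning rate schedule, epoch counter) is to persist optimizer state and epoch alongside the weights. This is a generic PyTorch-style sketch, not the PR's implementation; plain dicts stand in for `state_dict()` results so the example runs without torch:

```python
# Bundle everything needed to continue training mid-run, not just the weights.
def save_checkpoint(model_state, optimizer_state, epoch):
    return {"model": model_state, "optimizer": optimizer_state, "epoch": epoch}

def resume(checkpoint):
    # Training resumes at the epoch after the one that was saved.
    return checkpoint["model"], checkpoint["optimizer"], checkpoint["epoch"] + 1

ckpt = save_checkpoint({"w": 0.5}, {"lr": 0.01}, epoch=7)
model_state, optimizer_state, start_epoch = resume(ckpt)
print(start_epoch)  # 8
```

In real PyTorch code the dicts would be `model.state_dict()` and `optimizer.state_dict()` (the latter carries the learning rate), serialized with `torch.save`.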
```python
else:
    # get chosen arch from tuner
    self._chosen_arch = nni.get_next_parameter()
self._cache = self.sample_final()
```
Reviewer: What is _cache used for in classic_nas?
Author: _cache is stored so that on_forward_layer_choice, implemented in Mutator, can retrieve the decision.
```python
# doesn't support multihot for layer choice yet
onehot_list = [False] * mutable.length
assert 0 <= idx < mutable.length and search_space_ref[idx] == value, \
    "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_ref, value)
onehot_list[idx] = True
result[mutable.key] = torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable
```
Reviewer: Suggest extracting this part into a private member function.
```python
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
    assert 0 <= i < mutable.n_candidates and search_space_ref[i] == v, \
        "Index '{}' in search space '{}' is not '{}'".format(i, search_space_ref, v)
    assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
    multihot_list[i] = True
result[mutable.key] = torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable
```
Reviewer: Suggest extracting this part into a private member function.
New utility functions in preparation for SPOS: