Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix Error in nas SPOS trainer, apply_fixed_architecture #3051

Merged
merged 2 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nni/algorithms/nas/pytorch/spos/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
Expand All @@ -82,6 +83,7 @@ def validate_one_epoch(self, epoch):
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
Expand Down
16 changes: 11 additions & 5 deletions nni/nas/pytorch/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the docstring accordingly.

verbose : bool
    Print log messages if set to True. Default: True.
"""

def __init__(self, model, fixed_arc, strict=True):
def __init__(self, model, fixed_arc, strict=True, verbose=True):
super().__init__(model)
self._fixed_arc = fixed_arc
self.verbose = verbose

mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
Expand Down Expand Up @@ -99,10 +102,11 @@ def replace_layer_choice(self, module=None, prefix=""):
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
if self.verbose:
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
setattr(module, name, mutable[chosen.index(1)])
else:
if mutable.return_mask:
if mutable.return_mask and self.verbose:
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
"LayerChoice will not be replaced.")
# remove unused parameters
Expand All @@ -113,7 +117,7 @@ def replace_layer_choice(self, module=None, prefix=""):
self.replace_layer_choice(mutable, global_name)


def apply_fixed_architecture(model, fixed_arc):
def apply_fixed_architecture(model, fixed_arc, verbose=True):
"""
Load architecture from `fixed_arc` and apply to model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here


Expand All @@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
    Print log messages if set to True. Default: True.

Returns
-------
Expand All @@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc)
architecture = FixedArchitecture(model, fixed_arc, verbose)
architecture.reset()

# for the convenience of parameters counting
Expand Down