Skip to content

Commit

Permalink
TF NAS fix: avoid checking member during forward (microsoft#2781)
Browse files Browse the repository at this point in the history
Co-authored-by: liuzhe <[email protected]>
  • Loading branch information
liuzhe-lz and liuzhe authored Aug 12, 2020
1 parent 5623dbf commit e7fccfb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
8 changes: 4 additions & 4 deletions src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def validate_one_epoch(self, epoch):
meters = AverageMeterGroup()
for x, y in test_loader:
self.mutator.reset()
logits = self.model(x)
logits = self.model(x, training=False)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(logits, y)
metrics = self.metrics(y, logits)
loss = self.loss(y, logits)
metrics['loss'] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
Expand All @@ -151,8 +151,8 @@ def validate_one_epoch(self, epoch):

def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set)

def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size))
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
27 changes: 11 additions & 16 deletions src/sdk/pynni/nni/nas/tensorflow/mutables.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,19 @@ def __init__(self, key=None):
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")

def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)

def set_mutator(self, mutator):
if 'mutator' in self.__dict__:
if hasattr(self, 'mutator'):
raise RuntimeError('`set_mutator is called more than once. '
'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?')
self.__dict__['mutator'] = mutator
self.mutator = mutator

def call(self, *inputs):
raise NotImplementedError('Method `call` of Mutable must be overridden')

def build(self, input_shape):
self._check_built()

@property
def key(self):
return self._key
Expand All @@ -68,7 +67,6 @@ def __repr__(self):
class MutableScope(Mutable):
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
Expand All @@ -80,7 +78,7 @@ def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
super().__init__(key=key)
self.names = []
if isinstance(op_candidates, OrderedDict):
for name, _ in op_candidates.items():
for name in op_candidates:
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.names.append(name)
Expand All @@ -94,21 +92,18 @@ def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
self.choices = op_candidates
self.reduction = reduction
self.return_mask = return_mask
self._built = False

def call(self, *inputs):
if not self._built:
for op in self.choices:
if len(inputs) > 1: # FIXME: not tested
op.build([inp.shape for inp in inputs])
elif len(inputs) == 1:
op.build(inputs[0].shape)
self._built = True
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out

def build(self, input_shape):
self._check_built()
for op in self.choices:
op.build(input_shape)

def __len__(self):
return len(self.choices)

Expand Down

0 comments on commit e7fccfb

Please sign in to comment.