diff --git a/src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py b/src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py index 2d03a3fbb8..a9645e9203 100644 --- a/src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py @@ -136,10 +136,10 @@ def validate_one_epoch(self, epoch): meters = AverageMeterGroup() for x, y in test_loader: self.mutator.reset() - logits = self.model(x) + logits = self.model(x, training=False) if isinstance(logits, tuple): logits, _ = logits - metrics = self.metrics(logits, y) + metrics = self.metrics(y, logits) loss = self.loss(y, logits) metrics['loss'] = tf.reduce_mean(loss).numpy() meters.update(metrics) @@ -151,8 +151,8 @@ def validate_one_epoch(self, epoch): def _create_train_loader(self): train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size) - test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size) + test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size) return iter(train_set), iter(test_set) def _create_validate_loader(self): - return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size)) + return iter(self.test_set.shuffle(1000000).batch(self.batch_size)) diff --git a/src/sdk/pynni/nni/nas/tensorflow/mutables.py b/src/sdk/pynni/nni/nas/tensorflow/mutables.py index b83b6f6325..06183a34c1 100644 --- a/src/sdk/pynni/nni/nas/tensorflow/mutables.py +++ b/src/sdk/pynni/nni/nas/tensorflow/mutables.py @@ -28,20 +28,19 @@ def __init__(self, key=None): def __deepcopy__(self, memodict=None): raise NotImplementedError("Deep copy doesn't work for mutables.") - def __call__(self, *args, **kwargs): - self._check_built() - return super().__call__(*args, **kwargs) - def set_mutator(self, mutator): - if 'mutator' in self.__dict__: + if hasattr(self, 'mutator'): raise RuntimeError('`set_mutator is called more than once. ' 'Did you parse the search space multiple times? ' 'Or did you apply multiple fixed architectures?') - self.__dict__['mutator'] = mutator + self.mutator = mutator def call(self, *inputs): raise NotImplementedError('Method `call` of Mutable must be overridden') + def build(self, input_shape): + self._check_built() + @property def key(self): return self._key @@ -68,7 +67,6 @@ def __repr__(self): class MutableScope(Mutable): def __call__(self, *args, **kwargs): try: - self._check_built() self.mutator.enter_mutable_scope(self) return super().__call__(*args, **kwargs) finally: @@ -80,7 +78,7 @@ def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None): super().__init__(key=key) self.names = [] if isinstance(op_candidates, OrderedDict): - for name, _ in op_candidates.items(): + for name in op_candidates: assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \ "Please don't use a reserved name '{}' for your module.".format(name) self.names.append(name) @@ -94,21 +92,18 @@ def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None): self.choices = op_candidates self.reduction = reduction self.return_mask = return_mask - self._built = False def call(self, *inputs): - if not self._built: - for op in self.choices: - if len(inputs) > 1: # FIXME: not tested - op.build([inp.shape for inp in inputs]) - elif len(inputs) == 1: - op.build(inputs[0].shape) - self._built = True out, mask = self.mutator.on_forward_layer_choice(self, *inputs) if self.return_mask: return out, mask return out + def build(self, input_shape): + self._check_built() + for op in self.choices: + op.build(input_shape) + def __len__(self): return len(self.choices)