Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

TF NAS fix #2781

Merged
merged 1 commit into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def validate_one_epoch(self, epoch):
meters = AverageMeterGroup()
for x, y in test_loader:
self.mutator.reset()
logits = self.model(x)
logits = self.model(x, training=False)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(logits, y)
metrics = self.metrics(y, logits)
loss = self.loss(y, logits)
metrics['loss'] = tf.reduce_mean(loss).numpy()
meters.update(metrics)
Expand All @@ -151,8 +151,8 @@ def validate_one_epoch(self, epoch):

def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set)

def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size))
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
27 changes: 11 additions & 16 deletions src/sdk/pynni/nni/nas/tensorflow/mutables.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,19 @@ def __init__(self, key=None):
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")

def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)

def set_mutator(self, mutator):
if 'mutator' in self.__dict__:
if hasattr(self, 'mutator'):
raise RuntimeError('`set_mutator is called more than once. '
'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?')
self.__dict__['mutator'] = mutator
self.mutator = mutator

def call(self, *inputs):
raise NotImplementedError('Method `call` of Mutable must be overridden')

def build(self, input_shape):
self._check_built()

@property
def key(self):
return self._key
Expand All @@ -68,7 +67,6 @@ def __repr__(self):
class MutableScope(Mutable):
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
Expand All @@ -80,7 +78,7 @@ def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
super().__init__(key=key)
self.names = []
if isinstance(op_candidates, OrderedDict):
for name, _ in op_candidates.items():
for name in op_candidates:
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.names.append(name)
Expand All @@ -94,21 +92,18 @@ def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
self.choices = op_candidates
self.reduction = reduction
self.return_mask = return_mask
self._built = False

def call(self, *inputs):
if not self._built:
for op in self.choices:
if len(inputs) > 1: # FIXME: not tested
op.build([inp.shape for inp in inputs])
elif len(inputs) == 1:
op.build(inputs[0].shape)
self._built = True
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out

def build(self, input_shape):
self._check_built()
for op in self.choices:
op.build(input_shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the meaning of `op.build`?

Copy link
Contributor Author

@liuzhe-lz liuzhe-lz Aug 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TF layers do not take the input shape in `__init__`, so they create their weights when `build` is called.
Normally this happens implicitly during the first forward pass. We trigger it manually because only one candidate layer is connected during the first forward pass, so the other candidate layers cannot infer their input shape automatically.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, thanks.


def __len__(self):
return len(self.choices)

Expand Down