From 5a5b356124e7fbadfa38e2f01046d46b28aef847 Mon Sep 17 00:00:00 2001 From: Yuge Zhang <scottyugochang@gmail.com> Date: Fri, 21 Aug 2020 16:00:47 +0800 Subject: [PATCH] Fix visualization for ENAS micro --- src/sdk/pynni/nni/nas/pytorch/mutator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index f08061f3c6..1e97ab2123 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -162,7 +162,7 @@ def on_forward_layer_choice(self, mutable, *args, **kwargs): if self._connect_all: return self._all_connect_tensor_reduction(mutable.reduction, [op(*args, **kwargs) for op in mutable]), \ - torch.ones(len(mutable)) + torch.ones(len(mutable)).bool() def _map_fn(op, args, kwargs): return op(*args, **kwargs) @@ -192,7 +192,7 @@ def on_forward_input_choice(self, mutable, tensor_list): """ if self._connect_all: return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \ - torch.ones(mutable.n_candidates) + torch.ones(mutable.n_candidates).bool() mask = self._get_decision(mutable) assert len(mask) == mutable.n_candidates, \ "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)