From 5a5b356124e7fbadfa38e2f01046d46b28aef847 Mon Sep 17 00:00:00 2001
From: Yuge Zhang <scottyugochang@gmail.com>
Date: Fri, 21 Aug 2020 16:00:47 +0800
Subject: [PATCH] Fix visualization for ENAS micro

---
 src/sdk/pynni/nni/nas/pytorch/mutator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py
index f08061f3c6..1e97ab2123 100644
--- a/src/sdk/pynni/nni/nas/pytorch/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py
@@ -162,7 +162,7 @@ def on_forward_layer_choice(self, mutable, *args, **kwargs):
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction,
                                                       [op(*args, **kwargs) for op in mutable]), \
-                torch.ones(len(mutable))
+                torch.ones(len(mutable)).bool()
 
         def _map_fn(op, args, kwargs):
             return op(*args, **kwargs)
@@ -192,7 +192,7 @@ def on_forward_input_choice(self, mutable, tensor_list):
         """
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
-                torch.ones(mutable.n_candidates)
+                torch.ones(mutable.n_candidates).bool()
         mask = self._get_decision(mutable)
         assert len(mask) == mutable.n_candidates, \
             "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)