diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 889c0f35a7..8cd107ec9d 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -23,7 +23,9 @@ def forward(self, inputs, hidden): curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i])) next_c.append(curr_c) next_h.append(curr_h) - inputs = curr_h[-1] + # current implementation only supports batch size equals 1, + # but the algorithm does not necessarily have this limitation + inputs = curr_h[-1].view(1, -1) return next_c, next_h