From 23b892ca4d3368a386cf7ca434de9f83a9fba3d6 Mon Sep 17 00:00:00 2001 From: Yunho Jeon Date: Mon, 31 Aug 2020 06:21:07 +0000 Subject: [PATCH] fix wrong order of hidden and cell state (fix #2839) --- src/sdk/pynni/nni/nas/pytorch/enas/mutator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 7763622a58..7fdba26b99 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -17,16 +17,16 @@ def __init__(self, layers, size, bias): for _ in range(self.lstm_num_layers)]) def forward(self, inputs, hidden): - prev_c, prev_h = hidden - next_c, next_h = [], [] + prev_h, prev_c = hidden + next_h, next_c = [], [] for i, m in enumerate(self.lstm_modules): - curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i])) + curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i])) next_c.append(curr_c) next_h.append(curr_h) # current implementation only supports batch size equals 1, # but the algorithm does not necessarily have this limitation inputs = curr_h[-1].view(1, -1) - return next_c, next_h + return next_h, next_c class EnasMutator(Mutator): @@ -136,7 +136,7 @@ def _initialize(self): self.sample_skip_penalty = 0 def _lstm_next_step(self): - self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) + self._h, self._c = self.lstm(self._inputs, (self._h, self._c)) def _mark_anchor(self, key): self._anchors_hid[key] = self._h[-1]