Skip to content

Commit

Permalink
fix wrong order of hidden and cell state (fix microsoft#2839)
Browse files Browse the repository at this point in the history
  • Loading branch information
jyh2986 committed Aug 31, 2020
1 parent bf8be1e commit 23b892c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ def __init__(self, layers, size, bias):
for _ in range(self.lstm_num_layers)])

def forward(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_c, next_h
return next_h, next_c


class EnasMutator(Mutator):
Expand Down Expand Up @@ -136,7 +136,7 @@ def _initialize(self):
self.sample_skip_penalty = 0

def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
Expand Down

0 comments on commit 23b892c

Please sign in to comment.