Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix lstm layer with projection save params (#17266) #17286

Merged
merged 1 commit into from
Feb 3, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix lstm layer with projection save params (#17266)
szha authored and frankfliu committed Jan 31, 2020

Verified

This commit was signed with the committer’s verified signature. The key has expired.
zalegrala Zach Leslie
commit 44ca019660474b2a2aaa484b3e56f1c66eb14874
2 changes: 1 addition & 1 deletion python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
@@ -124,7 +124,7 @@ def __repr__(self):
def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z')
pattern = re.compile(r'(l|r)(\d)_(i2h|h2h|h2r)_(weight|bias)\Z')
def convert_key(m, bidirectional): # for compatibility with old parameter format
d, l, g, t = [m.group(i) for i in range(1, 5)]
if bidirectional:
2 changes: 2 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
@@ -137,6 +137,8 @@ def test_lstmp():
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
mx.nd.ones((8, 3, 20)),
[mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx)
lstm_layer.save_parameters('gpu_tmp.params')
lstm_layer.load_parameters('gpu_tmp.params')


@with_seed()