From 1bb272b63eed2992435ed7a902190da92925fe9a Mon Sep 17 00:00:00 2001
From: Per Goncalves da Silva
Date: Tue, 26 Mar 2019 13:41:54 +0100
Subject: [PATCH] Removes default value for ctx parameter in
 check_rnn_layer_forward and refactors tests

---
 tests/python/gpu/test_gluon_gpu.py      | 13 +++++----
 tests/python/unittest/test_gluon_rnn.py | 36 ++++++++++++-------------
 2 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 88b436a0deb2..18b2b533c662 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -88,7 +88,8 @@ def test_lstmp():
     rtol, atol = 1e-2, 1e-2
     batch_size, seq_len = 7, 11
     input_size = 5
-    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
+    ctx = mx.gpu(0)
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
     shapes = {'i2h_weight': (hidden_size*4, input_size),
               'h2h_weight': (hidden_size*4, projection_size),
               'i2h_bias': (hidden_size*4,),
@@ -101,8 +102,8 @@ def test_lstmp():
                                     projection_size=projection_size,
                                     input_size=input_size,
                                     prefix='lstm0_l0_')
-    lstm_layer.initialize(ctx=mx.gpu(0))
-    lstm_cell.initialize(ctx=mx.gpu(0))
+    lstm_layer.initialize(ctx=ctx)
+    lstm_cell.initialize(ctx=ctx)
     layer_params = lstm_layer.collect_params()
     cell_params = lstm_cell.collect_params()
     for k, v in weights.items():
@@ -121,13 +122,15 @@ def test_lstmp():
             print('checking gradient for {}'.format('lstm0_l0_'+k))
             assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
                                 rtol=rtol, atol=atol)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), ctx, [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])

     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5),
                             mx.nd.ones((8, 3, 20)),
+                            ctx,
                             run_only=True)
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
                             mx.nd.ones((8, 3, 20)),
+                            ctx,
                             [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))],
                             run_only=True)

diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index b410362c8fd1..a368311c74cb 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -427,7 +427,7 @@ def hybrid_forward(self, F, seq):
     assert_almost_equal(output1.asnumpy(), output2.asnumpy())


-def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):
+def check_rnn_layer_forward(layer, inputs, ctx, states=None, run_only=False):
     layer.collect_params().initialize(ctx=ctx)
     inputs = inputs.as_in_context(ctx)
     inputs.attach_grad()
@@ -476,27 +476,27 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):


 def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)],ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, ), mx.nd.ones((8, 3, 20), dtype=dtype),ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype),ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), ctx, [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)])
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype))

-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype),
-                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            run_only=True)
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
-                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
-                            run_only=True, ctx=ctx)
+                            mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            run_only=True)
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
-                            mx.nd.ones((8, 3, 20), dtype=dtype),
-                            [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
-                            run_only=True, ctx=ctx)
+                            mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            run_only=True)
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
-                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
+                            mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True)

     net = gluon.nn.Sequential()
     net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
@@ -628,7 +628,7 @@ def test_cell_fill_shape():
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     layer.hybridize()
-    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
+    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)), mx.cpu())
     print(layer)
     assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1]
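
Note: with this patch, check_rnn_layer_forward no longer defaults to mx.cpu();
every call site must pass the context explicitly as the third positional
argument, ahead of the optional states and run_only. A minimal sketch of the
new calling convention follows; the num_gpus() guard and the import of the
helper are illustrative assumptions, not part of this patch:

    import mxnet as mx
    from mxnet import gluon
    # Assumption: the patched helper is importable from the unittest module.
    from test_gluon_rnn import check_rnn_layer_forward

    # The helper no longer falls back to mx.cpu(); pick a context explicitly.
    ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()

    # New signature: check_rnn_layer_forward(layer, inputs, ctx, states=None, run_only=False)
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)), ctx)
    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5),
                            mx.nd.ones((8, 3, 20)), ctx, run_only=True)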