diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py
index 1d2fc677841a..b7e07a59b4c1 100644
--- a/python/mxnet/rnn/rnn_cell.py
+++ b/python/mxnet/rnn/rnn_cell.py
@@ -913,6 +913,26 @@ def __call__(self, inputs, states):
         output = symbol.elemwise_add(output, inputs, name="%s_plus_residual" % output.name)
         return output, states
 
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+        self.reset()
+
+        self.base_cell._modified = False
+        outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
+                                                layout=layout, merge_outputs=merge_outputs)
+        self.base_cell._modified = True
+
+        merge_outputs = isinstance(outputs, symbol.Symbol) if merge_outputs is None else \
+            merge_outputs
+        inputs, _ = _normalize_sequence(length, inputs, layout, merge_outputs)
+        if merge_outputs:
+            outputs = symbol.elemwise_add(outputs, inputs, name="%s_plus_residual" % outputs.name)
+        else:
+            outputs = [symbol.elemwise_add(output_sym, input_sym,
+                                           name="%s_plus_residual" % output_sym.name)
+                       for output_sym, input_sym in zip(outputs, inputs)]
+
+        return outputs, states
+
 
 class BidirectionalCell(BaseRNNCell):
     """Bidirectional RNN cell.
@@ -928,9 +948,18 @@ class BidirectionalCell(BaseRNNCell):
     """
     def __init__(self, l_cell, r_cell, params=None, output_prefix='bi_'):
         super(BidirectionalCell, self).__init__('', params=params)
+        self._output_prefix = output_prefix
         self._override_cell_params = params is not None
+
+        if self._override_cell_params:
+            assert l_cell._own_params and r_cell._own_params, \
+                "Either specify params for BidirectionalCell " \
+                "or child cells, not both."
+            l_cell.params._params.update(self.params._params)
+            r_cell.params._params.update(self.params._params)
+        self.params._params.update(l_cell.params._params)
+        self.params._params.update(r_cell.params._params)
         self._cells = [l_cell, r_cell]
-        self._output_prefix = output_prefix
 
     def unpack_weights(self, args):
         return _cells_unpack_weights(self._cells, args)
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index fd3dd9289836..48e44133216b 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1093,6 +1093,25 @@ def test_unfuse():
     check_rnn_consistency(fused, stack)
     check_rnn_consistency(stack, fused)
 
+def test_residual_fused():
+    cell = mx.rnn.ResidualCell(
+            mx.rnn.FusedRNNCell(50, num_layers=3, mode='lstm',
+                                prefix='rnn_', dropout=0.5))
+
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)]
+    outputs, _ = cell.unroll(2, inputs, merge_outputs=None)
+    assert sorted(cell.params._params.keys()) == \
+           ['rnn_parameters']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
+    assert outs == [(10, 2, 50)]
+    outputs = outputs.eval(ctx=mx.gpu(0),
+                           rnn_t0_data=mx.nd.ones((10, 50), ctx=mx.gpu(0))+5,
+                           rnn_t1_data=mx.nd.ones((10, 50), ctx=mx.gpu(0))+5,
+                           rnn_parameters=mx.nd.zeros((61200,), ctx=mx.gpu(0)))
+    expected_outputs = np.ones((10, 2, 50))+5
+    assert np.array_equal(outputs[0].asnumpy(), expected_outputs)
+
 if __name__ == '__main__':
     test_countsketch()
     test_ifft()
@@ -1103,6 +1122,7 @@ def test_unfuse():
     test_gru()
     test_rnn()
     test_unfuse()
+    test_residual_fused()
     test_convolution_options()
     test_convolution_versions()
     test_convolution_with_type()
diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py
index 903ce013e8f0..419104d57dd2 100644
--- a/tests/python/unittest/test_rnn.py
+++ b/tests/python/unittest/test_rnn.py
@@ -72,8 +72,6 @@ def test_residual():
 
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
     assert outs == [(10, 50), (10, 50)]
-    print(args)
-    print(outputs.list_arguments())
     outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50)),
                            rnn_t1_data=mx.nd.ones((10, 50)),
                            rnn_i2h_weight=mx.nd.zeros((150, 50)),
@@ -85,6 +83,38 @@ def test_residual():
     assert np.array_equal(outputs[1].asnumpy(), expected_outputs)
 
 
+def test_residual_bidirectional():
+    cell = mx.rnn.ResidualCell(
+            mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(25, prefix='rnn_l_'),
+                mx.rnn.GRUCell(25, prefix='rnn_r_')))
+
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)]
+    outputs, _ = cell.unroll(2, inputs, merge_outputs=False)
+    outputs = mx.sym.Group(outputs)
+    assert sorted(cell.params._params.keys()) == \
+           ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
+            'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']
+    assert outputs.list_outputs() == \
+           ['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
+    assert outs == [(10, 50), (10, 50)]
+    outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50))+5,
+                           rnn_t1_data=mx.nd.ones((10, 50))+5,
+                           rnn_l_i2h_weight=mx.nd.zeros((75, 50)),
+                           rnn_l_i2h_bias=mx.nd.zeros((75,)),
+                           rnn_l_h2h_weight=mx.nd.zeros((75, 25)),
+                           rnn_l_h2h_bias=mx.nd.zeros((75,)),
+                           rnn_r_i2h_weight=mx.nd.zeros((75, 50)),
+                           rnn_r_i2h_bias=mx.nd.zeros((75,)),
+                           rnn_r_h2h_weight=mx.nd.zeros((75, 25)),
+                           rnn_r_h2h_bias=mx.nd.zeros((75,)))
+    expected_outputs = np.ones((10, 50))+5
+    assert np.array_equal(outputs[0].asnumpy(), expected_outputs)
+    assert np.array_equal(outputs[1].asnumpy(), expected_outputs)
+
+
 def test_stack():
     cell = mx.rnn.SequentialRNNCell()
     for i in range(5):
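
For reviewers who want to exercise the new ResidualCell.unroll outside the test
suite, below is a minimal CPU-only sketch (not part of the patch; it assumes an
MXNet build with the symbolic mx.rnn API) of the merge_outputs=True branch. The
temporary toggling of base_cell._modified inside unroll is what lets the wrapped
cell be unrolled again without tripping the modified-cell assertion in
BaseRNNCell.unroll. The shapes and the zero-parameter identity check mirror the
assertions already made in the tests above.

    import mxnet as mx
    import numpy as np

    # Illustration only: wrap a GRUCell in ResidualCell and unroll with
    # merge_outputs=True, which exercises the new single-symbol branch
    # (one elemwise_add over the merged outputs instead of one per step).
    cell = mx.rnn.ResidualCell(mx.rnn.GRUCell(50, prefix='rnn_'))
    inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(2)]
    outputs, _ = cell.unroll(2, inputs, merge_outputs=True)

    # With layout 'NTC' the merged output is (batch, length, hidden).
    _, outs, _ = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
    assert outs == [(10, 2, 50)]

    # All-zero parameters keep the GRU's hidden state at zero (the candidate
    # activation is tanh(0) = 0), so the residual connection should pass the
    # inputs through unchanged, just as the unit tests assert.
    res = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50)),
                       rnn_t1_data=mx.nd.ones((10, 50)),
                       rnn_i2h_weight=mx.nd.zeros((150, 50)),
                       rnn_i2h_bias=mx.nd.zeros((150,)),
                       rnn_h2h_weight=mx.nd.zeros((150, 50)),
                       rnn_h2h_bias=mx.nd.zeros((150,)))
    assert np.array_equal(res[0].asnumpy(), np.ones((10, 2, 50)))

Note that with the default merge_outputs=None, the new unroll simply follows
whatever form the base cell returned (the isinstance check on the unrolled
outputs), so callers that never pass merge_outputs see unchanged behavior.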