diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index d41b5b4f030b..b9cee10b66fb 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -193,7 +193,9 @@ inline size_t GetRNNWorkspaceSize(index_t seq_length, case rnn_enum::kLstm: size = seq_length * batch_size * hidden_size * (4 + direction) + // wx*x + inter-y batch_size * hidden_size * 6 + // wh*h + h + c - seq_length * hidden_size * 8; // Used in Backward, Δbx, Δbh + seq_length * hidden_size * 8 + // Used in Backward, Δbx, Δbh + // temporary dy in backward computation for bidirectional layers + seq_length * batch_size * hidden_size * (direction - 1 ? direction : 0); break; case rnn_enum::kGru: // Differs with Lstm, the outputs of three gates are also held in memory diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index 459345797936..9f951856fc2f 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -559,6 +559,7 @@ void LstmBackward(DType* ws, const index_t w_size1 = (I + H) * H * 4; // first layer const index_t w_size2 = (D * H + H) * H * 4; // other layers const index_t cell_size = N * H; + const index_t y_size = T * N * H * D; DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3; for (int i = L - 1; i >= 0; --i) { const index_t input_size = i ? H * D : I; @@ -589,6 +590,10 @@ void LstmBackward(DType* ws, x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx], dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr, req_data, req_params, req_state, req_statecell); + + // Prevent overwriting dy while calculating dx in left2right layer + const int loop_iteration = (L - 1) - i; + dy_tmp_ptr = loop_iteration % 2 ? 
dy_tmp_ptr - y_size : dy_tmp_ptr + y_size; } if (dropout > 0.0f && i > 0 && req_data != kNullOp) { dropout_random = dropout_random - T * N * D * H; @@ -1504,7 +1509,7 @@ void GruBackward(DType* ws, if (dhy_l) dhy_l = dhy_l - D * N * H; y_l = y_l - T * N * H * D; - y_tmp = y_l; + y_tmp = y_tmp - T * N * H * D; if (l == 1) { wx_l = wx_l - (inputsize + H) * H * 3 * D; wh_l = wx_l + inputsize * 3 * H; diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index f2a220bbe719..6f9308b12cea 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -685,15 +685,10 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz stack_input_grad = sx.grad.asnumpy() assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol) - if mx.context.current_context().device_type == 'cpu' and \ - not mx.runtime.Features().is_enabled('MKLDNN') and \ - 'rnn' not in fused_layer.prefix: - print("LSTM and GRU on native CPU give wrong gradients. 
" - "Tracking issue: https://github.com/apache/incubator-mxnet/issues/17898.") - else: - assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol) - for key, value in fused_grads.items(): - assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol) + assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol) + for key, value in fused_grads.items(): + assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol) + num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1) check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2) @@ -719,61 +714,32 @@ def create_op_by_mode(mode): return fused_op, stack_op, recurrent_block_prefix -def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss): +def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss): fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) - # ==== Single layer ==== - fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix) - fused_layer.initialize() - - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) - with stack_layer.name_scope(): - stack_layer.add(stack_op(hidden_size, prefix='l0_')) - stack_layer.initialize() - check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) - - # ==== Multiple layer ==== - fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix) + fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix) fused_layer.initialize() stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) with stack_layer.name_scope(): - stack_layer.add(stack_op(hidden_size, prefix='l0_')) - stack_layer.add(stack_op(hidden_size, prefix='l1_')) - 
stack_layer.add(stack_op(hidden_size, prefix='l2_')) + for n in range(num_layers): + stack_layer.add(stack_op(hidden_size, prefix="l{}_".format(n))) stack_layer.initialize() - check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) -def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss): +def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss): fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode) - # ==== Single layer ==== - fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix) - fused_layer.initialize() - - stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) - with stack_layer.name_scope(): - stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'), - stack_op(hidden_size, prefix='r0_'))) - stack_layer.initialize() - check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) - - # ==== Multiple layer ==== - fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix) + fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix) fused_layer.initialize() stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix) with stack_layer.name_scope(): - stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'), - stack_op(hidden_size, prefix='r0_'))) - stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l1_'), - stack_op(hidden_size, prefix='r1_'))) - stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l2_'), - stack_op(hidden_size, prefix='r2_'))) - stack_layer.initialize() - + for n in range(num_layers): + stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix="l{}_".format(n)), + stack_op(hidden_size, 
prefix="r{}_".format(n)))) + stack_layer.initialize() check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) @@ -782,10 +748,11 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss): def test_fused_lstm_layer(): input_sizes = [8] hidden_sizes = [8, 16] - for input_size, hidden_size in product(input_sizes, hidden_sizes): + num_layers = [1, 2, 3, 4] + for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers): loss = mx.gluon.loss.L2Loss() - check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, loss) - check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, loss) + check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss) + check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss) @with_seed() @@ -793,10 +760,11 @@ def test_fused_lstm_layer(): def test_fused_gru_layer(): input_sizes = [8] hidden_sizes = [8, 16] - for input_size, hidden_size in product(input_sizes, hidden_sizes): + num_layers = [1, 2, 3, 4] + for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers): loss = mx.gluon.loss.L2Loss() - check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, loss) - check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, loss) + check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss) + check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss) @with_seed() @@ -804,10 +772,11 @@ def test_fused_gru_layer(): def test_fused_rnnrelu_layer(): input_sizes = [8] hidden_sizes = [8, 16] - for input_size, hidden_size in product(input_sizes, hidden_sizes): + num_layers = [1, 2, 3, 4] + for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers): loss = mx.gluon.loss.L2Loss() - check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, loss) - 
check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, loss) + check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss) + check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss) @with_seed() @@ -815,10 +784,11 @@ def test_fused_rnnrelu_layer(): def test_fused_rnntanh_layer(): input_sizes = [8] hidden_sizes = [8, 16] - for input_size, hidden_size in product(input_sizes, hidden_sizes): + num_layers = [1, 2, 3, 4] + for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers): loss = mx.gluon.loss.L2Loss() - check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss) - check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss) + check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss) + check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss) def test_rnn_unroll_variant_length():