diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index f54d7c543283..426bca931335 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -463,162 +463,94 @@ def test_function(self): def test_rnn(self): # implement a simple RNN - input_dim = 8 - output_dim = 4 - timesteps = 5 + num_samples = 4 + input_dim = 5 + output_dim = 3 + timesteps = 6 - input_val = np.random.random((32, timesteps, input_dim)) - init_state_val = np.random.random((32, output_dim)) - W_i_val = np.random.random((input_dim, output_dim)) - W_o_val = np.random.random((output_dim, output_dim)) + input_val = np.random.random((num_samples, timesteps, input_dim)).astype(np.float32) + init_state_val = np.random.random((num_samples, output_dim)).astype(np.float32) + W_i_val = np.random.random((input_dim, output_dim)).astype(np.float32) + W_o_val = np.random.random((output_dim, output_dim)).astype(np.float32) + np_mask = np.random.randint(2, size=(num_samples, timesteps)) - def rnn_step_fn(input_dim, output_dim, K): - W_i = K.variable(W_i_val) - W_o = K.variable(W_o_val) + def rnn_step_fn(input_dim, output_dim, k): + W_i = k.variable(W_i_val) + W_o = k.variable(W_o_val) def step_function(x, states): assert len(states) == 1 prev_output = states[0] - output = K.dot(x, W_i) + K.dot(prev_output, W_o) + output = k.dot(x, W_i) + k.dot(prev_output, W_o) return output, [output] return step_function # test default setup - last_output_list = [] - outputs_list = [] - state_list = [] - - unrolled_last_output_list = [] - unrolled_outputs_list = [] - unrolled_states_list = [] - - backwards_last_output_list = [] - backwards_outputs_list = [] - backwards_states_list = [] - - bwd_unrolled_last_output_list = [] - bwd_unrolled_outputs_list = [] - bwd_unrolled_states_list = [] - - masked_last_output_list = [] - masked_outputs_list = [] - masked_states_list = [] - - unrolled_masked_last_output_list = [] - unrolled_masked_outputs_list = [] - unrolled_masked_states_list = [] + last_output_list = [[], [], [], [], [], []] + outputs_list = [[], [], [], [], [], []] + state_list = [[], [], [], [], [], []] + kwargs_list = [ + {'go_backwards': False, 'mask': None}, + {'go_backwards': False, 'mask': None, 'unroll': True, 'input_length': timesteps}, + {'go_backwards': True, 'mask': None}, + {'go_backwards': True, 'mask': None, 'unroll': True, 'input_length': timesteps}, + {'go_backwards': False, 'mask': mask}, + {'go_backwards': False, 'mask': mask, 'unroll': True, 'input_length': timesteps}, + ] for k in BACKENDS: rnn_fn = rnn_step_fn(input_dim, output_dim, k) inputs = k.variable(input_val) initial_states = [k.variable(init_state_val)] - last_output, outputs, new_states = k.rnn(rnn_fn, inputs, - initial_states, - go_backwards=False, - mask=None) - - last_output_list.append(k.eval(last_output)) - outputs_list.append(k.eval(outputs)) - assert len(new_states) == 1 - state_list.append(k.eval(new_states[0])) - # test unroll - unrolled_last_output, unrolled_outputs, unrolled_new_states = k.rnn( - rnn_fn, inputs, - initial_states, - go_backwards=False, - mask=None, - unroll=True, - input_length=timesteps) - - unrolled_last_output_list.append(k.eval(unrolled_last_output)) - unrolled_outputs_list.append(k.eval(unrolled_outputs)) - assert len(unrolled_new_states) == 1 - unrolled_states_list.append(k.eval(unrolled_new_states[0])) - - backwards_last_output, backwards_outputs, backwards_new_states = k.rnn(rnn_fn, inputs, - initial_states, - go_backwards=True, - mask=None) - backwards_last_output_list.append(k.eval(backwards_last_output)) - backwards_outputs_list.append(k.eval(backwards_outputs)) - assert len(backwards_new_states) == 1 - backwards_states_list.append(k.eval(backwards_new_states[0])) - - bwd_unrolled_last_output, bwd_unrolled_outputs, bwd_unrolled_new_states = k.rnn( - rnn_fn, inputs, - initial_states, - go_backwards=True, - mask=None, - unroll=True, - input_length=timesteps) - - bwd_unrolled_last_output_list.append(k.eval(bwd_unrolled_last_output)) - bwd_unrolled_outputs_list.append(k.eval(bwd_unrolled_outputs)) - assert len(bwd_unrolled_new_states) == 1 - bwd_unrolled_states_list.append(k.eval(bwd_unrolled_new_states[0])) - - np_mask = np.random.randint(2, size=(32, timesteps)) mask = k.variable(np_mask) - masked_last_output, masked_outputs, masked_new_states = k.rnn( - rnn_fn, inputs, - initial_states, - go_backwards=False, - mask=mask) - masked_last_output_list.append(k.eval(masked_last_output)) - masked_outputs_list.append(k.eval(masked_outputs)) - assert len(masked_new_states) == 1 - masked_states_list.append(k.eval(masked_new_states[0])) - - unrolled_masked_last_output, unrolled_masked_outputs, unrolled_masked_new_states = k.rnn( - rnn_fn, inputs, - initial_states, - go_backwards=False, - mask=mask, - unroll=True, - input_length=timesteps) - unrolled_masked_last_output_list.append(k.eval(unrolled_masked_last_output)) - unrolled_masked_outputs_list.append(k.eval(unrolled_masked_outputs)) - assert len(unrolled_masked_new_states) == 1 - unrolled_masked_states_list.append(k.eval(unrolled_masked_new_states[0])) - - assert_list_pairwise(last_output_list, shape=False, atol=1e-04) - assert_list_pairwise(outputs_list, shape=False, atol=1e-04) - assert_list_pairwise(state_list, shape=False, atol=1e-04) - assert_list_pairwise(backwards_states_list, shape=False, atol=1e-04) - assert_list_pairwise(backwards_last_output_list, shape=False, atol=1e-04) - assert_list_pairwise(backwards_outputs_list, shape=False, atol=1e-04) - - for l, u_l in zip(last_output_list, unrolled_last_output_list): + for (i, kwargs) in enumerate(kwargs_list): + last_output, outputs, new_states = k.rnn(rnn_fn, inputs, + initial_states, + **kwargs) + + last_output_list[i].append(k.eval(last_output)) + outputs_list[i].append(k.eval(outputs)) + assert len(new_states) == 1 + state_list[i].append(k.eval(new_states[0])) + + assert_list_pairwise(last_output_list[0], shape=False, atol=1e-04) + assert_list_pairwise(outputs_list[0], shape=False, atol=1e-04) + assert_list_pairwise(state_list[0], shape=False, atol=1e-04) + assert_list_pairwise(last_output_list[2], shape=False, atol=1e-04) + assert_list_pairwise(outputs_list[2], shape=False, atol=1e-04) + assert_list_pairwise(state_list[2], shape=False, atol=1e-04) + + for l, u_l in zip(last_output_list[0], last_output_list[1]): assert_allclose(l, u_l, atol=1e-04) - for o, u_o in zip(outputs_list, unrolled_outputs_list): + for o, u_o in zip(outputs_list[0], outputs_list[1]): assert_allclose(o, u_o, atol=1e-04) - for s, u_s in zip(state_list, unrolled_states_list): + for s, u_s in zip(state_list[0], state_list[1]): assert_allclose(s, u_s, atol=1e-04) - for b_l, b_u_l in zip(backwards_last_output_list, bwd_unrolled_last_output_list): + for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]): assert_allclose(b_l, b_u_l, atol=1e-04) - for b_o, b_u_o, in zip(backwards_outputs_list, bwd_unrolled_outputs_list): + for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]): assert_allclose(b_o, b_u_o, atol=1e-04) - for b_s, b_u_s in zip(backwards_states_list, bwd_unrolled_states_list): + for b_s, b_u_s in zip(state_list[2], state_list[3]): assert_allclose(b_s, b_u_s, atol=1e-04) - for m_l, u_m_l, k in zip(masked_last_output_list, unrolled_masked_last_output_list, BACKENDS): + for m_l, u_m_l, k in zip(last_output_list[4], last_output_list[5], BACKENDS): # skip this compare on tensorflow if k != KTF: assert_allclose(m_l, u_m_l, atol=1e-04) - for m_o, u_m_o, k in zip(masked_outputs_list, unrolled_masked_outputs_list, BACKENDS): + for m_o, u_m_o, k in zip(outputs_list[4], outputs_list[5], BACKENDS): # skip this compare on tensorflow if k != KTF: assert_allclose(m_o, u_m_o, atol=1e-04) - for m_s, u_m_s, k in zip(masked_states_list, unrolled_masked_states_list, BACKENDS): + for m_s, u_m_s, k in zip(state_list[4], state_list[5], BACKENDS): if k != KTF: assert_allclose(m_s, u_m_s, atol=1e-04)