Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up RNN tests #7529

Merged
merged 1 commit into from
Aug 5, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 51 additions & 118 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,162 +463,95 @@ def test_function(self):

def test_rnn(self):
# implement a simple RNN
input_dim = 8
output_dim = 4
timesteps = 5
num_samples = 4
input_dim = 5
output_dim = 3
timesteps = 6

input_val = np.random.random((32, timesteps, input_dim))
init_state_val = np.random.random((32, output_dim))
W_i_val = np.random.random((input_dim, output_dim))
W_o_val = np.random.random((output_dim, output_dim))
input_val = np.random.random((num_samples, timesteps, input_dim)).astype(np.float32)
init_state_val = np.random.random((num_samples, output_dim)).astype(np.float32)
W_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
W_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
np_mask = np.random.randint(2, size=(num_samples, timesteps))

def rnn_step_fn(input_dim, output_dim, K):
W_i = K.variable(W_i_val)
W_o = K.variable(W_o_val)
def rnn_step_fn(input_dim, output_dim, k):
W_i = k.variable(W_i_val)
W_o = k.variable(W_o_val)

def step_function(x, states):
assert len(states) == 1
prev_output = states[0]
output = K.dot(x, W_i) + K.dot(prev_output, W_o)
output = k.dot(x, W_i) + k.dot(prev_output, W_o)
return output, [output]

return step_function

# test default setup
last_output_list = []
outputs_list = []
state_list = []

unrolled_last_output_list = []
unrolled_outputs_list = []
unrolled_states_list = []

backwards_last_output_list = []
backwards_outputs_list = []
backwards_states_list = []

bwd_unrolled_last_output_list = []
bwd_unrolled_outputs_list = []
bwd_unrolled_states_list = []

masked_last_output_list = []
masked_outputs_list = []
masked_states_list = []

unrolled_masked_last_output_list = []
unrolled_masked_outputs_list = []
unrolled_masked_states_list = []
last_output_list = [[], [], [], [], [], []]
outputs_list = [[], [], [], [], [], []]
state_list = [[], [], [], [], [], []]

for k in BACKENDS:
rnn_fn = rnn_step_fn(input_dim, output_dim, k)
inputs = k.variable(input_val)
initial_states = [k.variable(init_state_val)]
last_output, outputs, new_states = k.rnn(rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=None)

last_output_list.append(k.eval(last_output))
outputs_list.append(k.eval(outputs))
assert len(new_states) == 1
state_list.append(k.eval(new_states[0]))
# test unroll
unrolled_last_output, unrolled_outputs, unrolled_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=None,
unroll=True,
input_length=timesteps)

unrolled_last_output_list.append(k.eval(unrolled_last_output))
unrolled_outputs_list.append(k.eval(unrolled_outputs))
assert len(unrolled_new_states) == 1
unrolled_states_list.append(k.eval(unrolled_new_states[0]))

backwards_last_output, backwards_outputs, backwards_new_states = k.rnn(rnn_fn, inputs,
initial_states,
go_backwards=True,
mask=None)
backwards_last_output_list.append(k.eval(backwards_last_output))
backwards_outputs_list.append(k.eval(backwards_outputs))
assert len(backwards_new_states) == 1
backwards_states_list.append(k.eval(backwards_new_states[0]))

bwd_unrolled_last_output, bwd_unrolled_outputs, bwd_unrolled_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=True,
mask=None,
unroll=True,
input_length=timesteps)

bwd_unrolled_last_output_list.append(k.eval(bwd_unrolled_last_output))
bwd_unrolled_outputs_list.append(k.eval(bwd_unrolled_outputs))
assert len(bwd_unrolled_new_states) == 1
bwd_unrolled_states_list.append(k.eval(bwd_unrolled_new_states[0]))

np_mask = np.random.randint(2, size=(32, timesteps))
mask = k.variable(np_mask)

masked_last_output, masked_outputs, masked_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=mask)
masked_last_output_list.append(k.eval(masked_last_output))
masked_outputs_list.append(k.eval(masked_outputs))
assert len(masked_new_states) == 1
masked_states_list.append(k.eval(masked_new_states[0]))

unrolled_masked_last_output, unrolled_masked_outputs, unrolled_masked_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=mask,
unroll=True,
input_length=timesteps)
unrolled_masked_last_output_list.append(k.eval(unrolled_masked_last_output))
unrolled_masked_outputs_list.append(k.eval(unrolled_masked_outputs))
assert len(unrolled_masked_new_states) == 1
unrolled_masked_states_list.append(k.eval(unrolled_masked_new_states[0]))

assert_list_pairwise(last_output_list, shape=False, atol=1e-04)
assert_list_pairwise(outputs_list, shape=False, atol=1e-04)
assert_list_pairwise(state_list, shape=False, atol=1e-04)
assert_list_pairwise(backwards_states_list, shape=False, atol=1e-04)
assert_list_pairwise(backwards_last_output_list, shape=False, atol=1e-04)
assert_list_pairwise(backwards_outputs_list, shape=False, atol=1e-04)

for l, u_l in zip(last_output_list, unrolled_last_output_list):
kwargs_list = [
{'go_backwards': False, 'mask': None},
{'go_backwards': False, 'mask': None, 'unroll': True, 'input_length': timesteps},
{'go_backwards': True, 'mask': None},
{'go_backwards': True, 'mask': None, 'unroll': True, 'input_length': timesteps},
{'go_backwards': False, 'mask': mask},
{'go_backwards': False, 'mask': mask, 'unroll': True, 'input_length': timesteps},
]

for (i, kwargs) in enumerate(kwargs_list):
last_output, outputs, new_states = k.rnn(rnn_fn, inputs,
initial_states,
**kwargs)

last_output_list[i].append(k.eval(last_output))
outputs_list[i].append(k.eval(outputs))
assert len(new_states) == 1
state_list[i].append(k.eval(new_states[0]))

assert_list_pairwise(last_output_list[0], shape=False, atol=1e-04)
assert_list_pairwise(outputs_list[0], shape=False, atol=1e-04)
assert_list_pairwise(state_list[0], shape=False, atol=1e-04)
assert_list_pairwise(last_output_list[2], shape=False, atol=1e-04)
assert_list_pairwise(outputs_list[2], shape=False, atol=1e-04)
assert_list_pairwise(state_list[2], shape=False, atol=1e-04)

for l, u_l in zip(last_output_list[0], last_output_list[1]):
assert_allclose(l, u_l, atol=1e-04)

for o, u_o in zip(outputs_list, unrolled_outputs_list):
for o, u_o in zip(outputs_list[0], outputs_list[1]):
assert_allclose(o, u_o, atol=1e-04)

for s, u_s in zip(state_list, unrolled_states_list):
for s, u_s in zip(state_list[0], state_list[1]):
assert_allclose(s, u_s, atol=1e-04)

for b_l, b_u_l in zip(backwards_last_output_list, bwd_unrolled_last_output_list):
for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
assert_allclose(b_l, b_u_l, atol=1e-04)

for b_o, b_u_o, in zip(backwards_outputs_list, bwd_unrolled_outputs_list):
for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
assert_allclose(b_o, b_u_o, atol=1e-04)

for b_s, b_u_s in zip(backwards_states_list, bwd_unrolled_states_list):
for b_s, b_u_s in zip(state_list[2], state_list[3]):
assert_allclose(b_s, b_u_s, atol=1e-04)

for m_l, u_m_l, k in zip(masked_last_output_list, unrolled_masked_last_output_list, BACKENDS):
for m_l, u_m_l, k in zip(last_output_list[4], last_output_list[5], BACKENDS):
# skip this compare on tensorflow
if k != KTF:
assert_allclose(m_l, u_m_l, atol=1e-04)

for m_o, u_m_o, k in zip(masked_outputs_list, unrolled_masked_outputs_list, BACKENDS):
for m_o, u_m_o, k in zip(outputs_list[4], outputs_list[5], BACKENDS):
# skip this compare on tensorflow
if k != KTF:
assert_allclose(m_o, u_m_o, atol=1e-04)

for m_s, u_m_s, k in zip(masked_states_list, unrolled_masked_states_list, BACKENDS):
for m_s, u_m_s, k in zip(state_list[4], state_list[5], BACKENDS):
if k != KTF:
assert_allclose(m_s, u_m_s, atol=1e-04)

Expand Down