Skip to content

Commit

Permalink
Clean up RNN tests
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee committed Aug 5, 2017
1 parent 281bc58 commit 3a441a7
Showing 1 changed file with 50 additions and 118 deletions.
168 changes: 50 additions & 118 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,162 +463,94 @@ def test_function(self):

def test_rnn(self):
# implement a simple RNN
input_dim = 8
output_dim = 4
timesteps = 5
num_samples = 4
input_dim = 5
output_dim = 3
timesteps = 6

input_val = np.random.random((32, timesteps, input_dim))
init_state_val = np.random.random((32, output_dim))
W_i_val = np.random.random((input_dim, output_dim))
W_o_val = np.random.random((output_dim, output_dim))
input_val = np.random.random((num_samples, timesteps, input_dim)).astype(np.float32)
init_state_val = np.random.random((num_samples, output_dim)).astype(np.float32)
W_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
W_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
np_mask = np.random.randint(2, size=(num_samples, timesteps))

def rnn_step_fn(input_dim, output_dim, K):
W_i = K.variable(W_i_val)
W_o = K.variable(W_o_val)
def rnn_step_fn(input_dim, output_dim, k):
W_i = k.variable(W_i_val)
W_o = k.variable(W_o_val)

def step_function(x, states):
assert len(states) == 1
prev_output = states[0]
output = K.dot(x, W_i) + K.dot(prev_output, W_o)
output = k.dot(x, W_i) + k.dot(prev_output, W_o)
return output, [output]

return step_function

# test default setup
last_output_list = []
outputs_list = []
state_list = []

unrolled_last_output_list = []
unrolled_outputs_list = []
unrolled_states_list = []

backwards_last_output_list = []
backwards_outputs_list = []
backwards_states_list = []

bwd_unrolled_last_output_list = []
bwd_unrolled_outputs_list = []
bwd_unrolled_states_list = []

masked_last_output_list = []
masked_outputs_list = []
masked_states_list = []

unrolled_masked_last_output_list = []
unrolled_masked_outputs_list = []
unrolled_masked_states_list = []
last_output_list = [[], [], [], [], [], []]
outputs_list = [[], [], [], [], [], []]
state_list = [[], [], [], [], [], []]
kwargs_list = [
{'go_backwards': False, 'mask': None},
{'go_backwards': False, 'mask': None, 'unroll': True, 'input_length': timesteps},
{'go_backwards': True, 'mask': None},
{'go_backwards': True, 'mask': None, 'unroll': True, 'input_length': timesteps},
{'go_backwards': False, 'mask': mask},
{'go_backwards': False, 'mask': mask, 'unroll': True, 'input_length': timesteps},
]

for k in BACKENDS:
rnn_fn = rnn_step_fn(input_dim, output_dim, k)
inputs = k.variable(input_val)
initial_states = [k.variable(init_state_val)]
last_output, outputs, new_states = k.rnn(rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=None)

last_output_list.append(k.eval(last_output))
outputs_list.append(k.eval(outputs))
assert len(new_states) == 1
state_list.append(k.eval(new_states[0]))
# test unroll
unrolled_last_output, unrolled_outputs, unrolled_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=None,
unroll=True,
input_length=timesteps)

unrolled_last_output_list.append(k.eval(unrolled_last_output))
unrolled_outputs_list.append(k.eval(unrolled_outputs))
assert len(unrolled_new_states) == 1
unrolled_states_list.append(k.eval(unrolled_new_states[0]))

backwards_last_output, backwards_outputs, backwards_new_states = k.rnn(rnn_fn, inputs,
initial_states,
go_backwards=True,
mask=None)
backwards_last_output_list.append(k.eval(backwards_last_output))
backwards_outputs_list.append(k.eval(backwards_outputs))
assert len(backwards_new_states) == 1
backwards_states_list.append(k.eval(backwards_new_states[0]))

bwd_unrolled_last_output, bwd_unrolled_outputs, bwd_unrolled_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=True,
mask=None,
unroll=True,
input_length=timesteps)

bwd_unrolled_last_output_list.append(k.eval(bwd_unrolled_last_output))
bwd_unrolled_outputs_list.append(k.eval(bwd_unrolled_outputs))
assert len(bwd_unrolled_new_states) == 1
bwd_unrolled_states_list.append(k.eval(bwd_unrolled_new_states[0]))

np_mask = np.random.randint(2, size=(32, timesteps))
mask = k.variable(np_mask)

masked_last_output, masked_outputs, masked_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=mask)
masked_last_output_list.append(k.eval(masked_last_output))
masked_outputs_list.append(k.eval(masked_outputs))
assert len(masked_new_states) == 1
masked_states_list.append(k.eval(masked_new_states[0]))

unrolled_masked_last_output, unrolled_masked_outputs, unrolled_masked_new_states = k.rnn(
rnn_fn, inputs,
initial_states,
go_backwards=False,
mask=mask,
unroll=True,
input_length=timesteps)
unrolled_masked_last_output_list.append(k.eval(unrolled_masked_last_output))
unrolled_masked_outputs_list.append(k.eval(unrolled_masked_outputs))
assert len(unrolled_masked_new_states) == 1
unrolled_masked_states_list.append(k.eval(unrolled_masked_new_states[0]))

assert_list_pairwise(last_output_list, shape=False, atol=1e-04)
assert_list_pairwise(outputs_list, shape=False, atol=1e-04)
assert_list_pairwise(state_list, shape=False, atol=1e-04)
assert_list_pairwise(backwards_states_list, shape=False, atol=1e-04)
assert_list_pairwise(backwards_last_output_list, shape=False, atol=1e-04)
assert_list_pairwise(backwards_outputs_list, shape=False, atol=1e-04)

for l, u_l in zip(last_output_list, unrolled_last_output_list):
for (i, kwargs) in enumerate(kwargs_list):
last_output, outputs, new_states = k.rnn(rnn_fn, inputs,
initial_states,
**kwargs)

last_output_list[i].append(k.eval(last_output))
outputs_list[i].append(k.eval(outputs))
assert len(new_states) == 1
state_list[i].append(k.eval(new_states[0]))

assert_list_pairwise(last_output_list[0], shape=False, atol=1e-04)
assert_list_pairwise(outputs_list[0], shape=False, atol=1e-04)
assert_list_pairwise(state_list[0], shape=False, atol=1e-04)
assert_list_pairwise(last_output_list[2], shape=False, atol=1e-04)
assert_list_pairwise(outputs_list[2], shape=False, atol=1e-04)
assert_list_pairwise(state_list[2], shape=False, atol=1e-04)

for l, u_l in zip(last_output_list[0], last_output_list[1]):
assert_allclose(l, u_l, atol=1e-04)

for o, u_o in zip(outputs_list, unrolled_outputs_list):
for o, u_o in zip(outputs_list[0], outputs_list[1]):
assert_allclose(o, u_o, atol=1e-04)

for s, u_s in zip(state_list, unrolled_states_list):
for s, u_s in zip(state_list[0], state_list[1]):
assert_allclose(s, u_s, atol=1e-04)

for b_l, b_u_l in zip(backwards_last_output_list, bwd_unrolled_last_output_list):
for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
assert_allclose(b_l, b_u_l, atol=1e-04)

for b_o, b_u_o, in zip(backwards_outputs_list, bwd_unrolled_outputs_list):
for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
assert_allclose(b_o, b_u_o, atol=1e-04)

for b_s, b_u_s in zip(backwards_states_list, bwd_unrolled_states_list):
for b_s, b_u_s in zip(state_list[2], state_list[3]):
assert_allclose(b_s, b_u_s, atol=1e-04)

for m_l, u_m_l, k in zip(masked_last_output_list, unrolled_masked_last_output_list, BACKENDS):
for m_l, u_m_l, k in zip(last_output_list[4], last_output_list[5], BACKENDS):
# skip this compare on tensorflow
if k != KTF:
assert_allclose(m_l, u_m_l, atol=1e-04)

for m_o, u_m_o, k in zip(masked_outputs_list, unrolled_masked_outputs_list, BACKENDS):
for m_o, u_m_o, k in zip(outputs_list[4], outputs_list[5], BACKENDS):
# skip this compare on tensorflow
if k != KTF:
assert_allclose(m_o, u_m_o, atol=1e-04)

for m_s, u_m_s, k in zip(masked_states_list, unrolled_masked_states_list, BACKENDS):
for m_s, u_m_s, k in zip(state_list[4], state_list[5], BACKENDS):
if k != KTF:
assert_allclose(m_s, u_m_s, atol=1e-04)

Expand Down

0 comments on commit 3a441a7

Please sign in to comment.