Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add tests for multiple batches.
Browse files · Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Jun 14, 2018
1 parent d6799ac commit 13c2994
Showing 1 changed file with 59 additions and 38 deletions.
97 changes: 59 additions & 38 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5959,50 +5959,55 @@ def sym_group(out):
ret.extend(out[1])
return mx.sym.Group(ret)

# Inputs
data_arr = mx.nd.random.uniform(shape=(2, 2, 4))
h_arr = mx.nd.random.uniform(shape=(2, 4))
c_arr = mx.nd.random.uniform(shape=(2, 4))
i2h_warr = mx.nd.random.uniform(shape=(16, 4))
h2h_warr = mx.nd.random.uniform(shape=(16, 4))
i2h_barr = mx.nd.random.uniform(shape=(16))
h2h_barr = mx.nd.random.uniform(shape=(16))

args1 = {'data': data_arr, 'h': h_arr, 'c': c_arr,
'i2h_weight': i2h_warr, 'h2h_weight': h2h_warr,
'i2h_bias': i2h_barr, 'h2h_bias': h2h_barr}
args2 = {'data': data_arr, 'h': h_arr, 'c': c_arr,
'mylstm_i2h_weight': i2h_warr, 'mylstm_h2h_weight': h2h_warr,
'mylstm_i2h_bias': i2h_barr, 'mylstm_h2h_bias': h2h_barr}

# gradients for the backward of the foreach symbol
data_arr_grad1 = mx.nd.empty(data_arr.shape)
h_arr_grad1 = mx.nd.empty(h_arr.shape)
c_arr_grad1 = mx.nd.empty(c_arr.shape)
i2h_warr_grad1 = mx.nd.empty(i2h_warr.shape)
h2h_warr_grad1 = mx.nd.empty(h2h_warr.shape)
i2h_barr_grad1 = mx.nd.empty(i2h_barr.shape)
h2h_barr_grad1 = mx.nd.empty(h2h_barr.shape)
out = mx.sym.contrib.foreach(step, data, [init_h, init_c])
out = sym_group(out)
js_1 = out.tojson()
out = mx.sym.load_json(js_1)
js_2 = out.tojson()
assert js_1 == js_2

e1 = out.bind(ctx=default_context(),
args={'data': data_arr, 'h': h_arr, 'c': c_arr,
'i2h_weight': i2h_warr, 'h2h_weight': h2h_warr,
'i2h_bias': i2h_barr, 'h2h_bias': h2h_barr},
args_grad={'data': data_arr_grad1, 'h': h_arr_grad1, 'c': c_arr_grad1,
'i2h_weight': i2h_warr_grad1, 'h2h_weight': h2h_warr_grad1,
'i2h_bias': i2h_barr_grad1, 'h2h_bias': h2h_barr_grad1})
e1.forward(is_train=True)
outputs1 = e1.outputs
# backward
out_grads = []
for arr in e1.outputs:
out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape))
e1.backward(out_grads)
args_grad1 = {'data': data_arr_grad1, 'h': h_arr_grad1, 'c': c_arr_grad1,
'i2h_weight': i2h_warr_grad1, 'h2h_weight': h2h_warr_grad1,
'i2h_bias': i2h_barr_grad1, 'h2h_bias': h2h_barr_grad1}

# gradients for the backward of the unrolled symbol.
data_arr_grad2 = mx.nd.empty(data_arr.shape)
h_arr_grad2 = mx.nd.empty(h_arr.shape)
c_arr_grad2 = mx.nd.empty(c_arr.shape)
i2h_warr_grad2 = mx.nd.empty(i2h_warr.shape)
h2h_warr_grad2 = mx.nd.empty(h2h_warr.shape)
i2h_barr_grad2 = mx.nd.empty(i2h_barr.shape)
h2h_barr_grad2 = mx.nd.empty(h2h_barr.shape)
args_grad2 = {'data': data_arr_grad2, 'h': h_arr_grad2, 'c': c_arr_grad2,
'mylstm_i2h_weight': i2h_warr_grad2, 'mylstm_h2h_weight': h2h_warr_grad2,
'mylstm_i2h_bias': i2h_barr_grad2, 'mylstm_h2h_bias': h2h_barr_grad2}

# Symbol of running LSTM with foreach.
out = mx.sym.contrib.foreach(step, data, [init_h, init_c])
out = sym_group(out)
js_1 = out.tojson()
out = mx.sym.load_json(js_1)
js_2 = out.tojson()
assert js_1 == js_2
e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1)

# Symbol of running unrolled LSTM.
lstm = mx.rnn.LSTMCell(4, prefix='mylstm_')
h = init_h
c = init_c
Expand All @@ -6016,22 +6021,38 @@ def sym_group(out):
out = mx.sym.load_json(js_1)
js_2 = out.tojson()
assert js_1 == js_2

e2 = out.bind(ctx=default_context(),
args={'data': data_arr, 'h': h_arr, 'c': c_arr,
'mylstm_i2h_weight': i2h_warr, 'mylstm_h2h_weight': h2h_warr,
'mylstm_i2h_bias': i2h_barr, 'mylstm_h2h_bias': h2h_barr},
args_grad={'data': data_arr_grad2, 'h': h_arr_grad2, 'c': c_arr_grad2,
'mylstm_i2h_weight': i2h_warr_grad2, 'mylstm_h2h_weight': h2h_warr_grad2,
'mylstm_i2h_bias': i2h_barr_grad2, 'mylstm_h2h_bias': h2h_barr_grad2})
e2.forward(is_train=True)
outputs2 = e2.outputs
e2.backward(out_grads)

for i in range(len(outputs2)):
assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(), rtol=0.001, atol=0.0001)
for i in range(len(e1.grad_arrays)):
assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy())
e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2)

for i in range(5):
out_grads = []
for arr in e1.outputs:
out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape))

data_arr = mx.nd.random.uniform(shape=(2, 2, 4))
h_arr = mx.nd.random.uniform(shape=(2, 4))
c_arr = mx.nd.random.uniform(shape=(2, 4))
i2h_warr = mx.nd.random.uniform(shape=(16, 4))
h2h_warr = mx.nd.random.uniform(shape=(16, 4))
i2h_barr = mx.nd.random.uniform(shape=(16))
h2h_barr = mx.nd.random.uniform(shape=(16))

e1.forward(is_train=True, data = data_arr, h = h_arr, c = c_arr,
i2h_weight = i2h_warr, h2h_weight = h2h_warr,
i2h_bias = i2h_barr, h2h_bias = h2h_barr)
outputs1 = e1.outputs
e1.backward(out_grads)

e2.forward(is_train=True, data = data_arr, h = h_arr, c = c_arr,
mylstm_i2h_weight = i2h_warr, mylstm_h2h_weight = h2h_warr,
mylstm_i2h_bias = i2h_barr, mylstm_h2h_bias = h2h_barr)
outputs2 = e2.outputs
e2.backward(out_grads)

for i in range(len(outputs2)):
assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(),
rtol=0.001, atol=0.0001)
for i in range(len(e1.grad_arrays)):
assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy())


@with_seed()
Expand Down

0 comments on commit 13c2994

Please sign in to comment.