From 13c29948a9b324bffc4450a9b3fc5b4f0240398d Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 14 Jun 2018 22:37:55 +0000 Subject: [PATCH] add tests for multiple batches. --- tests/python/unittest/test_operator.py | 97 ++++++++++++++++---------- 1 file changed, 59 insertions(+), 38 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index a2045d77c955..7176c4dd5b5c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5959,6 +5959,7 @@ def sym_group(out): ret.extend(out[1]) return mx.sym.Group(ret) + # Inputs data_arr = mx.nd.random.uniform(shape=(2, 2, 4)) h_arr = mx.nd.random.uniform(shape=(2, 4)) c_arr = mx.nd.random.uniform(shape=(2, 4)) @@ -5966,7 +5967,14 @@ def sym_group(out): h2h_warr = mx.nd.random.uniform(shape=(16, 4)) i2h_barr = mx.nd.random.uniform(shape=(16)) h2h_barr = mx.nd.random.uniform(shape=(16)) - + args1 = {'data': data_arr, 'h': h_arr, 'c': c_arr, + 'i2h_weight': i2h_warr, 'h2h_weight': h2h_warr, + 'i2h_bias': i2h_barr, 'h2h_bias': h2h_barr} + args2 = {'data': data_arr, 'h': h_arr, 'c': c_arr, + 'mylstm_i2h_weight': i2h_warr, 'mylstm_h2h_weight': h2h_warr, + 'mylstm_i2h_bias': i2h_barr, 'mylstm_h2h_bias': h2h_barr} + + # gradients for the backward of the foreach symbol data_arr_grad1 = mx.nd.empty(data_arr.shape) h_arr_grad1 = mx.nd.empty(h_arr.shape) c_arr_grad1 = mx.nd.empty(c_arr.shape) @@ -5974,28 +5982,11 @@ def sym_group(out): h2h_warr_grad1 = mx.nd.empty(h2h_warr.shape) i2h_barr_grad1 = mx.nd.empty(i2h_barr.shape) h2h_barr_grad1 = mx.nd.empty(h2h_barr.shape) - out = mx.sym.contrib.foreach(step, data, [init_h, init_c]) - out = sym_group(out) - js_1 = out.tojson() - out = mx.sym.load_json(js_1) - js_2 = out.tojson() - assert js_1 == js_2 - - e1 = out.bind(ctx=default_context(), - args={'data': data_arr, 'h': h_arr, 'c': c_arr, - 'i2h_weight': i2h_warr, 'h2h_weight': h2h_warr, - 'i2h_bias': i2h_barr, 'h2h_bias': h2h_barr}, - args_grad={'data': data_arr_grad1, 'h': h_arr_grad1, 'c': c_arr_grad1, - 'i2h_weight': i2h_warr_grad1, 'h2h_weight': h2h_warr_grad1, - 'i2h_bias': i2h_barr_grad1, 'h2h_bias': h2h_barr_grad1}) - e1.forward(is_train=True) - outputs1 = e1.outputs - # backward - out_grads = [] - for arr in e1.outputs: - out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape)) - e1.backward(out_grads) + args_grad1 = {'data': data_arr_grad1, 'h': h_arr_grad1, 'c': c_arr_grad1, + 'i2h_weight': i2h_warr_grad1, 'h2h_weight': h2h_warr_grad1, + 'i2h_bias': i2h_barr_grad1, 'h2h_bias': h2h_barr_grad1} + # gradients for the backward of the unrolled symbol. data_arr_grad2 = mx.nd.empty(data_arr.shape) h_arr_grad2 = mx.nd.empty(h_arr.shape) c_arr_grad2 = mx.nd.empty(c_arr.shape) @@ -6003,6 +5994,20 @@ def sym_group(out): h2h_warr_grad2 = mx.nd.empty(h2h_warr.shape) i2h_barr_grad2 = mx.nd.empty(i2h_barr.shape) h2h_barr_grad2 = mx.nd.empty(h2h_barr.shape) + args_grad2 = {'data': data_arr_grad2, 'h': h_arr_grad2, 'c': c_arr_grad2, + 'mylstm_i2h_weight': i2h_warr_grad2, 'mylstm_h2h_weight': h2h_warr_grad2, + 'mylstm_i2h_bias': i2h_barr_grad2, 'mylstm_h2h_bias': h2h_barr_grad2} + + # Symbol of running LSTM with foreach. + out = mx.sym.contrib.foreach(step, data, [init_h, init_c]) + out = sym_group(out) + js_1 = out.tojson() + out = mx.sym.load_json(js_1) + js_2 = out.tojson() + assert js_1 == js_2 + e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1) + + # Symbol of running unrolled LSTM. lstm = mx.rnn.LSTMCell(4, prefix='mylstm_') h = init_h c = init_c @@ -6016,22 +6021,38 @@ def sym_group(out): out = mx.sym.load_json(js_1) js_2 = out.tojson() assert js_1 == js_2 - - e2 = out.bind(ctx=default_context(), - args={'data': data_arr, 'h': h_arr, 'c': c_arr, - 'mylstm_i2h_weight': i2h_warr, 'mylstm_h2h_weight': h2h_warr, - 'mylstm_i2h_bias': i2h_barr, 'mylstm_h2h_bias': h2h_barr}, - args_grad={'data': data_arr_grad2, 'h': h_arr_grad2, 'c': c_arr_grad2, - 'mylstm_i2h_weight': i2h_warr_grad2, 'mylstm_h2h_weight': h2h_warr_grad2, - 'mylstm_i2h_bias': i2h_barr_grad2, 'mylstm_h2h_bias': h2h_barr_grad2}) - e2.forward(is_train=True) - outputs2 = e2.outputs - e2.backward(out_grads) - - for i in range(len(outputs2)): - assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(), rtol=0.001, atol=0.0001) - for i in range(len(e1.grad_arrays)): - assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy()) + e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2) + + for i in range(5): + out_grads = [] + for arr in e1.outputs: + out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape)) + + data_arr = mx.nd.random.uniform(shape=(2, 2, 4)) + h_arr = mx.nd.random.uniform(shape=(2, 4)) + c_arr = mx.nd.random.uniform(shape=(2, 4)) + i2h_warr = mx.nd.random.uniform(shape=(16, 4)) + h2h_warr = mx.nd.random.uniform(shape=(16, 4)) + i2h_barr = mx.nd.random.uniform(shape=(16)) + h2h_barr = mx.nd.random.uniform(shape=(16)) + + e1.forward(is_train=True, data = data_arr, h = h_arr, c = c_arr, + i2h_weight = i2h_warr, h2h_weight = h2h_warr, + i2h_bias = i2h_barr, h2h_bias = h2h_barr) + outputs1 = e1.outputs + e1.backward(out_grads) + + e2.forward(is_train=True, data = data_arr, h = h_arr, c = c_arr, + mylstm_i2h_weight = i2h_warr, mylstm_h2h_weight = h2h_warr, + mylstm_i2h_bias = i2h_barr, mylstm_h2h_bias = h2h_barr) + outputs2 = e2.outputs + e2.backward(out_grads) + + for i in range(len(outputs2)): + assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(), + rtol=0.001, atol=0.0001) + for i in range(len(e1.grad_arrays)): + assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy()) @with_seed()