From ca14fa05434988ecafa80a9694efe0096f69c994 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Tue, 7 Aug 2018 17:54:46 +0000 Subject: [PATCH 1/5] added support to check_consistency function to generate random numbers for a specific datatype (ie. fp16) this ensures that for tests that compare results among different precisions, that data is generated in the least precise type and casted to the most precise changed test_pooling_with_type test case to specify fp16 precision for random input data renamed the 2nd test_pooling_with_type function to test_pooling_with_type2 so it doesnt redefine the first and both are tested fixed equation formatting issue in pooling operator description Added myself to the contributors readme file --- CONTRIBUTORS.md | 1 + python/mxnet/test_utils.py | 9 +- src/operator/nn/pooling.cc | 3 +- tests/python/gpu/test_operator_gpu.py | 167 ++++++++++++++++++++++---- 4 files changed, 151 insertions(+), 29 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b04e4a3d85c3..6bc97bb71fc1 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -176,3 +176,4 @@ List of Contributors * [Kou Ding](https://github.com/chinakook) * [Istvan Fehervari](https://github.com/ifeherva) * [Aaron Markham](https://github.com/aaronmarkham) +* [Sam Skalicky](https://github.com/samskalicky) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index e963d158446d..0662f71259f3 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -479,10 +479,8 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan= """ rtol = get_rtol(rtol) atol = get_atol(atol) - if almost_equal(a, b, rtol, atol, equal_nan=equal_nan): return - index, rel = find_max_violation(a, b, rtol, atol) np.set_printoptions(threshold=4, suppress=True) msg = npt.build_err_msg([a, b], @@ -1203,10 +1201,9 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole", else: raise ValueError('typ can only be "whole" or "forward".') - def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, tol=None, - raise_on_err=True, ground_truth=None, equal_nan=False, use_uniform=False): + raise_on_err=True, ground_truth=None, equal_nan=False, use_uniform=False, default_type=np.float64): """Check symbol gives the same output for different running context Parameters @@ -1283,9 +1280,9 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', for n, arr in exe_list[0].arg_dict.items(): if n not in arg_params: if use_uniform: - arg_params[n] = np.random.uniform(low=-0.92, high=0.92, size=arr.shape) + arg_params[n] = np.random.uniform(low=-0.92, high=0.92, size=arr.shape).astype(default_type) else: - arg_params[n] = np.random.normal(size=arr.shape, scale=scale) + arg_params[n] = np.random.normal(size=arr.shape, scale=scale).astype(default_type) for n, arr in exe_list[0].aux_dict.items(): if n not in aux_params: aux_params[n] = 0 diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 9b6996d0feb0..2380f0fc21fb 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -377,8 +377,7 @@ We can see that Lp pooling stands between those two, in practice the most common For each window ``X``, the mathematical expression for Lp pooling is: -..math:: - f(X) = \sqrt{p}{\sum\limits_{x \in X} x^p} +:math:`f(X) = \sqrt[p]{\sum_{x}^{X} x^p}` )code" ADD_FILELINE) .set_num_inputs(1) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index d8d34ef474dd..d222b314f5ed 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -36,8 +36,11 @@ from test_operator import * from test_optimizer import * from test_random import * +from test_gluon import * +from test_loss import * from test_exc_handling import * #from test_rnn import * +from test_gluon_rnn import * from test_sparse_ndarray import * from test_sparse_operator import * from test_ndarray import * @@ -127,7 +130,7 @@ def check_ifft(shape): init_complex.real[:,i] = init[0][:,2*i] init_complex.imag[:,i] = init[0][:,2*i+1] a = np.fft.ifft(init_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, out1[0]/shape_old[1],rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, out1[0]/shape_old[1],rtol=1e-3, atol=1e-12) if len(shape) == 4: init_complex = np.zeros(shape_old,dtype = np.complex64) @@ -135,7 +138,7 @@ def check_ifft(shape): init_complex.real[:,:,:,i] = init[0][:,:,:,2*i] init_complex.imag[:,:,:,i] = init[0][:,:,:,2*i+1] a = np.fft.ifft(init_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, out1[0]/shape_old[3],rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, out1[0]/shape_old[3],rtol=1e-3, atol=1e-12) # backward if len(shape) == 2: out_grad = mx.nd.empty(shape_old) @@ -148,7 +151,7 @@ def check_ifft(shape): temp[:,i] = exe.grad_arrays[0].asnumpy()[:,2*i] a = np.fft.fft(out_grad.asnumpy(), n=None, axis=-1, norm=None) - assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-12) if len(shape) == 4: out_grad = mx.nd.empty(shape_old) out_grad[:] = np.random.normal(-3, 3, shape_old) @@ -160,9 +163,9 @@ def check_ifft(shape): temp[:,:,:,i] = exe.grad_arrays[0].asnumpy()[:,:,:,2*i] a = np.fft.fft(out_grad.asnumpy(), n=None, axis=-1, norm=None) - assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-12) -@with_seed() +@with_seed(0) def test_ifft(): nrepeat = 2 maxdim = 10 @@ -194,7 +197,7 @@ def check_fft(shape): for exe in exe_list: for arr, iarr in zip(exe.arg_arrays, init): arr[:] = iarr.astype(arr.dtype) - # forward + #forward for exe in exe_list: exe.forward(is_train=True) out1 = [exe.outputs[0].asnumpy() for exe in exe_list] @@ -221,7 +224,7 @@ def check_fft(shape): a[i,j,:,p+1] = out2[i,j+out1[0].shape[1],:,k] p = p+2 - assert_almost_equal(a, out1[0],rtol=1e-3, atol=1e-5) + assert_almost_equal(a, out1[0],rtol=1e-3, atol=1e-6) # backward if len(shape) == 2: @@ -235,7 +238,7 @@ def check_fft(shape): for exe in exe_list: exe.backward([out_grad]) a = np.fft.ifft(out_grad_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[1],rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[1],rtol=1e-3, atol=1e-8) if len(shape) == 4: out_grad = mx.nd.empty(out1[0].shape) @@ -248,9 +251,9 @@ def check_fft(shape): for exe in exe_list: exe.backward([out_grad]) a = np.fft.ifft(out_grad_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[3],rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[3],rtol=1e-3, atol=1e-6) -@with_seed() +@with_seed(0) def test_fft(): nrepeat = 2 maxdim = 10 @@ -608,19 +611,21 @@ def test_convolution_versions(): @with_seed() def test_pooling_with_type(): + ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}, {'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float16}}, {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] + sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='valid', name='pool') - check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, default_type=np.float16) sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='full', name='pool') - check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, default_type=np.float16) sym = mx.sym.Pooling(kernel=(300,300), pool_type='max', global_pool=True, name='pool') - check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, default_type=np.float16) @with_seed() @@ -768,8 +773,8 @@ def test_spatial_transformer_with_type(): # Checking max pooling consistency over the data sets of different float types is problematic # as one max value in a float32 data set may not be the max value in a float16 data set. # This function will not be called. -@with_seed(1234) -def test_pooling_with_type(): +@with_seed() +def test_pooling_with_type2(): ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}, {'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float16}}, @@ -777,17 +782,16 @@ def test_pooling_with_type(): {'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] sym = mx.sym.Pooling(name='pool', kernel=(3,3), stride=(2,2), pool_type='max') - check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, default_type=np.float16) sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg') - check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, default_type=np.float16) - # this is unstable - # sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') - # check_consistency(sym, ctx_list) + sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') + check_consistency(sym, ctx_list, default_type=np.float16) sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum') - check_consistency(sym, ctx_list) + check_consistency(sym, ctx_list, default_type=np.float16) @unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/11517") @@ -1657,6 +1661,17 @@ def check_rnn_layer_w_rand_inputs(layer): for g, c in zip(gs, cs): assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) +@with_seed() +@assert_raises_cudnn_disabled() +def test_rnn_layer(): + check_rnn_layer(gluon.rnn.RNN(100, num_layers=3)) + check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3)) + check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3)) + check_rnn_layer(gluon.rnn.GRU(100, num_layers=3)) + + check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) + check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) + @with_seed() def test_sequence_reverse(): check_sequence_reverse(mx.gpu(0)) @@ -1674,6 +1689,28 @@ def test_autograd_save_memory(): x.backward() +@with_seed() +def test_gluon_ctc_consistency(): + loss = mx.gluon.loss.CTCLoss() + data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0) + cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0)) + gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0)) + + cpu_data = data.copy().as_in_context(mx.cpu(0)) + cpu_data.attach_grad() + with mx.autograd.record(): + l_cpu = loss(cpu_data, cpu_label) + l_cpu.backward() + + gpu_data = data.copyto(mx.gpu(0)) + gpu_data.attach_grad() + with mx.autograd.record(): + l_gpu = loss(gpu_data, gpu_label) + l_gpu.backward() + + assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) + + @with_seed() def test_cuda_rtc(): source = r''' @@ -1704,6 +1741,16 @@ def test_cuda_rtc(): assert (y.asnumpy() == 12).all() +@with_seed() +def test_global_norm_clip_multi_device(): + x1 = mx.nd.ones((3,3), ctx=mx.gpu(0)) + x2 = mx.nd.ones((4,4), ctx=mx.cpu(0)) + norm = gluon.utils.clip_global_norm([x1, x2], 1.0) + assert norm == 5.0 + assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) + assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) + + @with_seed() def test_cross_device_autograd(): x = mx.nd.random.uniform(shape=(10,)) @@ -1922,6 +1969,84 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 +def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) + +def test_sync_batchnorm(): + def get_num_devices(): + for i in range(100): + try: + mx.nd.zeros((1,), ctx=mx.gpu(i)) + except: + return i + # no need to use SyncBN with 1 gpu + if get_num_devices() < 2: + return + ndev = 2 + # check with unsync version + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) + if __name__ == '__main__': import nose nose.runmodule() From 20f831ab046dfb49321c6977bb83ab836d9d370e Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Tue, 7 Aug 2018 18:22:28 +0000 Subject: [PATCH 2/5] updated from latest in master (had old version of the file) --- tests/python/gpu/test_operator_gpu.py | 150 ++------------------------ 1 file changed, 10 insertions(+), 140 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index d222b314f5ed..921844acc085 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -36,11 +36,8 @@ from test_operator import * from test_optimizer import * from test_random import * -from test_gluon import * -from test_loss import * from test_exc_handling import * #from test_rnn import * -from test_gluon_rnn import * from test_sparse_ndarray import * from test_sparse_operator import * from test_ndarray import * @@ -130,7 +127,7 @@ def check_ifft(shape): init_complex.real[:,i] = init[0][:,2*i] init_complex.imag[:,i] = init[0][:,2*i+1] a = np.fft.ifft(init_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, out1[0]/shape_old[1],rtol=1e-3, atol=1e-12) + assert_almost_equal(a.real, out1[0]/shape_old[1],rtol=1e-3, atol=1e-5) if len(shape) == 4: init_complex = np.zeros(shape_old,dtype = np.complex64) @@ -138,7 +135,7 @@ def check_ifft(shape): init_complex.real[:,:,:,i] = init[0][:,:,:,2*i] init_complex.imag[:,:,:,i] = init[0][:,:,:,2*i+1] a = np.fft.ifft(init_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, out1[0]/shape_old[3],rtol=1e-3, atol=1e-12) + assert_almost_equal(a.real, out1[0]/shape_old[3],rtol=1e-3, atol=1e-5) # backward if len(shape) == 2: out_grad = mx.nd.empty(shape_old) @@ -151,7 +148,7 @@ def check_ifft(shape): temp[:,i] = exe.grad_arrays[0].asnumpy()[:,2*i] a = np.fft.fft(out_grad.asnumpy(), n=None, axis=-1, norm=None) - assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-12) + assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-5) if len(shape) == 4: out_grad = mx.nd.empty(shape_old) out_grad[:] = np.random.normal(-3, 3, shape_old) @@ -163,9 +160,9 @@ def check_ifft(shape): temp[:,:,:,i] = exe.grad_arrays[0].asnumpy()[:,:,:,2*i] a = np.fft.fft(out_grad.asnumpy(), n=None, axis=-1, norm=None) - assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-12) + assert_almost_equal(a.real, temp, rtol=1e-3, atol=1e-5) -@with_seed(0) +@with_seed() def test_ifft(): nrepeat = 2 maxdim = 10 @@ -197,7 +194,7 @@ def check_fft(shape): for exe in exe_list: for arr, iarr in zip(exe.arg_arrays, init): arr[:] = iarr.astype(arr.dtype) - #forward + # forward for exe in exe_list: exe.forward(is_train=True) out1 = [exe.outputs[0].asnumpy() for exe in exe_list] @@ -224,7 +221,7 @@ def check_fft(shape): a[i,j,:,p+1] = out2[i,j+out1[0].shape[1],:,k] p = p+2 - assert_almost_equal(a, out1[0],rtol=1e-3, atol=1e-6) + assert_almost_equal(a, out1[0],rtol=1e-3, atol=1e-5) # backward if len(shape) == 2: @@ -238,7 +235,7 @@ def check_fft(shape): for exe in exe_list: exe.backward([out_grad]) a = np.fft.ifft(out_grad_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[1],rtol=1e-3, atol=1e-8) + assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[1],rtol=1e-3, atol=1e-5) if len(shape) == 4: out_grad = mx.nd.empty(out1[0].shape) @@ -251,9 +248,9 @@ def check_fft(shape): for exe in exe_list: exe.backward([out_grad]) a = np.fft.ifft(out_grad_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[3],rtol=1e-3, atol=1e-6) + assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[3],rtol=1e-3, atol=1e-5) -@with_seed(0) +@with_seed() def test_fft(): nrepeat = 2 maxdim = 10 @@ -611,13 +608,11 @@ def test_convolution_versions(): @with_seed() def test_pooling_with_type(): - ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}, {'ctx': mx.gpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float16}}, {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] - sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='valid', name='pool') check_consistency(sym, ctx_list, default_type=np.float16) @@ -770,9 +765,6 @@ def test_spatial_transformer_with_type(): check_consistency(sym, ctx_list, grad_req="add") -# Checking max pooling consistency over the data sets of different float types is problematic -# as one max value in a float32 data set may not be the max value in a float16 data set. -# This function will not be called. @with_seed() def test_pooling_with_type2(): ctx_list = [{'ctx': mx.gpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, @@ -793,7 +785,6 @@ def test_pooling_with_type2(): sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum') check_consistency(sym, ctx_list, default_type=np.float16) - @unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/11517") @with_seed() def test_pooling_versions(): @@ -1661,17 +1652,6 @@ def check_rnn_layer_w_rand_inputs(layer): for g, c in zip(gs, cs): assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) -@with_seed() -@assert_raises_cudnn_disabled() -def test_rnn_layer(): - check_rnn_layer(gluon.rnn.RNN(100, num_layers=3)) - check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3)) - check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3)) - check_rnn_layer(gluon.rnn.GRU(100, num_layers=3)) - - check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) - check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) - @with_seed() def test_sequence_reverse(): check_sequence_reverse(mx.gpu(0)) @@ -1689,28 +1669,6 @@ def test_autograd_save_memory(): x.backward() -@with_seed() -def test_gluon_ctc_consistency(): - loss = mx.gluon.loss.CTCLoss() - data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0) - cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0)) - gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0)) - - cpu_data = data.copy().as_in_context(mx.cpu(0)) - cpu_data.attach_grad() - with mx.autograd.record(): - l_cpu = loss(cpu_data, cpu_label) - l_cpu.backward() - - gpu_data = data.copyto(mx.gpu(0)) - gpu_data.attach_grad() - with mx.autograd.record(): - l_gpu = loss(gpu_data, gpu_label) - l_gpu.backward() - - assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) - - @with_seed() def test_cuda_rtc(): source = r''' @@ -1741,16 +1699,6 @@ def test_cuda_rtc(): assert (y.asnumpy() == 12).all() -@with_seed() -def test_global_norm_clip_multi_device(): - x1 = mx.nd.ones((3,3), ctx=mx.gpu(0)) - x2 = mx.nd.ones((4,4), ctx=mx.cpu(0)) - norm = gluon.utils.clip_global_norm([x1, x2], 1.0) - assert norm == 5.0 - assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) - assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) - - @with_seed() def test_cross_device_autograd(): x = mx.nd.random.uniform(shape=(10,)) @@ -1969,84 +1917,6 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 -def _check_batchnorm_result(input, num_devices=1, cuda=False): - from mxnet.gluon.utils import split_and_load - def _find_bn(module): - if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2, ctx): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(ctx)) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - nch = input.shape[1] - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) - - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) - -def test_sync_batchnorm(): - def get_num_devices(): - for i in range(100): - try: - mx.nd.zeros((1,), ctx=mx.gpu(i)) - except: - return i - # no need to use SyncBN with 1 gpu - if get_num_devices() < 2: - return - ndev = 2 - # check with unsync version - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - if __name__ == '__main__': import nose nose.runmodule() From d4930369bb29181d4b61c7cab76393ac9b055210 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Tue, 7 Aug 2018 18:26:53 +0000 Subject: [PATCH 3/5] shortened lines per lint spec --- python/mxnet/test_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 0662f71259f3..976207fd4e70 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1203,7 +1203,8 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole", def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, tol=None, - raise_on_err=True, ground_truth=None, equal_nan=False, use_uniform=False, default_type=np.float64): + raise_on_err=True, ground_truth=None, equal_nan=False, + use_uniform=False, default_type=np.float64): """Check symbol gives the same output for different running context Parameters @@ -1280,9 +1281,11 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', for n, arr in exe_list[0].arg_dict.items(): if n not in arg_params: if use_uniform: - arg_params[n] = np.random.uniform(low=-0.92, high=0.92, size=arr.shape).astype(default_type) + arg_params[n] = np.random.uniform(low=-0.92, high=0.92, + size=arr.shape).astype(default_type) else: - arg_params[n] = np.random.normal(size=arr.shape, scale=scale).astype(default_type) + arg_params[n] = np.random.normal(size=arr.shape, + scale=scale).astype(default_type) for n, arr in exe_list[0].aux_dict.items(): if n not in aux_params: aux_params[n] = 0 From 71c55fdec3dd03c414ce493f4a18d0d7638bdd1c Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Wed, 8 Aug 2018 23:25:18 +0000 Subject: [PATCH 4/5] renamed default_type argument to rand_type for clarity updated function docstring with argument description removed rand_type setting for non-max pooling tests --- python/mxnet/test_utils.py | 10 +++++++--- tests/python/gpu/test_operator_gpu.py | 14 +++++++------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 976207fd4e70..f55381428615 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1204,7 +1204,7 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole", def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, tol=None, raise_on_err=True, ground_truth=None, equal_nan=False, - use_uniform=False, default_type=np.float64): + use_uniform=False, rand_type=np.float64): """Check symbol gives the same output for different running context Parameters @@ -1221,6 +1221,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', Optional, When flag set to true, random input data generated follows uniform distribution, not normal distribution + rand_type: np.dtype + Optional, when input data is passed via arg_params, + defaults to np.float64 (python float default) + Examples -------- >>> # create the symbol @@ -1282,10 +1286,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', if n not in arg_params: if use_uniform: arg_params[n] = np.random.uniform(low=-0.92, high=0.92, - size=arr.shape).astype(default_type) + size=arr.shape).astype(rand_type) else: arg_params[n] = np.random.normal(size=arr.shape, - scale=scale).astype(default_type) + scale=scale).astype(rand_type) for n, arr in exe_list[0].aux_dict.items(): if n not in aux_params: aux_params[n] = 0 diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 921844acc085..f1cae5199587 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -614,13 +614,13 @@ def test_pooling_with_type(): {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float64}}, {'ctx': mx.cpu(0), 'pool_data': (2, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='valid', name='pool') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(kernel=(3,3), pool_type='max', pooling_convention='full', name='pool') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(kernel=(300,300), pool_type='max', global_pool=True, name='pool') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) @with_seed() @@ -774,16 +774,16 @@ def test_pooling_with_type2(): {'ctx': mx.cpu(0), 'pool_data': (10, 2, 10, 10), 'type_dict': {'pool_data': np.float32}}] sym = mx.sym.Pooling(name='pool', kernel=(3,3), stride=(2,2), pool_type='max') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='avg') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list) sym = mx.sym.Pooling(name='pool', kernel=(5,5), pad=(2,2), pool_type='max') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list, rand_type=np.float16) sym = mx.sym.Pooling(name='pool', kernel=(3,3), pad=(1,1), pool_type='sum') - check_consistency(sym, ctx_list, default_type=np.float16) + check_consistency(sym, ctx_list) @unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/11517") @with_seed() From a2e878cf145d088e6dc8017cf83501ed4f9da4a8 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Wed, 8 Aug 2018 23:58:41 +0000 Subject: [PATCH 5/5] cleaned up check_consistency function docstring --- python/mxnet/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index f55381428615..69d916ef85e3 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1222,8 +1222,9 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', random input data generated follows uniform distribution, not normal distribution rand_type: np.dtype + casts the randomly generated data to this type Optional, when input data is passed via arg_params, - defaults to np.float64 (python float default) + defaults to np.float64 (numpy float default) Examples --------