diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 6e79e06d55b2..e1210fbd3e6e 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -870,8 +870,13 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} - args_grad_data = {k: mx.nd.array(v, ctx=ctx) if grad_stypes is None or k not in grad_stypes \ - else mx.nd.array(v, ctx=ctx).tostype(grad_stypes[k])} + args_grad_data = {} + for k, v in args_grad_npy.items(): + nd = mx.nd.array(v, ctx=ctx) + if grad_stypes is not None and k in grad_stypes: + args_grad_data[k] = nd.tostype(grad_stypes[k]) + else: + args_grad_data[k] = nd if isinstance(grad_req, str): grad_req = {k:grad_req for k in sym.list_arguments()}