Skip to content

Commit

Permalink
revert change in test util (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored Aug 22, 2017
1 parent d3c176c commit 08435ed
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,8 +870,13 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
if isinstance(expected, (list, tuple)):
expected = {k:v for k, v in zip(sym.list_arguments(), expected)}
args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()}
args_grad_data = {k: mx.nd.array(v, ctx=ctx) if grad_stypes is None or k not in grad_stypes \
else mx.nd.array(v, ctx=ctx).tostype(grad_stypes[k])}
args_grad_data = {}
for k, v in args_grad_npy.items():
nd = mx.nd.array(v, ctx=ctx)
if grad_stypes is not None and k in grad_stypes:
args_grad_data[k] = nd.tostype(grad_stypes[k])
else:
args_grad_data[k] = nd

if isinstance(grad_req, str):
grad_req = {k:grad_req for k in sym.list_arguments()}
Expand Down

0 comments on commit 08435ed

Please sign in to comment.