Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Use correct array type for outputs in HybridBlock.forward (#18554)
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu authored Jun 13, 2020
1 parent f1f3f44 commit 09cf48a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ def _get_graph_v2(self, *args):
with autograd.pause(), dc.context():
out = super().__call__(*args)
flatten_out, self._out_format = _flatten(out, "output")
symbol_outputs = dc.get_symbol(flatten_out)
symbol_outputs = dc.get_symbol(flatten_out, sym_cls=type(symbol_inputs[0]))
self._cached_graph = symbol_inputs, symbol_outputs
return self._cached_graph

Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_deferred_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,11 @@ def _assert_dc_gluon(setup, net, setup_is_deterministic=True, numpy=True, autogr
[p.grad() for p in net.collect_params().values()]
else:
ys_hybrid = net(*xs)

assert all(
isinstance(y, mx.numpy.ndarray) if numpy else isinstance(y, mx.ndarray.ndarray.NDArray)
for y in ys_hybrid)

ys_hybrid_np = [y.asnumpy() for y in ys_hybrid]

_all_same(ys_np, ys_hybrid_np)
Expand Down

0 comments on commit 09cf48a

Please sign in to comment.