Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixed auxiliary state issue for Gluon partition
Browse files Browse the repository at this point in the history
  • Loading branch information
guanxinq committed Jan 29, 2020
1 parent 434c3c7 commit 7228343
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
9 changes: 6 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def _build_cache(self, *args):
if name in data_names.keys():
arg_array.append(args[data_names[name]])
else:
if ctx == None:
if ctx is None:
ctx = params.get(name)._ctx_list[0]
arg_array.append(ndarray.random.uniform(shape=params.get(name)._shape))
# Partition the graph.
Expand Down Expand Up @@ -1053,9 +1053,12 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, backend=None, backend_args={}, **kwargs):
def hybridize(self, active=True, backend=None, backend_args=None, **kwargs):
self._backend = backend
self._backend_args = backend_args
if backend_args is None:
self._backend_args = {}
else:
self._backend_args = backend_args
self._active = active
self._flags = list(kwargs.items())
self._clear_cached_op()
Expand Down
24 changes: 13 additions & 11 deletions tests/python/unittest/test_subgraph_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def network_structure_3():
ret = ret1 + ret2
ret = mx.sym.BatchNorm(ret)
ret = mx.sym.BatchNorm(ret)
return (ret, ['data'], [(2, 3, 10, 10)])

# Return the names and shapes of 'data' and the auxiliary states
return (ret, ['data', *ret.list_auxiliary_states()], [(2,3,10,10), (3,), (3,), (3,), (3,)])

def network_structure_4():
# the last op has multiple duplicate outputs
data = mx.sym.var('data', shape=(2, 3, 10, 10))
Expand Down Expand Up @@ -84,23 +85,24 @@ def network_structure_7():
return (ret, ['data'], [(1,)])

def get_graphs():
return [(network_structure_1(), ['Convolution']),
return [
(network_structure_1(), ['Convolution']),
(network_structure_2(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
(network_structure_2(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
# To do: fix batch norm issue for gluon tests.
#(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
#(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
#(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
#(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
#(network_structure_3(), ['exp', 'BatchNorm']),
#(network_structure_3(), ['BatchNorm']),
(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
(network_structure_3(), ['exp', 'BatchNorm']),
(network_structure_3(), ['BatchNorm']),
(network_structure_4(), ['exp']),
(network_structure_5(), ['_plus', '_Plus', 'elemwise_add']),
(network_structure_6(), []),
(network_structure_6(), [mx.sym.sin.__name__]),
(network_structure_6(), [mx.sym.Convolution.__name__]),
(network_structure_6(), [mx.sym.sin.__name__, mx.sym.Convolution.__name__]),
(network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus'])]
(network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus'])
]

def check_subgraph_exe1(sym, subgraph_backend, op_names):
"""Use the partitioned sym to simple_bind an executor and compare the outputs
Expand Down

0 comments on commit 7228343

Please sign in to comment.