diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index bd6ee6a2cdcf..5ab95f65eccf 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -962,7 +962,7 @@ def _build_cache(self, *args): if name in data_names.keys(): arg_array.append(args[data_names[name]]) else: - if ctx == None: + if ctx is None: ctx = params.get(name)._ctx_list[0] arg_array.append(ndarray.random.uniform(shape=params.get(name)._shape)) # Partition the graph. @@ -1053,9 +1053,12 @@ def register_child(self, block, name=None): super(HybridBlock, self).register_child(block, name) self._clear_cached_op() - def hybridize(self, active=True, backend=None, backend_args={}, **kwargs): + def hybridize(self, active=True, backend=None, backend_args=None, **kwargs): self._backend = backend - self._backend_args = backend_args + if backend_args is None: + self._backend_args = {} + else: + self._backend_args = backend_args self._active = active self._flags = list(kwargs.items()) self._clear_cached_op() diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py index c4543b34aeae..000825f7febf 100644 --- a/tests/python/unittest/test_subgraph_op.py +++ b/tests/python/unittest/test_subgraph_op.py @@ -51,8 +51,9 @@ def network_structure_3(): ret = ret1 + ret2 ret = mx.sym.BatchNorm(ret) ret = mx.sym.BatchNorm(ret) - return (ret, ['data'], [(2, 3, 10, 10)]) - + # Return the names and shapes of 'data' and auxiliary states + return (ret, ['data', *ret.list_auxiliary_states()], [(2,3,10,10), (3,), (3,), (3,), (3,)]) + def network_structure_4(): # the last op has multiple duplicate outputs data = mx.sym.var('data', shape=(2, 3, 10, 10)) @@ -84,23 +85,24 @@ def network_structure_7(): return (ret, ['data'], [(1,)]) def get_graphs(): - return [(network_structure_1(), ['Convolution']), + return [ + (network_structure_1(), ['Convolution']), (network_structure_2(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']), (network_structure_2(), ['exp', 
'cos', '_Plus', 'elemwise_add', '_plus']), - # To do: fix batch norm issue for gluon tests. - #(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']), - #(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']), - #(network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']), - #(network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']), - #(network_structure_3(), ['exp', 'BatchNorm']), - #(network_structure_3(), ['BatchNorm']), + (network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']), + (network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']), + (network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']), + (network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']), + (network_structure_3(), ['exp', 'BatchNorm']), + (network_structure_3(), ['BatchNorm']), (network_structure_4(), ['exp']), (network_structure_5(), ['_plus', '_Plus', 'elemwise_add']), (network_structure_6(), []), (network_structure_6(), [mx.sym.sin.__name__]), (network_structure_6(), [mx.sym.Convolution.__name__]), (network_structure_6(), [mx.sym.sin.__name__, mx.sym.Convolution.__name__]), - (network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus'])] + (network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus']) + ] def check_subgraph_exe1(sym, subgraph_backend, op_names): """Use the partitioned sym to simple_bind an executor and compare the outputs