This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixes
leezu committed Jul 10, 2020
1 parent 2fa81d7 commit 0661ca6
Showing 8 changed files with 75 additions and 71 deletions.
32 changes: 24 additions & 8 deletions python/mxnet/gluon/block.py
@@ -39,7 +39,7 @@
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
from .utils import _check_same_symbol_type, _check_all_np_ndarrays
from .. import numpy_extension as _mx_npx
from .. import numpy as _mx_np
from .. import numpy as _mx_np, ndarray as nd
from .. util import is_np_array, np_shape, np_array


@@ -54,8 +54,8 @@ def _block_scope(block):
counter = _naming_counter.get(None)
if counter is not None:
count = counter.get(name, 0)
name = '%s%d'%(name, count)
counter[name] = count + 1
name = '%s%d'%(name, count)
counter_token = _naming_counter.set({})
prefix_token = _prefix.set(_prefix.get() + name + '_')
with _name.Prefix(_prefix.get()):
@@ -478,7 +478,10 @@ def load_dict(self, param_dict, ctx=None, allow_missing=False,
"which contains parameters %s. Set ignore_extra=True to ignore. "%(
name, error_str, _brief_print_list(params.keys())))
if name in params:
params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)
param = loaded[name]
if isinstance(param, np.ndarray):
param = _mx_np.array(param) if is_np_array() else nd.array(param)
params[name]._load_init(param, ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)

def register_child(self, block, name=None):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
@@ -561,8 +564,8 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
params = self.collect_params()
if verbose:
init.set_verbosity(verbose=verbose)
for k, v in params.items():
v.initialize(None, ctx, init, force_reinit=force_reinit, structural_name=k)
for v in params.values():
v.initialize(None, ctx, init, force_reinit=force_reinit)

def hybridize(self, active=True, **kwargs):
""" Please refer description of HybridBlock hybridize().
@@ -657,6 +660,14 @@ def share_parameters(self, shared):
dense1.weight = dense0.weight
dense1.bias = dense0.bias
Note that unlike the `load_parameters` or `load_dict` functions,
`share_parameters` results in the `Parameter` object being shared (or
tied) between the models, whereas `load_parameters` or `load_dict` only
set the value of the data dictionary of a model. If you call
`load_parameters` or `load_dict` after `share_parameters`, the loaded
value will be reflected in all networks that use the shared (or tied)
`Parameter` object.
Parameters
----------
shared : Dict
Expand Down Expand Up @@ -1276,10 +1287,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
"this block at least once before calling export.")
sym = copy.copy(self._cached_graph[1])

# Deduplicate params (shared parameters use the same input symbol)
reverse_params = {v: k for k, v in self.collect_params().items()}
params = {v: k for k, v in reverse_params.items()}

# In export we have global information on the structure of the graph
# can rename the symbol inputs to human-readable, deterministic names.
# That's not true in general, which is why internally random unique identifiers are used
rename_map = {param.var().name: name for name, param in self.collect_params().items()}
# That's not true in general, which is why internally random unique identifiers are used.
rename_map = {param.var().name: name for name, param in params.items()}
for var in sym.get_inputs():
if var.name in rename_map:
var._set_attr(name=rename_map[var.name])
@@ -1290,7 +1305,8 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
for name, param in self.collect_params().items():

for name, param in params.items():
if name in arg_names:
arg_dict['arg:%s'%name] = param._reduce()
else:
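As a usage illustration (not part of the diff): with the `load_dict` change above, a parameter dict holding plain numpy arrays can be passed directly, since values are now wrapped into MXNet arrays before `_load_init`. A minimal sketch, assuming MXNet built from this branch and a simple Dense block:

```python
import mxnet as mx

# Source network: initialize and snapshot its parameters as plain numpy arrays.
src = mx.gluon.nn.Dense(4, in_units=3)
src.initialize()
np_params = {k: v.data().asnumpy() for k, v in src.collect_params().items()}

# Destination network: load_dict converts the numpy values internally, so no
# manual mx.nd.array(...) / mx.np.array(...) wrapping is needed.
dst = mx.gluon.nn.Dense(4, in_units=3)
dst.load_dict(np_params)
```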
10 changes: 2 additions & 8 deletions python/mxnet/gluon/parameter.py
@@ -122,7 +122,6 @@ def __init__(self, name='weight', grad_req='write', shape=None, dtype=mx_real_t,
shape = (shape,)
self._shape = shape
self._name = name
self._structured_name = ''
self._dtype = dtype
self.lr_mult = lr_mult
self.wd_mult = wd_mult
@@ -360,7 +359,7 @@ def _finish_deferred_init(self):
zeros_fn = ndarray.zeros
data = zeros_fn(**kwargs)
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init, 'structure': self._structural_name}), data)
initializer.InitDesc(self.name, {'__init__': init}), data)

self._init_impl(data, ctx)

@@ -417,7 +416,7 @@ def _reduce(self):
return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
force_reinit=False, structural_name=''):
force_reinit=False):
"""Initializes parameter and gradient arrays. Only used for :py:class:`NDArray` API.
Parameters
@@ -438,10 +437,6 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
and :py:meth:`Parameter.init` are ``None``.
force_reinit : bool, default False
Whether to force re-initialization if parameter is already initialized.
structural_name : str, default ""
The structural name for the parameter in the block.
The value would be accessed in InitDesc.attrs['structure'] by self-defined initializers.
Users may want to initialize parameters based on the block's structure
Examples
--------
>>> weight = mx.gluon.Parameter('weight', shape=(2, 2))
@@ -470,7 +465,6 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform()),
stacklevel=2)
return
self._data = self._grad = None
self._structural_name = structural_name
if ctx is None:
ctx = [context.current_context()]
if isinstance(ctx, Context):
2 changes: 1 addition & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -283,7 +283,7 @@ def __neg__(self):
return negative(self)

def __deepcopy__(self, _):
return super(_Symbol, self).as_np_ndarray()
return super().__deepcopy__(_).as_np_ndarray()

def __eq__(self, other):
"""x.__eq__(y) <=> x == y"""
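For context (an illustrative sketch, not part of the commit): the old `__deepcopy__` merely re-wrapped the same underlying handle via `as_np_ndarray()`, so `copy.deepcopy` never copied the symbol; the fix delegates to the base-class deep copy first and only then re-wraps. Assuming MXNet from this branch:

```python
import copy
import mxnet as mx

x = mx.sym.var('data').as_np_ndarray()  # numpy-compatible _Symbol
y = copy.deepcopy(x)                    # now delegates to Symbol.__deepcopy__
assert type(y) is type(x)               # still a numpy-compatible symbol
assert y is not x                       # a fresh object backed by a copied symbol
```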
19 changes: 4 additions & 15 deletions python/mxnet/test_utils.py
@@ -2381,8 +2381,7 @@ def is_op_runnable():
@use_np
def check_gluon_hybridize_consistency(net_builder, data_l, numpy_func=None, test_grad=True,
rtol=1E-4, atol=1E-4):
"""Check whether a HybridBlock has consistent output between the hybridized
v.s. non-hybridized versions
"""Check whether a HybridBlock has consistent output when hybridized or not hybridized
The network should not contain any random number generators.
@@ -2404,15 +2403,6 @@ def check_gluon_hybridize_consistency(net_builder, data_l, numpy_func=None, test
atol : float, optional
The absolute error tolerance, default 1E-4. Default 1E-4.
"""
class _NumpyParamDictInit(mx.init.Initializer):
"""Initializes parameters with the cached numpy ndarrays dictionary
"""
def __init__(self, np_params):
super(_NumpyParamDictInit, self).__init__()
self._np_params = np_params

def _init_weight(self, name, arr):
arr[()] = self._np_params[name.attrs['structure']]
saved_out_np = None
saved_grad_np_l = None
params_init = None
@@ -2423,7 +2413,7 @@ def _init_weight(self, name, arr):
if params_init is None:
net.initialize()
else:
net.initialize(params_init)
net.load_dict(params_init)
if hybridize:
net.hybridize()
in_data_l = [ele.copy() for ele in data_l]
@@ -2435,9 +2425,8 @@ def _init_weight(self, name, arr):
out.backward(out)
else:
out = net(*in_data_l)
if params_init is None:
np_params = {k: v.data().asnumpy() for k, v in net.collect_params().items()}
params_init = _NumpyParamDictInit(np_params)
if params_init is None: # Deferred initialization finished
params_init = {k: v.data().asnumpy() for k, v in net.collect_params().items()}
if saved_out_np is None:
saved_out_np = out.asnumpy()
else:
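The replacement for the removed `_NumpyParamDictInit` initializer is the pattern sketched below (a stand-alone version, assuming this branch): run the first variant once so deferred initialization finishes, snapshot its parameters as numpy arrays, then feed that dict to `load_dict` for every later variant.

```python
import mxnet as mx

params_init = None
outputs = []
for hybridize in [False, True]:
    net = mx.gluon.nn.Dense(5)           # stand-in for net_builder()
    if params_init is None:
        net.initialize()
    else:
        net.load_dict(params_init)       # reuse the cached numpy parameters
    if hybridize:
        net.hybridize()
    out = net(mx.nd.ones((2, 3)))
    if params_init is None:              # deferred initialization has finished
        params_init = {k: v.data().asnumpy()
                       for k, v in net.collect_params().items()}
    outputs.append(out.asnumpy())
# Both variants should now produce identical outputs.
```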
14 changes: 7 additions & 7 deletions tests/python/gpu/test_gluon_gpu.py
@@ -443,20 +443,20 @@ def test_symbol_block_fp16(tmpdir):
net_fp32.hybridize()
data = mx.nd.zeros((1, 3, 224, 224), dtype='float16', ctx=ctx)
net_fp32.forward(data)
net_fp32.export(tmpfile, 0)
symbol_file, param_file = net_fp32.export(tmpfile, 0)

# 2. Load the saved model and verify if all the params are loaded correctly.
# and choose one of the param to verify the type if fp16.
sm = mx.sym.load(tmpfile + '-symbol.json')
# Choose one of the parameters to verify the type is fp16.
sm = mx.sym.load(symbol_file)
inputs = mx.sym.var('data', dtype='float16')
net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
net_fp16.load_parameters(tmpfile + '-0000.params', ctx=ctx)
net_fp16.load_parameters(param_file, ctx=ctx)
# 3. Get a conv layer's weight parameter name. Conv layer's weight param is
# expected to be of dtype casted, fp16.
name = None
for param_name, param in net_fp32.collect_params().items():
name = None
for param_name in net_fp32.collect_params().keys():
if 'conv' in param_name and 'weight' in param_name:
name = param.name
name = param_name
break
assert np.dtype(net_fp16.params[name].dtype) == np.dtype(np.float16)

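The test above relies on `export` returning the paths it wrote. A minimal CPU sketch of that flow (assumes this branch, where `HybridBlock.export` returns `(symbol_file, param_file)`):

```python
import mxnet as mx

net = mx.gluon.nn.Dense(2)
net.initialize()
net.hybridize()
net(mx.nd.ones((1, 4)))                        # build the cached graph

symbol_file, param_file = net.export('net', epoch=0)

# Reload through SymbolBlock using the returned paths instead of
# hand-building 'net-symbol.json' / 'net-0000.params'.
sym = mx.sym.load(symbol_file)
loaded = mx.gluon.SymbolBlock(sym, mx.sym.var('data'))
loaded.load_parameters(param_file)
```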
51 changes: 26 additions & 25 deletions tests/python/unittest/onnx/backend_test.py
@@ -19,21 +19,6 @@
# under the License.

"""ONNX test backend wrapper"""
try:
import onnx.backend.test
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")

import test_cases
import unittest
import backend as mxnet_backend
import logging

operations = ['import', 'export']
backends = ['mxnet', 'gluon']
# This is a pytest magic variable to load extra plugins
pytest_plugins = "onnx.backend.test.report",


def build_test_suite(backend_tests): # type: () -> unittest.TestSuite
'''
@@ -80,13 +65,29 @@ def prepare_tests(backend, oper):
return BACKEND_TESTS


for bkend in backends:
for operation in operations:
log = logging.getLogger(bkend + operation)
if bkend == 'gluon' and operation == 'export':
log.warning('Gluon->ONNX export not implemented. Skipping tests...')
continue
log.info('Executing tests for ' + bkend + ' backend: ' + operation)
mxnet_backend.MXNetBackend.set_params(bkend, operation)
BACKEND_TESTS = prepare_tests(mxnet_backend, operation)
unittest.TextTestRunner().run(build_test_suite(BACKEND_TESTS.enable_report()))
if __name__ == '__main__':
try:
import onnx.backend.test
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")

import test_cases
import unittest
import backend as mxnet_backend
import logging

operations = ['import', 'export']
backends = ['mxnet', 'gluon']
# This is a pytest magic variable to load extra plugins
pytest_plugins = "onnx.backend.test.report",

for bkend in backends:
for operation in operations:
log = logging.getLogger(bkend + operation)
if bkend == 'gluon' and operation == 'export':
log.warning('Gluon->ONNX export not implemented. Skipping tests...')
continue
log.info('Executing tests for ' + bkend + ' backend: ' + operation)
mxnet_backend.MXNetBackend.set_params(bkend, operation)
BACKEND_TESTS = prepare_tests(mxnet_backend, operation)
unittest.TextTestRunner().run(build_test_suite(BACKEND_TESTS.enable_report()))
10 changes: 4 additions & 6 deletions tests/python/unittest/onnx/mxnet_export_test.py
@@ -17,7 +17,6 @@


# pylint: disable=too-many-locals,wrong-import-position,import-error
from __future__ import absolute_import
import os, sys
import unittest
import logging
@@ -29,7 +28,6 @@
from mxnet.test_utils import set_default_context
from mxnet.gluon import nn
from mxnet.gluon import HybridBlock
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx

logger = logging.getLogger()
@@ -58,21 +56,21 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
data = nd.random.uniform(0, 1, (1, 1024))
output = _force_list(net(data)) # initialize weights
net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
net_params = {param.name: param._reduce() for name, param in net.collect_params().items()}
net_params = {param.var().name: param._reduce() for param in net.collect_params().values()}
net_params.update(extra_params)
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
export_path = onnx_mxnet.export_model(
export_path = mx.contrib.onnx.export_model(
sym=net_sym,
params=net_params,
input_shape=[shape_type(data.shape)],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
_assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0])
_assert_sym_equal(net_sym, mx.contrib.onnx.import_model(export_path)[0])

# Try importing the model to gluon
imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None)
imported_net = mx.contrib.onnx.import_to_gluon(export_path, ctx=None)
_assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs))

# Confirm network outputs are the same
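The change above keys the exported parameter dict by `param.var().name` rather than by the structural name returned from `collect_params()`. A small sketch of the distinction (illustrative only, assuming this branch):

```python
import mxnet as mx

net = mx.gluon.nn.Dense(3, in_units=2)
net.initialize()

for structural_name, param in net.collect_params().items():
    # structural_name is e.g. 'weight' / 'bias'; param.var().name is the
    # (internally generated) name of the corresponding symbolic input.
    print(structural_name, '->', param.var().name)
```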
8 changes: 7 additions & 1 deletion tests/python/unittest/test_gluon.py
@@ -252,6 +252,12 @@ def test_dense():
assert outs == [(17, 128)]


def test_hybrid_sequential_unique_internals():
net = mx.gluon.nn.HybridSequential()
net.add(mx.gluon.nn.Dense(100, activation='relu'), mx.gluon.nn.Dense(10))
assert len(set(s.name for s in net(mx.sym.Variable('data')).get_internals())) == 8


@with_seed()
def test_symbol_block(tmpdir):
model = nn.HybridSequential()
@@ -1515,7 +1521,7 @@ def __init__(self):
backbone = gluon.model_zoo.vision.resnet18_v1()
backbone.initialize()
backbone.hybridize()
backbone(mx.nd.random.normal(shape=(1, 3, 32, 32)))
backbone(mx.nd.random.normal(shape=(1, 3, 32, 32), ctx=mx.cpu()))
sym_file, params_file = backbone.export(tmpfile)
self.backbone = gluon.SymbolBlock.imports(sym_file, 'data', params_file)
self.body = nn.Conv2D(3, 1)
