From e2366e9102e6862416bf998af52baaa5e9c0a31b Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Wed, 15 Jul 2020 22:01:36 +0000 Subject: [PATCH] Refactor scope functionality in Python API (#18619) * Refactor scope functionality in Python API - Remove deprecated metaclass functionality - Remove global state in naming - Switch from threading.local to asyncio compatible contextvars - Stop exposing UUIDs in parameter name * Fix dependencies * Fixes * Fixes * Fix * Fix after merge master --- ci/docker/Dockerfile.build.centos7 | 18 +- ci/docker/install/requirements | 2 + example/profiler/profiler_matmul.py | 4 +- include/mxnet/c_api.h | 8 + python/mxnet/attribute.py | 46 ++-- python/mxnet/base.py | 63 ------ python/mxnet/context.py | 49 +---- python/mxnet/gluon/block.py | 204 ++++++++---------- .../gluon/contrib/estimator/estimator.py | 2 +- python/mxnet/gluon/contrib/nn/basic_layers.py | 3 +- python/mxnet/gluon/parameter.py | 19 +- python/mxnet/gluon/trainer.py | 13 +- python/mxnet/name.py | 53 ++--- python/mxnet/optimizer/updater.py | 4 +- python/mxnet/profiler.py | 60 ++---- python/mxnet/symbol/contrib.py | 4 +- python/mxnet/symbol/numpy/_symbol.py | 2 +- python/mxnet/symbol/register.py | 28 +-- python/mxnet/symbol/symbol.py | 39 +++- python/mxnet/test_utils.py | 23 +- python/setup.py | 2 +- src/c_api/c_api_symbolic.cc | 14 ++ tests/python/gpu/test_gluon_gpu.py | 14 +- tests/python/gpu/test_profiler_gpu.py | 2 +- tests/python/mkl/test_mkldnn.py | 4 +- tests/python/unittest/onnx/backend_test.py | 51 ++--- .../python/unittest/onnx/mxnet_export_test.py | 10 +- tests/python/unittest/onnx/test_node.py | 8 +- tests/python/unittest/test_autograd.py | 1 + .../python/unittest/test_deferred_compute.py | 3 +- tests/python/unittest/test_gluon.py | 68 +++--- tests/python/unittest/test_gluon_contrib.py | 4 +- tests/python/unittest/test_gluon_rnn.py | 63 +++--- tests/python/unittest/test_gluon_trainer.py | 15 ++ tests/python/unittest/test_memory_opt.py | 5 - .../unittest/test_numpy_default_dtype.py | 5 - tests/python/unittest/test_numpy_op.py | 9 +- tests/python/unittest/test_profiler.py | 4 +- tests/python/unittest/test_sparse_ndarray.py | 13 +- tests/python/unittest/test_thread_local.py | 153 ++++++------- 40 files changed, 466 insertions(+), 626 deletions(-) diff --git a/ci/docker/Dockerfile.build.centos7 b/ci/docker/Dockerfile.build.centos7 index a0b5b127e7ea..8a718c4d1339 100644 --- a/ci/docker/Dockerfile.build.centos7 +++ b/ci/docker/Dockerfile.build.centos7 @@ -119,22 +119,14 @@ RUN export SHORT_CUDA_VERSION=${CUDA_VERSION%.*} && \ yum clean all; \ fi -# Python dependencies -RUN pip3 install --no-cache-dir --upgrade pip && \ - pip3 install --no-cache-dir pylint cython numpy requests h5py scipy==1.2.3 wheel \ - pytest==5.3.5 \ - pytest-env==0.6.2 \ - pytest-cov==2.8.1 \ - pytest-xdist==1.31.0 \ - pytest-timeout==1.3.4 \ - mock==2.0.0 \ - onnx==1.5.0 \ - protobuf==3.5.2 \ - tabulate==0.7.5 - # Fix the en_DK.UTF-8 locale to test locale invariance RUN localedef -i en_DK -f UTF-8 en_DK.UTF-8 +# Python dependencies +RUN python3 -m pip install --upgrade pip +COPY install/requirements /work/ +RUN python3 -m pip install -r /work/requirements + ARG USER_ID=0 COPY install/docker_filepermissions.sh /work/ RUN /work/docker_filepermissions.sh diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements index ce06681d96af..bd0114ce0464 100644 --- a/ci/docker/install/requirements +++ b/ci/docker/install/requirements @@ -22,6 +22,7 @@ numpy>=1.17 requests>=2.20.0,<3 graphviz<0.9.0,>=0.8.1 +contextvars;python_version<"3.7" # Optional dependencies onnx==1.5.0 @@ -42,6 +43,7 @@ pytest-xdist==1.31.0 pytest-timeout==1.3.4 flaky==3.6.1 setuptools +wheel mock==2.0.0 # TVM dependencies diff --git a/example/profiler/profiler_matmul.py b/example/profiler/profiler_matmul.py index 6b92bcc21ec0..7fe43c84b1bb 100644 --- a/example/profiler/profiler_matmul.py +++ b/example/profiler/profiler_matmul.py @@ -52,10 +52,10 @@ def parse_args(): print("execution begin") for i in range(args.iter_num): if i == args.begin_profiling_iter: - t0 = time.clock() + t0 = time.process_time() mx.profiler.set_state('run') if i == args.end_profiling_iter: - t1 = time.clock() + t1 = time.process_time() mx.profiler.set_state('stop') executor.forward() c = executor.outputs[0] diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 36c76e5a3c57..04d863991b48 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1748,6 +1748,14 @@ MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol, */ MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out); +/*! + * \brief Get a symbol that contains all the inputs. + * \param symbol The symbol + * \param out The output symbol whose outputs are all the internals. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetInputs(SymbolHandle symbol, + SymbolHandle *out); /*! * \brief Get a symbol that contains only direct children. * \param symbol The symbol diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index bed723d69e15..2de8798fccff 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -14,16 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# coding: utf-8 """Attribute scoping support for symbolic API.""" -import threading -import warnings +import contextvars from collections import defaultdict -from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass +from .base import string_types -class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)): +class AttrScope: """Attribute manager for scoping. User can also inherit this object to change naming behavior. @@ -33,7 +30,6 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)): kwargs The attributes to set for all symbol creations in the scope. """ - _current = threading.local() _subgraph_names = defaultdict(int) def __init__(self, **kwargs): @@ -65,37 +61,23 @@ def get(self, attr): else: return attr if attr else {} - def __enter__(self): - # pylint: disable=protected-access - if not hasattr(AttrScope._current, "value"): - AttrScope._current.value = AttrScope() - self._old_scope = AttrScope._current.value - attr = AttrScope._current.value._attr.copy() + def __enter__(self): # pylint: disable=protected-access + attr = _current.get()._attr.copy() attr.update(self._attr) self._attr = attr - AttrScope._current.value = self + # Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value + self._old_scope = _current.get() + _current.set(self) return self def __exit__(self, ptype, value, trace): assert self._old_scope - AttrScope._current.value = self._old_scope + _current.set(self._old_scope) + - #pylint: disable=no-self-argument - @classproperty - def current(cls): - warnings.warn("AttrScope.current has been deprecated. " - "It is advised to use the `with` statement with AttrScope.", - DeprecationWarning) - if not hasattr(AttrScope._current, "value"): - cls._current.value = AttrScope() - return cls._current.value +_current = contextvars.ContextVar('namemanager', default=AttrScope()) - @current.setter - def current(cls, val): - warnings.warn("AttrScope.current has been deprecated. " - "It is advised to use the `with` statement with AttrScope.", - DeprecationWarning) - cls._current.value = val - #pylint: enable=no-self-argument -AttrScope._current.value = AttrScope() +def current(): + """Returns the current name manager.""" + return _current.get() diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 9a25d9dc43aa..65687fff54a9 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -273,69 +273,6 @@ class MXCallbackList(ctypes.Structure): ] -# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property -class _MXClassPropertyDescriptor(object): - def __init__(self, fget, fset=None): - self.fget = fget - self.fset = fset - - def __get__(self, obj, clas=None): - if clas is None: - clas = type(obj) - return self.fget.__get__(obj, clas)() - - def __set__(self, obj, value): - if not self.fset: - raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__) - if inspect.isclass(obj): - type_ = obj - obj = None - else: - type_ = type(obj) - return self.fset.__get__(obj, type_)(value) - - def setter(self, func): - if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) - self.fset = func - return self - - -class _MXClassPropertyMetaClass(type): - def __setattr__(cls, key, value): - obj = cls.__dict__.get(key) - if obj and isinstance(obj, _MXClassPropertyDescriptor): - return obj.__set__(cls, value) - - return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value) - - -# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py -# pylint: disable=unused-argument -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - - def __new__(cls, name, this_bases, d): - return meta(name, bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) -# pylint: enable=unused-argument - - -def classproperty(func): - if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) - - return _MXClassPropertyDescriptor(func) - - def _load_lib(): """Load library by searching possible path.""" lib_path = libinfo.find_lib_path() diff --git a/python/mxnet/context.py b/python/mxnet/context.py index ac9b497396bb..5cd9c7fad766 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -14,18 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# coding: utf-8 """Context management API of mxnet.""" -import threading -import warnings +import contextvars import ctypes -from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass from .base import _LIB from .base import check_call -class Context(with_metaclass(_MXClassPropertyMetaClass, object)): +class Context: """Constructs a context. MXNet can run operations on CPU and different GPUs. @@ -66,8 +62,6 @@ class Context(with_metaclass(_MXClassPropertyMetaClass, object)): >>> gpu_array.context gpu(1) """ - # static class variable - _default_ctx = threading.local() devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'} devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5} def __init__(self, device_type, device_id=0): @@ -115,34 +109,13 @@ def __repr__(self): return self.__str__() def __enter__(self): - if not hasattr(Context._default_ctx, "value"): - Context._default_ctx.value = Context('cpu', 0) - self._old_ctx = Context._default_ctx.value - Context._default_ctx.value = self + # Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value + self._old_ctx = _current.get() + _current.set(self) return self def __exit__(self, ptype, value, trace): - Context._default_ctx.value = self._old_ctx - - #pylint: disable=no-self-argument - @classproperty - def default_ctx(cls): - warnings.warn("Context.default_ctx has been deprecated. " - "Please use Context.current_context() instead. " - "Please use test_utils.set_default_context to set a default context", - DeprecationWarning) - if not hasattr(Context._default_ctx, "value"): - cls._default_ctx.value = Context('cpu', 0) - return cls._default_ctx.value - - @default_ctx.setter - def default_ctx(cls, val): - warnings.warn("Context.default_ctx has been deprecated. " - "Please use Context.current_context() instead. " - "Please use test_utils.set_default_context to set a default context", - DeprecationWarning) - cls._default_ctx.value = val - #pylint: enable=no-self-argument + _current.set(self._old_ctx) def empty_cache(self): """Empties the memory cache for the current contexts device. @@ -162,9 +135,6 @@ def empty_cache(self): dev_id = ctypes.c_int(self.device_id) check_call(_LIB.MXStorageEmptyCache(dev_type, dev_id)) -# initialize the default context in Context -Context._default_ctx.value = Context('cpu', 0) - def cpu(device_id=0): """Returns a CPU context. @@ -299,6 +269,9 @@ def gpu_memory_info(device_id=0): return (free.value, total.value) +_current = contextvars.ContextVar('namemanager', default=Context('cpu', 0)) + + def current_context(): """Returns the current context. @@ -321,6 +294,4 @@ def current_context(): ------- default_ctx : Context """ - if not hasattr(Context._default_ctx, "value"): - Context._default_ctx.value = Context('cpu', 0) - return Context._default_ctx.value + return _current.get() diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 1f9cd43dd2ee..d6782ba94224 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -20,82 +20,50 @@ """Base container class for all neural network models.""" __all__ = ['Block', 'HybridBlock', 'SymbolBlock'] -import threading import copy import warnings import weakref from collections import OrderedDict, defaultdict +import contextlib +import contextvars import re import numpy as np from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str -from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc +from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \ + profiler as _profiler, context as _context from ..symbol.numpy import _symbol as np_symbol from ..symbol import Symbol from ..ndarray import NDArray -from .. import name as _name -from .. import profiler as _profiler from .parameter import Parameter, DeferredInitializationError from .utils import _indent, _brief_print_list, HookHandle, shape_is_known from .utils import _check_same_symbol_type, _check_all_np_ndarrays from .. import numpy_extension as _mx_npx -from .. import numpy as _mx_np +from .. import numpy as _mx_np, ndarray as nd from .. util import is_np_array, np_shape, np_array +_naming_counter = contextvars.ContextVar('namecounter') +_prefix = contextvars.ContextVar('prefix', default='') -class _BlockScope(object): - """Scope for collecting child `Block` s.""" - _current = threading.local() - def __init__(self, block): - self._block = weakref.ref(block) if block is not None else None - self._counter = {} - self._local = threading.local() - self._local._old_scope = None - self._local._name_scope = None - - @staticmethod - def count(hint): - """ - Creates unique name for new `Block`. - The profiler scope is to support the GPU memory profiler. - """ - current = getattr(_BlockScope._current, "value", None) - if current is None: - if not hasattr(_name.NameManager._current, "value"): - _name.NameManager._current.value = _name.NameManager() - block_name = _name.NameManager._current.value.get(None, hint) - return block_name - - count = current._counter.get(hint, 0) - block_name = '%s%d'%(hint, count) - current._counter[hint] = count + 1 - return block_name - - def __enter__(self): - block = self._block() - if block is None or block.name == '': - return self - self._local._old_scope = getattr(_BlockScope._current, "value", None) - _BlockScope._current.value = self - self._local._name_scope = _name.Prefix(block.name + '_') - self._local._name_scope.__enter__() - _profiler_scope_name = block.name + ":" - self._local._profiler_scope = _profiler.Scope(_profiler_scope_name) - self._local._profiler_scope.__enter__() - return self - - def __exit__(self, ptype, value, trace): - block = self._block() - if block is None or block.name == '': - return - self._local._name_scope.__exit__(ptype, value, trace) - self._local._name_scope = None - self._local._profiler_scope.__exit__(ptype, value, trace) - self._local._profiler_scope = None - _BlockScope._current.value = self._local._old_scope +@contextlib.contextmanager +def _block_scope(block): + """Append the classname of the current Block to the symbolic and memory profiler name scopes.""" + name = type(block).__name__.lower() + counter = _naming_counter.get(None) + if counter is not None: + count = counter.get(name, 0) + counter[name] = count + 1 + name = '%s%d'%(name, count) + counter_token = _naming_counter.set({}) + prefix_token = _prefix.set(_prefix.get() + name + '_') + with _name.Prefix(_prefix.get()): + with _profiler.scope(name + ':'): + yield + _naming_counter.reset(counter_token) + _prefix.reset(prefix_token) def _gather_type_ctx_info(args): @@ -230,7 +198,7 @@ def _merger(args, fmt): return _merger(args, fmt)[0] -class Block(object): +class Block: """Base class for all neural network layers and models. Your models should subclass this class. @@ -265,8 +233,6 @@ def __init__(self): self._reg_params = {} self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() - self._name = _BlockScope.count(self._alias()) - self._scope = _BlockScope(self) def __repr__(self): s = '{name}(\n{modstr}\n)' @@ -323,11 +289,6 @@ def _find_unregistered_block_in_container(data): def _alias(self): return self.__class__.__name__.lower() - @property - def name(self): - """Name of this :py:class:`Block`, class name + counter """ - return self._name - @property def params(self): """Returns this :py:class:`Block`'s parameter dictionary (does not include its @@ -511,6 +472,8 @@ def load_dict(self, param_dict, ctx=None, allow_missing=False, "Set allow_missing=True to ignore missing parameters."%( name, error_str, _brief_print_list(loaded.keys())) + if ctx is None: + ctx = _context.current_context() for name in loaded: if not ignore_extra and name not in params: raise ValueError( @@ -518,7 +481,10 @@ def load_dict(self, param_dict, ctx=None, allow_missing=False, "which contains parameters %s. Set ignore_extra=True to ignore. "%( name, error_str, _brief_print_list(params.keys()))) if name in params: - params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source) + param = loaded[name] + if isinstance(param, np.ndarray): + param = _mx_np.array(param) if is_np_array() else nd.array(param) + params[name]._load_init(param, ctx, cast_dtype=cast_dtype, dtype_source=dtype_source) def register_child(self, block, name=None): """Registers block as a child of self. :py:class:`Block` s assigned to self as @@ -601,8 +567,8 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, params = self.collect_params() if verbose: init.set_verbosity(verbose=verbose) - for k, v in params.items(): - v.initialize(None, ctx, init, force_reinit=force_reinit, structural_name=k) + for v in params.values(): + v.initialize(None, ctx, init, force_reinit=force_reinit) def hybridize(self, active=True, **kwargs): """ Please refer description of HybridBlock hybridize(). @@ -697,6 +663,14 @@ def share_parameters(self, shared): dense1.weight = dense0.weight dense1.bias = dense0.bias + Note that unlike the `load_parameters` or `load_dict` functions, + `share_parameters` results in the `Parameter` object being shared (or + tied) between the models, whereas `load_parameters` or `load_dict` only + set the value of the data dictionary of a model. If you call + `load_parameters` or `load_dict` after `share_parameters`, the loaded + value will be reflected in all networks that use the shared (or tied) + `Parameter` object. + Parameters ---------- shared : Dict @@ -820,7 +794,7 @@ def regroup(args, fmt): def _register_summary_hook(block): assert not isinstance(block, HybridBlock) or not block._active, \ - '"{}" must not be hybridized to print summary.'.format(block.name) + '"{}" must not be hybridized to print summary.'.format(type(block).__name__) def _summary_hook(block, _, outputs): class_name = block.__class__.__name__ block_idx = len(summary) - 1 @@ -972,8 +946,9 @@ def _get_graph_v1(self, *args): else: flatten_inputs.append(None) grouped_inputs = _regroup(flatten_inputs, self._in_format) + params = {i: j.var() for i, j in self._reg_params.items()} - with self._scope: + with _block_scope(self): out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter out, self._out_format = _flatten(out, "output") @@ -1018,10 +993,10 @@ def _get_graph(self, *args): def _build_cache(self, *args): data, out = self._get_graph(*args) data_names = {data.name: i for i, data in enumerate(data)} - params = self.collect_params() - params = {p.name: p for p in params.values()} - input_names = out.list_inputs() + params = {p.var().name: p for p in self.collect_params().values()} + param_serialization_names = {p.var().name: n for n, p in self.collect_params().items()} param_names = set(params.keys()) + input_names = out.list_inputs() expected_names = set(input_names) for name in expected_names: assert name in param_names or name in data_names, \ @@ -1075,14 +1050,15 @@ def _build_cache(self, *args): # and might not contain some parameters that were deleted during optimization. self._cached_op_args = [] for i, name in enumerate(input_names): - pair = None + triple = None if name in data_names: data_indices.append(i) - pair = (True, data_names[name]) + triple = (True, name, data_names[name]) else: param_indices.append(i) if name in params: param = params[name] + serialization_name = param_serialization_names[name] # HybridBlock.export else: # The param is missing from the original params dictionary, which means the param must have # been added by the Partition API backend @@ -1096,10 +1072,12 @@ def _build_cache(self, *args): 'Please check the backend.') param = Parameter(name) + param._var_name = name + serialization_name = name # HybridBlock.export param._load_init(param_data, args[0].context) - pair = (False, param) + triple = (False, serialization_name, param) - self._cached_op_args.append(pair) + self._cached_op_args.append(triple) flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ self._flags @@ -1148,7 +1126,7 @@ def _call_cached_op(self, *args): args_without_none = [ele for ele in args if ele is not None] cargs = [args_without_none[i] if is_arg else i.data() - for is_arg, i in self._cached_op_args] + for is_arg, name, i in self._cached_op_args] out = self._cached_op(*cargs) if isinstance(out, NDArray): out = [out] @@ -1292,7 +1270,7 @@ def _infer_attrs(self, infer_fn, attr, *args): sdict.update({name : attr for name, attr in \ zip(out.list_auxiliary_states(), aux_attrs)}) for i in self.collect_params().values(): - setattr(i, attr, sdict[i.name]) + setattr(i, attr, sdict[i.var().name]) def infer_shape(self, *args): """Infers shape of Parameters from inputs.""" @@ -1342,16 +1320,28 @@ def export(self, path, epoch=0, remove_amp_cast=True): raise RuntimeError( "Please first call block.hybridize() and then run forward with " "this block at least once before calling export.") - sym = self._cached_graph[1] + sym = copy.copy(self._cached_graph[1]) + + # Deduplicate params (shared parameters use the same input symbol) + reverse_params = {v: k for k, v in self.collect_params().items()} + params = {v: k for k, v in reverse_params.items()} + + # In export we have global information on the structure of the graph + # can rename the symbol inputs to human-readable, deterministic names. + # That's not true in general, which is why internally random unique identifiers are used. + rename_map = {param.var().name: name for name, param in params.items()} + for var in sym.get_inputs(): + if var.name in rename_map: + var._set_attr(name=rename_map[var.name]) + sym_filename = '%s-symbol.json'%path sym.save(sym_filename, remove_amp_cast=remove_amp_cast) arg_names = set(sym.list_arguments()) aux_names = set(sym.list_auxiliary_states()) arg_dict = {} - for is_arg, param in self._cached_op_args: + for is_arg, name, param in self._cached_op_args: if not is_arg: - name = param.name if name in arg_names: arg_dict['arg:{}'.format(name)] = param._reduce() else: @@ -1450,8 +1440,9 @@ def forward(self, x, *args): params = {k: v.data(ctx) for k, v in self._reg_params.items()} return self.hybrid_forward(ndarray, x, *args, **params) + params = {i: j.var() for i, j in self._reg_params.items()} - with self._scope: + with _block_scope(self): return self.hybrid_forward(symbol, x, *args, **params) def hybrid_forward(self, F, x, *args, **kwargs): @@ -1577,15 +1568,6 @@ def __repr__(self): def __init__(self, outputs, inputs, params=None): super(SymbolBlock, self).__init__() - structure = defaultdict(list) - if params is None: - params = {} - self._structured_named = False - elif any(k.find('.') != -1 for k in params): - self._structured_named = True - for k, v in params.items(): - structure[v.name].append(k) - params = {p.name : p for p in params.values()} if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1: inputs = [inputs] @@ -1620,29 +1602,31 @@ def __init__(self, outputs, inputs, params=None): arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params) - def _set_params_attr(name, **kwargs): - if params.get(name) is None: - param = Parameter(**kwargs) - param._name = name - else: - param = params[name] - param._check_and_setattr(**kwargs) - if self._structured_named: - lis = structure[name] - assert len(lis) > 0, "Can not find structured name for Parameter %s in 'params'. " \ - "Please check 'params' is complete!" % name - for structured_name in lis: - self._reg_params[structured_name] = param - else: - self._reg_params[name] = param + if params is None: + params = {} + unused_params = set(params.keys()) - set(arg_params) - set(aux_params) + if len(unused_params) > 0: + raise ValueError('{} params are unused by the model.'.format(unused_params)) + self._reg_params = params for i, arg in enumerate(arg_params): - if arg not in input_names: - _set_params_attr(name=arg, allow_deferred_init=True, dtype=arg_types[i]) - + if arg in self._reg_params: + self._reg_params[arg]._check_and_setattr(allow_deferred_init=True, dtype=arg_types[i]) + if self._reg_params[arg]._var is None: + self._reg_params[arg]._var_name = arg + elif arg not in input_names: + self._reg_params[arg] = Parameter(name=arg, allow_deferred_init=True, dtype=arg_types[i]) + self._reg_params[arg]._var_name = arg for i, aux in enumerate(aux_params): - if aux not in input_names: - _set_params_attr(name=aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i]) + if aux in self._reg_params: + self._reg_params[aux]._check_and_setattr(grad_req='null', allow_deferred_init=True, + dtype=aux_types[i]) + if self._reg_params[aux]._var is None: + self._reg_params[aux]._var_name = aux + elif aux not in input_names: + self._reg_params[aux] = Parameter(name=aux, grad_req='null', + allow_deferred_init=True, dtype=aux_types[i]) + self._reg_params[aux]._var_name = aux self._cached_graph = syms, out diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 8bdecccc844c..bd1e166165b8 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -242,7 +242,7 @@ def _add_default_training_metrics(self): suggested_metric = _suggest_metric_for_loss(self.loss) if suggested_metric: self._train_metrics = [suggested_metric] - loss_name = self.loss.name.rstrip('1234567890') + loss_name = type(self.loss).__name__ self._train_metrics.append(metric_loss(loss_name)) for metric in self._train_metrics: diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index d5bf41614749..e32c37146efd 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -24,6 +24,7 @@ 'PixelShuffle3D'] import warnings +import uuid from .... import ndarray as nd, context from ...block import HybridBlock from ...nn import Sequential, HybridSequential, BatchNorm @@ -179,7 +180,7 @@ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, num_devices = self._get_num_devices() if num_devices is None else num_devices self._kwargs = {'eps': epsilon, 'momentum': momentum, 'fix_gamma': not scale, 'use_global_stats': use_global_stats, - 'ndev': num_devices, 'key': self.name} + 'ndev': num_devices, 'key': uuid.uuid4()} def _get_num_devices(self): warnings.warn("Caution using SyncBatchNorm: " diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 37d5140e7939..4d1efdc9fe65 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -107,6 +107,8 @@ def __init__(self, name='weight', grad_req='write', shape=None, dtype=mx_real_t, lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False, differentiable=True, stype='default', grad_stype='default'): self._var = None + self._uuid = str(uuid.uuid4()) + self._var_name = None self._data = None self._grad = None self._ctx_list = None @@ -119,8 +121,7 @@ def __init__(self, name='weight', grad_req='write', shape=None, dtype=mx_real_t, if isinstance(shape, int): shape = (shape,) self._shape = shape - self._name = 'param_{}_{}'.format(str(uuid.uuid4()).replace('-', '_'), name) - self._structured_name = '' + self._name = name self._dtype = dtype self.lr_mult = lr_mult self.wd_mult = wd_mult @@ -358,7 +359,7 @@ def _finish_deferred_init(self): zeros_fn = ndarray.zeros data = zeros_fn(**kwargs) initializer.create(default_init)( - initializer.InitDesc(self.name, {'__init__': init, 'structure': self._structural_name}), data) + initializer.InitDesc(self.name, {'__init__': init}), data) self._init_impl(data, ctx) @@ -415,7 +416,7 @@ def _reduce(self): return data def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), - force_reinit=False, structural_name=''): + force_reinit=False): """Initializes parameter and gradient arrays. Only used for :py:class:`NDArray` API. Parameters @@ -436,10 +437,6 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), and :py:meth:`Parameter.init` are ``None``. force_reinit : bool, default False Whether to force re-initialization if parameter is already initialized. - structural_name : str, default "" - The structural name for the parameter in the block. - The value would be accessed in InitDesc.attrs['structure'] by self-defined initializers. - Users may want to initialize parameters based on the block's structure Examples -------- >>> weight = mx.gluon.Parameter('weight', shape=(2, 2)) @@ -468,7 +465,6 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), stacklevel=2) return self._data = self._grad = None - self._structural_name = structural_name if ctx is None: ctx = [context.current_context()] if isinstance(ctx, Context): @@ -647,7 +643,10 @@ def zero_grad(self): def var(self): """Returns a symbol representing this parameter.""" if self._var is None: - self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype, + if self._var_name is None: # _var_name is set manually in SymbolBlock.import + self._var_name = self._uuid + + self._var = symbol.var(self._var_name, shape=self.shape, dtype=self.dtype, lr_mult=self.lr_mult, wd_mult=self.wd_mult, init=self.init, stype=self._stype) if is_np_array(): diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 66db9235528a..dc9da80923da 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -96,7 +96,10 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', raise ValueError( "First argument must be a list or dict of Parameters, " \ "got list of %s."%(type(param))) - self._param2idx[param.name] = i + if param._uuid in self._param2idx: + # Shared parameters have same uuid; only need to store one of the shared versions + continue + self._param2idx[param._uuid] = i self._params.append(param) param._set_trainer(self) if param._stype != 'default': @@ -164,7 +167,7 @@ def _init_params(self): params_to_init.append(param) else: param_arrays = param._check_and_get(param._data, list) - idx = self._param2idx[param.name] + idx = self._param2idx[param._uuid] if param._stype != 'default': self._kvstore.init(idx, param_arrays[0]) else: @@ -221,7 +224,7 @@ def _init_kvstore(self): # - backward() # - push_and_update(grad) # - pull(weight) - arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} + arg_arrays = {param._uuid: param.data(self._contexts[0]) for param in self._params} kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) self._distributed = 'dist' in kvstore.type if kvstore else False update_on_kvstore = self._distributed @@ -239,7 +242,7 @@ def _init_kvstore(self): else: # Training with dense weight and dense gradients. # The only unsupported mode is async with update_on_kvstore=False - arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} + arg_arrays = {param._uuid: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) self._distributed = 'dist' in kvstore.type if kvstore else False @@ -311,7 +314,7 @@ def _row_sparse_pull(self, parameter, out, row_id, full_idx=False): self._init_kvstore() if self._params_to_init: self._init_params() - idx = self._param2idx[parameter.name] + idx = self._param2idx[parameter._uuid] if full_idx and 'dist' not in self._kvstore.type: assert row_id.size == out.shape[0] self._kvstore.pull(idx, out=out, priority=-idx, ignore_sparse=False) diff --git a/python/mxnet/name.py b/python/mxnet/name.py index e39752ecfda9..59e4f6b39a1c 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -14,24 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# coding: utf-8 """Automatic naming support for symbolic API.""" -import threading -import warnings -from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass +import contextvars + -class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)): +class NameManager: """NameManager to do automatic naming. Developers can also inherit from this class to change naming behavior. """ - _current = threading.local() - def __init__(self): self._counter = {} - self._local = threading.local() - self._local._old_manager = None + self._old_manager = None def get(self, name, hint): """Get the canonical name for a symbol. @@ -65,30 +59,14 @@ def get(self, name, hint): return name def __enter__(self): - if not hasattr(NameManager._current, "value"): - NameManager._current.value = NameManager() - self._local._old_manager = NameManager._current.value - NameManager._current.value = self + # Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value + self._old_manager = _current.get() + _current.set(self) return self def __exit__(self, ptype, value, trace): - assert self._local._old_manager - NameManager._current.value = self._local._old_manager - - #pylint: disable=no-self-argument - @classproperty - def current(cls): - warnings.warn("NameManager.current has been deprecated. " - "It is advised to use the `with` statement with NameManager.", - DeprecationWarning) - if not hasattr(NameManager._current, "value"): - cls._current.value = NameManager() - return cls._current.value - - @current.setter - def current(cls, val): - cls._current.value = val - #pylint: enable=no-self-argument + _current.set(self._old_manager) + class Prefix(NameManager): """A name manager that attaches a prefix to all names. @@ -103,12 +81,17 @@ class Prefix(NameManager): ['data', 'mynet_fc1_weight', 'mynet_fc1_bias'] """ def __init__(self, prefix): - super(Prefix, self).__init__() + super().__init__() self._prefix = prefix def get(self, name, hint): - name = super(Prefix, self).get(name, hint) + name = super().get(name, hint) return self._prefix + name -# initialize the default name manager -NameManager._current.value = NameManager() + +_current = contextvars.ContextVar('namemanager', default=NameManager()) + + +def current(): + """Returns the current name manager.""" + return _current.get() diff --git a/python/mxnet/optimizer/updater.py b/python/mxnet/optimizer/updater.py index a96961446f4b..9a5b25ecc2a4 100644 --- a/python/mxnet/optimizer/updater.py +++ b/python/mxnet/optimizer/updater.py @@ -21,7 +21,7 @@ import numpy from ..base import py_str from ..ndarray import NDArray -from ..profiler import Scope +from ..profiler import scope as profiler_scope from ..util import is_np_array from .utils import _as_classic @@ -55,7 +55,7 @@ def __call__(self, index, grad, weight): indices[i] = py_str(idx) idx = indices[i] if idx not in self.states: - with Scope("updater:optimizer_state"): + with profiler_scope("updater:optimizer_state"): self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i]) self.states_synced[idx] = True elif not self.states_synced[idx]: diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py index 3b8830d11245..d43f7383daa3 100644 --- a/python/mxnet/profiler.py +++ b/python/mxnet/profiler.py @@ -20,7 +20,8 @@ # pylint: disable=too-many-branches, too-many-statements """Profiler setting methods.""" import ctypes -import threading +import contextlib +import contextvars import warnings from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle @@ -489,7 +490,7 @@ def __init__(self, domain, name): self.name = name self.domain = domain - def mark(self, scope='process'): + def mark(self, scope='process'): # pylint: disable=redefined-outer-name """Set up the profiler state to record operator. Parameters @@ -502,50 +503,27 @@ def mark(self, scope='process'): check_call(_LIB.MXProfileSetMarker(self.domain.handle, c_str(self.name), c_str(scope))) -class Scope(object): - """ - The `_profiler.Scope` was developed to assign the profiler scope for the GPU - memory profiler. It is implicitly invoked when the Gluon API is used. +@contextlib.contextmanager +def scope(name=':', append_mode=False): + """Assign the profiler scope for the GPU memory profiler. + + It is implicitly invoked when the Gluon API is used. Parameters ========== name : Name of the Profiler Scope append_mode : Whether to append the old profiler scope at the front. - """ - _current = threading.local() - - def __init__(self, name=':', append_mode=False): - self._name = name + ":" if not name.endswith(":") else name - self._old_scope = None - if append_mode: - if not hasattr(Scope._current, "value"): - Scope._current.value = Scope() - self._name = Scope._current.value.name + self._name - - def __enter__(self): - if not hasattr(Scope._current, "value"): - Scope._current.value = Scope() - self._old_scope = Scope._current.value - Scope._current.value = self - # Invoke the C API to propagate the profiler scope information to the - # C++ backend. - check_call(_LIB.MXSetProfilerScope(c_str(self.name))) - return self - def __exit__(self, ptype, value, trace): - assert self._old_scope - Scope._current.value = self._old_scope - # If the old profiler scope is also of type `profiler.Scope`, invoke the - # C API once again to recover the previous scope information. Otherwise, - # the default scope `:` will be set. - if isinstance(self._old_scope, Scope): - check_call(_LIB.MXSetProfilerScope(c_str(self._old_scope.name))) - else: - check_call(_LIB.MXSetProfilerScope(c_str(":"))) - - @property - def name(self): - return self._name + """ + name = name + ":" if not name.endswith(":") else name + token = _current_scope.set(_current_scope.get() + name if append_mode else name) + # Invoke the C API to propagate the profiler scope information to the + # C++ backend. + check_call(_LIB.MXSetProfilerScope(c_str(name))) + yield name + _current_scope.reset(token) + # Invoke the C API once again to recover the previous scope information. + check_call(_LIB.MXSetProfilerScope(c_str(_current_scope.get()))) # initialize the default profiler scope -Scope._current.value = Scope() +_current_scope = contextvars.ContextVar('profilerscope', default=':') diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index d1048df7cd6f..3b98c429aac8 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -32,7 +32,7 @@ from . import symbol from ..base import _LIB, check_call from ..base import SymbolHandle, _as_list -from ..attribute import AttrScope +from ..attribute import AttrScope, current as current_attribute __all__ = ["rand_zipfian", "foreach", "while_loop", "cond"] @@ -163,7 +163,7 @@ def _cut_subgraph(subg): return syms def _get_unique_subgraph_name(subgraph_name): - attrs = AttrScope._current.value._attr + attrs = current_attribute()._attr if attrs.get("__subgraph_name__", "") != "": subgraph_name = "".join([attrs["__subgraph_name__"], "$", subgraph_name]) AttrScope._subgraph_names[subgraph_name] += 1 diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index d3521cad1274..9b193f850a93 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -283,7 +283,7 @@ def __neg__(self): return negative(self) def __deepcopy__(self, _): - return super(_Symbol, self).as_np_ndarray() + return super().__deepcopy__(_).as_np_ndarray() def __eq__(self, other): """x.__eq__(y) <=> x == y""" diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index e1ce4d44ee2b..ffea572896ab 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -22,13 +22,13 @@ import numpy as _np from . import _internal +from .. import name as _name, attribute from ._internal import SymbolBase, _symbol_creator -from ..attribute import AttrScope from ..base import mx_uint, check_call, _LIB, py_str from ..symbol_doc import _build_doc from ..base import _Null, _init_op_module, _is_np_op, _output_is_list from ..name import NameManager -from ..profiler import Scope +from ..profiler import _current_scope as _profiler_scope # pylint: enable=unused-import @@ -170,13 +170,9 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) dtype_name, dtype_name, dtype_name)) code.append(""" attr = kwargs.pop('attr', None) - if not hasattr(AttrScope._current, "value"): - AttrScope._current.value = AttrScope() - kwargs.update(AttrScope._current.value.get(attr)) + kwargs.update(attribute.current().get(attr)) name = kwargs.pop('name', None) - if not hasattr(NameManager._current, "value"): - NameManager._current.value = NameManager() - name = NameManager._current.value.get(name, '%s') + name = _name.current().get(name, '%s') _ = kwargs.pop('out', None) keys = [] vals = [] @@ -198,9 +194,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) code.append(""" if 'profiler_scope' not in keys: keys.append('profiler_scope') - if not hasattr(Scope._current, "value"): - Scope._current.value = Scope() - vals.append(Scope._current.value.name) + vals.append(_profiler_scope.get()) return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s, %s)"""%( handle.value, str(is_np_op), str(output_is_list))) else: @@ -208,9 +202,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) def %s(%s):"""%(func_name, ', '.join(signature))) if not signature_only: code.append(""" - if not hasattr(AttrScope._current, "value"): - AttrScope._current.value = AttrScope() - kwargs.update(AttrScope._current.value.get(attr)) + kwargs.update(attribute.current().get(attr)) sym_kwargs = dict() _keys = [] _vals = [] @@ -255,14 +247,10 @@ def %s(%s):"""%(func_name, ', '.join(signature))) dtype_name, dtype_name)) code.append(""" - if not hasattr(NameManager._current, "value"): - NameManager._current.value = NameManager() - name = NameManager._current.value.get(name, '%s') + name = _name.current().get(name, '%s') if 'profiler_scope' not in _keys: _keys.append('profiler_scope') - if not hasattr(Scope._current, "value"): - Scope._current.value = Scope() - _vals.append(Scope._current.value.name) + _vals.append(_profiler_scope.get()) return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s, %s)"""%( func_name.lower(), handle.value, str(is_np_op), str(output_is_list))) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 9d23050d7355..039ac0d9d195 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -30,7 +30,7 @@ from numbers import Number import numpy as _numpy # pylint: disable=relative-import -from ..attribute import AttrScope +from .. import attribute from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array from ..base import mx_uint, py_str, string_types, integer_types, mx_int, mx_int64 from ..base import NDArrayHandle, SymbolHandle @@ -43,7 +43,7 @@ from . import op from ._internal import SymbolBase, _set_symbol_class from ..util import is_np_shape -from ..profiler import Scope +from ..profiler import _current_scope as _profiler_scope __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", @@ -675,6 +675,33 @@ def _set_attr(self, **kwargs): check_call(_LIB.MXSymbolSetAttr( self.handle, c_str(key), c_str(str(value)))) + def get_inputs(self): + """Gets a new grouped symbol `sgroup`. The output of `sgroup` is a list of inputs to this symbol. + + Consider the following code: + + Example + ------- + >>> a = mx.sym.var('a') + >>> b = mx.sym.var('b') + >>> c = a + b + >>> d = c.get_inputs() + >>> d + + >>> d.list_outputs() + ['a', 'b'] + + Returns + ------- + sgroup : Symbol + A symbol group containing all input nodes of the computation graph + used to compute the symbol. + """ + handle = SymbolHandle() + check_call(_LIB.MXSymbolGetInputs( + self.handle, ctypes.byref(handle))) + return Symbol(handle=handle) + def get_internals(self): """Gets a new grouped symbol `sgroup`. The output of `sgroup` is a list of outputs of all of the internal nodes. @@ -2678,9 +2705,7 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, handle = SymbolHandle() check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle))) ret = Symbol(handle) - if not hasattr(AttrScope._current, "value"): - AttrScope._current.value = AttrScope() - attr = AttrScope._current.value.get(attr) + attr = attribute.current().get(attr) attr = {} if attr is None else attr if shape is not None: attr['__shape__'] = str(shape) @@ -2703,9 +2728,7 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, if profiler_scope is not None: attr['__profiler_scope__'] = profiler_scope else: - if not hasattr(Scope._current, "value"): - Scope._current.value = Scope() - attr['__profiler_scope__'] = Scope._current.value.name + attr['__profiler_scope__'] = _profiler_scope.get() for k, v in kwargs.items(): if k.startswith('__') and k.endswith('__'): attr[k] = str(v) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index cbfa49f490be..dd0278379f11 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -44,7 +44,7 @@ # in rare cases requests may be not installed pass import mxnet as mx -from .context import Context, current_context +from .context import current_context from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from .symbol import Symbol from .symbol.numpy import _Symbol as np_symbol @@ -62,7 +62,7 @@ def default_context(): def set_default_context(ctx): """Set default context.""" - Context._default_ctx.value = ctx + mx.context._current.set(ctx) def default_dtype(): @@ -2381,8 +2381,7 @@ def is_op_runnable(): @use_np def check_gluon_hybridize_consistency(net_builder, data_l, numpy_func=None, test_grad=True, rtol=1E-4, atol=1E-4): - """Check whether a HybridBlock has consistent output between the hybridized - v.s. non-hybridized versions + """Check whether a HybridBlock has consistent output when hybridized or not hybridized The network should not contain any random number generators. @@ -2404,15 +2403,6 @@ def check_gluon_hybridize_consistency(net_builder, data_l, numpy_func=None, test atol : float, optional The absolute error tolerance, default 1E-4. Default 1E-4. """ - class _NumpyParamDictInit(mx.init.Initializer): - """Initializes parameters with the cached numpy ndarrays dictionary - """ - def __init__(self, np_params): - super(_NumpyParamDictInit, self).__init__() - self._np_params = np_params - - def _init_weight(self, name, arr): - arr[()] = self._np_params[name.attrs['structure']] saved_out_np = None saved_grad_np_l = None params_init = None @@ -2423,7 +2413,7 @@ def _init_weight(self, name, arr): if params_init is None: net.initialize() else: - net.initialize(params_init) + net.load_dict(params_init) if hybridize: net.hybridize() in_data_l = [ele.copy() for ele in data_l] @@ -2435,9 +2425,8 @@ def _init_weight(self, name, arr): out.backward(out) else: out = net(*in_data_l) - if params_init is None: - np_params = {k: v.data().asnumpy() for k, v in net.collect_params().items()} - params_init = _NumpyParamDictInit(np_params) + if params_init is None: # Deferred initialization finished + params_init = {k: v.data().asnumpy() for k, v in net.collect_params().items()} if saved_out_np is None: saved_out_np = out.asnumpy() else: diff --git a/python/setup.py b/python/setup.py index a040cbbf74e2..e2b4d645a792 100644 --- a/python/setup.py +++ b/python/setup.py @@ -30,7 +30,7 @@ else: from setuptools import setup from setuptools.extension import Extension - kwargs = {'install_requires': ['numpy>=1.17', 'requests>=2.20.0,<3', 'graphviz<0.9.0,>=0.8.1'], 'zip_safe': False} + kwargs = {'install_requires': ['numpy>=1.17', 'requests>=2.20.0,<3', 'graphviz<0.9.0,>=0.8.1', 'contextvars;python_version<"3.7"'], 'zip_safe': False} with_cython = False if '--with-cython' in sys.argv: diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 0cc0ef7341e5..3052256c825d 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -177,6 +177,20 @@ int MXSymbolGetOutput(SymbolHandle symbol, return NNSymbolGetOutput(symbol, index, out); } +int MXSymbolGetInputs(SymbolHandle symbol, + SymbolHandle *out) { + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + std::vector inputs = static_cast(symbol)->ListInputs( + nnvm::Symbol::ListInputOption(0)); + for (const nnvm::ObjectPtr &o : inputs) { + nnvm::NodeEntry e(o); + s->outputs.push_back(e); + } + *out = s; + API_END_HANDLE_ERROR(delete s); +} + int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 27d5f5e0ec3d..ba7deaeab503 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -443,20 +443,20 @@ def test_symbol_block_fp16(tmpdir): net_fp32.hybridize() data = mx.nd.zeros((1, 3, 224, 224), dtype='float16', ctx=ctx) net_fp32.forward(data) - net_fp32.export(tmpfile, 0) + symbol_file, param_file = net_fp32.export(tmpfile, 0) # 2. Load the saved model and verify if all the params are loaded correctly. - # and choose one of the param to verify the type if fp16. - sm = mx.sym.load(tmpfile + '-symbol.json') + # Choose one of the parameters to verify the type is fp16. + sm = mx.sym.load(symbol_file) inputs = mx.sym.var('data', dtype='float16') net_fp16 = mx.gluon.SymbolBlock(sm, inputs) - net_fp16.load_parameters(tmpfile + '-0000.params', ctx=ctx) + net_fp16.load_parameters(param_file, ctx=ctx) # 3. Get a conv layer's weight parameter name. Conv layer's weight param is # expected to be of dtype casted, fp16. - name = None - for param_name, param in net_fp32.collect_params().items(): + name = None + for param_name in net_fp32.collect_params().keys(): if 'conv' in param_name and 'weight' in param_name: - name = param.name + name = param_name break assert np.dtype(net_fp16.params[name].dtype) == np.dtype(np.float16) diff --git a/tests/python/gpu/test_profiler_gpu.py b/tests/python/gpu/test_profiler_gpu.py index 9356eb9df46a..11a0b7d12c0e 100644 --- a/tests/python/gpu/test_profiler_gpu.py +++ b/tests/python/gpu/test_profiler_gpu.py @@ -34,7 +34,7 @@ def test_gpu_memory_profiler_symbolic(): enable_profiler('test_profiler.json', False, False) profiler.set_state('run') - with profiler.Scope("tensordot"): + with profiler.scope("tensordot"): A = mx.sym.Variable('A') B = mx.sym.Variable('B') C = mx.symbol.dot(A, B, name='dot') diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index d8489240d552..0bd5d1ead13b 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -24,7 +24,7 @@ import mxnet as mx import pytest from mxnet.test_utils import rand_ndarray, assert_almost_equal -from mxnet import gluon +from mxnet import gluon, context from mxnet.gluon import nn from mxnet.test_utils import * curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -210,7 +210,7 @@ def test_flatten_slice_after_conv(): shape = (2, 16, 16, 16) val = np.random.rand(2, 16, 16, 16).astype(np.float32) - exe = slice1._simple_bind(Context.default_ctx, data=shape) + exe = slice1._simple_bind(context.current_context(), data=shape) exe.arg_arrays[0][:] = val exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape) exe.arg_arrays[2][:] = np.random.normal(size=exe.arg_arrays[2].shape) diff --git a/tests/python/unittest/onnx/backend_test.py b/tests/python/unittest/onnx/backend_test.py index 69d9e1427b69..cf3cdc1bfb54 100644 --- a/tests/python/unittest/onnx/backend_test.py +++ b/tests/python/unittest/onnx/backend_test.py @@ -19,21 +19,6 @@ # under the License. """ONNX test backend wrapper""" -try: - import onnx.backend.test -except ImportError: - raise ImportError("Onnx and protobuf need to be installed") - -import test_cases -import unittest -import backend as mxnet_backend -import logging - -operations = ['import', 'export'] -backends = ['mxnet', 'gluon'] -# This is a pytest magic variable to load extra plugins -pytest_plugins = "onnx.backend.test.report", - def build_test_suite(backend_tests): # type: () -> unittest.TestSuite ''' @@ -80,13 +65,29 @@ def prepare_tests(backend, oper): return BACKEND_TESTS -for bkend in backends: - for operation in operations: - log = logging.getLogger(bkend + operation) - if bkend == 'gluon' and operation == 'export': - log.warning('Gluon->ONNX export not implemented. Skipping tests...') - continue - log.info('Executing tests for ' + bkend + ' backend: ' + operation) - mxnet_backend.MXNetBackend.set_params(bkend, operation) - BACKEND_TESTS = prepare_tests(mxnet_backend, operation) - unittest.TextTestRunner().run(build_test_suite(BACKEND_TESTS.enable_report())) +if __name__ == '__main__': + try: + import onnx.backend.test + except ImportError: + raise ImportError("Onnx and protobuf need to be installed") + + import test_cases + import unittest + import backend as mxnet_backend + import logging + + operations = ['import', 'export'] + backends = ['mxnet', 'gluon'] + # This is a pytest magic variable to load extra plugins + pytest_plugins = "onnx.backend.test.report", + + for bkend in backends: + for operation in operations: + log = logging.getLogger(bkend + operation) + if bkend == 'gluon' and operation == 'export': + log.warning('Gluon->ONNX export not implemented. Skipping tests...') + continue + log.info('Executing tests for ' + bkend + ' backend: ' + operation) + mxnet_backend.MXNetBackend.set_params(bkend, operation) + BACKEND_TESTS = prepare_tests(mxnet_backend, operation) + unittest.TextTestRunner().run(build_test_suite(BACKEND_TESTS.enable_report())) diff --git a/tests/python/unittest/onnx/mxnet_export_test.py b/tests/python/unittest/onnx/mxnet_export_test.py index 61d00f2fca3c..0248e4b3d238 100644 --- a/tests/python/unittest/onnx/mxnet_export_test.py +++ b/tests/python/unittest/onnx/mxnet_export_test.py @@ -17,7 +17,6 @@ # pylint: disable=too-many-locals,wrong-import-position,import-error -from __future__ import absolute_import import os, sys import unittest import logging @@ -29,7 +28,6 @@ from mxnet.test_utils import set_default_context from mxnet.gluon import nn from mxnet.gluon import HybridBlock -from mxnet.contrib import onnx as onnx_mxnet import mxnet as mx logger = logging.getLogger() @@ -58,21 +56,21 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= data = nd.random.uniform(0, 1, (1, 1024)) output = _force_list(net(data)) # initialize weights net_sym = _optional_group(net(sym.Variable('data')), group_outputs) - net_params = {param.name: param._reduce() for name, param in net.collect_params().items()} + net_params = {param.var().name: param._reduce() for param in net.collect_params().values()} net_params.update(extra_params) with tempfile.TemporaryDirectory() as tmpdirname: onnx_file_path = os.path.join(tmpdirname, 'net.onnx') - export_path = onnx_mxnet.export_model( + export_path = mx.contrib.onnx.export_model( sym=net_sym, params=net_params, input_shape=[shape_type(data.shape)], onnx_file_path=onnx_file_path) assert export_path == onnx_file_path # Try importing the model to symbol - _assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0]) + _assert_sym_equal(net_sym, mx.contrib.onnx.import_model(export_path)[0]) # Try importing the model to gluon - imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None) + imported_net = mx.contrib.onnx.import_to_gluon(export_path, ctx=None) _assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs)) # Confirm network outputs are the same diff --git a/tests/python/unittest/onnx/test_node.py b/tests/python/unittest/onnx/test_node.py index 351ac5eb5f1b..fabba1a9e555 100644 --- a/tests/python/unittest/onnx/test_node.py +++ b/tests/python/unittest/onnx/test_node.py @@ -22,7 +22,6 @@ those PRs merged, this file will get EOL'ed. """ # pylint: disable=too-many-locals,wrong-import-position,import-error -from __future__ import absolute_import import sys import os import unittest @@ -34,7 +33,6 @@ from onnx import checker, numpy_helper, helper, load_model from onnx import TensorProto from mxnet.test_utils import download -from mxnet.contrib import onnx as onnx_mxnet import mxnet as mx import backend @@ -123,13 +121,13 @@ def test_exports(self): mx_op = mx_op(**attrs) mx_op.initialize() mx_op(mx.nd.zeros(input_shape)) - params = {p.name: p.data() for p in mx_op.collect_params().values()} + params = {p.var().name: p.data() for p in mx_op.collect_params().values()} outsym = mx_op(input_sym) else: params = {} outsym = mx_op(input_sym, **attrs) - converted_model = onnx_mxnet.export_model(outsym, params, [input_shape], np.float32, - onnx_file_path=outsym.name + ".onnx") + converted_model = mx.contrib.onnx.export_model(outsym, params, [input_shape], np.float32, + onnx_file_path=outsym.name + ".onnx") model = load_model(converted_model) checker.check_model(model) diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index d7f078c28b7f..cc1e87ae94e4 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -420,6 +420,7 @@ def test_get_symbol(): assert len(get_symbol(y).list_arguments()) == 2 @with_seed() +@pytest.mark.garbage_expected def test_grad_with_stype(): def check_grad_with_stype(array_stype, grad_stype, expected_stype): x = mx.nd.zeros((1, 1), stype=array_stype) diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py index c2441dc54835..4b98470730e0 100644 --- a/tests/python/unittest/test_deferred_compute.py +++ b/tests/python/unittest/test_deferred_compute.py @@ -522,7 +522,8 @@ def test_dc_hybridblock_symbolblock_error(): inputs = mx.sym.var('data') outputs = model(inputs).get_internals() - smodel = mx.gluon.SymbolBlock(outputs, inputs, params=model.collect_params()) + smodel = mx.gluon.SymbolBlock(outputs, inputs) + smodel.initialize() assert len(smodel(mx.nd.zeros((16, 10)))) == 14 diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 77d5119e7dff..52bab2b04ca4 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -252,6 +252,12 @@ def test_dense(): assert outs == [(17, 128)] +def test_hybrid_sequential_unique_internals(): + net = mx.gluon.nn.HybridSequential() + net.add(mx.gluon.nn.Dense(100, activation='relu'), mx.gluon.nn.Dense(10)) + assert len(set(s.name for s in net(mx.sym.Variable('data')).get_internals())) == 8 + + @with_seed() def test_symbol_block(tmpdir): model = nn.HybridSequential() @@ -265,8 +271,8 @@ def test_symbol_block(tmpdir): inputs = mx.sym.var('data') outputs = model(inputs).get_internals() - - smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params()) + params = {p.var().name: p for p in model.collect_params().values()} + smodel = gluon.SymbolBlock(outputs, inputs, params=params) assert len(smodel(mx.nd.zeros((16, 10)))) == 14 @@ -288,7 +294,8 @@ def hybrid_forward(self, F, x): inputs = mx.sym.var('data') outputs = model(inputs) - smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params()) + params = {p.var().name: p for p in model.collect_params().values()} + smodel = gluon.SymbolBlock(outputs, inputs, params=params) net = Net(smodel) net.hybridize() assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray) @@ -306,12 +313,10 @@ def hybrid_forward(self, F, x): net_fp32.hybridize() data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx) net_fp32.forward(data) - net_fp32.export(tmpfile, 0) + sym_file, params_file = net_fp32.export(tmpfile, 0) # 2.a Load the saved model and verify if all the params are loaded correctly. # and choose one of the param to verify the type if fp64.\ - sym_file = tmpfile + '-symbol.json' - params_file = tmpfile + '-0000.params' sm = mx.sym.load(sym_file) inputs = mx.sym.var('data', dtype='float64') net_fp64 = mx.gluon.SymbolBlock(sm, inputs) @@ -1506,17 +1511,19 @@ def __init__(self, b1, b2): c.load_parameters(param_path) @with_seed() -def test_symbol_block_save_load(): +def test_symbol_block_save_load(tmpdir): + tmp = str(tmpdir) + tmpfile = os.path.join(tmp, 'resnet34_fp64') + class Net(gluon.HybridBlock): def __init__(self): super(Net, self).__init__() backbone = gluon.model_zoo.vision.resnet18_v1() - data = mx.sym.var('data') - featnames = [backbone.features[i][1].name for i in range(4, 7)] - out_names = ['_'.join([featname, 'activation0_output']) for featname in featnames] - internals = backbone(data).get_internals() - outs = [internals[out_name] for out_name in out_names] - self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params()) + backbone.initialize() + backbone.hybridize() + backbone(mx.nd.random.normal(shape=(1, 3, 32, 32))) + sym_file, params_file = backbone.export(tmpfile) + self.backbone = gluon.SymbolBlock.imports(sym_file, 'data', params_file) self.body = nn.Conv2D(3, 1) def hybrid_forward(self, F, x): @@ -1527,10 +1534,11 @@ def hybrid_forward(self, F, x): net1.initialize(mx.init.Normal()) net1.hybridize() net1(mx.nd.random.normal(shape=(1, 3, 32, 32))) - net1.save_parameters('./test_symbol_block_save_load.params') + params_file = os.path.join(tmp, './test_symbol_block_save_load.params') + net1.save_parameters(params_file) net2 = Net() - net2.load_parameters('./test_symbol_block_save_load.params', ctx=mx.cpu()) + net2.load_parameters(params_file) @with_seed() @@ -1721,21 +1729,21 @@ def mon_callback(node_name, opr_name, arr): model.add(mx.gluon.nn.Dense(2)) model.initialize() model.hybridize() - check_name(model, [model[0].name + "_fwd_output"]) + check_name(model, ["hybridsequential_dense0_fwd_output"]) # Test with Activation, FListInputNames not registered, input name will have _input appended model = mx.gluon.nn.HybridSequential() model.add(mx.gluon.nn.Activation("relu")) model.initialize() model.hybridize() - check_name(model, [model[0].name + "_fwd_output"]) + check_name(model, ["hybridsequential_activation0_fwd_output"]) # Test with Pooling, monitor_all is set to True model = mx.gluon.nn.HybridSequential() model.add(mx.gluon.nn.AvgPool1D()) model.initialize() model.hybridize() - check_name(model, [model[0].name + '_fwd_data', model[0].name + '_fwd_output'], + check_name(model, ['hybridsequential_avgpool1d0_fwd_data', 'hybridsequential_avgpool1d0_fwd_output'], expected_opr_names=["Pooling"], monitor_all=True) # stack two layers and test @@ -1745,16 +1753,16 @@ def mon_callback(node_name, opr_name, arr): model.initialize() model.hybridize() check_name(model, - [model[0].name + '_fwd_data', model[0].name + '_fwd_weight', - model[0].name + '_fwd_bias', model[0].name + '_fwd_output', - model[1].name + '_fwd_input0', model[1].name + '_fwd_output'], monitor_all=True) + ['hybridsequential_dense0_fwd_data', 'hybridsequential_dense0_fwd_weight', + 'hybridsequential_dense0_fwd_bias', 'hybridsequential_dense0_fwd_output', + 'hybridsequential_activation0_fwd_input0', 'hybridsequential_activation0_fwd_output'], monitor_all=True) # check with different hybridize modes model.hybridize(static_alloc=True) check_name(model, - [model[0].name + '_fwd_data', model[0].name + '_fwd_weight', - model[0].name + '_fwd_bias', model[0].name + '_fwd_output', - model[1].name + '_fwd_input0', model[1].name + '_fwd_output'], monitor_all=True) + ['hybridsequential_dense0_fwd_data', 'hybridsequential_dense0_fwd_weight', + 'hybridsequential_dense0_fwd_bias', 'hybridsequential_dense0_fwd_output', + 'hybridsequential_activation0_fwd_input0', 'hybridsequential_activation0_fwd_output'], monitor_all=True) @with_seed() def test_apply(): @@ -1962,10 +1970,10 @@ def hybrid_forward(self, F, x): check_layer_forward_withinput(net, x) @with_seed() -def test_group_conv2d_16c(): - grp_list = [16] +@pytest.mark.parametrize('grp', [16]) +@pytest.mark.parametrize('kernel_size', [1, 3]) +def test_group_conv2d_16c(grp, kernel_size): input_size_list = np.random.randint(low=3, high=65, size=10).tolist() - kernel_list = [1, 3] batch_size = 4 class Net(gluon.HybridBlock): def __init__(self, @@ -1983,10 +1991,8 @@ def hybrid_forward(self, F, x): for i in range(len(input_size_list)): x = mx.nd.random.uniform(-1.0, 1.0, shape=(batch_size, 3, input_size_list[i], input_size_list[i])) - for j in range(len(grp_list)): - for k in range(len(kernel_list)): - net = Net(grp_list[j], kernel_list[k]) - check_layer_forward_withinput(net, x) + net = Net(grp, kernel_size) + check_layer_forward_withinput(net, x) @with_seed() @pytest.mark.skip(reason='skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/11164') diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index 33ea1e495e91..f3c7000c1e90 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -35,7 +35,7 @@ def check_rnn_cell(cell, in_shape=(10, 50), out_shape=(10, 100), begin_state=Non outputs = mx.sym.Group(outputs) assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight'] - assert outputs.list_outputs() == [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + assert outputs.list_outputs() == [type(cell).__name__.lower() + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] args, outs, auxs = outputs.infer_shape(rnn_t0_data=in_shape, rnn_t1_data=in_shape, @@ -121,7 +121,7 @@ def test_lstmp(): outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) expected_params = ['h2h_bias', 'h2h_weight', 'h2r_weight', 'i2h_bias', 'i2h_weight'] - expected_outputs = [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + expected_outputs = [type(cell).__name__.lower() + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] assert sorted(cell.collect_params().keys()) == expected_params assert outputs.list_outputs() == expected_outputs, outputs.list_outputs() diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 933c2c17d95f..fbb7070eaa34 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -52,8 +52,10 @@ def test_rnn(): outputs = mx.sym.Group(outputs) assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight'] - assert outputs.list_outputs() == \ - [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + assert outputs.list_outputs() == [ + 'rnncell_t0_out_output', 'rnncell_t1_out_output', + 'rnncell_t2_out_output' + ] args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -65,8 +67,10 @@ def test_lstm(): outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight'] - assert outputs.list_outputs() == \ - [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + assert outputs.list_outputs() == [ + 'lstmcell_t0_out_output', 'lstmcell_t1_out_output', + 'lstmcell_t2_out_output' + ] args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -171,8 +175,10 @@ def test_gru(): outputs, _ = cell.unroll(3, inputs) outputs = mx.sym.Group(outputs) assert sorted(cell.collect_params().keys()) == ['h2h_bias', 'h2h_weight', 'i2h_bias', 'i2h_weight'] - assert outputs.list_outputs() == \ - [cell.name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + assert outputs.list_outputs() == [ + 'grucell_t0_out_output', 'grucell_t1_out_output', + 'grucell_t2_out_output' + ] args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -187,17 +193,15 @@ def test_residual(): params = cell.collect_params() assert sorted(params.keys()) == \ ['base_cell.h2h_bias', 'base_cell.h2h_weight', 'base_cell.i2h_bias', 'base_cell.i2h_weight'] - # assert outputs.list_outputs() == \ - # ['rnn_t0_out_plus_residual_output', 'rnn_t1_out_plus_residual_output'] args, outs, auxs = outputs.infer_shape(t0_data=(10, 50), t1_data=(10, 50)) assert outs == [(10, 50), (10, 50)] - outputs = outputs.eval(**{'t0_data':mx.nd.ones((10, 50)), - 't1_data':mx.nd.ones((10, 50)), - params['base_cell.i2h_weight'].name:mx.nd.zeros((150, 50)), - params['base_cell.i2h_bias'].name:mx.nd.zeros((150,)), - params['base_cell.h2h_weight'].name:mx.nd.zeros((150, 50)), - params['base_cell.h2h_bias'].name:mx.nd.zeros((150,))}) + outputs = outputs.eval(**{'t0_data': mx.nd.ones((10, 50)), + 't1_data': mx.nd.ones((10, 50)), + cell.base_cell.i2h_weight.var().name: mx.nd.zeros((150, 50)), + cell.base_cell.i2h_bias.var().name: mx.nd.zeros((150, )), + cell.base_cell.h2h_weight.var().name: mx.nd.zeros((150, 50)), + cell.base_cell.h2h_bias.var().name: mx.nd.zeros((150, ))}) expected_outputs = np.ones((10, 50)) assert np.array_equal(outputs[0].asnumpy(), expected_outputs) assert np.array_equal(outputs[1].asnumpy(), expected_outputs) @@ -212,11 +216,11 @@ def test_residual_bidirectional(): inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)] outputs, _ = cell.unroll(2, inputs, merge_outputs=False) outputs = mx.sym.Group(outputs) - params = cell.collect_params() + params = cell.collect_params() assert sorted(params.keys()) == \ - ['base_cell.l_cell.h2h_bias', 'base_cell.l_cell.h2h_weight', + ['base_cell.l_cell.h2h_bias', 'base_cell.l_cell.h2h_weight', 'base_cell.l_cell.i2h_bias', 'base_cell.l_cell.i2h_weight', - 'base_cell.r_cell.h2h_bias', 'base_cell.r_cell.h2h_weight', + 'base_cell.r_cell.h2h_bias', 'base_cell.r_cell.h2h_weight', 'base_cell.r_cell.i2h_bias', 'base_cell.r_cell.i2h_weight'] # assert outputs.list_outputs() == \ # ['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output'] @@ -225,14 +229,14 @@ def test_residual_bidirectional(): assert outs == [(10, 50), (10, 50)] outputs = outputs.eval(**{'rnn_t0_data':mx.nd.ones((10, 50))+5, 'rnn_t1_data':mx.nd.ones((10, 50))+5, - params['base_cell.l_cell.i2h_weight'].name:mx.nd.zeros((75, 50)), - params['base_cell.l_cell.i2h_bias'].name:mx.nd.zeros((75,)), - params['base_cell.l_cell.h2h_weight'].name:mx.nd.zeros((75, 25)), - params['base_cell.l_cell.h2h_bias'].name:mx.nd.zeros((75,)), - params['base_cell.r_cell.i2h_weight'].name:mx.nd.zeros((75, 50)), - params['base_cell.r_cell.i2h_bias'].name:mx.nd.zeros((75,)), - params['base_cell.r_cell.h2h_weight'].name:mx.nd.zeros((75, 25)), - params['base_cell.r_cell.h2h_bias'].name:mx.nd.zeros((75,))}) + cell.base_cell.l_cell.i2h_weight.var().name:mx.nd.zeros((75, 50)), + cell.base_cell.l_cell.i2h_bias.var().name:mx.nd.zeros((75,)), + cell.base_cell.l_cell.h2h_weight.var().name:mx.nd.zeros((75, 25)), + cell.base_cell.l_cell.h2h_bias.var().name:mx.nd.zeros((75,)), + cell.base_cell.r_cell.i2h_weight.var().name:mx.nd.zeros((75, 50)), + cell.base_cell.r_cell.i2h_bias.var().name:mx.nd.zeros((75,)), + cell.base_cell.r_cell.h2h_weight.var().name:mx.nd.zeros((75, 25)), + cell.base_cell.r_cell.h2h_bias.var().name:mx.nd.zeros((75,))}) expected_outputs = np.ones((10, 50))+5 assert np.array_equal(outputs[0].asnumpy(), expected_outputs) assert np.array_equal(outputs[1].asnumpy(), expected_outputs) @@ -260,8 +264,7 @@ def test_stack(): assert '1.base_cell.h2h_bias' in keys assert '1.base_cell.i2h_weight' in keys assert '1.base_cell.i2h_bias' in keys - assert outputs.list_outputs() == \ - [cell[4].name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + assert outputs.list_outputs() == ['lstmcell_t0_out_output', 'lstmcell_t1_out_output', 'lstmcell_t2_out_output'] args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -290,8 +293,7 @@ def test_hybridstack(): assert '1.base_cell.h2h_bias' in keys assert '1.base_cell.i2h_weight' in keys assert '1.base_cell.i2h_bias' in keys - assert outputs.list_outputs() == \ - [cell[4].name + name for name in ['_t0_out_output', '_t1_out_output', '_t2_out_output']] + assert outputs.list_outputs() == ['lstmcell_t0_out_output', 'lstmcell_t1_out_output', 'lstmcell_t2_out_output'] args, outs, auxs = outputs.infer_shape(t0_data=(10,50), t1_data=(10,50), t2_data=(10,50)) assert outs == [(10, 100), (10, 100), (10, 100)] @@ -533,7 +535,7 @@ def hybrid_forward(self, F, seq): symbol_file="./model-symbol.json", input_names=["data"], param_file="./model-0000.params", - ctx=mx.Context.default_ctx + ctx=mx.context.current_context() ) output2 = symbol(input) assert_almost_equal(output1.asnumpy(), output2.asnumpy()) @@ -910,4 +912,3 @@ def hybrid_forward(self, F, inputs, valid_len): _check_bidirectional_unroll_valid_length(1) _check_bidirectional_unroll_valid_length(3) - diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 8cf78042411e..1e3a1028cf43 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -344,3 +344,18 @@ def test_gluon_trainer_param_order(): expected_idx = 0 if name == 'ones_' else 1 expected_name = '{}.weight'.format(expected_idx) assert trainer._params[expected_idx].name == params[expected_name].name + + +def test_trainer_allreduce_hybridsequential(): + contexts = [mx.cpu(0), mx.cpu(1)] + net = mx.gluon.nn.HybridSequential() + for _ in range(8): # Create a network with 8 layers + net.add(mx.gluon.nn.Dense(1, weight_initializer='ones', bias_initializer='ones')) + net.initialize(ctx=contexts) + net.hybridize() + trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', update_on_kvstore=False) + for ctx in contexts: + with mx.autograd.record(): + out = net(mx.nd.ones((1, 1), ctx=ctx)) + out.backward() + trainer.allreduce_grads() diff --git a/tests/python/unittest/test_memory_opt.py b/tests/python/unittest/test_memory_opt.py index 0cc217ffff47..d2f85e567553 100644 --- a/tests/python/unittest/test_memory_opt.py +++ b/tests/python/unittest/test_memory_opt.py @@ -104,8 +104,3 @@ def test_fc(): z = mx.sym.Activation(y, act_type='tanh', name='z') z = mx.sym.FullyConnected(z, num_hidden=num_hidden) exec = z._simple_bind(mx.cpu(), 'write', x=(num_hidden,)) - - -if __name__ == "__main__": - import nose - nose.runmodule() diff --git a/tests/python/unittest/test_numpy_default_dtype.py b/tests/python/unittest/test_numpy_default_dtype.py index 3c70aade737d..906298036978 100644 --- a/tests/python/unittest/test_numpy_default_dtype.py +++ b/tests/python/unittest/test_numpy_default_dtype.py @@ -223,8 +223,3 @@ def check_deepnp_indices_default_dtype(): check_deepnp_indices_default_dtype() check_np_indices_default_dtype() - - -if __name__ == '__main__': - import nose - nose.runmodule() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 6fb9af06e67c..2599db03b2b3 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1064,14 +1064,15 @@ def test_np_moment(): class TestMoment(HybridBlock): def __init__(self, name, axis=None, dtype=None, keepdims=False, ddof=0): super(TestMoment, self).__init__() - self._name = name + self._moment_name = name self._axis = axis self._dtype = dtype self._keepdims = keepdims self._ddof = ddof def hybrid_forward(self, F, a, *args, **kwargs): - return getattr(a, self._name)(axis=self._axis, dtype=self._dtype, keepdims=self._keepdims, ddof=self._ddof) + return getattr(a, self._moment_name)(axis=self._axis, dtype=self._dtype, + keepdims=self._keepdims, ddof=self._ddof) def is_int(dtype): return 'int' in dtype @@ -4562,9 +4563,9 @@ class TestRandomGrad(HybridBlock): def __init__(self, shape, op_name): super(TestRandomGrad, self).__init__() self._shape = shape - self._name = op_name + self._dist_name = op_name def hybrid_forward(self, F, loc, scale): - op = getattr(F.np.random, self._name, None) + op = getattr(F.np.random, self._dist_name, None) assert op is not None return op(loc=loc, scale=scale, size=self._shape) diff --git a/tests/python/unittest/test_profiler.py b/tests/python/unittest/test_profiler.py index 0bd3eaf1fb8d..eabebf48b785 100644 --- a/tests/python/unittest/test_profiler.py +++ b/tests/python/unittest/test_profiler.py @@ -64,10 +64,10 @@ def test_profiler(): for i in range(iter_num): if i == begin_profiling_iter: - t0 = time.clock() + t0 = time.process_time() profiler.set_state('run') if i == end_profiling_iter: - t1 = time.clock() + t1 = time.process_time() profiler.set_state('stop') executor.forward() c = executor.outputs[0] diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 4e122453ff7b..31c9cb8403b3 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -18,6 +18,7 @@ import pickle as pkl from mxnet.ndarray import NDArray +import mxnet as mx from mxnet.test_utils import * from common import setup_module, with_seed, random_seed, teardown_module from mxnet.base import mx_real_t @@ -623,7 +624,7 @@ def check_create_csr_from_nd(shape, density, dtype): # verify csr matrix dtype and ctx is consistent from the ones provided assert csr_created.dtype == dtype, (csr_created, dtype) assert csr_created.data.dtype == dtype, (csr_created.data.dtype, dtype) - assert csr_created.context == Context.default_ctx, (csr_created.context, Context.default_ctx) + assert csr_created.context == mx.context.current_context(), (csr_created.context, mx.context.current_context()) csr_copy = mx.nd.array(csr_created) assert(same(csr_copy.asnumpy(), csr_created.asnumpy())) @@ -641,7 +642,7 @@ def check_create_csr_from_coo(shape, density, dtype): # verify csr matrix dtype and ctx is consistent assert csr_created.dtype == dtype, (csr_created.dtype, dtype) assert csr_created.data.dtype == dtype, (csr_created.data.dtype, dtype) - assert csr_created.context == Context.default_ctx, (csr_created.context, Context.default_ctx) + assert csr_created.context == mx.context.current_context(), (csr_created.context, mx.context.current_context()) def check_create_csr_from_scipy(shape, density, f): def assert_csr_almost_equal(nd, sp): @@ -768,7 +769,7 @@ def check_create_from_dns(shape, f, dense_arr, dtype, default_dtype, ctx): # verify the default dtype inferred from dense arr arr2 = f(dense_arr) assert(arr2.dtype == default_dtype) - assert(arr2.context == Context.default_ctx) + assert(arr2.context == mx.context.current_context()) shape = rand_shape_2d() dtype = np.int32 src_dtype = np.float64 @@ -791,7 +792,7 @@ def check_create_from_sp(shape, f, sp_arr, dtype, src_dtype, ctx): # verify the default dtype inferred from dense arr arr2 = f(sp_arr) assert(arr2.dtype == src_dtype) - assert(arr2.context == Context.default_ctx) + assert(arr2.context == mx.context.current_context()) shape = rand_shape_2d() src_dtype = np.float64 @@ -830,7 +831,7 @@ def check_csr_empty(shape, dtype, ctx): # check the default value for dtype and ctx arr = mx.nd.sparse.csr_matrix(shape) assert(arr.dtype == np.float32) - assert(arr.context == Context.default_ctx) + assert(arr.context == mx.context.current_context()) def check_rsp_empty(shape, dtype, ctx): arr = mx.nd.sparse.row_sparse_array(shape, dtype=dtype, ctx=ctx) @@ -841,7 +842,7 @@ def check_rsp_empty(shape, dtype, ctx): # check the default value for dtype and ctx arr = mx.nd.sparse.row_sparse_array(shape) assert(arr.dtype == np.float32) - assert(arr.context == Context.default_ctx) + assert(arr.context == mx.context.current_context()) stypes = ['csr', 'row_sparse'] shape = rand_shape_2d() diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index 521bc834c4f5..05308a2d3dfd 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -18,21 +18,19 @@ import threading import numpy as np import mxnet as mx -from mxnet import context, attribute, name -from mxnet.gluon import block +from mxnet import context, attribute from mxnet.context import Context from mxnet.attribute import AttrScope -from mxnet.name import NameManager from mxnet.test_utils import assert_almost_equal, set_default_context from mxnet.util import _NumpyArrayScope, set_np_shape def test_context(): ctx_list = [] - ctx_list.append(Context.default_ctx) + ctx_list.append(context.current_context()) def f(): set_default_context(mx.gpu(11)) - ctx_list.append(Context.default_ctx) + ctx_list.append(context.current_context()) thread = threading.Thread(target=f) thread.start() thread.join() @@ -41,95 +39,107 @@ def f(): assert Context.devtype2str[ctx_list[1].device_typeid] == "gpu" assert ctx_list[1].device_id == 11 - event = threading.Event() + e1 = threading.Event() + e2 = threading.Event() status = [False] def g(): with mx.cpu(10): - event.wait() - if Context.default_ctx.device_id == 10: + e2.set() + e1.wait() + if context.current_context().device_id == 10: status[0] = True thread = threading.Thread(target=g) thread.start() - Context.default_ctx = Context("cpu", 11) - event.set() - thread.join() - event.clear() + e2.wait() + with Context("cpu", 11): + e1.set() + thread.join() + e1.clear() + e2.clear() assert status[0], "Spawned thread didn't set the correct context" def test_attrscope(): attrscope_list = [] - AttrScope.current = AttrScope(y="hi", z="hey") - attrscope_list.append(AttrScope.current) - def f(): - AttrScope.current = AttrScope(x="hello") - attrscope_list.append(AttrScope.current) - thread = threading.Thread(target=f) - thread.start() - thread.join() - assert len(attrscope_list[0]._attr) == 2 - assert attrscope_list[1]._attr["x"] == "hello" + with AttrScope(y="hi", z="hey") as attrscope: + attrscope_list.append(attrscope) - event = threading.Event() + def f(): + with AttrScope(x="hello") as attrscope: + attrscope_list.append(attrscope) + + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert len(attrscope_list[0]._attr) == 2 + assert attrscope_list[1]._attr["x"] == "hello" + + e1 = threading.Event() + e2 = threading.Event() status = [False] def g(): with mx.AttrScope(x="hello"): - event.wait() - if "hello" in AttrScope.current._attr.values(): + e2.set() + e1.wait() + if "hello" in mx.attribute.current()._attr.values(): status[0] = True thread = threading.Thread(target=g) thread.start() - AttrScope.current = AttrScope(x="hi") - event.set() - thread.join() - AttrScope.current = AttrScope() - event.clear() + e2.wait() + with AttrScope(x="hi"): + e1.set() + thread.join() + e1.clear() + e2.clear() assert status[0], "Spawned thread didn't set the correct attr key values" def test_name(): name_list = [] - NameManager.current = NameManager() - NameManager.current.get(None, "main_thread") - name_list.append(NameManager.current) + name_manager = mx.name.current() + name_manager.get(None, "main_thread") + name_list.append(name_manager) def f(): - NameManager.current = NameManager() - NameManager.current.get(None, "spawned_thread") - name_list.append(NameManager.current) + with mx.name.NameManager(): + name_manager = mx.name.current() + name_manager.get(None, "spawned_thread") + name_list.append(name_manager) thread = threading.Thread(target=f) thread.start() thread.join() assert "main_thread" in name_list[0]._counter, "cannot find the string `main thread` in name_list[0]._counter" assert "spawned_thread" in name_list[1]._counter, "cannot find the string `spawned thread` in name_list[1]._counter" - event = threading.Event() + e1 = threading.Event() + e2 = threading.Event() status = [False] def g(): - with NameManager(): - if "main_thread" not in NameManager.current._counter: + with mx.name.NameManager(): + e2.set() + e1.wait() + if "main_thread" not in mx.name.current()._counter: status[0] = True thread = threading.Thread(target=g) thread.start() - NameManager.current = NameManager() - NameManager.current.get(None, "main_thread") - event.set() - thread.join() - event.clear() + e2.wait() + with mx.name.NameManager(): + mx.name.current().get(None, "main_thread") + e1.set() + thread.join() + e1.clear() + e2.clear() assert status[0], "Spawned thread isn't using thread local NameManager" def test_blockscope(): - class dummy_block(object): - def __init__(self, prefix): - self.name = prefix - self._empty_prefix = False - self._profiler_scope_name = ':' + class dummy_block: + pass blockscope_list = [] status = [False] event = threading.Event() def f(): - net = dummy_block("spawned") # BlockScope only keeps a weakref to the Block - with block._BlockScope(net): - x = NameManager.current.get(None, "hello") + net = dummy_block() # BlockScope only keeps a weakref to the Block + with mx.gluon.block._block_scope(net): + x = mx.name.current().get(None, "hello") event.wait() - if x == "spawned_hello0": + if x == "dummy_block_hello0": status[0] = True thread = threading.Thread(target=f) thread.start() @@ -220,40 +230,3 @@ def f(): assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2))) finally: set_np_shape(0) - -def test_blockscope_multithread(): - event = threading.Event() - status = [False] - - class dummy_block(object): - def __init__(self, prefix): - self.prefix = prefix - self._profiler_scope_name = prefix - self._empty_prefix = False - - def f(scope): - try: - with scope: - event.wait() - except: - status[0] = True - - def g(scope): - with scope: - pass - event.set() - - scope = block._BlockScope(dummy_block("scope_")) - count = 2 - threads = [threading.Thread(target=f, args=(scope,)), - threading.Thread(target=g, args=(scope,))] - for i in range(count): - threads[i].start() - for i in range(count): - threads[i].join() - assert status[0] is False, "_BlockScope does not work with multithread" - - -if __name__ == '__main__': - import nose - nose.runmodule()