Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Refactor scope functionality in Python API (#18619)
Browse files Browse the repository at this point in the history
* Refactor scope functionality in Python API

- Remove deprecated metaclass functionality
- Remove global state in naming
- Switch from threading.local to asyncio compatible contextvars
- Stop exposing UUIDs in parameter name

* Fix dependencies

* Fixes

* Fixes

* Fix

* Fix after merge master
  • Loading branch information
leezu authored Jul 15, 2020
1 parent 12ec046 commit e2366e9
Show file tree
Hide file tree
Showing 40 changed files with 466 additions and 626 deletions.
18 changes: 5 additions & 13 deletions ci/docker/Dockerfile.build.centos7
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,14 @@ RUN export SHORT_CUDA_VERSION=${CUDA_VERSION%.*} && \
yum clean all; \
fi

# Python dependencies
RUN pip3 install --no-cache-dir --upgrade pip && \
pip3 install --no-cache-dir pylint cython numpy requests h5py scipy==1.2.3 wheel \
pytest==5.3.5 \
pytest-env==0.6.2 \
pytest-cov==2.8.1 \
pytest-xdist==1.31.0 \
pytest-timeout==1.3.4 \
mock==2.0.0 \
onnx==1.5.0 \
protobuf==3.5.2 \
tabulate==0.7.5

# Fix the en_DK.UTF-8 locale to test locale invariance
RUN localedef -i en_DK -f UTF-8 en_DK.UTF-8

# Python dependencies
RUN python3 -m pip install --upgrade pip
COPY install/requirements /work/
RUN python3 -m pip install -r /work/requirements

ARG USER_ID=0
COPY install/docker_filepermissions.sh /work/
RUN /work/docker_filepermissions.sh
Expand Down
2 changes: 2 additions & 0 deletions ci/docker/install/requirements
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
numpy>=1.17
requests>=2.20.0,<3
graphviz<0.9.0,>=0.8.1
contextvars;python_version<"3.7"

# Optional dependencies
onnx==1.5.0
Expand All @@ -42,6 +43,7 @@ pytest-xdist==1.31.0
pytest-timeout==1.3.4
flaky==3.6.1
setuptools
wheel
mock==2.0.0

# TVM dependencies
Expand Down
4 changes: 2 additions & 2 deletions example/profiler/profiler_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def parse_args():
print("execution begin")
for i in range(args.iter_num):
if i == args.begin_profiling_iter:
t0 = time.clock()
t0 = time.process_time()
mx.profiler.set_state('run')
if i == args.end_profiling_iter:
t1 = time.clock()
t1 = time.process_time()
mx.profiler.set_state('stop')
executor.forward()
c = executor.outputs[0]
Expand Down
8 changes: 8 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,14 @@ MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
*/
MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains all the inputs.
* \param symbol The symbol
* \param out The output symbol whose outputs are all the internals.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetInputs(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
Expand Down
46 changes: 14 additions & 32 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Attribute scoping support for symbolic API."""
import threading
import warnings
import contextvars
from collections import defaultdict

from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import string_types

class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
class AttrScope:
"""Attribute manager for scoping.
User can also inherit this object to change naming behavior.
Expand All @@ -33,7 +30,6 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
kwargs
The attributes to set for all symbol creations in the scope.
"""
_current = threading.local()
_subgraph_names = defaultdict(int)

def __init__(self, **kwargs):
Expand Down Expand Up @@ -65,37 +61,23 @@ def get(self, attr):
else:
return attr if attr else {}

def __enter__(self):
# pylint: disable=protected-access
if not hasattr(AttrScope._current, "value"):
AttrScope._current.value = AttrScope()
self._old_scope = AttrScope._current.value
attr = AttrScope._current.value._attr.copy()
def __enter__(self): # pylint: disable=protected-access
attr = _current.get()._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope._current.value = self
# Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value
self._old_scope = _current.get()
_current.set(self)
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope._current.value = self._old_scope
_current.set(self._old_scope)


#pylint: disable=no-self-argument
@classproperty
def current(cls):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
if not hasattr(AttrScope._current, "value"):
cls._current.value = AttrScope()
return cls._current.value
_current = contextvars.ContextVar('namemanager', default=AttrScope())

@current.setter
def current(cls, val):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
cls._current.value = val
#pylint: enable=no-self-argument

AttrScope._current.value = AttrScope()
def current():
"""Returns the current name manager."""
return _current.get()
63 changes: 0 additions & 63 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,69 +273,6 @@ class MXCallbackList(ctypes.Structure):
]


# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
class _MXClassPropertyDescriptor(object):
def __init__(self, fget, fset=None):
self.fget = fget
self.fset = fset

def __get__(self, obj, clas=None):
if clas is None:
clas = type(obj)
return self.fget.__get__(obj, clas)()

def __set__(self, obj, value):
if not self.fset:
raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__)
if inspect.isclass(obj):
type_ = obj
obj = None
else:
type_ = type(obj)
return self.fset.__get__(obj, type_)(value)

def setter(self, func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
self.fset = func
return self


class _MXClassPropertyMetaClass(type):
def __setattr__(cls, key, value):
obj = cls.__dict__.get(key)
if obj and isinstance(obj, _MXClassPropertyDescriptor):
return obj.__set__(cls, value)

return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)


# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
# pylint: disable=unused-argument
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):

def __new__(cls, name, this_bases, d):
return meta(name, bases, d)

@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
# pylint: enable=unused-argument


def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)

return _MXClassPropertyDescriptor(func)


def _load_lib():
"""Load library by searching possible path."""
lib_path = libinfo.find_lib_path()
Expand Down
49 changes: 10 additions & 39 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Context management API of mxnet."""
import threading
import warnings
import contextvars
import ctypes
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import _LIB
from .base import check_call


class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
class Context:
"""Constructs a context.
MXNet can run operations on CPU and different GPUs.
Expand Down Expand Up @@ -66,8 +62,6 @@ class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
>>> gpu_array.context
gpu(1)
"""
# static class variable
_default_ctx = threading.local()
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
def __init__(self, device_type, device_id=0):
Expand Down Expand Up @@ -115,34 +109,13 @@ def __repr__(self):
return self.__str__()

def __enter__(self):
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
self._old_ctx = Context._default_ctx.value
Context._default_ctx.value = self
# Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value
self._old_ctx = _current.get()
_current.set(self)
return self

def __exit__(self, ptype, value, trace):
Context._default_ctx.value = self._old_ctx

#pylint: disable=no-self-argument
@classproperty
def default_ctx(cls):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
if not hasattr(Context._default_ctx, "value"):
cls._default_ctx.value = Context('cpu', 0)
return cls._default_ctx.value

@default_ctx.setter
def default_ctx(cls, val):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
cls._default_ctx.value = val
#pylint: enable=no-self-argument
_current.set(self._old_ctx)

def empty_cache(self):
"""Empties the memory cache for the current contexts device.
Expand All @@ -162,9 +135,6 @@ def empty_cache(self):
dev_id = ctypes.c_int(self.device_id)
check_call(_LIB.MXStorageEmptyCache(dev_type, dev_id))

# initialize the default context in Context
Context._default_ctx.value = Context('cpu', 0)


def cpu(device_id=0):
"""Returns a CPU context.
Expand Down Expand Up @@ -299,6 +269,9 @@ def gpu_memory_info(device_id=0):
return (free.value, total.value)


_current = contextvars.ContextVar('namemanager', default=Context('cpu', 0))


def current_context():
"""Returns the current context.
Expand All @@ -321,6 +294,4 @@ def current_context():
-------
default_ctx : Context
"""
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
return Context._default_ctx.value
return _current.get()
Loading

0 comments on commit e2366e9

Please sign in to comment.