Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Refactor scope functionality in Python API #18619

Merged
merged 8 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions ci/docker/Dockerfile.build.centos7
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,14 @@ RUN export SHORT_CUDA_VERSION=${CUDA_VERSION%.*} && \
yum clean all; \
fi

# Python dependencies
RUN pip3 install --no-cache-dir --upgrade pip && \
pip3 install --no-cache-dir pylint cython numpy requests h5py scipy==1.2.3 wheel \
pytest==5.3.5 \
pytest-env==0.6.2 \
pytest-cov==2.8.1 \
pytest-xdist==1.31.0 \
pytest-timeout==1.3.4 \
mock==2.0.0 \
onnx==1.5.0 \
protobuf==3.5.2 \
tabulate==0.7.5

# Fix the en_DK.UTF-8 locale to test locale invariance
RUN localedef -i en_DK -f UTF-8 en_DK.UTF-8

# Python dependencies
RUN python3 -m pip install --upgrade pip
COPY install/requirements /work/
RUN python3 -m pip install -r /work/requirements

ARG USER_ID=0
COPY install/docker_filepermissions.sh /work/
RUN /work/docker_filepermissions.sh
Expand Down
2 changes: 2 additions & 0 deletions ci/docker/install/requirements
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
numpy>=1.17
requests>=2.20.0,<3
graphviz<0.9.0,>=0.8.1
contextvars;python_version<"3.7"

# Optional dependencies
onnx==1.5.0
Expand All @@ -42,6 +43,7 @@ pytest-xdist==1.31.0
pytest-timeout==1.3.4
flaky==3.6.1
setuptools
wheel
mock==2.0.0

# TVM dependencies
Expand Down
4 changes: 2 additions & 2 deletions example/profiler/profiler_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def parse_args():
print("execution begin")
for i in range(args.iter_num):
if i == args.begin_profiling_iter:
t0 = time.clock()
t0 = time.process_time()
mx.profiler.set_state('run')
if i == args.end_profiling_iter:
t1 = time.clock()
t1 = time.process_time()
mx.profiler.set_state('stop')
executor.forward()
c = executor.outputs[0]
Expand Down
8 changes: 8 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,14 @@ MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
*/
MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains all the inputs.
* \param symbol The symbol
* \param out The output symbol whose outputs are all the internals.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetInputs(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
Expand Down
46 changes: 14 additions & 32 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Attribute scoping support for symbolic API."""
import threading
import warnings
import contextvars
from collections import defaultdict

from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import string_types

class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
class AttrScope:
"""Attribute manager for scoping.

User can also inherit this object to change naming behavior.
Expand All @@ -33,7 +30,6 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
kwargs
The attributes to set for all symbol creations in the scope.
"""
_current = threading.local()
_subgraph_names = defaultdict(int)

def __init__(self, **kwargs):
Expand Down Expand Up @@ -65,37 +61,23 @@ def get(self, attr):
else:
return attr if attr else {}

def __enter__(self):
# pylint: disable=protected-access
if not hasattr(AttrScope._current, "value"):
AttrScope._current.value = AttrScope()
self._old_scope = AttrScope._current.value
attr = AttrScope._current.value._attr.copy()
def __enter__(self): # pylint: disable=protected-access
attr = _current.get()._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope._current.value = self
# Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value
self._old_scope = _current.get()
_current.set(self)
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope._current.value = self._old_scope
_current.set(self._old_scope)


#pylint: disable=no-self-argument
@classproperty
def current(cls):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
if not hasattr(AttrScope._current, "value"):
cls._current.value = AttrScope()
return cls._current.value
_current = contextvars.ContextVar('namemanager', default=AttrScope())

@current.setter
def current(cls, val):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
cls._current.value = val
#pylint: enable=no-self-argument

AttrScope._current.value = AttrScope()
def current():
"""Returns the current name manager."""
return _current.get()
63 changes: 0 additions & 63 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,69 +273,6 @@ class MXCallbackList(ctypes.Structure):
]


# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
class _MXClassPropertyDescriptor(object):
def __init__(self, fget, fset=None):
self.fget = fget
self.fset = fset

def __get__(self, obj, clas=None):
if clas is None:
clas = type(obj)
return self.fget.__get__(obj, clas)()

def __set__(self, obj, value):
if not self.fset:
raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__)
if inspect.isclass(obj):
type_ = obj
obj = None
else:
type_ = type(obj)
return self.fset.__get__(obj, type_)(value)

def setter(self, func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
self.fset = func
return self


class _MXClassPropertyMetaClass(type):
def __setattr__(cls, key, value):
obj = cls.__dict__.get(key)
if obj and isinstance(obj, _MXClassPropertyDescriptor):
return obj.__set__(cls, value)

return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)


# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
# pylint: disable=unused-argument
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):

def __new__(cls, name, this_bases, d):
return meta(name, bases, d)

@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
# pylint: enable=unused-argument


def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)

return _MXClassPropertyDescriptor(func)


def _load_lib():
"""Load library by searching possible path."""
lib_path = libinfo.find_lib_path()
Expand Down
49 changes: 10 additions & 39 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Context management API of mxnet."""
import threading
import warnings
import contextvars
import ctypes
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import _LIB
from .base import check_call


class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
class Context:
"""Constructs a context.

MXNet can run operations on CPU and different GPUs.
Expand Down Expand Up @@ -66,8 +62,6 @@ class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
>>> gpu_array.context
gpu(1)
"""
# static class variable
_default_ctx = threading.local()
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
def __init__(self, device_type, device_id=0):
Expand Down Expand Up @@ -115,34 +109,13 @@ def __repr__(self):
return self.__str__()

def __enter__(self):
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
self._old_ctx = Context._default_ctx.value
Context._default_ctx.value = self
# Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value
self._old_ctx = _current.get()
_current.set(self)
return self

def __exit__(self, ptype, value, trace):
Context._default_ctx.value = self._old_ctx

#pylint: disable=no-self-argument
@classproperty
def default_ctx(cls):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
if not hasattr(Context._default_ctx, "value"):
cls._default_ctx.value = Context('cpu', 0)
return cls._default_ctx.value

@default_ctx.setter
def default_ctx(cls, val):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
cls._default_ctx.value = val
#pylint: enable=no-self-argument
_current.set(self._old_ctx)

def empty_cache(self):
"""Empties the memory cache for the current contexts device.
Expand All @@ -162,9 +135,6 @@ def empty_cache(self):
dev_id = ctypes.c_int(self.device_id)
check_call(_LIB.MXStorageEmptyCache(dev_type, dev_id))

# initialize the default context in Context
Context._default_ctx.value = Context('cpu', 0)


def cpu(device_id=0):
"""Returns a CPU context.
Expand Down Expand Up @@ -299,6 +269,9 @@ def gpu_memory_info(device_id=0):
return (free.value, total.value)


_current = contextvars.ContextVar('namemanager', default=Context('cpu', 0))


def current_context():
"""Returns the current context.

Expand All @@ -321,6 +294,4 @@ def current_context():
-------
default_ctx : Context
"""
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
return Context._default_ctx.value
return _current.get()
Loading