Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Refactor scope functionality in Python API #18619

Merged
merged 8 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/profiler/profiler_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def parse_args():
print("execution begin")
for i in range(args.iter_num):
if i == args.begin_profiling_iter:
t0 = time.clock()
t0 = time.process_time()
mx.profiler.set_state('run')
if i == args.end_profiling_iter:
t1 = time.clock()
t1 = time.process_time()
mx.profiler.set_state('stop')
executor.forward()
c = executor.outputs[0]
Expand Down
8 changes: 8 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,14 @@ MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
*/
MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains all the inputs.
* \param symbol The symbol
* \param out The output symbol whose outputs are all the internals.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetInputs(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
Expand Down
46 changes: 14 additions & 32 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Attribute scoping support for symbolic API."""
import threading
import warnings
import contextvars
from collections import defaultdict

from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import string_types

class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
class AttrScope:
"""Attribute manager for scoping.

User can also inherit this object to change naming behavior.
Expand All @@ -33,7 +30,6 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
kwargs
The attributes to set for all symbol creations in the scope.
"""
_current = threading.local()
_subgraph_names = defaultdict(int)

def __init__(self, **kwargs):
Expand Down Expand Up @@ -65,37 +61,23 @@ def get(self, attr):
else:
return attr if attr else {}

def __enter__(self):
# pylint: disable=protected-access
if not hasattr(AttrScope._current, "value"):
AttrScope._current.value = AttrScope()
self._old_scope = AttrScope._current.value
attr = AttrScope._current.value._attr.copy()
def __enter__(self): # pylint: disable=protected-access
attr = _current.get()._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope._current.value = self
# Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value
self._old_scope = _current.get()
_current.set(self)
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope._current.value = self._old_scope
_current.set(self._old_scope)


#pylint: disable=no-self-argument
@classproperty
def current(cls):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
if not hasattr(AttrScope._current, "value"):
cls._current.value = AttrScope()
return cls._current.value
_current = contextvars.ContextVar('namemanager', default=AttrScope())

@current.setter
def current(cls, val):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
cls._current.value = val
#pylint: enable=no-self-argument

AttrScope._current.value = AttrScope()
def current():
"""Returns the current name manager."""
return _current.get()
63 changes: 0 additions & 63 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,69 +273,6 @@ class MXCallbackList(ctypes.Structure):
]


# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
class _MXClassPropertyDescriptor(object):
def __init__(self, fget, fset=None):
self.fget = fget
self.fset = fset

def __get__(self, obj, clas=None):
if clas is None:
clas = type(obj)
return self.fget.__get__(obj, clas)()

def __set__(self, obj, value):
if not self.fset:
raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__)
if inspect.isclass(obj):
type_ = obj
obj = None
else:
type_ = type(obj)
return self.fset.__get__(obj, type_)(value)

def setter(self, func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
self.fset = func
return self


class _MXClassPropertyMetaClass(type):
def __setattr__(cls, key, value):
obj = cls.__dict__.get(key)
if obj and isinstance(obj, _MXClassPropertyDescriptor):
return obj.__set__(cls, value)

return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)


# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
# pylint: disable=unused-argument
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):

def __new__(cls, name, this_bases, d):
return meta(name, bases, d)

@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
# pylint: enable=unused-argument


def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)

return _MXClassPropertyDescriptor(func)


def _load_lib():
"""Load library by searching possible path."""
lib_path = libinfo.find_lib_path()
Expand Down
49 changes: 10 additions & 39 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Context management API of mxnet."""
import threading
import warnings
import contextvars
import ctypes
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import _LIB
from .base import check_call


class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
class Context:
"""Constructs a context.

MXNet can run operations on CPU and different GPUs.
Expand Down Expand Up @@ -66,8 +62,6 @@ class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
>>> gpu_array.context
gpu(1)
"""
# static class variable
_default_ctx = threading.local()
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
def __init__(self, device_type, device_id=0):
Expand Down Expand Up @@ -115,34 +109,13 @@ def __repr__(self):
return self.__str__()

def __enter__(self):
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
self._old_ctx = Context._default_ctx.value
Context._default_ctx.value = self
# Token can't be pickled and Token.old_value is Token.MISSING if _current.get() uses default value
self._old_ctx = _current.get()
_current.set(self)
return self

def __exit__(self, ptype, value, trace):
Context._default_ctx.value = self._old_ctx

#pylint: disable=no-self-argument
@classproperty
def default_ctx(cls):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
if not hasattr(Context._default_ctx, "value"):
cls._default_ctx.value = Context('cpu', 0)
return cls._default_ctx.value

@default_ctx.setter
def default_ctx(cls, val):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
cls._default_ctx.value = val
#pylint: enable=no-self-argument
_current.set(self._old_ctx)

def empty_cache(self):
"""Empties the memory cache for the current contexts device.
Expand All @@ -162,9 +135,6 @@ def empty_cache(self):
dev_id = ctypes.c_int(self.device_id)
check_call(_LIB.MXStorageEmptyCache(dev_type, dev_id))

# initialize the default context in Context
Context._default_ctx.value = Context('cpu', 0)


def cpu(device_id=0):
"""Returns a CPU context.
Expand Down Expand Up @@ -299,6 +269,9 @@ def gpu_memory_info(device_id=0):
return (free.value, total.value)


_current = contextvars.ContextVar('namemanager', default=Context('cpu', 0))


def current_context():
"""Returns the current context.

Expand All @@ -321,6 +294,4 @@ def current_context():
-------
default_ctx : Context
"""
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
return Context._default_ctx.value
return _current.get()
Loading