diff --git a/doc/ref_fundamental.rst b/doc/ref_fundamental.rst index f769771..63eefb3 100644 --- a/doc/ref_fundamental.rst +++ b/doc/ref_fundamental.rst @@ -6,6 +6,14 @@ Reference: Basic Building Blocks Context ------- +.. note:: this class implements Python's ``__copy__`` and ``__deepcopy__`` + protocols. Each of these returns the context being 'copied' identically. + +.. note:: during an pickle operation, the current default :class:`Context` + is always used. + +.. seealso:: :ref:`sec-context-management` + .. autoclass:: Context() :members: diff --git a/doc/reference.rst b/doc/reference.rst index 37d6f47..3a7b5e7 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -88,15 +88,15 @@ Lifetime Helpers function call to which they're passed. These callback return a callback handle that must be kept alive until the callback is no longer needed. -Global Data -^^^^^^^^^^^ +.. _sec-context-management: + +Context Management +^^^^^^^^^^^^^^^^^^ -.. data:: DEFAULT_CONTEXT +.. autofunction:: get_default_context - ISL objects being unpickled or initialized from strings will be instantiated - within this :class:`Context`. +.. autofunction:: push_context - .. versionadded:: 2015.2 Symbolic Constants ^^^^^^^^^^^^^^^^^^ diff --git a/islpy/__init__.py b/islpy/__init__.py index c998af9..9a5c559 100644 --- a/islpy/__init__.py +++ b/islpy/__init__.py @@ -20,6 +20,9 @@ THE SOFTWARE. """ +import sys +import contextlib + import islpy._isl as _isl from islpy.version import VERSION, VERSION_TEXT # noqa import six @@ -145,14 +148,104 @@ EXPR_CLASSES = tuple(cls for cls in ALL_CLASSES if "Aff" in cls.__name__ or "Polynomial" in cls.__name__) -DEFAULT_CONTEXT = Context() + +def _module_property(func): + """Decorator to turn module functions into properties. + Function names must be prefixed with an underscore.""" + module = sys.modules[func.__module__] + + def base_getattr(name): + raise AttributeError( + f"module '{module.__name__}' has no attribute '{name}'") + + old_getattr = getattr(module, "__getattr__", base_getattr) + + def new_getattr(name): + if f"_{name}" == func.__name__: + return func() + else: + return old_getattr(name) + + module.__getattr__ = new_getattr + return func + + +import threading + + +_thread_local_storage = threading.local() + + +def _check_init_default_context(): + if not hasattr(_thread_local_storage, "islpy_default_contexts"): + _thread_local_storage.islpy_default_contexts = [Context()] + + +def get_default_context(): + """Get or create the default context under current thread. + + :return: the current default :class:`Context` + + .. versionadded:: 2020.3 + """ + _check_init_default_context() + return _thread_local_storage.islpy_default_contexts[-1] def _get_default_context(): - """A callable to get the default context for the benefit of Python's - ``__reduce__`` protocol. + from warnings import warn + warn("It appears that you might be deserializing an islpy.Context" + "that was serialized by a previous version of islpy." + "If so, this is discouraged and please consider to re-serialize" + "the Context with the newer version to avoid possible inconsistencies.", + UserWarning) + return get_default_context() + + +@contextlib.contextmanager +def push_context(ctx=None): + """Context manager to push new default :class:`Context` + + :param ctx: an optional explicit context that is pushed to + the stack of default :class:`Context` s + + .. versionadded:: 2020.3 + + :mod:`islpy` internally maintains a stack of default :class:`Context` s + for each Python thread. + By default, each stack is initialized with a base default :class:`Context`. + ISL objects being unpickled or initialized from strings will be + instantiated within the top :class:`Context` of the stack of + the executing thread. + + Usage example:: + + with islpy.push_context() as dctx: + s = islpy.Set("{[0]: }") + assert s.get_ctx() == dctx + """ - return DEFAULT_CONTEXT + if ctx is None: + ctx = Context() + _check_init_default_context() + _thread_local_storage.islpy_default_contexts.append(ctx) + yield ctx + _thread_local_storage.islpy_default_contexts.pop() + + +@_module_property +def _DEFAULT_CONTEXT(): # noqa: N802 + from warnings import warn + warn("Use of islpy.DEFAULT_CONTEXT is deprecated " + "and will be removed in 2022." + " Please use `islpy.get_default_context()` instead. ", + FutureWarning, + stacklevel=3) + return get_default_context() + + +if sys.version_info < (3, 7): + DEFAULT_CONTEXT = get_default_context() def _read_from_str_wrapper(cls, context, s): @@ -168,10 +261,14 @@ def _add_functionality(): # {{{ Context def context_reduce(self): - if self._wraps_same_instance_as(DEFAULT_CONTEXT): - return (_get_default_context, ()) - else: - return (Context, ()) + return (get_default_context, ()) + + def context_copy(self): + return self + + def context_deepcopy(self, memo): + del memo + return self def context_eq(self, other): return isinstance(other, Context) and self._wraps_same_instance_as(other) @@ -180,9 +277,10 @@ def context_ne(self, other): return not self.__eq__(other) Context.__reduce__ = context_reduce + Context.__copy__ = context_copy + Context.__deepcopy__ = context_deepcopy Context.__eq__ = context_eq Context.__ne__ = context_ne - # }}} # {{{ generic initialization, pickling @@ -197,7 +295,7 @@ def obj_new(cls, s=None, context=None): return cls._prev_new(cls) if context is None: - context = DEFAULT_CONTEXT + context = get_default_context() result = cls.read_from_str(context, s) return result @@ -473,7 +571,7 @@ def obj_get_coefficients_by_name(self, dimtype=None, dim_to_name=None): def id_new(cls, name, user=None, context=None): if context is None: - context = DEFAULT_CONTEXT + context = get_default_context() result = cls.alloc(context, name, user) result._made_from_python = True @@ -777,7 +875,7 @@ def expr_like_floordiv(self, other): def val_new(cls, src, context=None): if context is None: - context = DEFAULT_CONTEXT + context = get_default_context() if isinstance(src, six.string_types): result = cls.read_from_str(context, src) @@ -1274,7 +1372,7 @@ def make_zero_and_vars(set_vars, params=[], ctx=None): ) """ if ctx is None: - ctx = DEFAULT_CONTEXT + ctx = get_default_context() if isinstance(set_vars, str): set_vars = [s.strip() for s in set_vars.split(",")] diff --git a/test/test_isl.py b/test/test_isl.py index 0b6002d..62e2aea 100644 --- a/test/test_isl.py +++ b/test/test_isl.py @@ -189,11 +189,11 @@ def cb_print_for(printer, options, node): printer = printer.print_str("Callback For") return printer - opts = isl.AstPrintOptions.alloc(isl.DEFAULT_CONTEXT) + opts = isl.AstPrintOptions.alloc(isl.get_default_context()) opts, cb_print_user_handle = opts.set_print_user(cb_print_user) opts, cb_print_for_handle = opts.set_print_for(cb_print_for) - printer = isl.Printer.to_str(isl.DEFAULT_CONTEXT) + printer = isl.Printer.to_str(isl.get_default_context()) printer = printer.set_output_format(isl.format.C) printer.print_str("// Start\n") printer = ast.print_(printer, opts) @@ -248,7 +248,7 @@ def isl_ast_codegen(S): # noqa: N803 m = isl.Map.identity(m.get_space()) m = isl.Map.from_domain(S) ast = b.ast_from_schedule(m) - p = isl.Printer.to_str(isl.DEFAULT_CONTEXT) + p = isl.Printer.to_str(isl.get_default_context()) p = p.set_output_format(isl.format.C) p.flush() p = p.print_ast_node(ast) @@ -362,8 +362,47 @@ def test_bound(): def test_copy_context(): ctx = isl.Context() import copy - assert not ctx._wraps_same_instance_as(copy.copy(ctx)) - assert not isl.DEFAULT_CONTEXT._wraps_same_instance_as(copy.copy(ctx)) + assert ctx._wraps_same_instance_as(copy.copy(ctx)) + assert ctx == copy.copy(ctx) + assert not isl.get_default_context()._wraps_same_instance_as(copy.copy(ctx)) + + +def test_context_manager(): + import pickle + + def transfer_copy(obj): + return pickle.loads(pickle.dumps(obj)) + + b1 = isl.BasicSet("{ [0] : }") + old_dctx = isl.get_default_context() + assert b1.get_ctx() == old_dctx + + with isl.push_context() as dctx: + assert dctx == isl.get_default_context() + assert not old_dctx._wraps_same_instance_as(dctx) + b2 = isl.BasicSet("{ [0] : }") + assert b2.get_ctx() == dctx + # Under context manager always use `dctx` + assert transfer_copy(b2).get_ctx() == transfer_copy(b1).get_ctx() == dctx + + # Check for proper exit + assert old_dctx == isl.get_default_context() + + # Check for nested context + with isl.push_context() as c1: + with isl.push_context() as c2: + assert c1 != c2 + with isl.push_context() as c3: + assert c2 != c3 + # Check for proper exit + assert old_dctx == isl.get_default_context() + + +def test_deprecated_default_context(): + import warnings + with warnings.catch_warnings(): + dctx = isl.DEFAULT_CONTEXT + assert dctx == isl.get_default_context() def test_ast_node_list_free():