From c00d9a4b04a1861dd884d311bdc966ddfacd83d6 Mon Sep 17 00:00:00 2001 From: Sebastian Marsching Date: Tue, 27 Nov 2018 20:59:41 +0100 Subject: [PATCH] Fix race condition in Salt loader. There was a race condition in the salt loader when injecting global values (e.g. "__pillar__" or "__salt__") into modules. One effect of this race condition was that in a setup with multiple threads, some threads may see pillar data intended for other threads or the pillar data seen by a thread might even change spuriously. There have been earlier attempts to fix this problem (#27937, #29397). These patches tried to fix the problem by storing the dictionary that keeps the relevant data in a thread-local variable and referencing this thread-local variable from the variables that are injected into the modules. These patches did not fix the problem completely because they only work when a module is loaded through a single loader instance only. When there is more than one loader, there is more than one thread-local variable and the variable injected into a module is changed to point to another thread-local variable when the module is loaded again. Thus, the problem resurfaced while working on #39670. This patch attempts to solve the problem from a slightly different angle, complementing the earlier patches: The value injected into the modules now is a proxy that internally uses a thread-local variable to decide to which object it points. This means that when loading a module again through a different loader (possibly passing different pillar data), the data is actually only changed in the thread in which the loader is used. Other threads are not affected by such a change. This means that it will work correctly in the current situation where loaders are possibly created by many different modules and these modules do not necessary know in which context they are executed. Thus it is much more flexible and reliable than the more explicit approach used by the two earlier patches. Unfortunately, the stand JSON and Msgpack serialization code cannot handle proxied objects, so they have to be unwrapped before passing them to that code. The salt.utils.json and salt.utils.msgpack modules have been modified to take care of unwrapping objects that are proxied using the ThreadLocalProxy. --- salt/loader.py | 74 ++- salt/utils/json.py | 17 +- salt/utils/msgpack.py | 19 +- salt/utils/thread_local_proxy.py | 599 ++++++++++++++++++++ tests/unit/utils/test_thread_local_proxy.py | 37 ++ 5 files changed, 741 insertions(+), 5 deletions(-) create mode 100644 salt/utils/thread_local_proxy.py create mode 100644 tests/unit/utils/test_thread_local_proxy.py diff --git a/salt/loader.py b/salt/loader.py index 8a6a9eca6b56..848eb201c1ab 100644 --- a/salt/loader.py +++ b/salt/loader.py @@ -14,6 +14,7 @@ import logging import inspect import tempfile +import threading import functools import threading import traceback @@ -33,6 +34,7 @@ import salt.utils.lazy import salt.utils.odict import salt.utils.platform +import salt.utils.thread_local_proxy import salt.utils.versions from salt.exceptions import LoaderError from salt.template import check_render_pipe_str @@ -1007,6 +1009,76 @@ def _mod_type(module_path): return 'ext' +def _inject_into_mod(mod, name, value, force_lock=False): + ''' + Inject a variable into a module. This is used to inject "globals" like + ``__salt__``, ``__pillar``, or ``grains``. + + Instead of injecting the value directly, a ``ThreadLocalProxy`` is created. + If such a proxy is already present under the specified name, it is updated + with the new value. This update only affects the current thread, so that + the same name can refer to different values depending on the thread of + execution. + + This is important for data that is not truly global. For example, pillar + data might be dynamically overriden through function parameters and thus + the actual values available in pillar might depend on the thread that is + calling a module. + + mod: + module object into which the value is going to be injected. + + name: + name of the variable that is injected into the module. + + value: + value that is injected into the variable. The value is not injected + directly, but instead set as the new reference of the proxy that has + been created for the variable. + + force_lock: + whether the lock should be acquired before checking whether a proxy + object for the specified name has already been injected into the + module. If ``False`` (the default), this function checks for the + module's variable without acquiring the lock and only acquires the lock + if a new proxy has to be created and injected. + ''' + from salt.utils.thread_local_proxy import ThreadLocalProxy + old_value = getattr(mod, name, None) + # We use a double-checked locking scheme in order to avoid taking the lock + # when a proxy object has already been injected. + # In most programming languages, double-checked locking is considered + # unsafe when used without explicit memory barriers because one might read + # an uninitialized value. In CPython it is safe due to the global + # interpreter lock (GIL). In Python implementations that do not have the + # GIL, it could be unsafe, but at least Jython also guarantees that (for + # Python objects) memory is not corrupted when writing and reading without + # explicit synchronization + # (http://www.jython.org/jythonbook/en/1.0/Concurrency.html). + # Please note that in order to make this code safe in a runtime environment + # that does not make this guarantees, it is not sufficient. The + # ThreadLocalProxy must also be created with fallback_to_shared set to + # False or a lock must be added to the ThreadLocalProxy. + if force_lock: + with _inject_into_mod.lock: + if isinstance(old_value, ThreadLocalProxy): + ThreadLocalProxy.set_reference(old_value, value) + else: + setattr(mod, name, ThreadLocalProxy(value, True)) + else: + if isinstance(old_value, ThreadLocalProxy): + ThreadLocalProxy.set_reference(old_value, value) + else: + _inject_into_mod(mod, name, value, True) + + +# Lock used when injecting globals. This is needed to avoid a race condition +# when two threads try to load the same module concurrently. This must be +# outside the loader because there might be more than one loader for the same +# namespace. +_inject_into_mod.lock = threading.RLock() + + # TODO: move somewhere else? class FilterDictWrapper(MutableMapping): ''' @@ -1560,7 +1632,7 @@ def _load_module(self, name): # pack whatever other globals we were asked to for p_name, p_value in six.iteritems(self.pack): - setattr(mod, p_name, p_value) + _inject_into_mod(mod, p_name, p_value) module_name = mod.__name__.rsplit('.', 1)[-1] diff --git a/salt/utils/json.py b/salt/utils/json.py index a578b8f8436d..547baad79996 100644 --- a/salt/utils/json.py +++ b/salt/utils/json.py @@ -12,6 +12,7 @@ # Import Salt libs import salt.utils.data import salt.utils.stringutils +from salt.utils.thread_local_proxy import ThreadLocalProxy # Import 3rd-party libs from salt.ext import six @@ -114,11 +115,17 @@ def dump(obj, fp, **kwargs): using the _json_module argument) ''' json_module = kwargs.pop('_json_module', json) + orig_enc_func = kwargs.pop('default', lambda x: x) + + def _enc_func(obj): + obj = ThreadLocalProxy.unproxy(obj) + return orig_enc_func(obj) + if 'ensure_ascii' not in kwargs: kwargs['ensure_ascii'] = False if six.PY2: obj = salt.utils.data.encode(obj) - return json_module.dump(obj, fp, **kwargs) # future lint: blacklisted-function + return json_module.dump(obj, fp, default=_enc_func, **kwargs) # future lint: blacklisted-function def dumps(obj, **kwargs): @@ -138,8 +145,14 @@ def dumps(obj, **kwargs): ''' import sys json_module = kwargs.pop('_json_module', json) + orig_enc_func = kwargs.pop('default', lambda x: x) + + def _enc_func(obj): + obj = ThreadLocalProxy.unproxy(obj) + return orig_enc_func(obj) + if 'ensure_ascii' not in kwargs: kwargs['ensure_ascii'] = False if six.PY2: obj = salt.utils.data.encode(obj) - return json_module.dumps(obj, **kwargs) # future lint: blacklisted-function + return json_module.dumps(obj, default=_enc_func, **kwargs) # future lint: blacklisted-function diff --git a/salt/utils/msgpack.py b/salt/utils/msgpack.py index 7e66cb8ed739..7f36502afdcf 100644 --- a/salt/utils/msgpack.py +++ b/salt/utils/msgpack.py @@ -13,6 +13,9 @@ # Fall back to msgpack_pure import msgpack_pure as msgpack # pylint: disable=import-error +# Import Salt libs +from salt.utils.thread_local_proxy import ThreadLocalProxy + def pack(o, stream, **kwargs): ''' @@ -26,7 +29,13 @@ def pack(o, stream, **kwargs): msgpack module using the _msgpack_module argument. ''' msgpack_module = kwargs.pop('_msgpack_module', msgpack) - return msgpack_module.pack(o, stream, **kwargs) + orig_enc_func = kwargs.pop('default', lambda x: x) + + def _enc_func(obj): + obj = ThreadLocalProxy.unproxy(obj) + return orig_enc_func(obj) + + return msgpack_module.pack(o, stream, default=_enc_func, **kwargs) def packb(o, **kwargs): @@ -41,7 +50,13 @@ def packb(o, **kwargs): msgpack module using the _msgpack_module argument. ''' msgpack_module = kwargs.pop('_msgpack_module', msgpack) - return msgpack_module.packb(o, **kwargs) + orig_enc_func = kwargs.pop('default', lambda x: x) + + def _enc_func(obj): + obj = ThreadLocalProxy.unproxy(obj) + return orig_enc_func(obj) + + return msgpack_module.packb(o, default=_enc_func, **kwargs) def unpack(stream, **kwargs): diff --git a/salt/utils/thread_local_proxy.py b/salt/utils/thread_local_proxy.py new file mode 100644 index 000000000000..8be7ad03be30 --- /dev/null +++ b/salt/utils/thread_local_proxy.py @@ -0,0 +1,599 @@ +# -*- coding: utf-8 -*- +''' +Proxy object that can reference different values depending on the current +thread of execution. + +..versionadded:: 2018.3.4 + +''' + +# Import python libs +from __future__ import absolute_import +import threading + +# Import 3rd-party libs +from salt.ext import six + + +class ThreadLocalProxy(object): + ''' + Proxy that delegates all operations to its referenced object. The referenced + object is hold through a thread-local variable, so that this proxy may refer + to different objects in different threads of execution. + + For all practical purposes (operators, attributes, `isinstance`), the proxy + acts like the referenced object. Thus, code receiving the proxy object + instead of the reference object typically does not have to be changed. The + only exception is code that explicitly uses the ``type()`` function for + checking the proxy's type. While `isinstance(proxy, ...)` will yield the + expected results (based on the actual type of the referenced object), using + something like ``issubclass(type(proxy), ...)`` will not work, because + these tests will be made on the type of the proxy object instead of the + type of the referenced object. In order to avoid this, such code must be + changed to use ``issubclass(type(ThreadLocalProxy.unproxy(proxy)), ...)``. + + If an instance of this class is created with the ``fallback_to_shared`` flag + set and a thread uses the instance without setting the reference explicitly, + the reference for this thread is initialized with the latest reference set + by any thread. + + This class has primarily been designed for use by the Salt loader, but it + might also be useful in other places. + ''' + + __slots__ = ['_thread_local', '_last_reference', '_fallback_to_shared'] + + @staticmethod + def get_reference(proxy): + ''' + Return the object that is referenced by the specified proxy. + + If the proxy has not been bound to a reference for the current thread, + the behavior depends on th the ``fallback_to_shared`` flag that has + been specified when creating the proxy. If the flag has been set, the + last reference that has been set by any thread is returned (and + silently set as the reference for the current thread). If the flag has + not been set, an ``AttributeError`` is raised. + + If the object references by this proxy is itself a proxy, that proxy is + returned. Use ``unproxy`` for unwrapping the referenced object until it + is not a proxy. + + proxy: + proxy object for which the reference shall be returned. If the + specified object is not an instance of `ThreadLocalProxy`, the + behavior is unspecified. Typically, an ``AttributeError`` is + going to be raised. + ''' + thread_local = object.__getattribute__(proxy, '_thread_local') + try: + return thread_local.reference + except AttributeError: + fallback_to_shared = object.__getattribute__( + proxy, '_fallback_to_shared') + if fallback_to_shared: + # If the reference has never been set in the current thread of + # execution, we use the reference that has been last set by any + # thread. + reference = object.__getattribute__(proxy, '_last_reference') + # We save the reference in the thread local so that future + # calls to get_reference will have consistent results. + ThreadLocalProxy.set_reference(proxy, reference) + return reference + else: + # We could simply return None, but this would make it hard to + # debug situations where the reference has not been set (the + # problem might go unnoticed until some code tries to do + # something with the returned object and it might not be easy to + # find out from where the None value originates). + # For this reason, we raise an AttributeError with an error + # message explaining the problem. + raise AttributeError( + 'The proxy object has not been bound to a reference in this thread of execution.') + + @staticmethod + def set_reference(proxy, new_reference): + ''' + Set the reference to be used the current thread of execution. + + After calling this function, the specified proxy will act like it was + the referenced object. + + proxy: + proxy object for which the reference shall be set. If the specified + object is not an instance of `ThreadLocalProxy`, the behavior is + unspecified. Typically, an ``AttributeError`` is going to be + raised. + + new_reference: + reference the proxy should point to for the current thread after + calling this function. + ''' + # If the new reference is itself a proxy, we have to ensure that it does + # not refer to this proxy. If it does, we simply return because updating + # the reference would result in an inifite loop when trying to use the + # proxy. + possible_proxy = new_reference + while isinstance(possible_proxy, ThreadLocalProxy): + if possible_proxy is proxy: + return + possible_proxy = ThreadLocalProxy.get_reference(possible_proxy) + thread_local = object.__getattribute__(proxy, '_thread_local') + thread_local.reference = new_reference + object.__setattr__(proxy, '_last_reference', new_reference) + + @staticmethod + def unset_reference(proxy): + ''' + Unset the reference to be used by the current thread of execution. + + After calling this function, the specified proxy will act like the + reference had never been set for the current thread. + + proxy: + proxy object for which the reference shall be unset. If the + specified object is not an instance of `ThreadLocalProxy`, the + behavior is unspecified. Typically, an ``AttributeError`` is going + to be raised. + ''' + thread_local = object.__getattribute__(proxy, '_thread_local') + del thread_local.reference + + @staticmethod + def unproxy(possible_proxy): + ''' + Unwrap and return the object referenced by a proxy. + + This function is very similar to :func:`get_reference`, but works for + both proxies and regular objects. If the specified object is a proxy, + its reference is extracted with ``get_reference`` and returned. If it + is not a proxy, it is returned as is. + + If the object references by the proxy is itself a proxy, the unwrapping + is repeated until a regular (non-proxy) object is found. + + possible_proxy: + object that might or might not be a proxy. + ''' + while isinstance(possible_proxy, ThreadLocalProxy): + possible_proxy = ThreadLocalProxy.get_reference(possible_proxy) + return possible_proxy + + def __init__(self, initial_reference, fallback_to_shared=False): + ''' + Create a proxy object that references the specified object. + + initial_reference: + object this proxy should initially reference (for the current + thread of execution). The :func:`set_reference` function is called + for the newly created proxy, passing this object. + + fallback_to_shared: + flag indicating what should happen when the proxy is used in a + thread where the reference has not been set explicitly. If + ``True``, the thread's reference is silently initialized to use the + reference last set by any thread. If ``False`` (the default), an + exception is raised when the proxy is used in a thread without + first initializing the reference in this thread. + ''' + object.__setattr__(self, '_thread_local', threading.local()) + object.__setattr__(self, '_fallback_to_shared', fallback_to_shared) + ThreadLocalProxy.set_reference(self, initial_reference) + + def __repr__(self): + reference = ThreadLocalProxy.get_reference(self) + return repr(reference) + + def __str__(self): + reference = ThreadLocalProxy.get_reference(self) + return str(reference) + + def __lt__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference < other + + def __le__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference <= other + + def __eq__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference == other + + def __ne__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference != other + + def __gt__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference > other + + def __ge__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference >= other + + def __hash__(self): + reference = ThreadLocalProxy.get_reference(self) + return hash(reference) + + def __nonzero__(self): + reference = ThreadLocalProxy.get_reference(self) + return bool(reference) + + def __getattr__(self, name): + reference = ThreadLocalProxy.get_reference(self) + # Old-style classes might not have a __getattr__ method, but using + # getattr(...) will still work. + try: + original_method = reference.__getattr__ + except AttributeError: + return getattr(reference, name) + return reference.__getattr__(name) + + def __setattr__(self, name, value): + reference = ThreadLocalProxy.get_reference(self) + reference.__setattr__(name, value) + + def __delattr__(self, name): + reference = ThreadLocalProxy.get_reference(self) + reference.__delattr__(name) + + def __getattribute__(self, name): + reference = ThreadLocalProxy.get_reference(self) + return reference.__getattribute__(name) + + def __call__(self, *args, **kwargs): + reference = ThreadLocalProxy.get_reference(self) + return reference(*args, **kwargs) + + def __len__(self): + reference = ThreadLocalProxy.get_reference(self) + return len(reference) + + def __getitem__(self, key): + reference = ThreadLocalProxy.get_reference(self) + return reference[key] + + def __setitem__(self, key, value): + reference = ThreadLocalProxy.get_reference(self) + reference[key] = value + + def __delitem__(self, key): + reference = ThreadLocalProxy.get_reference(self) + del reference[key] + + def __iter__(self): + reference = ThreadLocalProxy.get_reference(self) + return reference.__iter__() + + def __reversed__(self): + reference = ThreadLocalProxy.get_reference(self) + return reversed(reference) + + def __contains__(self, item): + reference = ThreadLocalProxy.get_reference(self) + return item in reference + + def __add__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference + other + + def __sub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference - other + + def __mul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference * other + + def __floordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference // other + + def __mod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference % other + + def __divmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return divmod(reference, other) + + def __pow__(self, other, modulo=None): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + modulo = ThreadLocalProxy.unproxy(modulo) + if modulo is None: + return pow(reference, other) + else: + return pow(reference, other, modulo) + + def __lshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference << other + + def __rshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference >> other + + def __and__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference & other + + def __xor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference ^ other + + def __or__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference | other + + def __div__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__div__ + except AttributeError: + return NotImplemented + return func(other) + + def __truediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__truediv__ + except AttributeError: + return NotImplemented + return func(other) + + def __radd__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other + reference + + def __rsub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other - reference + + def __rmul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other * reference + + def __rdiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__rdiv__ + except AttributeError: + return NotImplemented + return func(other) + + def __rtruediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__rtruediv__ + except AttributeError: + return NotImplemented + return func(other) + + def __rfloordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other // reference + + def __rmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other % reference + + def __rdivmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return divmod(other, reference) + + def __rpow__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other ** reference + + def __rlshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other << reference + + def __rrshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other >> reference + + def __rand__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other & reference + + def __rxor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other ^ reference + + def __ror__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other | reference + + def __iadd__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference += other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __isub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference -= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __imul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference *= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __idiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__idiv__ + except AttributeError: + return NotImplemented + reference = func(other) + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __itruediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__itruediv__ + except AttributeError: + return NotImplemented + reference = func(other) + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ifloordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference //= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __imod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference %= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ipow__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference **= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ilshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference <<= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __irshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference >>= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __iand__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference &= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ixor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference ^= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ior__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference |= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __neg__(self): + reference = ThreadLocalProxy.get_reference(self) + return - reference + + def __pos__(self): + reference = ThreadLocalProxy.get_reference(self) + return + reference + + def __abs__(self): + reference = ThreadLocalProxy.get_reference(self) + return abs(reference) + + def __invert__(self): + reference = ThreadLocalProxy.get_reference(self) + return ~ reference + + def __complex__(self): + reference = ThreadLocalProxy.get_reference(self) + return complex(reference) + + def __int__(self): + reference = ThreadLocalProxy.get_reference(self) + return int(reference) + + def __float__(self): + reference = ThreadLocalProxy.get_reference(self) + return float(reference) + + def __oct__(self): + reference = ThreadLocalProxy.get_reference(self) + return oct(reference) + + def __hex__(self): + reference = ThreadLocalProxy.get_reference(self) + return hex(reference) + + def __index__(self): + reference = ThreadLocalProxy.get_reference(self) + try: + func = reference.__index__ + except AttributeError: + return NotImplemented + return func() + + def __coerce__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return coerce(reference, other) + + if six.PY2: + # pylint: disable=incompatible-py3-code + def __unicode__(self): + reference = ThreadLocalProxy.get_reference(self) + return unicode(reference) + + def __long__(self): + reference = ThreadLocalProxy.get_reference(self) + return long(reference) diff --git a/tests/unit/utils/test_thread_local_proxy.py b/tests/unit/utils/test_thread_local_proxy.py new file mode 100644 index 000000000000..0cd77379aed8 --- /dev/null +++ b/tests/unit/utils/test_thread_local_proxy.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- + +# Import python libs +from __future__ import absolute_import + +# Import Salt Libs +from salt.utils import thread_local_proxy + +# Import Salt Testing Libs +from tests.support.unit import TestCase + + +class ThreadLocalProxyTestCase(TestCase): + ''' + Test case for salt.utils.thread_local_proxy module. + ''' + + def test_set_reference_avoid_loop(self): + ''' + Test that passing another proxy (or the same proxy) to set_reference + does not results in a recursive proxy loop. + ''' + test_obj1 = 1 + test_obj2 = 2 + proxy1 = thread_local_proxy.ThreadLocalProxy(test_obj1) + proxy2 = thread_local_proxy.ThreadLocalProxy(proxy1) + self.assertEqual(test_obj1, proxy1) + self.assertEqual(test_obj1, proxy2) + self.assertEqual(proxy1, proxy2) + thread_local_proxy.ThreadLocalProxy.set_reference(proxy1, test_obj2) + self.assertEqual(test_obj2, proxy1) + self.assertEqual(test_obj2, proxy2) + self.assertEqual(proxy1, proxy2) + thread_local_proxy.ThreadLocalProxy.set_reference(proxy1, proxy2) + self.assertEqual(test_obj2, proxy1) + self.assertEqual(test_obj2, proxy2) + self.assertEqual(proxy1, proxy2)