diff --git a/easypy/concurrency.py b/easypy/concurrency.py index 8374a540..d7781618 100644 --- a/easypy/concurrency.py +++ b/easypy/concurrency.py @@ -87,6 +87,7 @@ from easypy.units import MINUTE, HOUR from easypy.colors import colorize, uncolored from easypy.sync import SynchronizationCoordinator, ProcessExiting, raise_in_main_thread +from easypy.decorations import parametrizeable_decorator MAX_THREAD_POOL_SIZE = int(os.environ.get('EASYPY_MAX_THREAD_POOL_SIZE', 50)) @@ -210,6 +211,7 @@ class MultiException(PException, metaclass=MultiExceptionMeta): :param complete: ``True`` if all threads failed on exception """ + COMMON_TYPE = BaseException # the fallback common type template = "{0.common_type.__qualname__} raised from concurrent invocation (x{0.count}/{0.invocations_count})" def __reduce__(self): @@ -892,7 +894,7 @@ def concestor(*cls_list): return object # the base-class that rules the all -class concurrent(object): +class _concurrent(object): """ Higher-level thread execution. @@ -972,7 +974,8 @@ def __repr__(self): flags += 'T' return "<%s[%s] '%s'>" % (self.__class__.__name__, self.threadname, flags) - def _logged_func(self): + def _logged_func(self, kwargs=None): + kwargs = {**self.kwargs, **(kwargs or {})} stack = ExitStack() self.exc = None self.timer = Timer() @@ -983,7 +986,7 @@ def _logged_func(self): stack.enter_context(_logger.suppressed()) _logger.debug("%s - starting", self) while True: - self._result = self.func(*self.args, **self.kwargs) + self._result = self.func(*self.args, **kwargs) if not self.loop: return if self.wait(self.sleep): @@ -1004,7 +1007,8 @@ def _logged_func(self): stack.close() def stop(self): - _logger.debug("%s - stopping", self) + if not self.stopper.is_set(): + _logger.debug("%s - stopping", self) self.stopper.set() def wait(self, timeout=None): @@ -1014,12 +1018,16 @@ def wait(self, timeout=None): while not timer.expired: if self.stopper.is_set(): return True - non_gevent_sleep(0.1) + if IS_GEVENT: + time.sleep(0.1) + else: + non_gevent_sleep(0.1) return False return self.stopper.wait(timeout) def result(self, timeout=None): - self.wait(timeout=timeout) + if not self.wait(timeout=timeout): + raise TimeoutError() if self.throw and self.exc: raise self.exc return self._result @@ -1037,18 +1045,20 @@ def paused(self): self.start() @contextmanager - def _running(self): + def _running(self, *args, **kwargs): + func = lambda *args, **kwargs: self._logged_func(*args, **kwargs) + if DISABLE_CONCURRENCY: - self._logged_func() + self._logged_func(*args, **kwargs) yield self return if self.real_thread_no_greenlet and IS_GEVENT: _logger.debug('sending job to a real OS thread') - self._join = defer_to_thread(func=self._logged_func, threadname=self.threadname) + self._join = defer_to_thread(func=func, threadname=self.threadname) else: # threading.Thread could be a real thread or a gevent-patched thread... - self.thread = threading.Thread(target=self._logged_func, name=self.threadname, daemon=self.daemon) + self.thread = threading.Thread(target=func, name=self.threadname, daemon=self.daemon) _logger.debug('sending job to %s', self.thread) self.stopper.clear() self.thread.start() @@ -1061,12 +1071,24 @@ def _running(self): if self.throw and self.exc: raise self.exc - def __enter__(self): - self._ctx = self._running() + def start(self, *args, **kwargs): + self._ctx = self._running(*args, **kwargs) return self._ctx.__enter__() + def join(self): + self.__exit__(None, None, None) + + __enter__ = start + def __exit__(self, *args): - return self._ctx.__exit__(*args) + if not self._ctx: + return + try: + return self._ctx.__exit__(*args) + finally: + self._ctx = None + + __del__ = join def __iter__(self): # TODO: document or remove @@ -1076,12 +1098,16 @@ def __iter__(self): yield self self.iterations += 1 - start = __enter__ + def __call__(self, *args, timeout_=None, **kwargs): + self.start(*args, **kwargs) + if not self.wait(timeout=timeout_): + raise TimeoutError() + return self.result() - def join(self): - self.__exit__(None, None, None) - __del__ = join +@parametrizeable_decorator +def concurrent(func=None, *args, **kwargs): + return _concurrent(func, *args, **kwargs) # re-exports diff --git a/easypy/decorations.py b/easypy/decorations.py index dd6bc45d..a454af67 100644 --- a/easypy/decorations.py +++ b/easypy/decorations.py @@ -12,11 +12,11 @@ def parametrizeable_decorator(deco): @wraps(deco) - def inner(func=None, **kwargs): + def inner(func=None, *args, **kwargs): if func is None: - return partial(deco, **kwargs) + return partial(deco, *args, **kwargs) else: - return wraps(func)(deco(func, **kwargs)) + return wraps(func)(deco(func, *args, **kwargs)) return inner diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 3e8ebb06..04392334 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -24,6 +24,18 @@ def test_thread_stacks(): print(get_thread_stacks().render()) +def test_call_concurrent(): + func = lambda a, b: a + b + c = concurrent(func, 1, b=2, threadname='add') + assert c() == 3 + + +def test_call_concurrent_timeout(): + c = concurrent(sleep, 1, threadname='sleep') + with pytest.raises(TimeoutError): + c(timeout_=0.1) + + def test_thread_contexts_counters(): TC = ThreadContexts(counters=('i', 'j')) assert TC.i == TC.j == 0