Improve and fix loop checking for task completeness.
riga committed Jul 5, 2021
1 parent 9920bfd · commit 043b4c6
Showing 1 changed file with 83 additions and 40 deletions.
luigi/worker.py: 123 changes (83 additions & 40 deletions)
@@ -39,6 +39,7 @@
 import subprocess
 import sys
 import contextlib
+import warnings
 
 import queue as Queue
 import random
@@ -342,21 +343,65 @@ def respond(self, response):
         self._scheduler.add_scheduler_message_response(self._task_id, self._message_id, response)
 
 
+class SyncResult(object):
+    """
+    Synchronous implementation of ``multiprocessing.pool.AsyncResult`` that immediately calls *func*
+    with *args* and *kwargs*. Its methods :py:meth:`get`, :py:meth:`wait`, :py:meth:`ready` and
+    :py:meth:`successful` work in a similar fashion, depending on the result of the function call.
+    """
+
+    def __init__(self, func, args=None, kwargs=None):
+        super(SyncResult, self).__init__()
+
+        # store function and arguments
+        self._func = func
+        self._args = args or ()
+        self._kwargs = kwargs or {}
+
+        # store return value and potential exceptions
+        self._return_value = None
+        self._exception = None
+
+        # immediately call
+        self._call()
+
+    def _call(self):
+        try:
+            self._return_value = self._func(*self._args, **self._kwargs)
+        except BaseException as e:
+            self._exception = e
+
+    def get(self, timeout=None):
+        if self._exception:
+            raise self._exception
+        else:
+            return self._return_value
+
+    def wait(self, timeout=None):
+        return
+
+    def ready(self):
+        return True
+
+    def successful(self):
+        return self._exception is None
+
+
 class SingleProcessPool:
     """
     Dummy process pool for using a single processor.
     Imitates the api of multiprocessing.Pool using single-processor equivalents.
     """
 
-    def apply_async(self, function, args):
-        return function(*args)
+    def apply_async(self, function, args=None, kwargs=None):
+        return SyncResult(function, args=args, kwargs=kwargs)
 
     def close(self):
-        pass
+        return
 
     def join(self):
-        pass
+        return
 
 
 class DequeQueue(collections.deque):
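
For illustration only (not part of the commit): a minimal sketch of how the new SyncResult behaves when obtained through SingleProcessPool, assuming both classes are importable from luigi.worker as defined in the hunk above.

from luigi.worker import SingleProcessPool

pool = SingleProcessPool()

# The call runs immediately; the SyncResult only wraps its outcome.
result = pool.apply_async(lambda x: 2 * x, (21,))
assert result.ready()           # always True, the call already happened
assert result.successful()      # no exception was stored
assert result.get() == 42       # returns the stored return value

# A failing call stores the exception and re-raises it on get().
failed = pool.apply_async(lambda: 1 / 0)
assert not failed.successful()
try:
    failed.get()
except ZeroDivisionError:
    pass  # the stored ZeroDivisionError is re-raised here
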
@@ -380,6 +425,8 @@ class AsyncCompletionException(Exception):
"""

def __init__(self, trace):
warnings.warn("{} is deprecated and will be removed in a future release".format(
self.__class__.__name__), DeprecationWarning)
self.trace = trace


Expand All @@ -389,19 +436,17 @@ class TracebackWrapper:
"""

def __init__(self, trace):
warnings.warn("{} is deprecated and will be removed in a future release".format(
self.__class__.__name__), DeprecationWarning)
self.trace = trace


def check_complete(task, out_queue):
def check_complete(task):
"""
Checks if task is complete, puts the result to out_queue.
Checks if task is complete.
"""
logger.debug("Checking if %s is complete", task)
try:
is_complete = task.complete()
except Exception:
is_complete = TracebackWrapper(traceback.format_exc())
out_queue.put((task, is_complete))
return task.complete()


class worker(Config):
@@ -727,7 +772,7 @@ def _handle_task_load_error(self, exception, task_ids):
         )
         notifications.send_error_email(subject, error_message)
 
-    def add(self, task, multiprocess=False, processes=0):
+    def add(self, task, multiprocess=False, processes=0, wait_interval=0.01):
         """
         Add a Task for the worker to check and possibly schedule and run.
@@ -737,36 +782,36 @@ def add(self, task, multiprocess=False, processes=0):
             self._first_task = task.task_id
         self.add_succeeded = True
         if multiprocess:
-            queue = multiprocessing.Manager().Queue()
             pool = multiprocessing.Pool(processes=processes if processes > 0 else None)
         else:
-            queue = DequeQueue()
             pool = SingleProcessPool()
         self._validate_task(task)
-        pool.apply_async(check_complete, [task, queue])
+        results = [(task, pool.apply_async(check_complete, (task,)))]
 
-        # we track queue size ourselves because len(queue) won't work for multiprocessing
-        queue_size = 1
         try:
             seen = {task.task_id}
-            while queue_size:
-                current = queue.get()
-                queue_size -= 1
-                item, is_complete = current
-                for next in self._add(item, is_complete):
-                    if next.task_id not in seen:
-                        self._validate_task(next)
-                        seen.add(next.task_id)
-                        pool.apply_async(check_complete, [next, queue])
-                        queue_size += 1
-        except (KeyboardInterrupt, TaskException):
-            raise
-        except Exception as ex:
-            self.add_succeeded = False
-            formatted_traceback = traceback.format_exc()
-            self._log_unexpected_error(task)
-            task.trigger_event(Event.BROKEN_TASK, task, ex)
-            self._email_unexpected_error(task, formatted_traceback)
+            while results:
+                # fetch the first done result
+                for i, (task, result) in enumerate(list(results)):
+                    if result.ready():
+                        results.pop(i)
+                        break
+                else:
+                    time.sleep(wait_interval)
+                    continue
+
+                # get the result or error
+                try:
+                    is_complete = result.get()
+                except Exception as e:
+                    is_complete = e
+
+                for dep in self._add(task, is_complete):
+                    if dep.task_id not in seen:
+                        self._validate_task(dep)
+                        seen.add(dep.task_id)
+                        results.append((dep, pool.apply_async(check_complete, (dep,))))
+        except BaseException:
+            raise
         finally:
             pool.close()
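
For readers tracing the new control flow: the queue-and-counter bookkeeping is replaced by polling a list of (task, result) pairs. A simplified, self-contained sketch of that pattern (an illustrative stand-in with a hypothetical handle callback, not the worker's actual code):

import time

def drain(results, handle, wait_interval=0.01):
    # results: a list of (task, async_result) pairs, as built by add()
    while results:
        # Pop the first finished result; if none is ready, sleep and poll again.
        for i, (task, result) in enumerate(list(results)):
            if result.ready():
                results.pop(i)
                break
        else:
            time.sleep(wait_interval)
            continue

        # An exception raised by the completeness check becomes the
        # is_complete value that is handed on (here, to the callback).
        try:
            is_complete = result.get()
        except Exception as e:
            is_complete = e

        handle(task, is_complete)  # may append new (dep, result) pairs
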
@@ -800,8 +845,6 @@ def _add(self, task, is_complete):
             self._check_complete_value(is_complete)
         except KeyboardInterrupt:
             raise
-        except AsyncCompletionException as ex:
-            formatted_traceback = ex.trace
         except BaseException:
             formatted_traceback = traceback.format_exc()
 
@@ -881,9 +924,9 @@ def _validate_dependency(self, dependency):
             raise Exception('requires() must return Task objects but {} is a {}'.format(dependency, type(dependency)))
 
     def _check_complete_value(self, is_complete):
-        if is_complete not in (True, False):
-            if isinstance(is_complete, TracebackWrapper):
-                raise AsyncCompletionException(is_complete.trace)
+        if isinstance(is_complete, BaseException):
+            raise is_complete
+        elif not isinstance(is_complete, bool):
             raise Exception("Return value of Task.complete() must be boolean (was %r)" % is_complete)
 
     def _add_worker(self):
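
Taken together, the changes route completeness failures as values instead of TracebackWrapper objects: check_complete() simply raises, SyncResult (or the multiprocessing AsyncResult) stores the exception, add() catches it from get(), and _check_complete_value() re-raises it inside _add(), where the existing traceback logging applies. A minimal sketch of that round trip, using a hypothetical toy task:

from luigi.worker import SyncResult

class BrokenTask:
    def complete(self):
        raise RuntimeError("storage backend unreachable")

# check_complete() no longer needs an out_queue; it just calls complete().
result = SyncResult(BrokenTask().complete)
assert not result.successful()
try:
    result.get()  # re-raises the stored RuntimeError
except RuntimeError as e:
    is_complete = e  # this exception value is what _add() receives
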
