From b12be005fd015551908a42518edf99181d1c5171 Mon Sep 17 00:00:00 2001 From: Ahmet DAL Date: Thu, 28 Jul 2016 14:31:36 +0300 Subject: [PATCH] Marking as minimum upstream severity instead of max (#1789) Luigi marks a wrapper task as `UPSTREAM_FAILED` or `UPSTREAM_DISABLED` when ANY of its upstream tasks is `FAILED` or `DISABLED`. Moreover, when the wrapper task is marked as `UPSTREAM_DISABLED`, Luigi shuts the worker down, as it should. This prevents another PENDING task under the wrapper task from running, because the worker is down. This behaviour should change: marking a task as `UPSTREAM_*` should require ALL of its upstream tasks to be affected instead of ANY. --- luigi/scheduler.py | 5 ++--- test/scheduler_visualisation_test.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/luigi/scheduler.py b/luigi/scheduler.py index 366fbd5991..ed8ed643dc 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -862,9 +862,8 @@ def _upstream_status(self, task_id, upstream_status_table): elif upstream_status_table[dep_id] == '' and dep.deps: # This is the postorder update step when we set the # status based on the previously calculated child elements - status = max((upstream_status_table.get(a_task_id, '') - for a_task_id in dep.deps), - key=UPSTREAM_SEVERITY_KEY) + upstream_severities = list(upstream_status_table.get(a_task_id) for a_task_id in dep.deps if a_task_id in upstream_status_table) or [''] + status = min(upstream_severities, key=UPSTREAM_SEVERITY_KEY) upstream_status_table[dep_id] = status return upstream_status_table[dep_id] diff --git a/test/scheduler_visualisation_test.py b/test/scheduler_visualisation_test.py index 51df381fb7..2f7f195d7c 100644 --- a/test/scheduler_visualisation_test.py +++ b/test/scheduler_visualisation_test.py @@ -423,7 +423,7 @@ def requires(self): self.assertEqual(db['status'], 'DONE') missing_input = remote.task_list('PENDING', 'UPSTREAM_MISSING_INPUT') - self.assertEqual(len(missing_input), 2) + self.assertEqual(len(missing_input), 3) pa = 
missing_input.get(A().task_id) self.assertEqual(pa['status'], 'PENDING') @@ -433,14 +433,15 @@ def requires(self): self.assertEqual(pc['status'], 'PENDING') self.assertEqual(remote._upstream_status(C().task_id, {}), 'UPSTREAM_MISSING_INPUT') - upstream_failed = remote.task_list('PENDING', 'UPSTREAM_FAILED') - self.assertEqual(len(upstream_failed), 2) - pe = upstream_failed.get(E().task_id) + pe = missing_input.get(E().task_id) self.assertEqual(pe['status'], 'PENDING') - self.assertEqual(remote._upstream_status(E().task_id, {}), 'UPSTREAM_FAILED') + self.assertEqual(remote._upstream_status(E().task_id, {}), 'UPSTREAM_MISSING_INPUT') - pe = upstream_failed.get(D().task_id) - self.assertEqual(pe['status'], 'PENDING') + upstream_failed = remote.task_list('PENDING', 'UPSTREAM_FAILED') + self.assertEqual(len(upstream_failed), 1) + + pd = upstream_failed.get(D().task_id) + self.assertEqual(pd['status'], 'PENDING') self.assertEqual(remote._upstream_status(D().task_id, {}), 'UPSTREAM_FAILED') pending = dict(missing_input)