diff --git a/doc/luigi_patterns.rst b/doc/luigi_patterns.rst index 532b2d9966..327f86432e 100644 --- a/doc/luigi_patterns.rst +++ b/doc/luigi_patterns.rst @@ -226,6 +226,33 @@ the task parameters or other dynamic attributes: Since, by default, resources have a usage limit of 1, no two instances of Task A will now run if they have the same `important_file_name` property. +Decreasing resources of running tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +At scheduling time, the Luigi scheduler needs to be aware of the maximum +resource consumption a task might have once it runs. For some tasks, however, +it can be beneficial to decrease the amount of consumed resources between two +steps within their run method (e.g. after some heavy computation). In this +case, a different task waiting for that particular resource can be scheduled +right away, before the first task has finished. + +.. code-block:: python + + class A(luigi.Task): + + # set maximum resources a priori + resources = {"some_resource": 3} + + def run(self): + # do something + ... + + # decrease consumption of "some_resource" by one + self.decrease_running_resources({"some_resource": 1}) + + # continue with reduced resources + ...
+ Monitoring task pipelines ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/luigi/scheduler.py b/luigi/scheduler.py index 0a25cf7704..f83893e3c0 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -784,6 +784,8 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, worker_id = worker worker = self._update_worker(worker_id) + resources = {} if resources is None else resources.copy() + if retry_policy_dict is None: retry_policy_dict = {} @@ -815,7 +817,9 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if status == RUNNING and not task.worker_running: task.worker_running = worker_id if batch_id: - task.resources_running = self._state.get_batch_running_tasks(batch_id)[0].resources_running + # copy resources_running of the first batch task + batch_tasks = self._state.get_batch_running_tasks(batch_id) + task.resources_running = batch_tasks[0].resources_running.copy() task.time_running = time.time() if tracking_url is not None or task.status != RUNNING: @@ -970,8 +974,9 @@ def _used_resources(self): used_resources = collections.defaultdict(int) if self._resources is not None: for task in self._state.get_active_tasks_by_status(RUNNING): - if getattr(task, 'resources_running', task.resources): - for resource, amount in six.iteritems(getattr(task, 'resources_running', task.resources)): + resources_running = getattr(task, "resources_running", task.resources) + if resources_running: + for resource, amount in six.iteritems(resources_running): used_resources[resource] += amount return used_resources @@ -1175,7 +1180,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, elif best_task: self._state.set_status(best_task, RUNNING, self._config) best_task.worker_running = worker_id - best_task.resources_running = best_task.resources + best_task.resources_running = best_task.resources.copy() best_task.time_running = time.time() self._update_task_history(best_task, RUNNING, host=host) @@ -1237,6 +1242,7 @@ def 
_serialize_task(self, task_id, include_deps=True, deps=None): 'name': task.family, 'priority': task.priority, 'resources': task.resources, + 'resources_running': getattr(task, "resources_running", None), 'tracking_url': getattr(task, "tracking_url", None), 'status_message': getattr(task, "status_message", None), 'progress_percentage': getattr(task, "progress_percentage", None) @@ -1521,6 +1527,31 @@ def get_task_progress_percentage(self, task_id): else: return {"taskId": task_id, "progressPercentage": None} + @rpc_method() + def decrease_running_task_resources(self, task_id, decrease_resources): + if self._state.has_task(task_id): + task = self._state.get_task(task_id) + if task.status != RUNNING: + return + + def decrease(resources, decrease_resources): + for resource, decrease_amount in six.iteritems(decrease_resources): + if decrease_amount > 0 and resource in resources: + resources[resource] = max(0, resources[resource] - decrease_amount) + + decrease(task.resources_running, decrease_resources) + if task.batch_id is not None: + for batch_task in self._state.get_batch_running_tasks(task.batch_id): + decrease(batch_task.resources_running, decrease_resources) + + @rpc_method() + def get_running_task_resources(self, task_id): + if self._state.has_task(task_id): + task = self._state.get_task(task_id) + return {"taskId": task_id, "resources": getattr(task, "resources_running", None)} + else: + return {"taskId": task_id, "resources": None} + def _update_task_history(self, task, status, host=None): try: if status == DONE or status == FAILED: diff --git a/luigi/static/visualiser/js/visualiserApp.js b/luigi/static/visualiser/js/visualiserApp.js index b49def6df7..2cd3ceef51 100644 --- a/luigi/static/visualiser/js/visualiserApp.js +++ b/luigi/static/visualiser/js/visualiserApp.js @@ -65,7 +65,7 @@ function visualiserApp(luigi) { taskParams: taskParams, displayName: task.display_name, priority: task.priority, - resources: JSON.stringify(task.resources).replace(/,"/g, ', 
"'), + resources: JSON.stringify(task.resources_running || task.resources).replace(/,"/g, ', "'), displayTime: displayTime, displayTimestamp: task.last_updated, timeRunning: time_running, diff --git a/luigi/task.py b/luigi/task.py index 890ff348e5..08e40ae214 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -34,6 +34,7 @@ import copy import functools +import luigi from luigi import six from luigi import parameter @@ -686,7 +687,7 @@ def _dump(self): pickle.dumps(self) """ - unpicklable_properties = ('set_tracking_url', 'set_status_message', 'set_progress_percentage') + unpicklable_properties = tuple(luigi.worker.TaskProcess.forward_reporter_callbacks.values()) reserved_properties = {} for property_name in unpicklable_properties: if hasattr(self, property_name): diff --git a/luigi/worker.py b/luigi/worker.py index 4a2e276c38..aec97d32b9 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -110,6 +110,15 @@ class TaskProcess(multiprocessing.Process): Mainly for convenience since this is run in a separate process. 
""" + # mapping of status_reporter methods to task callbacks that are added to the task + # before they actually run, and removed afterwards + forward_reporter_callbacks = { + "update_tracking_url": "set_tracking_url", + "update_status_message": "set_status_message", + "update_progress_percentage": "set_progress_percentage", + "decrease_running_resources": "decrease_running_resources", + } + def __init__(self, task, worker_id, result_queue, status_reporter, use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True): super(TaskProcess, self).__init__() @@ -124,15 +133,15 @@ def __init__(self, task, worker_id, result_queue, status_reporter, self.check_unfulfilled_deps = check_unfulfilled_deps def _run_get_new_deps(self): - self.task.set_tracking_url = self.status_reporter.update_tracking_url - self.task.set_status_message = self.status_reporter.update_status_message - self.task.set_progress_percentage = self.status_reporter.update_progress_percentage + # set task callbacks before running + for reporter_attr, task_attr in six.iteritems(self.forward_reporter_callbacks): + setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr)) task_gen = self.task.run() - self.task.set_tracking_url = None - self.task.set_status_message = None - self.task.set_progress_percentage = None + # reset task callbacks + for reporter_attr, task_attr in six.iteritems(self.forward_reporter_callbacks): + setattr(self.task, task_attr, None) if not isinstance(task_gen, types.GeneratorType): return None @@ -274,6 +283,9 @@ def update_status_message(self, message): def update_progress_percentage(self, percentage): self._scheduler.set_task_progress_percentage(self._task_id, percentage) + def decrease_running_resources(self, decrease_resources): + self._scheduler.decrease_running_task_resources(self._task_id, decrease_resources) + class SingleProcessPool(object): """ diff --git a/test/scheduler_api_test.py b/test/scheduler_api_test.py index e1033abe62..8fe672a0ee 100644 
--- a/test/scheduler_api_test.py +++ b/test/scheduler_api_test.py @@ -408,12 +408,12 @@ def test_set_batch_runner_max(self): self.sch.add_task(worker=WORKER, task_id='A_2', status=DONE) self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(DONE, '').keys())) - def _start_simple_batch(self, use_max=False, mark_running=True): + def _start_simple_batch(self, use_max=False, mark_running=True, resources=None): self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, - batchable=True) + batchable=True, resources=resources) self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, - batchable=True) + batchable=True, resources=resources) response = self.sch.get_work(worker=WORKER) if mark_running: batch_id = response['batch_id'] @@ -496,6 +496,13 @@ def test_batch_update_progress(self): for task_id in ('A_1', 'A_2', 'A_1_2'): self.assertEqual(30, self.sch.get_task_progress_percentage(task_id)['progressPercentage']) + def test_batch_decrease_resources(self): + self.sch.update_resources(x=3) + self._start_simple_batch(resources={'x': 3}) + self.sch.decrease_running_task_resources('A_1_2', {'x': 1}) + for task_id in ('A_1', 'A_2', 'A_1_2'): + self.assertEqual(2, self.sch.get_running_task_resources(task_id)['resources']['x']) + def test_batch_tracking_url(self): self._start_simple_batch() self.sch.add_task(worker=WORKER, task_id='A_1_2', tracking_url='http://test.tracking.url/') diff --git a/test/task_running_resources_test.py b/test/task_running_resources_test.py new file mode 100644 index 0000000000..79727d971f --- /dev/null +++ b/test/task_running_resources_test.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import time +import signal +import multiprocessing +from contextlib import contextmanager + +from helpers import unittest, RunOnceTask + +import luigi +import luigi.server + + +class ResourceTestTask(RunOnceTask): + + param = luigi.Parameter() + reduce_foo = luigi.BoolParameter() + + def process_resources(self): + return {"foo": 2} + + def run(self): + if self.reduce_foo: + self.decrease_running_resources({"foo": 1}) + + time.sleep(2) + + super(ResourceTestTask, self).run() + + +class ResourceWrapperTask(RunOnceTask): + + reduce_foo = ResourceTestTask.reduce_foo + + def requires(self): + return [ + ResourceTestTask(param="a", reduce_foo=self.reduce_foo), + ResourceTestTask(param="b"), + ] + + +class LocalRunningResourcesTest(unittest.TestCase): + + def test_resource_reduction(self): + # trivial resource reduction on local scheduler + # test the running_task_resources setter and getter + sch = luigi.scheduler.Scheduler(resources={"foo": 2}) + + with luigi.worker.Worker(scheduler=sch) as w: + task = ResourceTestTask(param="a", reduce_foo=True) + + w.add(task) + w.run() + + self.assertEqual(sch.get_running_task_resources(task.task_id)["resources"]["foo"], 1) + + +class ConcurrentRunningResourcesTest(unittest.TestCase): + + def get_app(self): + return luigi.server.app(luigi.scheduler.Scheduler()) + + def setUp(self): + super(ConcurrentRunningResourcesTest, self).setUp() + + # run the luigi server in a new process and wait for its startup + self._process = multiprocessing.Process(target=luigi.server.run) + self._process.start() + time.sleep(0.5) 
+ + # configure the rpc scheduler, update the foo resource + self.sch = luigi.rpc.RemoteScheduler() + self.sch.update_resource("foo", 3) + + def tearDown(self): + super(ConcurrentRunningResourcesTest, self).tearDown() + + # graceful server shutdown + self._process.terminate() + self._process.join(timeout=1) + if self._process.is_alive(): + os.kill(self._process.pid, signal.SIGKILL) + + @contextmanager + def worker(self, scheduler=None, processes=2): + with luigi.worker.Worker(scheduler=scheduler or self.sch, worker_processes=processes) as w: + w._config.wait_interval = 0.2 + w._config.check_unfulfilled_deps = False + yield w + + @contextmanager + def assert_duration(self, min_duration=0, max_duration=-1): + t0 = time.time() + try: + yield + finally: + duration = time.time() - t0 + self.assertGreater(duration, min_duration) + if max_duration > 0: + self.assertLess(duration, max_duration) + + def test_tasks_serial(self): + # serial test + # run two tasks that do not reduce the "foo" resource + # as the total foo resource (3) is smaller than the requirement of two tasks (4), + # the scheduler is forced to run them serially which takes longer than 4 seconds + with self.worker() as w: + w.add(ResourceWrapperTask(reduce_foo=False)) + + with self.assert_duration(min_duration=4): + w.run() + + def test_tasks_parallel(self): + # parallel test + # run two tasks and the first one lowers its requirement on the "foo" resource, so that + # the total "foo" resource (3) is sufficient to run both tasks in parallel shortly after + # the first task started, so the entire process should not exceed 4 seconds + with self.worker() as w: + w.add(ResourceWrapperTask(reduce_foo=True)) + + with self.assert_duration(max_duration=4): + w.run()