diff --git a/salt/utils/process.py b/salt/utils/process.py index e70aca530c21..bc4e8c510be6 100644 --- a/salt/utils/process.py +++ b/salt/utils/process.py @@ -14,6 +14,7 @@ import types import signal import logging +import functools import threading import contextlib import subprocess @@ -21,7 +22,6 @@ import multiprocessing.util import socket - # Import salt libs import salt.defaults.exitcodes import salt.utils.files @@ -716,6 +716,15 @@ def __init__(self, *args, **kwargs): (salt.log.setup.shutdown_multiprocessing_logging, [], {}) ] + # Because we need to enforce our after fork and finalize routines, + # we must wrap this class run method to allow for these extra steps + # to be executed pre and post calling the actual run method, + # having subclasses call super would just not work. + # + # We use setattr here to fool pylint not to complain that we're + # overriding run from the subclass here + setattr(self, 'run', self.__decorate_run(self.run)) + # __setstate__ and __getstate__ are only used on Windows. def __setstate__(self, state): args = state['args'] @@ -741,25 +750,30 @@ def __getstate__(self): def _setup_process_logging(self): salt.log.setup.setup_multiprocessing_logging(self.log_queue) - def run(self): - for method, args, kwargs in self._after_fork_methods: - method(*args, **kwargs) - try: - return super(Process, self).run() - except SystemExit: - # These are handled by multiprocessing.Process._bootstrap() - raise - except Exception as exc: - log.error( - 'An un-handled exception from the multiprocessing process ' - '\'%s\' was caught:\n', self.name, exc_info=True) - # Re-raise the exception. multiprocessing.Process will write it to - # sys.stderr and set the proper exitcode and we have already logged - # it above. - raise - finally: - for method, args, kwargs in self._finalize_methods: + def __decorate_run(self, run_func): + + @functools.wraps(run_func) + def wrapped_run_func(): + for method, args, kwargs in self._after_fork_methods: method(*args, **kwargs) + try: + return run_func() + except SystemExit: + # These are handled by multiprocessing.Process._bootstrap() + six.reraise(*sys.exc_info()) + except Exception as exc: # pylint: disable=broad-except + log.error( + 'An un-handled exception from the multiprocessing process ' + '\'%s\' was caught:\n', self.name, exc_info=True) + # Re-raise the exception. multiprocessing.Process will write it to + # sys.stderr and set the proper exitcode and we have already logged + # it above. + six.reraise(*sys.exc_info()) + finally: + for method, args, kwargs in self._finalize_methods: + method(*args, **kwargs) + + return wrapped_run_func class MultiprocessingProcess(Process): diff --git a/tests/unit/utils/test_process.py b/tests/unit/utils/test_process.py index b892401f0899..a87d4c68340a 100644 --- a/tests/unit/utils/test_process.py +++ b/tests/unit/utils/test_process.py @@ -43,7 +43,9 @@ def wrapper(self): def _die(): salt.utils.process.appendproctitle('test_{0}'.format(name)) - setattr(self, 'die_' + name, _die) + attrname = 'die_' + name + setattr(self, attrname, _die) + self.addCleanup(delattr, self, attrname) return wrapper @@ -61,7 +63,9 @@ def _incr(counter, num): salt.utils.process.appendproctitle('test_{0}'.format(name)) for _ in range(0, num): counter.value += 1 - setattr(self, 'incr_' + name, _incr) + attrname = 'incr_' + name + setattr(self, attrname, _incr) + self.addCleanup(delattr, self, attrname) return wrapper @@ -79,7 +83,9 @@ def _spin(): salt.utils.process.appendproctitle('test_{0}'.format(name)) while True: time.sleep(1) - setattr(self, 'spin_' + name, _spin) + attrname = 'spin_' + name + setattr(self, attrname, _spin) + self.addCleanup(delattr, self, attrname) return wrapper @@ -241,6 +247,48 @@ def test_daemonize_if(self): # pylint: enable=assignment-from-none +class TestProcessCallbacks(TestCase): + + @staticmethod + def process_target(evt): + evt.set() + + @skipIf(NO_MOCK, NO_MOCK_REASON) + def test_callbacks(self): + 'Validate Process call after fork and finalize methods' + teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging' + log_to_mock = 'salt.utils.process.Process._setup_process_logging' + with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb: + evt = multiprocessing.Event() + proc = salt.utils.process.Process(target=self.process_target, args=(evt,)) + proc.run() + assert evt.is_set() + mb.assert_called() + ma.assert_called() + + @skipIf(NO_MOCK, NO_MOCK_REASON) + def test_callbacks_called_when_run_overriden(self): + 'Validate Process sub classes call after fork and finalize methods when run is overridden' + + class MyProcess(salt.utils.process.Process): + + def __init__(self): + super(MyProcess, self).__init__() + self.evt = multiprocessing.Event() + + def run(self): + self.evt.set() + + teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging' + log_to_mock = 'salt.utils.process.Process._setup_process_logging' + with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb: + proc = MyProcess() + proc.run() + assert proc.evt.is_set() + ma.assert_called() + mb.assert_called() + + class TestSignalHandlingProcess(TestCase): @classmethod @@ -323,33 +371,6 @@ def test_signal_processing_regression_test(self): def no_op_target(): pass - @skipIf(NO_MOCK, NO_MOCK_REASON) - def test_signal_processing_test_after_fork_called(self): - 'Validate Process and sub classes call after fork methods' - evt = multiprocessing.Event() - sig_to_mock = 'salt.utils.process.SignalHandlingProcess._setup_signals' - log_to_mock = 'salt.utils.process.Process._setup_process_logging' - with patch(sig_to_mock) as ma, patch(log_to_mock) as mb: - self.sh_proc = salt.utils.process.SignalHandlingProcess(target=self.no_op_target) - self.sh_proc.run() - ma.assert_called() - mb.assert_called() - - @skipIf(NO_MOCK, NO_MOCK_REASON) - def test_signal_processing_test_final_methods_called(self): - 'Validate Process and sub classes call finalize methods' - evt = multiprocessing.Event() - teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging' - log_to_mock = 'salt.utils.process.Process._setup_process_logging' - sig_to_mock = 'salt.utils.process.SignalHandlingProcess._setup_signals' - # Mock _setup_signals so we do not register one for this process. - with patch(sig_to_mock): - with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb: - self.sh_proc = salt.utils.process.SignalHandlingProcess(target=self.no_op_target) - self.sh_proc.run() - ma.assert_called() - mb.assert_called() - @staticmethod def pid_setting_target(sub_target, val, evt): val.value = os.getpid() @@ -406,6 +427,58 @@ def test_signal_processing_handle_signals_called(self): proc.join(30) +class TestSignalHandlingProcessCallbacks(TestCase): + + @staticmethod + def process_target(evt): + evt.set() + + @skipIf(NO_MOCK, NO_MOCK_REASON) + def test_callbacks(self): + 'Validate SignalHandlingProcess call after fork and finalize methods' + + teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging' + log_to_mock = 'salt.utils.process.Process._setup_process_logging' + sig_to_mock = 'salt.utils.process.SignalHandlingProcess._setup_signals' + # Mock _setup_signals so we do not register one for this process. + evt = multiprocessing.Event() + with patch(sig_to_mock): + with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb: + sh_proc = salt.utils.process.SignalHandlingProcess( + target=self.process_target, + args=(evt,) + ) + sh_proc.run() + assert evt.is_set() + ma.assert_called() + mb.assert_called() + + @skipIf(NO_MOCK, NO_MOCK_REASON) + def test_callbacks_called_when_run_overriden(self): + 'Validate SignalHandlingProcess sub classes call after fork and finalize methods when run is overridden' + + class MyProcess(salt.utils.process.SignalHandlingProcess): + + def __init__(self): + super(MyProcess, self).__init__() + self.evt = multiprocessing.Event() + + def run(self): + self.evt.set() + + teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging' + log_to_mock = 'salt.utils.process.Process._setup_process_logging' + sig_to_mock = 'salt.utils.process.SignalHandlingProcess._setup_signals' + # Mock _setup_signals so we do not register one for this process. + with patch(sig_to_mock): + with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb: + sh_proc = MyProcess() + sh_proc.run() + assert sh_proc.evt.is_set() + ma.assert_called() + mb.assert_called() + + class TestDup2(TestCase): def test_dup2_no_fileno(self):