diff --git a/dvc/commands/queue/kill.py b/dvc/commands/queue/kill.py index 098c0911947..80d4dd5ca88 100644 --- a/dvc/commands/queue/kill.py +++ b/dvc/commands/queue/kill.py @@ -11,13 +11,15 @@ class CmdQueueKill(CmdBase): """Kill exp task in queue.""" def run(self): - self.repo.experiments.celery_queue.kill(revs=self.args.task) + self.repo.experiments.celery_queue.kill( + revs=self.args.task, force=self.args.force + ) return 0 def add_parser(queue_subparsers, parent_parser): - QUEUE_KILL_HELP = "Kill actively running experiment queue tasks." + QUEUE_KILL_HELP = "Send SIGINT(Ctrl-C) to running experiment queue tasks." queue_kill_parser = queue_subparsers.add_parser( "kill", parents=[parent_parser], @@ -25,6 +27,15 @@ def add_parser(queue_subparsers, parent_parser): help=QUEUE_KILL_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) + queue_kill_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Send SIGKILL (kill -9) instead to running experiment queue " + "tasks. (The default `queue kill` will terminate more gracefully," + " collecting and cleaning up all resources)", + ) queue_kill_parser.add_argument( "task", nargs="*", diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index b8dca00db0b..00fb8657da0 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -300,12 +300,15 @@ def _get_running_task_ids(self) -> Set[str]: return running_task_ids def _try_to_kill_tasks( - self, to_kill: Dict[QueueEntry, str] + self, to_kill: Dict[QueueEntry, str], force: bool ) -> Dict[QueueEntry, str]: fail_to_kill_entries: Dict[QueueEntry, str] = {} for queue_entry, rev in to_kill.items(): try: - self.proc.kill(queue_entry.stash_rev) + if force: + self.proc.kill(queue_entry.stash_rev) + else: + self.proc.interrupt(queue_entry.stash_rev) logger.debug(f"Task {rev} had been killed.") except ProcessLookupError: fail_to_kill_entries[queue_entry] = rev @@ -333,7 +336,7 @@ def _mark_inactive_tasks_failure(self, remained_entries): if remained_revs: raise CannotKillTasksError(remained_revs) - def kill(self, revs: Collection[str]) -> None: + def kill(self, revs: Collection[str], force: bool = False) -> None: name_dict: Dict[ str, Optional[QueueEntry] ] = self.match_queue_entry_by_name(set(revs), self.iter_active()) @@ -349,7 +352,7 @@ def kill(self, revs: Collection[str]) -> None: raise UnresolvedQueueExpNamesError(missing_revs) fail_to_kill_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks( - to_kill + to_kill, force ) if fail_to_kill_entries: diff --git a/dvc/stage/run.py b/dvc/stage/run.py index ca76dcb2fa7..1f068627457 100644 --- a/dvc/stage/run.py +++ b/dvc/stage/run.py @@ -2,7 +2,6 @@ import os import signal import subprocess -import threading from dvc.stage.monitor import Monitor from dvc.utils import fix_env @@ -75,19 +74,11 @@ def get_executable(): def _run(stage, executable, cmd, checkpoint_func, **kwargs): - main_thread = isinstance( - threading.current_thread(), - threading._MainThread, # pylint: disable=protected-access - ) exec_cmd = _make_cmd(executable, cmd) - old_handler = None try: p = subprocess.Popen(exec_cmd, **kwargs) - if main_thread: - old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) - tasks = _get_monitor_tasks(stage, checkpoint_func, p) if tasks: @@ -101,9 +92,9 @@ def _run(stage, executable, cmd, checkpoint_func, **kwargs): if t.killed.is_set(): raise t.error_cls(cmd, p.returncode) raise StageCmdFailedError(cmd, p.returncode) - finally: - if old_handler: - signal.signal(signal.SIGINT, old_handler) + except KeyboardInterrupt: + p.send_signal(signal.SIGINT) + p.communicate() def _get_monitor_tasks(stage, checkpoint_func, proc): diff --git a/tests/unit/command/test_queue.py b/tests/unit/command/test_queue.py index 7f54e7952e4..feba666ec20 100644 --- a/tests/unit/command/test_queue.py +++ b/tests/unit/command/test_queue.py @@ -92,6 +92,7 @@ def test_experiments_kill(dvc, scm, mocker): [ "queue", "kill", + "--force", "exp1", "exp2", ] @@ -105,7 +106,7 @@ def test_experiments_kill(dvc, scm, mocker): ) assert cmd.run() == 0 - m.assert_called_once_with(revs=["exp1", "exp2"]) + m.assert_called_once_with(revs=["exp1", "exp2"], force=True) def test_experiments_start(dvc, scm, mocker): diff --git a/tests/unit/repo/experiments/queue/test_celery.py b/tests/unit/repo/experiments/queue/test_celery.py index 6dede4d58aa..83eea296c1a 100644 --- a/tests/unit/repo/experiments/queue/test_celery.py +++ b/tests/unit/repo/experiments/queue/test_celery.py @@ -80,7 +80,8 @@ def test_post_run_after_kill(test_queue): assert result_foo.get(timeout=10) == "foo" -def test_celery_queue_kill(test_queue, mocker): +@pytest.mark.parametrize("force", [True, False]) +def test_celery_queue_kill(test_queue, mocker, force): mock_entry_foo = mocker.Mock(stash_rev="foo") mock_entry_bar = mocker.Mock(stash_rev="bar") @@ -139,13 +140,13 @@ def kill_function(rev): kill_mock = mocker.patch.object( test_queue.proc, - "kill", + "kill" if force else "interrupt", side_effect=mocker.MagicMock(side_effect=kill_function), ) with pytest.raises( CannotKillTasksError, match="Task 'foobar' is initializing," ): - test_queue.kill(["bar", "foo", "foobar"]) + test_queue.kill(["bar", "foo", "foobar"], force=force) assert kill_mock.called_once_with(mock_entry_foo.stash_rev) assert kill_mock.called_once_with(mock_entry_bar.stash_rev) assert kill_mock.called_once_with(mock_entry_foobar.stash_rev)