From 19f98f5a9f4e14be74e538d51a9ca62d238d89ff Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 5 Dec 2022 11:11:47 +0800 Subject: [PATCH] Use `SIGINT` as the default signal in `queue kill` fix: #8624 1. Add a new flag `--force` for `queue kill` 2. Make `SIGINT` the default signal and send `SIGKILL` only with `--force` 3. Add tests for `queue kill` 4. Bump dvc-task to 0.1.7 --- dvc/commands/queue/kill.py | 15 +++++++++++++-- dvc/repo/experiments/queue/celery.py | 11 +++++++---- dvc/stage/run.py | 2 +- pyproject.toml | 2 +- tests/unit/command/test_queue.py | 3 ++- tests/unit/repo/experiments/queue/test_celery.py | 7 ++++--- 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/dvc/commands/queue/kill.py b/dvc/commands/queue/kill.py index 098c0911947..80d4dd5ca88 100644 --- a/dvc/commands/queue/kill.py +++ b/dvc/commands/queue/kill.py @@ -11,13 +11,15 @@ class CmdQueueKill(CmdBase): """Kill exp task in queue.""" def run(self): - self.repo.experiments.celery_queue.kill(revs=self.args.task) + self.repo.experiments.celery_queue.kill( + revs=self.args.task, force=self.args.force + ) return 0 def add_parser(queue_subparsers, parent_parser): - QUEUE_KILL_HELP = "Kill actively running experiment queue tasks." + QUEUE_KILL_HELP = "Send SIGINT(Ctrl-C) to running experiment queue tasks." queue_kill_parser = queue_subparsers.add_parser( "kill", parents=[parent_parser], @@ -25,6 +27,15 @@ def add_parser(queue_subparsers, parent_parser): help=QUEUE_KILL_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) + queue_kill_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Send SIGKILL (kill -9) instead to running experiment queue " + "tasks. 
(The default `queue kill` will terminate more gracefully," + " collecting and cleaning up all resources)", + ) queue_kill_parser.add_argument( "task", nargs="*", diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index b8dca00db0b..00fb8657da0 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -300,12 +300,15 @@ def _get_running_task_ids(self) -> Set[str]: return running_task_ids def _try_to_kill_tasks( - self, to_kill: Dict[QueueEntry, str] + self, to_kill: Dict[QueueEntry, str], force: bool ) -> Dict[QueueEntry, str]: fail_to_kill_entries: Dict[QueueEntry, str] = {} for queue_entry, rev in to_kill.items(): try: - self.proc.kill(queue_entry.stash_rev) + if force: + self.proc.kill(queue_entry.stash_rev) + else: + self.proc.interrupt(queue_entry.stash_rev) logger.debug(f"Task {rev} had been killed.") except ProcessLookupError: fail_to_kill_entries[queue_entry] = rev @@ -333,7 +336,7 @@ def _mark_inactive_tasks_failure(self, remained_entries): if remained_revs: raise CannotKillTasksError(remained_revs) - def kill(self, revs: Collection[str]) -> None: + def kill(self, revs: Collection[str], force: bool = False) -> None: name_dict: Dict[ str, Optional[QueueEntry] ] = self.match_queue_entry_by_name(set(revs), self.iter_active()) @@ -349,7 +352,7 @@ def kill(self, revs: Collection[str]) -> None: raise UnresolvedQueueExpNamesError(missing_revs) fail_to_kill_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks( - to_kill + to_kill, force ) if fail_to_kill_entries: diff --git a/dvc/stage/run.py b/dvc/stage/run.py index ca76dcb2fa7..3831b4e3238 100644 --- a/dvc/stage/run.py +++ b/dvc/stage/run.py @@ -79,9 +79,9 @@ def _run(stage, executable, cmd, checkpoint_func, **kwargs): threading.current_thread(), threading._MainThread, # pylint: disable=protected-access ) + old_handler = None exec_cmd = _make_cmd(executable, cmd) - old_handler = None try: p = subprocess.Popen(exec_cmd, **kwargs) diff 
--git a/pyproject.toml b/pyproject.toml index 319d59b7f22..98f5766b3c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ "typing-extensions>=3.7.4", "scmrepo==0.1.4", "dvc-render==0.0.14", - "dvc-task==0.1.6", + "dvc-task==0.1.7", "dvclive>=1.0", "dvc-data==0.28.3", "dvc-http==2.27.2", diff --git a/tests/unit/command/test_queue.py b/tests/unit/command/test_queue.py index 7f54e7952e4..feba666ec20 100644 --- a/tests/unit/command/test_queue.py +++ b/tests/unit/command/test_queue.py @@ -92,6 +92,7 @@ def test_experiments_kill(dvc, scm, mocker): [ "queue", "kill", + "--force", "exp1", "exp2", ] @@ -105,7 +106,7 @@ def test_experiments_kill(dvc, scm, mocker): ) assert cmd.run() == 0 - m.assert_called_once_with(revs=["exp1", "exp2"]) + m.assert_called_once_with(revs=["exp1", "exp2"], force=True) def test_experiments_start(dvc, scm, mocker): diff --git a/tests/unit/repo/experiments/queue/test_celery.py b/tests/unit/repo/experiments/queue/test_celery.py index 6dede4d58aa..83eea296c1a 100644 --- a/tests/unit/repo/experiments/queue/test_celery.py +++ b/tests/unit/repo/experiments/queue/test_celery.py @@ -80,7 +80,8 @@ def test_post_run_after_kill(test_queue): assert result_foo.get(timeout=10) == "foo" -def test_celery_queue_kill(test_queue, mocker): +@pytest.mark.parametrize("force", [True, False]) +def test_celery_queue_kill(test_queue, mocker, force): mock_entry_foo = mocker.Mock(stash_rev="foo") mock_entry_bar = mocker.Mock(stash_rev="bar") @@ -139,13 +140,13 @@ def kill_function(rev): kill_mock = mocker.patch.object( test_queue.proc, - "kill", + "kill" if force else "interrupt", side_effect=mocker.MagicMock(side_effect=kill_function), ) with pytest.raises( CannotKillTasksError, match="Task 'foobar' is initializing," ): - test_queue.kill(["bar", "foo", "foobar"]) + test_queue.kill(["bar", "foo", "foobar"], force=force) assert kill_mock.called_once_with(mock_entry_foo.stash_rev) assert kill_mock.called_once_with(mock_entry_bar.stash_rev) 
assert kill_mock.called_once_with(mock_entry_foobar.stash_rev)