Skip to content

Commit

Permalink
Use SIGINT as the default signal in queue kill
Browse files Browse the repository at this point in the history
fix: iterative#8624
1. Add a new flag `--force` for `queue kill`
2. Make `SIGINT` to be the default option and `SIGKILL` to be with
   `--force`
3. Remove `SIGINT` blocking.
4. Add tests for `queue kill`
  • Loading branch information
karajan1001 committed Dec 5, 2022
1 parent aa2e830 commit 2ec97e1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 22 deletions.
15 changes: 13 additions & 2 deletions dvc/commands/queue/kill.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,31 @@ class CmdQueueKill(CmdBase):
"""Kill exp task in queue."""

def run(self):
self.repo.experiments.celery_queue.kill(revs=self.args.task)
self.repo.experiments.celery_queue.kill(
revs=self.args.task, force=self.args.force
)

return 0


def add_parser(queue_subparsers, parent_parser):
QUEUE_KILL_HELP = "Kill actively running experiment queue tasks."
QUEUE_KILL_HELP = "Send SIGINT(Ctrl-C) to running experiment queue tasks."
queue_kill_parser = queue_subparsers.add_parser(
"kill",
parents=[parent_parser],
description=append_doc_link(QUEUE_KILL_HELP, "queue/kill"),
help=QUEUE_KILL_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
queue_kill_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Send SIGKILL (kill -9) instead to running experiment queue "
"tasks. (The default `queue kill` will terminate more gracefully,"
" collecting and cleaning up all resources)",
)
queue_kill_parser.add_argument(
"task",
nargs="*",
Expand Down
11 changes: 7 additions & 4 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,15 @@ def _get_running_task_ids(self) -> Set[str]:
return running_task_ids

def _try_to_kill_tasks(
self, to_kill: Dict[QueueEntry, str]
self, to_kill: Dict[QueueEntry, str], force: bool
) -> Dict[QueueEntry, str]:
fail_to_kill_entries: Dict[QueueEntry, str] = {}
for queue_entry, rev in to_kill.items():
try:
self.proc.kill(queue_entry.stash_rev)
if force:
self.proc.kill(queue_entry.stash_rev)
else:
self.proc.interrupt(queue_entry.stash_rev)
logger.debug(f"Task {rev} had been killed.")
except ProcessLookupError:
fail_to_kill_entries[queue_entry] = rev
Expand Down Expand Up @@ -333,7 +336,7 @@ def _mark_inactive_tasks_failure(self, remained_entries):
if remained_revs:
raise CannotKillTasksError(remained_revs)

def kill(self, revs: Collection[str]) -> None:
def kill(self, revs: Collection[str], force: bool = False) -> None:
name_dict: Dict[
str, Optional[QueueEntry]
] = self.match_queue_entry_by_name(set(revs), self.iter_active())
Expand All @@ -349,7 +352,7 @@ def kill(self, revs: Collection[str]) -> None:
raise UnresolvedQueueExpNamesError(missing_revs)

fail_to_kill_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks(
to_kill
to_kill, force
)

if fail_to_kill_entries:
Expand Down
15 changes: 3 additions & 12 deletions dvc/stage/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import signal
import subprocess
import threading

from dvc.stage.monitor import Monitor
from dvc.utils import fix_env
Expand Down Expand Up @@ -75,19 +74,11 @@ def get_executable():


def _run(stage, executable, cmd, checkpoint_func, **kwargs):
main_thread = isinstance(
threading.current_thread(),
threading._MainThread, # pylint: disable=protected-access
)

exec_cmd = _make_cmd(executable, cmd)
old_handler = None

try:
p = subprocess.Popen(exec_cmd, **kwargs)
if main_thread:
old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)

tasks = _get_monitor_tasks(stage, checkpoint_func, p)

if tasks:
Expand All @@ -101,9 +92,9 @@ def _run(stage, executable, cmd, checkpoint_func, **kwargs):
if t.killed.is_set():
raise t.error_cls(cmd, p.returncode)
raise StageCmdFailedError(cmd, p.returncode)
finally:
if old_handler:
signal.signal(signal.SIGINT, old_handler)
except KeyboardInterrupt:
p.send_signal(signal.SIGINT)
p.communicate()


def _get_monitor_tasks(stage, checkpoint_func, proc):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/command/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_experiments_kill(dvc, scm, mocker):
[
"queue",
"kill",
"--force",
"exp1",
"exp2",
]
Expand All @@ -105,7 +106,7 @@ def test_experiments_kill(dvc, scm, mocker):
)

assert cmd.run() == 0
m.assert_called_once_with(revs=["exp1", "exp2"])
m.assert_called_once_with(revs=["exp1", "exp2"], force=True)


def test_experiments_start(dvc, scm, mocker):
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/repo/experiments/queue/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def test_post_run_after_kill(test_queue):
assert result_foo.get(timeout=10) == "foo"


def test_celery_queue_kill(test_queue, mocker):
@pytest.mark.parametrize("force", [True, False])
def test_celery_queue_kill(test_queue, mocker, force):

mock_entry_foo = mocker.Mock(stash_rev="foo")
mock_entry_bar = mocker.Mock(stash_rev="bar")
Expand Down Expand Up @@ -139,13 +140,13 @@ def kill_function(rev):

kill_mock = mocker.patch.object(
test_queue.proc,
"kill",
"kill" if force else "interrupt",
side_effect=mocker.MagicMock(side_effect=kill_function),
)
with pytest.raises(
CannotKillTasksError, match="Task 'foobar' is initializing,"
):
test_queue.kill(["bar", "foo", "foobar"])
test_queue.kill(["bar", "foo", "foobar"], force=force)
assert kill_mock.called_once_with(mock_entry_foo.stash_rev)
assert kill_mock.called_once_with(mock_entry_bar.stash_rev)
assert kill_mock.called_once_with(mock_entry_foobar.stash_rev)
Expand Down

0 comments on commit 2ec97e1

Please sign in to comment.