Skip to content

Commit

Permalink
Use SIGINT as the default signal in queue kill
Browse files Browse the repository at this point in the history
fix: #8624
1. Add a new flag `--force` for `queue kill`
2. Make `SIGINT` the default signal and send `SIGKILL` only when
   `--force` is given
3. Add tests for `queue kill`
4. Bump dvc-task to 0.1.9
  • Loading branch information
karajan1001 committed Dec 30, 2022
1 parent 5f0b1ed commit 06673dd
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 17 deletions.
16 changes: 14 additions & 2 deletions dvc/commands/queue/kill.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,32 @@ class CmdQueueKill(CmdBase):
"""Kill exp task in queue."""

def run(self):
self.repo.experiments.celery_queue.kill(revs=self.args.task)
self.repo.experiments.celery_queue.kill(
revs=self.args.task, force=self.args.force
)

return 0


def add_parser(queue_subparsers, parent_parser):
QUEUE_KILL_HELP = "Kill actively running experiment queue tasks."
QUEUE_KILL_HELP = (
"Gracefully interrupt running experiment queue tasks "
"(equivalent to Ctrl-C)"
)
queue_kill_parser = queue_subparsers.add_parser(
"kill",
parents=[parent_parser],
description=append_doc_link(QUEUE_KILL_HELP, "queue/kill"),
help=QUEUE_KILL_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
queue_kill_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Forcefully and immediately kill running experiment queue tasks",
)
queue_kill_parser.add_argument(
"task",
nargs="*",
Expand Down
18 changes: 10 additions & 8 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,15 @@ def _get_running_task_ids(self) -> Set[str]:
return running_task_ids

def _try_to_kill_tasks(
self, to_kill: Dict[QueueEntry, str]
self, to_kill: Dict[QueueEntry, str], force: bool
) -> Dict[QueueEntry, str]:
fail_to_kill_entries: Dict[QueueEntry, str] = {}
for queue_entry, rev in to_kill.items():
try:
self.proc.kill(queue_entry.stash_rev)
if force:
self.proc.kill(queue_entry.stash_rev)
else:
self.proc.interrupt(queue_entry.stash_rev)
logger.debug(f"Task {rev} had been killed.")
except ProcessLookupError:
fail_to_kill_entries[queue_entry] = rev
Expand Down Expand Up @@ -331,20 +334,19 @@ def _mark_inactive_tasks_failure(self, remained_entries):
if remained_revs:
raise CannotKillTasksError(remained_revs)

def _kill_entries(self, entries: Dict[QueueEntry, str]):
def _kill_entries(self, entries: Dict[QueueEntry, str], force: bool):
logger.debug(
"Found active tasks: '%s' to kill",
list(entries.values()),
)
inactive_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks(
entries
entries, force
)

if inactive_entries:
self._mark_inactive_tasks_failure(inactive_entries)

def kill(self, revs: Collection[str]) -> None:

def kill(self, revs: Collection[str], force: bool = False) -> None:
name_dict: Dict[
str, Optional[QueueEntry]
] = self.match_queue_entry_by_name(set(revs), self.iter_active())
Expand All @@ -360,7 +362,7 @@ def kill(self, revs: Collection[str]) -> None:
raise UnresolvedQueueExpNamesError(missing_revs)

if to_kill:
self._kill_entries(to_kill)
self._kill_entries(to_kill, force)

def shutdown(self, kill: bool = False):
self.celery.control.shutdown()
Expand All @@ -369,7 +371,7 @@ def shutdown(self, kill: bool = False):
for entry in self.iter_active():
to_kill[entry] = entry.name or entry.stash_rev
if to_kill:
self._kill_entries(to_kill)
self._kill_entries(to_kill, True)

def follow(
self,
Expand Down
2 changes: 1 addition & 1 deletion dvc/stage/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def _run(stage: "Stage", executable, cmd, checkpoint_func, **kwargs):
threading.current_thread(),
threading._MainThread, # type: ignore[attr-defined]
)
old_handler = None

exec_cmd = _make_cmd(executable, cmd)
old_handler = None

try:
p = subprocess.Popen(exec_cmd, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies = [
"typing-extensions>=3.7.4",
"scmrepo==0.1.5",
"dvc-render==0.0.17",
"dvc-task==0.1.8",
"dvc-task==0.1.9",
"dvclive>=1.2.2",
"dvc-data==0.28.4",
"dvc-http==2.27.2",
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/command/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_experiments_kill(dvc, scm, mocker):
[
"queue",
"kill",
"--force",
"exp1",
"exp2",
]
Expand All @@ -105,7 +106,7 @@ def test_experiments_kill(dvc, scm, mocker):
)

assert cmd.run() == 0
m.assert_called_once_with(revs=["exp1", "exp2"])
m.assert_called_once_with(revs=["exp1", "exp2"], force=True)


def test_experiments_start(dvc, scm, mocker):
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/repo/experiments/queue/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_shutdown_with_kill(test_queue, mocker):

shutdown_spy.assert_called_once()
kill_spy.assert_called_once_with(
{mock_entry_foo: "foo", mock_entry_bar: "bar"}
{mock_entry_foo: "foo", mock_entry_bar: "bar"}, True
)


Expand Down Expand Up @@ -78,7 +78,8 @@ def test_post_run_after_kill(test_queue):
assert result_foo.get(timeout=10) == "foo"


def test_celery_queue_kill(test_queue, mocker):
@pytest.mark.parametrize("force", [True, False])
def test_celery_queue_kill(test_queue, mocker, force):

mock_entry_foo = mocker.Mock(stash_rev="foo")
mock_entry_bar = mocker.Mock(stash_rev="bar")
Expand Down Expand Up @@ -137,13 +138,13 @@ def kill_function(rev):

kill_mock = mocker.patch.object(
test_queue.proc,
"kill",
"kill" if force else "interrupt",
side_effect=mocker.MagicMock(side_effect=kill_function),
)
with pytest.raises(
CannotKillTasksError, match="Task 'foobar' is initializing,"
):
test_queue.kill(["bar", "foo", "foobar"])
test_queue.kill(["bar", "foo", "foobar"], force=force)
assert kill_mock.called_once_with(mock_entry_foo.stash_rev)
assert kill_mock.called_once_with(mock_entry_bar.stash_rev)
assert kill_mock.called_once_with(mock_entry_foobar.stash_rev)
Expand Down

0 comments on commit 06673dd

Please sign in to comment.