Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp: support standalone (non-queued) --temp runs #7894

Merged
merged 7 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 55 additions & 40 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
import time
from typing import Any, Dict, Iterable, Mapping, Optional
from typing import Any, Dict, Iterable, Optional

from funcy import cached_property, first

Expand All @@ -18,7 +18,9 @@
)
from .executor.base import BaseExecutor, ExecutorInfo
from .queue.base import BaseStashQueue, QueueEntry
from .queue.local import LocalCeleryQueue, WorkspaceQueue
from .queue.celery import LocalCeleryQueue
from .queue.tempdir import TempDirQueue
from .queue.workspace import WorkspaceQueue
from .refs import (
CELERY_FAILED_STASH,
CELERY_STASH,
Expand Down Expand Up @@ -78,6 +80,12 @@ def args_file(self):
def workspace_queue(self) -> WorkspaceQueue:
return WorkspaceQueue(self.repo, WORKSPACE_STASH)

@cached_property
def tempdir_queue(self) -> TempDirQueue:
# NOTE: tempdir and workspace stash is shared since both
# implementations immediately push -> pop (queue length is only 0 or 1)
return TempDirQueue(self.repo, WORKSPACE_STASH)

@cached_property
def celery_queue(self) -> LocalCeleryQueue:
return LocalCeleryQueue(self.repo, CELERY_STASH, CELERY_FAILED_STASH)
Expand All @@ -91,21 +99,12 @@ def stash_revs(self) -> Dict[str, ExpStashEntry]:

def reproduce_one(
self,
queue: bool = False,
tmp_dir: bool = False,
checkpoint_resume: Optional[str] = None,
reset: bool = False,
machine: Optional[str] = None,
**kwargs,
):
"""Reproduce and checkout a single experiment."""
if queue and not checkpoint_resume:
reset = True

if reset:
self.reset_checkpoints()

if not (queue or tmp_dir or machine):
"""Reproduce and checkout a single (standalone) experiment."""
if not (tmp_dir or machine):
staged, _, _ = self.scm.status(untracked_files="no")
if staged:
logger.warning(
Expand All @@ -114,6 +113,30 @@ def reproduce_one(
)
self.scm.reset()

exp_queue: BaseStashQueue = (
self.tempdir_queue if tmp_dir else self.workspace_queue
)
self.queue_one(exp_queue, **kwargs)
if machine:
# TODO: decide how to handle queued remote execution
raise NotImplementedError
results = self._reproduce_queue(exp_queue)
exp_rev = first(results)
if exp_rev is not None:
self._log_reproduced(results, tmp_dir=tmp_dir)
return results

def queue_one(
self,
queue: BaseStashQueue,
checkpoint_resume: Optional[str] = None,
reset: bool = False,
**kwargs,
) -> QueueEntry:
"""Queue a single experiment."""
if reset:
self.reset_checkpoints()

if checkpoint_resume:
from dvc.scm import resolve_rev

Expand All @@ -129,29 +152,12 @@ def reproduce_one(
else:
checkpoint_resume = self._workspace_resume_rev()

exp_queue = (
self.celery_queue if (tmp_dir or queue) else self.workspace_queue
)
entry = self.new(
exp_queue,
return self.new(
queue,
checkpoint_resume=checkpoint_resume,
reset=reset,
**kwargs,
)
if queue:
name = entry.name or entry.stash_rev[:7]
ui.write(f"Queued experiment '{name}' for future execution.")
return [entry.stash_rev]
if tmp_dir:
return self.reproduce_celery(entries=[entry])
if machine:
# TODO: decide how to handle queued remote execution
raise NotImplementedError
results = self._reproduce_queue(exp_queue)
exp_rev = first(results)
if exp_rev is not None:
self._log_reproduced(results, tmp_dir=tmp_dir)
return results

def _workspace_resume_rev(self) -> Optional[str]:
last_checkpoint = self._get_last_checkpoint()
Expand All @@ -173,6 +179,10 @@ def reproduce_celery(
self.celery_queue.spawn_worker()
failed = []
try:
ui.write(
"Following logs for all queued experiments. Use Ctrl+C to "
"stop following logs (experiment execution will continue).\n"
)
for entry in entries:
# wait for task execution to start
while not self.celery_queue.proc.get(entry.stash_rev):
Expand Down Expand Up @@ -334,7 +344,7 @@ def reset_checkpoints(self):
@unlocked_repo
def _reproduce_queue(
self, queue: BaseStashQueue, **kwargs
) -> Mapping[str, str]:
) -> Dict[str, str]:
"""Reproduce queued experiments.

Arguments:
Expand Down Expand Up @@ -422,15 +432,20 @@ def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Any]:
"""Return info for running experiments."""
result = {}
infofile = self.workspace_queue.get_infofile_path("workspace")
result.update(self._get_running_exp("workspace", infofile, fetch_refs))
for entry in self.celery_queue.iter_active():
infofile = self.celery_queue.get_infofile_path(entry.stash_rev)
result.update(
self._get_running_exp(entry.stash_rev, infofile, fetch_refs)
)
result.update(
self._fetch_running_exp("workspace", infofile, fetch_refs)
)
for queue in (self.tempdir_queue, self.celery_queue):
for entry in queue.iter_active():
infofile = queue.get_infofile_path(entry.stash_rev)
result.update(
self._fetch_running_exp(
entry.stash_rev, infofile, fetch_refs
)
)
return result

def _get_running_exp(
def _fetch_running_exp(
self, rev: str, infofile: str, fetch_refs: bool
) -> Dict[str, Any]:
from dvc.scm import InvalidRemoteSCMRepo
Expand Down
17 changes: 12 additions & 5 deletions dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scmrepo.exceptions import SCMError as _SCMError

from dvc.scm import SCM, GitMergeError
from dvc.utils.fs import remove
from dvc.utils.fs import makedirs, remove

from ..refs import (
EXEC_APPLY,
Expand Down Expand Up @@ -65,7 +65,7 @@ class TempDirExecutor(BaseLocalExecutor):
# suggestions) that are not applicable outside of workspace runs
WARN_UNTRACKED = True
QUIET = True
DEFAULT_LOCATION = "temp"
DEFAULT_LOCATION = "tempdir"

def init_git(self, scm: "Git", branch: Optional[str] = None):
from dulwich.repo import Repo as DulwichRepo
Expand Down Expand Up @@ -120,11 +120,16 @@ def from_stash_entry(
repo: "Repo",
stash_rev: str,
entry: "ExpStashEntry",
wdir: Optional[str] = None,
**kwargs,
):
tmp_dir = mkdtemp(dir=os.path.join(repo.tmp_dir, EXEC_TMP_DIR))
parent_dir: str = wdir or os.path.join(repo.tmp_dir, EXEC_TMP_DIR)
makedirs(parent_dir, exist_ok=True)
tmp_dir = mkdtemp(dir=parent_dir)
try:
executor = cls._from_stash_entry(repo, stash_rev, entry, tmp_dir)
executor = cls._from_stash_entry(
repo, stash_rev, entry, tmp_dir, **kwargs
)
logger.debug("Init temp dir executor in '%s'", tmp_dir)
return executor
except Exception:
Expand All @@ -147,7 +152,9 @@ def from_stash_entry(
**kwargs,
):
root_dir = repo.scm.root_dir
executor = cls._from_stash_entry(repo, stash_rev, entry, root_dir)
executor = cls._from_stash_entry(
repo, stash_rev, entry, root_dir, **kwargs
)
logger.debug("Init workspace executor in '%s'", root_dir)
return executor

Expand Down
5 changes: 4 additions & 1 deletion dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def setup_executor(
exp: "Experiments",
queue_entry: QueueEntry,
executor_cls: Type[BaseExecutor] = WorkspaceExecutor,
**kwargs,
) -> BaseExecutor:
scm = exp.scm
stash = ExpStash(scm, queue_entry.stash_ref)
Expand All @@ -563,7 +564,9 @@ def setup_executor(
# EXEC_MERGE - the unmerged changes (from our stash)
# to be reproduced
# EXEC_BASELINE - the baseline commit for this experiment
return executor_cls.from_stash_entry(exp.repo, stash_rev, stash_entry)
return executor_cls.from_stash_entry(
exp.repo, stash_rev, stash_entry, **kwargs
)

def get_infofile_path(self, name: str) -> str:
return os.path.join(
Expand Down
Loading