Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batch dump and batch update of old stages #10630

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs):
if update_lock:
self._dump_lockfile(stage, **kwargs)

def batch_dump(self, stages, update_pipeline=True, update_lock=True, **kwargs):
    """Batch-dump the given stages into this dvcfile and its lockfile.

    Args:
        stages: iterable of ``PipelineStage`` objects. It is materialized
            once, so a generator is safe to pass.
        update_pipeline: when true, every stage that is not a data source
            is written through ``_dump_pipeline_file``.
        update_lock: when true, all stages are handed to
            ``_batch_dump_lockfile`` in a single call.
        **kwargs: forwarded unchanged to ``_batch_dump_lockfile``.

    Returns without touching any file when ``stages`` is empty.
    """
    from dvc.stage import PipelineStage

    # The collection is walked more than once below; a one-shot iterator
    # would otherwise be exhausted before the lockfile step sees it.
    stages = list(stages)
    if not stages:
        return

    # Validate every entry before writing anything, so a bad entry cannot
    # leave the pipeline file half-updated.
    for stage in stages:
        assert isinstance(stage, PipelineStage)

    # The path check depends only on this dvcfile, not on the stage, so it
    # is done once instead of once per stage.
    if self.verify:
        check_dvcfile_path(self.repo, self.path)

    if update_pipeline:
        for stage in stages:
            if not stage.is_data_source:
                self._dump_pipeline_file(stage)

    if update_lock:
        self._batch_dump_lockfile(stages, **kwargs)

def dump_dataset(self, dataset):
with modify_yaml(self.path, fs=self.repo.fs) as data:
parsed = self.datasets if data else []
Expand All @@ -263,6 +277,9 @@ def dump_dataset(self, dataset):
def _dump_lockfile(self, stage, **kwargs):
self._lockfile.dump(stage, **kwargs)

def _batch_dump_lockfile(self, stages, **kwargs):
    """Hand all ``stages`` to the lockfile's ``batch_dump`` in one call.

    Any extra keyword arguments are passed through untouched.
    """
    lockfile = self._lockfile
    lockfile.batch_dump(stages, **kwargs)

@staticmethod
def _check_if_parametrized(stage, action: str = "dump") -> None:
if stage.raw_data.parametrized:
Expand Down Expand Up @@ -400,26 +417,34 @@ def dump_dataset(self, dataset: dict):
self.repo.scm_context.track_file(self.relpath)

def dump(self, stage, **kwargs):
stage_data = serialize.to_lockfile(stage, **kwargs)

with modify_yaml(self.path, fs=self.repo.fs) as data:
if not data:
data.update({"schema": "2.0"})
# order is important, meta should always be at the top
logger.info("Generating lock file '%s'", self.relpath)
self.dump_stage_in_data(stage, data, **kwargs)

data["stages"] = data.get("stages", {})
modified = data["stages"].get(stage.name, {}) != stage_data.get(
stage.name, {}
)
if modified:
logger.info("Updating lock file '%s'", self.relpath)
def dump_stage_in_data(self, stage, data, **kwargs):
stage_data = serialize.to_lockfile(stage, **kwargs)
data["stages"] = data.get("stages", {})
modified = data["stages"].get(stage.name, {}) != stage_data.get(stage.name, {})
if modified:
logger.info("Updating lock file '%s'", self.relpath)

data["stages"].update(stage_data)
data["stages"].update(stage_data)

if modified:
self.repo.scm_context.track_file(self.relpath)

def batch_dump(self, stages, **kwargs):
    """Write several stages to the lockfile in one read-modify-write pass.

    Args:
        stages: iterable of stages to serialize into the lockfile.
        **kwargs: forwarded to ``dump_stage_in_data`` for every stage, the
            same way ``dump`` forwards them for a single stage. Accepting
            them keeps callers that pass serialization options from
            failing with ``TypeError``.

    Returns without opening the lockfile when ``stages`` is empty.
    """
    # Materialize once so the emptiness check cannot consume a generator.
    stages = list(stages)
    if not stages:
        return

    with modify_yaml(self.path, fs=self.repo.fs) as data:
        if not data:
            data.update({"schema": "2.0"})
            # order is important, meta should always be at the top
            logger.info("Generating lock file '%s'", self.relpath)
        for stage in stages:
            self.dump_stage_in_data(stage, data, **kwargs)

def remove_stage(self, stage):
if not self.exists():
return
Expand Down
52 changes: 40 additions & 12 deletions dvc/repo/commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,56 @@ def commit(
data_only=False,
relink=True,
):
from dvc.dvcfile import ProjectFile

stages_info = [
info
for info in self.stage.collect_granular(
target, with_deps=with_deps, recursive=recursive
)
if not data_only or info.stage.is_data_source
]

stages_by_dvcfile = {}
for stage_info in stages_info:
stage = stage_info.stage
if force:
stage.save(allow_missing=allow_missing)
if stage.dvcfile.path not in stages_by_dvcfile:
stages_by_dvcfile[stage.dvcfile.path] = {
"dvcfile": stage.dvcfile,
"stages": [],
}
stages_by_dvcfile[stage.dvcfile.path]["stages"].append(stage)

for val in stages_by_dvcfile.values():
dvcfile = val["stages"][0].dvcfile
if isinstance(dvcfile, ProjectFile):
dvcfile._reset()
old_stages = dict(dvcfile.stages)
else:
changes = stage.changed_entries()
if any(changes):
prompt_to_commit(stage, changes, force=force)
stage.save(allow_missing=allow_missing)
stage.commit(
filter_info=stage_info.filter_info,
allow_missing=allow_missing,
relink=relink,
)
stage.dump(update_pipeline=False)
old_stages = None

for stage in val["stages"]:
old_stage = old_stages[stage.name] if old_stages else None

if force:
stage.save(allow_missing=allow_missing, old_stage=old_stage)
else:
changes = stage.changed_entries()
if any(changes):
prompt_to_commit(stage, changes, force=force)
stage.save(allow_missing=allow_missing, old_stage=old_stage)
stage.commit(
filter_info=stage_info.filter_info,
allow_missing=allow_missing,
relink=relink,
)

if isinstance(dvcfile, ProjectFile):
dvcfile.batch_dump(val["stages"], update_pipeline=False)
else:
for stage in val["stages"]:
stage.dump(update_pipeline=False)

return [s.stage for s in stages_info]


Expand Down
21 changes: 11 additions & 10 deletions dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,10 @@ def compute_md5(self) -> Optional[str]:
logger.debug("Computed %s md5: '%s'", self, m)
return m

def save(self, allow_missing: bool = False, run_cache: bool = True):
def save(self, allow_missing: bool = False, run_cache: bool = True, old_stage=None):
self.save_deps(allow_missing=allow_missing)

self.save_outs(allow_missing=allow_missing)
self.save_outs(allow_missing=allow_missing, old_stage=old_stage)

self.md5 = self.compute_md5()

Expand All @@ -509,25 +509,26 @@ def save_deps(self, allow_missing=False):
if not allow_missing:
raise

def get_versioned_outs(self) -> dict[str, "Output"]:
def get_versioned_outs(self, old_stage=None) -> dict[str, "Output"]:
from .exceptions import StageFileDoesNotExistError, StageNotFound

try:
old = self.reload()
except (StageFileDoesNotExistError, StageNotFound):
return {}
if not old_stage:
try:
old_stage = self.reload()
except (StageFileDoesNotExistError, StageNotFound):
return {}

return {
out.def_path: out
for out in old.outs
for out in old_stage.outs
if out.files is not None
or (out.meta is not None and out.meta.version_id is not None)
}

def save_outs(self, allow_missing: bool = False):
def save_outs(self, allow_missing: bool = False, old_stage=None):
from dvc.output import OutputDoesNotExistError

old_versioned_outs = self.get_versioned_outs()
old_versioned_outs = self.get_versioned_outs(old_stage=old_stage)
for out in self.outs:
try:
out.save()
Expand Down
Loading