diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py
index 1a150e8c9d..c7f5f86162 100644
--- a/dvc/dvcfile.py
+++ b/dvc/dvcfile.py
@@ -240,6 +240,37 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs):
         if update_lock:
             self._dump_lockfile(stage, **kwargs)
 
+    def batch_dump(self, stages, update_pipeline=True, update_lock=True, **kwargs):
+        """Dump several stages to this dvcfile in a single pass.
+
+        Batch counterpart of `dump()`: every stage goes through the same
+        pipeline-file update, but all lockfile entries are handed to the
+        lockfile together instead of one call per stage.
+
+        Args:
+            stages: iterable of `PipelineStage` objects to dump.
+            update_pipeline: also write non-data-source stages to the
+                pipeline file.
+            update_lock: also write the lockfile entries of the stages.
+            **kwargs: forwarded to the lockfile serialization.
+        """
+        from dvc.stage import PipelineStage
+
+        # `stages` is walked twice below, so materialize one-shot iterables.
+        stages = list(stages)
+
+        # The path check does not depend on the stage: run it once.
+        if self.verify:
+            check_dvcfile_path(self.repo, self.path)
+
+        for stage in stages:
+            assert isinstance(stage, PipelineStage)
+            if update_pipeline and not stage.is_data_source:
+                self._dump_pipeline_file(stage)
+
+        if update_lock:
+            self._batch_dump_lockfile(stages, **kwargs)
+
     def dump_dataset(self, dataset):
         with modify_yaml(self.path, fs=self.repo.fs) as data:
             parsed = self.datasets if data else []
@@ -263,6 +294,10 @@ def dump_dataset(self, dataset):
     def _dump_lockfile(self, stage, **kwargs):
         self._lockfile.dump(stage, **kwargs)
 
+    def _batch_dump_lockfile(self, stages, **kwargs):
+        """Write the lock entries of `stages` through the lockfile in one go."""
+        self._lockfile.batch_dump(stages, **kwargs)
+
     @staticmethod
     def _check_if_parametrized(stage, action: str = "dump") -> None:
         if stage.raw_data.parametrized:
@@ -400,26 +435,45 @@ def dump_dataset(self, dataset: dict):
         self.repo.scm_context.track_file(self.relpath)
 
     def dump(self, stage, **kwargs):
-        stage_data = serialize.to_lockfile(stage, **kwargs)
-
         with modify_yaml(self.path, fs=self.repo.fs) as data:
             if not data:
                 data.update({"schema": "2.0"})
                 # order is important, meta should always be at the top
                 logger.info("Generating lock file '%s'", self.relpath)
+            self.dump_stage_in_data(stage, data, **kwargs)
 
-            data["stages"] = data.get("stages", {})
-            modified = data["stages"].get(stage.name, {}) != stage_data.get(
-                stage.name, {}
-            )
-            if modified:
-                logger.info("Updating lock file '%s'", self.relpath)
+    def dump_stage_in_data(self, stage, data, **kwargs):
+        """Merge the lock entry of `stage` into the loaded lockfile `data`.
+
+        `data` is updated in place; persisting it is left to the caller.
+        The lockfile is passed to `scm_context.track_file()` only when the
+        entry of the stage actually changed.
+        """
+        stage_data = serialize.to_lockfile(stage, **kwargs)
+        data["stages"] = data.get("stages", {})
+        modified = data["stages"].get(stage.name, {}) != stage_data.get(stage.name, {})
+        if modified:
+            logger.info("Updating lock file '%s'", self.relpath)
 
-            data["stages"].update(stage_data)
+        data["stages"].update(stage_data)
 
         if modified:
             self.repo.scm_context.track_file(self.relpath)
 
+    def batch_dump(self, stages, **kwargs):
+        """Dump the lock entries of all `stages` within one `modify_yaml`.
+
+        `**kwargs` are forwarded to the serialization of every stage, the
+        same way `dump()` does for a single stage.
+        """
+        with modify_yaml(self.path, fs=self.repo.fs) as data:
+            if not data:
+                data.update({"schema": "2.0"})
+                # order is important, meta should always be at the top
+                logger.info("Generating lock file '%s'", self.relpath)
+            for stage in stages:
+                self.dump_stage_in_data(stage, data, **kwargs)
+
     def remove_stage(self, stage):
         if not self.exists():
             return
diff --git a/dvc/repo/commit.py b/dvc/repo/commit.py
index cea727f207..8cf10cbf1e 100644
--- a/dvc/repo/commit.py
+++ b/dvc/repo/commit.py
@@ -52,6 +52,8 @@ def commit(
     data_only=False,
     relink=True,
 ):
+    from dvc.dvcfile import ProjectFile
+
     stages_info = [
         info
         for info in self.stage.collect_granular(
@@ -59,21 +61,52 @@ def commit(
         )
         if not data_only or info.stage.is_data_source
     ]
+
+    # Group by dvcfile so that each file is loaded and dumped only once.
+    infos_by_dvcfile = {}
     for stage_info in stages_info:
         stage = stage_info.stage
-        if force:
-            stage.save(allow_missing=allow_missing)
+        group = infos_by_dvcfile.setdefault(
+            stage.dvcfile.path, {"dvcfile": stage.dvcfile, "infos": []}
+        )
+        group["infos"].append(stage_info)
+
+    for group in infos_by_dvcfile.values():
+        dvcfile = group["dvcfile"]
+        stages = [info.stage for info in group["infos"]]
+        if isinstance(dvcfile, ProjectFile):
+            dvcfile._reset()
+            old_stages = dict(dvcfile.stages)
         else:
-            changes = stage.changed_entries()
-            if any(changes):
-                prompt_to_commit(stage, changes, force=force)
-                stage.save(allow_missing=allow_missing)
-        stage.commit(
-            filter_info=stage_info.filter_info,
-            allow_missing=allow_missing,
-            relink=relink,
-        )
-        stage.dump(update_pipeline=False)
+            old_stages = None
+
+        for info in group["infos"]:
+            stage = info.stage
+            # Fall back to `None` (the stage then reloads itself) when the
+            # stage is not among the ones loaded from the dvcfile.
+            old_stage = old_stages.get(stage.name) if old_stages else None
+
+            if force:
+                stage.save(allow_missing=allow_missing, old_stage=old_stage)
+            else:
+                changes = stage.changed_entries()
+                if any(changes):
+                    prompt_to_commit(stage, changes, force=force)
+                    stage.save(allow_missing=allow_missing, old_stage=old_stage)
+            # Use the filter collected for *this* stage; the `stage_info`
+            # name left over from the grouping loop points at the last one.
+            stage.commit(
+                filter_info=info.filter_info,
+                allow_missing=allow_missing,
+                relink=relink,
+            )
+
+        if isinstance(dvcfile, ProjectFile):
+            dvcfile.batch_dump(stages, update_pipeline=False)
+        else:
+            for stage in stages:
+                stage.dump(update_pipeline=False)
+
     return [s.stage for s in stages_info]
 
 
diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py
index 92835fe4c5..fbff441fb8 100644
--- a/dvc/stage/__init__.py
+++ b/dvc/stage/__init__.py
@@ -489,10 +489,10 @@ def compute_md5(self) -> Optional[str]:
         logger.debug("Computed %s md5: '%s'", self, m)
         return m
 
-    def save(self, allow_missing: bool = False, run_cache: bool = True):
+    def save(self, allow_missing: bool = False, run_cache: bool = True, old_stage=None):
         self.save_deps(allow_missing=allow_missing)
 
-        self.save_outs(allow_missing=allow_missing)
+        self.save_outs(allow_missing=allow_missing, old_stage=old_stage)
 
         self.md5 = self.compute_md5()
 
@@ -509,25 +509,28 @@ def save_deps(self, allow_missing=False):
                 if not allow_missing:
                     raise
 
-    def get_versioned_outs(self) -> dict[str, "Output"]:
+    def get_versioned_outs(self, old_stage=None) -> dict[str, "Output"]:
         from .exceptions import StageFileDoesNotExistError, StageNotFound
 
-        try:
-            old = self.reload()
-        except (StageFileDoesNotExistError, StageNotFound):
-            return {}
+        # A caller that already holds the previously dumped stage can pass
+        # it in as `old_stage` to skip the `reload()` below.
+        if old_stage is None:
+            try:
+                old_stage = self.reload()
+            except (StageFileDoesNotExistError, StageNotFound):
+                return {}
 
         return {
             out.def_path: out
-            for out in old.outs
+            for out in old_stage.outs
             if out.files is not None
             or (out.meta is not None and out.meta.version_id is not None)
         }
 
-    def save_outs(self, allow_missing: bool = False):
+    def save_outs(self, allow_missing: bool = False, old_stage=None):
         from dvc.output import OutputDoesNotExistError
 
-        old_versioned_outs = self.get_versioned_outs()
+        old_versioned_outs = self.get_versioned_outs(old_stage=old_stage)
         for out in self.outs:
             try:
                 out.save()