Skip to content

Commit

Permalink
parsing: use relative paths; force absolute paths to relative
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored and efiop committed Apr 18, 2022
1 parent 517f45d commit 5eb16ab
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 38 deletions.
13 changes: 12 additions & 1 deletion dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import (
Expand Down Expand Up @@ -134,9 +135,19 @@ def make_definition(

class DataResolver:
def __init__(self, repo: "Repo", wdir: str, d: dict):
from dvc.fs import LocalFileSystem

self.fs = fs = repo.fs

if os.path.isabs(wdir):
start = (
os.curdir if isinstance(fs, LocalFileSystem) else repo.root_dir
)
wdir = relpath(wdir, start)
wdir = "" if wdir == os.curdir else wdir

self.wdir = wdir
self.relpath = relpath(fs.path.join(self.wdir, "dvc.yaml"))
self.relpath = os.path.normpath(fs.path.join(self.wdir, "dvc.yaml"))

vars_ = d.get(VARS_KWD, [])
check_interpolations(vars_, VARS_KWD, self.relpath)
Expand Down
36 changes: 17 additions & 19 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
recurse,
str_interpolate,
)
from dvc.utils import relpath

logger = logging.getLogger(__name__)
SeqOrMap = Union[Sequence, Mapping]
Expand Down Expand Up @@ -359,20 +358,19 @@ def load_from(
) -> "Context":
from dvc.utils.serialize import LOADERS

file = relpath(path)
if not fs.exists(path):
raise ParamsLoadError(f"'{file}' does not exist")
raise ParamsLoadError(f"'{path}' does not exist")
if fs.isdir(path):
raise ParamsLoadError(f"'{file}' is a directory")
raise ParamsLoadError(f"'{path}' is a directory")

_, ext = os.path.splitext(file)
_, ext = os.path.splitext(path)
loader = LOADERS[ext]

data = loader(path, fs=fs)
if not isinstance(data, Mapping):
typ = type(data).__name__
raise ParamsLoadError(
f"expected a dictionary, got '{typ}' in file '{file}'"
f"expected a dictionary, got '{typ}' in file '{path}'"
)

if select_keys:
Expand All @@ -381,12 +379,12 @@ def load_from(
except KeyError as exc:
key, *_ = exc.args
raise ParamsLoadError(
f"could not find '{key}' in '{file}'"
f"could not find '{key}' in '{path}'"
) from exc

meta = Meta(source=file, local=False)
meta = Meta(source=path, local=False)
ctx = cls(data, meta=meta)
ctx.imports[os.path.abspath(path)] = select_keys
ctx.imports[path] = select_keys
return ctx

def merge_update(self, other: "Context", overwrite=False):
Expand All @@ -397,26 +395,26 @@ def merge_update(self, other: "Context", overwrite=False):

def merge_from(self, fs, item: str, wdir: str, overwrite=False):
path, _, keys_str = item.partition(":")
select_keys = lfilter(bool, keys_str.split(",")) if keys_str else None
path = os.path.normpath(fs.path.join(wdir, path))

abspath = os.path.abspath(fs.path.join(wdir, path))
if abspath in self.imports:
if not select_keys and self.imports[abspath] is None:
select_keys = lfilter(bool, keys_str.split(",")) if keys_str else None
if path in self.imports:
if not select_keys and self.imports[path] is None:
return # allow specifying complete filepath multiple times
self.check_loaded(abspath, item, select_keys)
self.check_loaded(path, item, select_keys)

ctx = Context.load_from(fs, abspath, select_keys)
ctx = Context.load_from(fs, path, select_keys)

try:
self.merge_update(ctx, overwrite=overwrite)
except ReservedKeyError as exc:
raise ReservedKeyError(exc.keys, item) from exc

cp = ctx.imports[abspath]
if abspath not in self.imports:
self.imports[abspath] = cp
cp = ctx.imports[path]
if path not in self.imports:
self.imports[path] = cp
elif cp:
self.imports[abspath].extend(cp)
self.imports[path].extend(cp)

def check_loaded(self, path, item, keys):
if not keys and isinstance(self.imports[path], list):
Expand Down
17 changes: 16 additions & 1 deletion dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from dvc.scm import NoSCMError
from dvc.stage import PipelineStage
from dvc.ui import ui
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils import (
error_handler,
errored_revisions,
onerror_collect,
relpath,
)
from dvc.utils.serialize import LOADERS

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,10 +90,20 @@ def _read_params(


def _collect_vars(repo, params) -> Dict:
from dvc.fs.git import GitFileSystem

vars_params: Dict[str, Dict] = defaultdict(dict)
rel_to_root = relpath(repo.root_dir)

for stage in repo.index.stages:
if isinstance(stage, PipelineStage) and stage.tracked_vars:
for file, vars_ in stage.tracked_vars.items():
if isinstance(repo.fs, GitFileSystem):
# GitFileSystem uses relatively-absolute paths from the
# root of the repo. We need to convert them to relative
# paths based on the current working directory.
file = os.path.normpath(os.path.join(rel_to_root, file))

# `params` files are shown regardless of whether they are `tracked` or not,
# so to reduce noise and duplication they are skipped here
if file in params:
Expand Down
4 changes: 1 addition & 3 deletions tests/func/parsing/test_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,7 @@ def test_foreach_with_interpolated_wdir_and_local_vars(
}
},
}
assert resolver.context.imports == {
str(tmp_dir / DEFAULT_PARAMS_FILE): None
}
assert resolver.context.imports == {DEFAULT_PARAMS_FILE: None}


def test_foreach_do_syntax_is_checked_once(tmp_dir, dvc, mocker):
Expand Down
4 changes: 2 additions & 2 deletions tests/func/parsing/test_interpolated_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_with_templated_wdir(tmp_dir, dvc):
DEFAULT_PARAMS_FILE: {"dict.bar": "bar", "dict.ws": "data"},
}
}
assert resolver.context.imports == {str(tmp_dir / "params.yaml"): None}
assert resolver.context.imports == {"params.yaml": None}
assert resolver.context == {"dict": {"bar": "bar", "ws": "data"}}


Expand Down Expand Up @@ -236,7 +236,7 @@ def test_vars_relpath_overwrite(tmp_dir, dvc):
}
resolver = DataResolver(dvc, tmp_dir.fs_path, d)
resolver.resolve()
assert resolver.context.imports == {str(tmp_dir / "params.yaml"): None}
assert resolver.context.imports == {"params.yaml": None}


@pytest.mark.parametrize("local", [True, False])
Expand Down
30 changes: 30 additions & 0 deletions tests/func/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,33 @@ def test_circular_import(tmp_dir, dvc, scm, erepo_dir):
erepo_dir.dvc.imp(
os.fspath(tmp_dir), "dir_imported", "circular_import"
)


@pytest.mark.parametrize("paths", ([], ["dir"]))
def test_parameterized_repo(tmp_dir, dvc, scm, erepo_dir, paths):
    """Import the output of a templated stage from an external repo.

    The external repo gets a ``params.yaml`` defining ``out`` and a
    ``dvc.yaml`` whose ``train`` stage interpolates ``${out}`` into its
    command and outputs. Both live either at the repo root (``paths == []``)
    or in the ``dir`` subdirectory (``paths == ["dir"]``). The test then
    imports the resulting ``foo`` file and checks its content and the
    recorded ``def_repo`` of the import stage's dependency.
    """
    stage_dir = erepo_dir.joinpath(*paths)
    stage_dir.mkdir(parents=True, exist_ok=True)

    # Lay out the parametrized pipeline and the file its stage produces.
    (stage_dir / "params.yaml").dump({"out": "foo"})
    train_stage = {"cmd": "echo ${out} > ${out}", "outs": ["${out}"]}
    (stage_dir / "dvc.yaml").dump({"stages": {"train": train_stage}})
    stage_dir.gen({"foo": "foo"})

    # The file names below are relative, so commit from inside stage_dir.
    committed_files = ["params.yaml", "dvc.yaml", "dvc.lock", ".gitignore"]
    with stage_dir.chdir():
        erepo_dir.dvc.commit(None, force=True)
        erepo_dir.scm.add_commit(committed_files, message="init")

    repo_url = os.fspath(erepo_dir)
    target = os.path.join(*paths, "foo")
    stage = dvc.imp(repo_url, target, "foo_imported")

    assert (tmp_dir / "foo_imported").read_text() == "foo"

    expected_def_repo = {
        "url": repo_url,
        "rev_lock": erepo_dir.scm.get_rev(),
    }
    assert stage.deps[0].def_repo == expected_def_repo
22 changes: 10 additions & 12 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,13 @@ def test_track(tmp_dir):
"dct": {"foo": "foo", "bar": "bar", "baz": "baz"},
}
fs = LocalFileSystem()
path = tmp_dir / "params.yaml"
path.dump(d, fs=fs)
(tmp_dir / "params.yaml").dump(d, fs=fs)

context = Context.load_from(fs, path)
context = Context.load_from(fs, "params.yaml")

def key_tracked(d, key):
assert len(d) == 1
return key in d[relpath(path)]
return key in d["params.yaml"]

with context.track() as tracked:
context.select("lst")
Expand Down Expand Up @@ -323,10 +322,10 @@ def test_track_from_multiple_files(tmp_dir):
d2 = {"Train": {"us": {"layers": 100}}}

fs = LocalFileSystem()
path1 = tmp_dir / "params.yaml"
path2 = tmp_dir / "params2.yaml"
path1.dump(d1, fs=fs)
path2.dump(d2, fs=fs)
path1 = "params.yaml"
path2 = "params2.yaml"
(tmp_dir / path1).dump(d1, fs=fs)
(tmp_dir / path2).dump(d2, fs=fs)

context = Context.load_from(fs, path1)
c = Context.load_from(fs, path2)
Expand Down Expand Up @@ -428,16 +427,15 @@ def test_resolve_resolves_boolean_value():

def test_load_from_raises_if_file_not_exist(tmp_dir, dvc):
with pytest.raises(ParamsLoadError) as exc_info:
Context.load_from(dvc.fs, tmp_dir / DEFAULT_PARAMS_FILE)
Context.load_from(dvc.fs, DEFAULT_PARAMS_FILE)

assert str(exc_info.value) == "'params.yaml' does not exist"


def test_load_from_raises_if_file_is_directory(tmp_dir, dvc):
data_dir = tmp_dir / "data"
data_dir.mkdir()
(tmp_dir / "data").mkdir()

with pytest.raises(ParamsLoadError) as exc_info:
Context.load_from(dvc.fs, data_dir)
Context.load_from(dvc.fs, "data")

assert str(exc_info.value) == "'data' is a directory"

0 comments on commit 5eb16ab

Please sign in to comment.