Skip to content
forked from pydata/xarray

Commit

Permalink
Add some extra tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Mar 5, 2021
1 parent 5f86df0 commit 61c5e43
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 23 deletions.
38 changes: 26 additions & 12 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,19 +2084,33 @@ def rolling_window(
var = self

if utils.is_scalar(dim):
for arg in [window, window_dim, center]:
assert utils.is_scalar(arg)
for name, arg in zip(
["window", "window_dim", "center"], [window, window_dim, center]
):
if not utils.is_scalar(arg):
raise ValueError(
f"Expected {name}={arg!r} to be a scalar like 'dim'."
)
dim = [dim]
window = [window]
window_dim = [window_dim]
center = [center]
else:
if len(dim) != len(window):
raise ValueError(
"'dim', 'window', 'window_dim', and 'center' must be the same length. "
f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r},"
f" and center={center!r}."
)

# dim is now a list
nroll = len(dim)
if utils.is_scalar(window):
window = [window] * nroll
if utils.is_scalar(window_dim):
window_dim = [window_dim] * nroll
if utils.is_scalar(center):
center = [center] * nroll
if (
len(dim) != len(window)
or len(dim) != len(window_dim)
or len(dim) != len(center)
):
raise ValueError(
"'dim', 'window', 'window_dim', and 'center' must be the same length. "
f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r},"
f" and center={center!r}."
)

pads = {}
for d, win, cent in zip(dim, window, center):
Expand Down
46 changes: 35 additions & 11 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,27 @@ def test_nd_rolling(self, center, dims):
)
assert_equal(actual, expected)

@pytest.mark.parametrize(
("dim, window, window_dim, center"),
[
("x", [3, 3], "x_w", True),
("x", 3, ("x_w", "x_w"), True),
("x", 3, "x_w", [True, True]),
],
)
def test_rolling_window_errors(self, dim, window, window_dim, center):
x = self.cls(
("x", "y", "z"),
np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float),
)
with pytest.raises(ValueError):
x.rolling_window(
dim=dim,
window=window,
window_dim=window_dim,
center=center,
)


class TestVariable(VariableSubclassobjects):
cls = staticmethod(Variable)
Expand Down Expand Up @@ -2198,23 +2219,23 @@ def test_datetime64(self):

# These tests make use of multi-dimensional variables, which are not valid
# IndexVariable objects:
@pytest.mark.xfail
@pytest.mark.skip
def test_getitem_error(self):
super().test_getitem_error()

@pytest.mark.xfail
@pytest.mark.skip
def test_getitem_advanced(self):
super().test_getitem_advanced()

@pytest.mark.xfail
@pytest.mark.skip
def test_getitem_fancy(self):
super().test_getitem_fancy()

@pytest.mark.xfail
@pytest.mark.skip
def test_getitem_uint(self):
super().test_getitem_fancy()

@pytest.mark.xfail
@pytest.mark.skip
@pytest.mark.parametrize(
"mode",
[
Expand All @@ -2233,24 +2254,27 @@ def test_getitem_uint(self):
def test_pad(self, mode, xr_arg, np_arg):
super().test_pad(mode, xr_arg, np_arg)

@pytest.mark.xfail
@pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS)
@pytest.mark.skip
def test_pad_constant_values(self, xr_arg, np_arg):
super().test_pad_constant_values(xr_arg, np_arg)

@pytest.mark.xfail
@pytest.mark.skip
def test_rolling_window(self):
super().test_rolling_window()

@pytest.mark.xfail
@pytest.mark.skip
def test_rolling_1d(self):
super().test_rolling_1d()

@pytest.mark.xfail
@pytest.mark.skip
def test_nd_rolling(self):
super().test_nd_rolling()

@pytest.mark.xfail
@pytest.mark.skip
def test_rolling_window_errors(self):
super().test_rolling_window_errors()

@pytest.mark.skip
def test_coarsen_2d(self):
super().test_coarsen_2d()

Expand Down

0 comments on commit 61c5e43

Please sign in to comment.