diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d6ef206e7b5..52bcb2c161c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2084,19 +2084,33 @@ def rolling_window( var = self if utils.is_scalar(dim): - for arg in [window, window_dim, center]: - assert utils.is_scalar(arg) + for name, arg in zip( + ["window", "window_dim", "center"], [window, window_dim, center] + ): + if not utils.is_scalar(arg): + raise ValueError( + f"Expected {name}={arg!r} to be a scalar like 'dim'." + ) dim = [dim] - window = [window] - window_dim = [window_dim] - center = [center] - else: - if len(dim) != len(window): - raise ValueError( - "'dim', 'window', 'window_dim', and 'center' must be the same length. " - f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r}," - f" and center={center!r}." - ) + + # dim is now a list + nroll = len(dim) + if utils.is_scalar(window): + window = [window] * nroll + if utils.is_scalar(window_dim): + window_dim = [window_dim] * nroll + if utils.is_scalar(center): + center = [center] * nroll + if ( + len(dim) != len(window) + or len(dim) != len(window_dim) + or len(dim) != len(center) + ): + raise ValueError( + "'dim', 'window', 'window_dim', and 'center' must be the same length. " + f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r}," + f" and center={center!r}." + ) pads = {} for d, win, cent in zip(dim, window, center): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index fa07714bba6..1aaaff9b2d4 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -948,6 +948,27 @@ def test_nd_rolling(self, center, dims): ) assert_equal(actual, expected) + @pytest.mark.parametrize( + ("dim, window, window_dim, center"), + [ + ("x", [3, 3], "x_w", True), + ("x", 3, ("x_w", "x_w"), True), + ("x", 3, "x_w", [True, True]), + ], + ) + def test_rolling_window_errors(self, dim, window, window_dim, center): + x = self.cls( + ("x", "y", "z"), + np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float), + ) + with pytest.raises(ValueError): + x.rolling_window( + dim=dim, + window=window, + window_dim=window_dim, + center=center, + ) + class TestVariable(VariableSubclassobjects): cls = staticmethod(Variable) @@ -2198,23 +2219,23 @@ def test_datetime64(self): # These tests make use of multi-dimensional variables, which are not valid # IndexVariable objects: - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_error(self): super().test_getitem_error() - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_advanced(self): super().test_getitem_advanced() - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_fancy(self): super().test_getitem_fancy() - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_uint(self): super().test_getitem_fancy() - @pytest.mark.xfail + @pytest.mark.skip @pytest.mark.parametrize( "mode", [ @@ -2233,24 +2254,27 @@ def test_getitem_uint(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) - @pytest.mark.xfail - @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.skip def test_pad_constant_values(self, xr_arg, np_arg): super().test_pad_constant_values(xr_arg, np_arg) - @pytest.mark.xfail + @pytest.mark.skip def test_rolling_window(self): super().test_rolling_window() - @pytest.mark.xfail + @pytest.mark.skip def test_rolling_1d(self): super().test_rolling_1d() - @pytest.mark.xfail + @pytest.mark.skip def test_nd_rolling(self): super().test_nd_rolling() - @pytest.mark.xfail + @pytest.mark.skip + def test_rolling_window_errors(self): + super().test_rolling_window_errors() + + @pytest.mark.skip def test_coarsen_2d(self): super().test_coarsen_2d()