Skip to content

Commit

Permalink
More consistency checks (#2859)
Browse files Browse the repository at this point in the history
* Enable additional invariant checks in xarray's test suite

* Tweak internal consistency checks

* Various small fixes

* Always use internal invariant checks

* Fix coordinates type from DataArray.transpose()
  • Loading branch information
shoyer authored Jun 18, 2019
1 parent 4c758e6 commit 145f25f
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 52 deletions.
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _update_coords(self, coords):

self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dict(dims)
self._data._dims = dims
self._data._indexes = None

def __delitem__(self, key):
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import warnings
from collections import OrderedDict
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -67,7 +68,7 @@ def _infer_coords_and_dims(shape, coords, dims):
for dim, coord in zip(dims, coords):
var = as_variable(coord, name=dim)
var.dims = (dim,)
new_coords[dim] = var
new_coords[dim] = var.to_index_variable()

sizes = dict(zip(dims, shape))
for k, v in new_coords.items():
Expand Down Expand Up @@ -1442,7 +1443,7 @@ def transpose(self, *dims, transpose_coords=None) -> 'DataArray':

variable = self.variable.transpose(*dims)
if transpose_coords:
coords = {}
coords = OrderedDict() # type: OrderedDict[Any, Variable]
for name, coord in self.coords.items():
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
coords[name] = coord.variable.transpose(*coord_dims)
Expand Down
30 changes: 18 additions & 12 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def calculate_dimensions(variables):
Returns dictionary mapping from dimension names to sizes. Raises ValueError
if any of the dimension sizes conflict.
"""
dims = OrderedDict()
dims = {}
last_used = {}
scalar_vars = set(k for k, v in variables.items() if not v.dims)
for k, var in variables.items():
Expand Down Expand Up @@ -692,7 +692,7 @@ def _construct_direct(cls, variables, coord_names, dims, attrs=None,

@classmethod
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
dims = dict(calculate_dimensions(variables))
dims = calculate_dimensions(variables)
return cls._construct_direct(variables, coord_names, dims, attrs)

# TODO(shoyer): re-enable type checking on this signature when pytype has a
Expand Down Expand Up @@ -753,18 +753,20 @@ def _replace_with_new_dims( # type: ignore
coord_names: set = None,
attrs: 'Optional[OrderedDict]' = __default,
indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default,
encoding: Optional[dict] = __default,
inplace: bool = False,
) -> T:
"""Replace variables with recalculated dimensions."""
dims = dict(calculate_dimensions(variables))
dims = calculate_dimensions(variables)
return self._replace(
variables, coord_names, dims, attrs, indexes, inplace=inplace)
variables, coord_names, dims, attrs, indexes, encoding,
inplace=inplace)

def _replace_vars_and_dims( # type: ignore
self: T,
variables: 'OrderedDict[Any, Variable]' = None,
coord_names: set = None,
dims: 'OrderedDict[Any, int]' = None,
dims: Dict[Any, int] = None,
attrs: 'Optional[OrderedDict]' = __default,
inplace: bool = False,
) -> T:
Expand Down Expand Up @@ -1080,6 +1082,7 @@ def __delitem__(self, key):
"""
del self._variables[key]
self._coord_names.discard(key)
self._dims = calculate_dimensions(self._variables)

# mutable objects should not be hashable
# https://github.com/python/mypy/issues/4266
Expand Down Expand Up @@ -2469,7 +2472,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
else:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
variables[k] = v.set_dims(k)
variables[k] = v.set_dims(k).to_index_variable()

new_dims = self._dims.copy()
new_dims.update(dim)
Expand Down Expand Up @@ -3556,12 +3559,15 @@ def from_dict(cls, d):
def _unary_op(f, keep_attrs=False):
@functools.wraps(f)
def func(self, *args, **kwargs):
ds = self.coords.to_dataset()
for k in self.data_vars:
ds._variables[k] = f(self._variables[k], *args, **kwargs)
if keep_attrs:
ds._attrs = self._attrs
return ds
variables = OrderedDict()
for k, v in self._variables.items():
if k in self._coord_names:
variables[k] = v
else:
variables[k] = f(v, *args, **kwargs)
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(
variables, attrs=attrs, encoding=None)

return func

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def merge_core(objs,
'coordinates or not in the merged result: %s'
% ambiguous_coords)

return variables, coord_names, dict(dims)
return variables, coord_names, dims


def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA):
Expand Down
138 changes: 109 additions & 29 deletions xarray/testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Testing functions exposed to the user API"""
from collections import OrderedDict
from typing import Hashable, Union

import numpy as np
import pandas as pd

from xarray.core import duck_array_ops, formatting
from xarray.core import duck_array_ops
from xarray.core import formatting
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.variable import IndexVariable, Variable
from xarray.core.indexes import default_indexes


Expand Down Expand Up @@ -48,12 +53,11 @@ def assert_equal(a, b):
assert_identical, assert_allclose, Dataset.equals, DataArray.equals,
numpy.testing.assert_array_equal
"""
import xarray as xr
__tracebackhide__ = True # noqa: F841
assert type(a) == type(b) # noqa
if isinstance(a, (xr.Variable, xr.DataArray)):
if isinstance(a, (Variable, DataArray)):
assert a.equals(b), formatting.diff_array_repr(a, b, 'equals')
elif isinstance(a, xr.Dataset):
elif isinstance(a, Dataset):
assert a.equals(b), formatting.diff_dataset_repr(a, b, 'equals')
else:
raise TypeError('{} not supported by assertion comparison'
Expand All @@ -77,15 +81,14 @@ def assert_identical(a, b):
--------
assert_equal, assert_allclose, Dataset.equals, DataArray.equals
"""
import xarray as xr
__tracebackhide__ = True # noqa: F841
assert type(a) == type(b) # noqa
if isinstance(a, xr.Variable):
if isinstance(a, Variable):
assert a.identical(b), formatting.diff_array_repr(a, b, 'identical')
elif isinstance(a, xr.DataArray):
elif isinstance(a, DataArray):
assert a.name == b.name
assert a.identical(b), formatting.diff_array_repr(a, b, 'identical')
elif isinstance(a, (xr.Dataset, xr.Variable)):
elif isinstance(a, (Dataset, Variable)):
assert a.identical(b), formatting.diff_dataset_repr(a, b, 'identical')
else:
raise TypeError('{} not supported by assertion comparison'
Expand Down Expand Up @@ -117,15 +120,14 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
--------
assert_identical, assert_equal, numpy.testing.assert_allclose
"""
import xarray as xr
__tracebackhide__ = True # noqa: F841
assert type(a) == type(b) # noqa
kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes)
if isinstance(a, xr.Variable):
if isinstance(a, Variable):
assert a.dims == b.dims
allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs)
assert allclose, '{}\n{}'.format(a.values, b.values)
elif isinstance(a, xr.DataArray):
elif isinstance(a, DataArray):
assert_allclose(a.variable, b.variable, **kwargs)
assert set(a.coords) == set(b.coords)
for v in a.coords.variables:
Expand All @@ -135,7 +137,7 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
b.coords[v].values, **kwargs)
assert allclose, '{}\n{}'.format(a.coords[v].values,
b.coords[v].values)
elif isinstance(a, xr.Dataset):
elif isinstance(a, Dataset):
assert set(a.data_vars) == set(b.data_vars)
assert set(a.coords) == set(b.coords)
for k in list(a.variables) + list(a.coords):
Expand All @@ -147,14 +149,12 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):


def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
import xarray as xr

assert isinstance(indexes, OrderedDict), indexes
assert all(isinstance(v, pd.Index) for v in indexes.values()), \
{k: type(v) for k, v in indexes.items()}

index_vars = {k for k, v in possible_coord_variables.items()
if isinstance(v, xr.IndexVariable)}
if isinstance(v, IndexVariable)}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)

# Note: when we support non-default indexes, these checks should be opt-in
Expand All @@ -166,17 +166,97 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
(indexes, defaults)


def _assert_indexes_invariants(a):
    """Check only the index-related invariants of an xarray object.

    For a DataArray or Dataset whose ``_indexes`` is not None, the indexes
    are handed to ``_assert_indexes_invariants_checks`` together with the
    object's candidate coordinate variables (``_coords`` / ``_variables``)
    and its dimensions (``dims`` / ``_dims``).  Nothing is checked for any
    other input, including Variable.
    """
    import xarray as xr

    if isinstance(a, xr.DataArray):
        if a._indexes is None:
            return
        _assert_indexes_invariants_checks(a._indexes, a._coords, a.dims)
        return

    if isinstance(a, xr.Dataset):
        if a._indexes is None:
            return
        _assert_indexes_invariants_checks(
            a._indexes, a._variables, a._dims)
        return

    # Anything else (a Variable in particular) has no indexes to check.
def _assert_variable_invariants(var: Variable, name: Hashable = None):
    """Assert that a variable's private attributes are well-formed.

    Checks, in order, that ``var._dims`` is a tuple, that it has one entry
    per axis of ``var._data.shape``, that ``var._encoding`` is None or a
    dict, and that ``var._attrs`` is None or an OrderedDict.

    If ``name`` is not None it is prepended to the tuple reported with a
    failing assertion, so the offending variable can be identified.
    """
    # Message prefix: empty for an anonymous variable, ``(name,)`` otherwise.
    label = () if name is None else (name,)  # type: tuple

    assert isinstance(var._dims, tuple), label + (var._dims,)

    same_rank = len(var._dims) == len(var._data.shape)
    assert same_rank, label + (var._dims, var._data.shape)

    encoding_ok = var._encoding is None or isinstance(var._encoding, dict)
    assert encoding_ok, label + (var._encoding,)

    attrs_ok = var._attrs is None or isinstance(var._attrs, OrderedDict)
    assert attrs_ok, label + (var._attrs,)


def _assert_dataarray_invariants(da: DataArray):
    """Assert that a DataArray's private attributes are consistent.

    Verifies that ``_variable`` is a valid Variable, that ``_coords`` is an
    OrderedDict of valid Variables whose dimensions are a subset of the
    array's dimensions, that every coordinate whose dims equal ``(name,)``
    is an IndexVariable, that ``_indexes`` (when not None) passes the index
    checks, and that ``_initialized`` is True.
    """
    assert isinstance(da._variable, Variable), da._variable
    _assert_variable_invariants(da._variable)

    coords = da._coords
    assert isinstance(coords, OrderedDict), coords

    # Each property is checked for *all* coordinates before moving on to the
    # next property, so the first failing assertion is the same as it would
    # be with one ``all(...)`` per property.
    for coord in coords.values():
        assert isinstance(coord, Variable), coords

    for coord in coords.values():
        assert set(coord.dims) <= set(da.dims), \
            (da.dims, {k: v.dims for k, v in coords.items()})

    for coord_name, coord in coords.items():
        if coord.dims == (coord_name,):
            assert isinstance(coord, IndexVariable), \
                {k: type(v) for k, v in coords.items()}

    for coord_name, coord in coords.items():
        _assert_variable_invariants(coord, coord_name)

    if da._indexes is not None:
        _assert_indexes_invariants_checks(da._indexes, coords, da.dims)

    assert da._initialized is True


def _assert_dataset_invariants(ds: Dataset):
    """Assert that a Dataset's private attributes are consistent.

    Verifies, in order:

    - ``_variables`` is an OrderedDict of valid Variables;
    - ``_coord_names`` is a set of names that all appear in ``_variables``;
    - ``_dims`` is exactly a ``dict`` of ints whose keys are the dimensions
      used by the variables and whose values match every variable's sizes;
    - every variable whose dims equal ``(name,)`` is an IndexVariable, and
      every variable named after a dimension has dims ``(name,)``;
    - ``_indexes`` (when not None) passes the index checks;
    - ``_encoding`` is None or a dict, ``_attrs`` is None or an OrderedDict,
      and ``_initialized`` is True.
    """
    variables = ds._variables
    assert isinstance(variables, OrderedDict), type(variables)
    for var in variables.values():
        assert isinstance(var, Variable), variables
    for var_name, var in variables.items():
        _assert_variable_invariants(var, var_name)

    coord_names = ds._coord_names
    assert isinstance(coord_names, set), coord_names
    assert coord_names <= variables.keys(), (coord_names, set(variables))

    dims = ds._dims
    # Exact type check: dict subclasses such as OrderedDict are rejected.
    assert type(dims) is dict, dims
    for size in dims.values():
        assert isinstance(size, int), dims

    # The recorded dimensions must be exactly those used by the variables.
    var_dims = {dim for var in variables.values() for dim in var.dims}
    assert dims.keys() == var_dims, (set(dims), var_dims)

    for var in variables.values():
        for dim in var.sizes:
            assert dims[dim] == var.sizes[dim], \
                (dims, {k: v.sizes for k, v in variables.items()})

    for var_name, var in variables.items():
        if var.dims == (var_name,):
            assert isinstance(var, IndexVariable), \
                {k: type(v) for k, v in variables.items()
                 if v.dims == (k,)}

    for var_name, var in variables.items():
        if var_name in dims:
            assert var.dims == (var_name,), \
                {k: v.dims for k, v in variables.items() if k in dims}

    if ds._indexes is not None:
        _assert_indexes_invariants_checks(ds._indexes, variables, dims)

    assert ds._encoding is None or isinstance(ds._encoding, dict)
    assert ds._attrs is None or isinstance(ds._attrs, OrderedDict)
    assert ds._initialized is True


def _assert_internal_invariants(
    xarray_obj: Union[DataArray, Dataset, Variable],
):
    """Check that an xarray object satisfies its own internal invariants.

    Dispatches on the type of ``xarray_obj`` to the Variable, DataArray or
    Dataset checker and raises TypeError for any other type.

    This exists for the benefit of xarray's own test suite, but may be
    useful in external projects if they (ill-advisedly) create objects using
    xarray's private APIs.
    """
    # Tried in this order; the first matching type wins.
    checkers = (
        (Variable, _assert_variable_invariants),
        (DataArray, _assert_dataarray_invariants),
        (Dataset, _assert_dataset_invariants),
    )
    for supported_type, check in checkers:
        if isinstance(xarray_obj, supported_type):
            check(xarray_obj)
            return

    raise TypeError(
        '{} is not a supported type for xarray invariant checks'
        .format(type(xarray_obj)))
13 changes: 6 additions & 7 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,20 @@ def source_ndarray(array):

# Internal versions of xarray's test functions that validate additional
# invariants
# TODO: add more invariant checks.

def assert_equal(a, b):
xarray.testing.assert_equal(a, b)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def assert_identical(a, b):
xarray.testing.assert_identical(a, b)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)


def assert_allclose(a, b, **kwargs):
xarray.testing.assert_allclose(a, b, **kwargs)
xarray.testing._assert_indexes_invariants(a)
xarray.testing._assert_indexes_invariants(b)
xarray.testing._assert_internal_invariants(a)
xarray.testing._assert_internal_invariants(b)
5 changes: 5 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2752,6 +2752,11 @@ def test_delitem(self):
assert set(data.variables) == all_items - set(['var1', 'numbers'])
assert 'numbers' not in data.coords

expected = Dataset()
actual = Dataset({'y': ('x', [1, 2])})
del actual['y']
assert_identical(expected, actual)

def test_squeeze(self):
data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])})
for args in [[], [['x']], [['x', 'z']]]:
Expand Down

0 comments on commit 145f25f

Please sign in to comment.