diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index e22e0019..a49761fb 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -255,6 +255,31 @@ def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]: return list(results) +def _get_bounds(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + Translate from key (either CF key or variable name) to its bounds' variable names. + This function interprets the ``bounds`` attribute on DataArrays. + + Parameters + ---------- + obj : DataArray, Dataset + DataArray belonging to the coordinate to be checked + key : str + key to check for. + + Returns + ------- + List[str], Variable name(s) in parent xarray object that are bounds of `key` + """ + + results = set() + for var in apply_mapper(_get_all, obj, key, error=False, default=[key]): + if "bounds" in obj[var].attrs: + results |= {obj[var].attrs["bounds"]} + + return list(results) + + def _get_with_standard_name( obj: Union[DataArray, Dataset], name: Union[str, List[str]] ) -> List[str]: @@ -436,6 +461,14 @@ def _getattr( try: attribute: Union[Mapping, Callable] = getattr(obj, attr) except AttributeError: + if getattr( + CFDatasetAccessor if isinstance(obj, DataArray) else CFDataArrayAccessor, + attr, + None, + ): + raise AttributeError( + f"{obj.__class__.__name__+'.cf'!r} object has no attribute {attr!r}" + ) raise AttributeError( f"{attr!r} is not a valid attribute on the underlying xarray object." ) @@ -976,7 +1009,9 @@ def __repr__(self): coords = self._obj.coords dims = self._obj.dims - def make_text_section(subtitle, vardict, valid_values, default_keys=None): + def make_text_section(subtitle, attr, valid_values, default_keys=None): + + vardict = getattr(self, attr, {}) star = " * " tab = len(star) * " " @@ -1019,21 +1054,21 @@ def make_text_section(subtitle, vardict, valid_values, default_keys=None): return "\n".join(rows) + "\n" text = "Coordinates:" - text += make_text_section("CF Axes", self.axes, coords, _AXIS_NAMES) + text += make_text_section("CF Axes", "axes", coords, _AXIS_NAMES) + text += make_text_section("CF Coordinates", "coordinates", coords, _COORD_NAMES) text += make_text_section( - "CF Coordinates", self.coordinates, coords, _COORD_NAMES + "Cell Measures", "cell_measures", coords, _CELL_MEASURES ) - text += make_text_section( - "Cell Measures", self.cell_measures, coords, _CELL_MEASURES - ) - text += make_text_section("Standard Names", self.standard_names, coords) + text += make_text_section("Standard Names", "standard_names", coords) + text += make_text_section("Bounds", "bounds", coords) if isinstance(self._obj, Dataset): data_vars = self._obj.data_vars text += "\nData Variables:" text += make_text_section( - "Cell Measures", self.cell_measures, data_vars, _CELL_MEASURES + "Cell Measures", "cell_measures", data_vars, _CELL_MEASURES ) - text += make_text_section("Standard Names", self.standard_names, data_vars) + text += make_text_section("Standard Names", "standard_names", data_vars) + text += make_text_section("Bounds", "bounds", data_vars) return text @@ -1144,7 +1179,7 @@ def get_standard_names(self) -> List[str]: @property def standard_names(self) -> Dict[str, List[str]]: """ - Returns a sorted list of standard names in Dataset. + Returns a dictionary mapping standard names to variable names. Parameters ---------- @@ -1153,7 +1188,7 @@ def standard_names(self) -> Dict[str, List[str]]: Returns ------- - Dictionary of standard names in dataset + Dictionary mapping standard names to variable names. """ if isinstance(self._obj, Dataset): variables = self._obj.variables @@ -1480,6 +1515,26 @@ def __getitem__(self, key: Union[str, List[str]]) -> Union[DataArray, Dataset]: """ return _getitem(self, key) + @property + def bounds(self) -> Dict[str, List[str]]: + """ + Property that returns a dictionary mapping valid keys + to the variable names of their bounds. + + Returns + ------- + Dictionary mapping valid keys to the variable names of their bounds. + """ + + obj = self._obj + keys = self.keys() | set(obj.variables) + + vardict = { + key: apply_mapper(_get_bounds, obj, key, error=False) for key in keys + } + + return {k: sorted(v) for k, v in vardict.items() if v} + def get_bounds(self, key: str) -> DataArray: """ Get bounds variable corresponding to key. @@ -1493,12 +1548,8 @@ def get_bounds(self, key: str) -> DataArray: ------- DataArray """ - name = apply_mapper( - _single(_get_all), self._obj, key, error=False, default=[key] - )[0] - bounds = self._obj[name].attrs["bounds"] - obj = self._maybe_to_dataset() - return obj[bounds] + + return apply_mapper(_variables(_single(_get_bounds)), self._obj, key)[0] def get_bounds_dim_name(self, key: str) -> str: """ diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 77253114..41bbb9f8 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -61,10 +61,14 @@ def test_repr(): * longitude: ['lon'] * time: ['time'] + - Bounds: n/a + Data Variables: - Cell Measures: area, volume: n/a - Standard Names: air_temperature: ['air'] + + - Bounds: n/a """ assert actual == dedent(expected) @@ -89,6 +93,8 @@ def test_repr(): - Standard Names: * latitude: ['lat'] * longitude: ['lon'] * time: ['time'] + + - Bounds: n/a """ assert actual == dedent(expected) @@ -108,11 +114,15 @@ def test_repr(): - Standard Names: n/a + - Bounds: n/a + Data Variables: - Cell Measures: area, volume: n/a - Standard Names: sea_water_potential_temperature: ['TEMP'] sea_water_x_velocity: ['UVEL'] + + - Bounds: n/a """ assert actual == dedent(expected) @@ -163,6 +173,8 @@ def test_cell_measures(): - Standard Names: air_temperature: ['air'] foo_std_name: ['foo'] + + - Bounds: n/a """ assert actual.endswith(dedent(expected)) @@ -625,6 +637,11 @@ def test_add_bounds(obj, dims): def test_bounds(): ds = airds.copy(deep=True).cf.add_bounds("lat") + + actual = ds.cf.bounds + expected = {"Y": ["lat_bounds"], "lat": ["lat_bounds"], "latitude": ["lat_bounds"]} + assert ds.cf.bounds == expected + actual = ds.cf[["lat"]] expected = ds[["lat", "lat_bounds"]] assert_identical(actual, expected) @@ -651,6 +668,19 @@ def test_bounds(): with pytest.warns(UserWarning, match="{'foo'} not found in object"): ds.cf[["air"]] + # Dataset has bounds + expected = """\ + - Bounds: Y: ['lat_bounds'] + lat: ['lat_bounds'] + latitude: ['lat_bounds'] + """ + assert dedent(expected) in ds.cf.__repr__() + + # DataArray does not have bounds + expected = airds.cf["air"].cf.__repr__() + actual = ds.cf["air"].cf.__repr__() + assert actual == expected + def test_bounds_to_vertices(): # All available diff --git a/doc/api.rst b/doc/api.rst index 4c987f51..49aa425c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -63,6 +63,7 @@ Attributes :template: autosummary/accessor_attribute.rst Dataset.cf.axes + Dataset.cf.bounds Dataset.cf.cell_measures Dataset.cf.coordinates Dataset.cf.standard_names diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1760fc56..3a54cbd1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -6,6 +6,7 @@ What's New v0.5.2 (unreleased) =================== +- Added :py:attr:`Dataset.cf.axes` to return a dictionary mapping valid keys to the variable names of their bounds. By `Mattia Almansi`_. - :py:meth:`DataArray.cf.differentiate` and :py:meth:`Dataset.cf.differentiate` can optionally correct sign of the derivative by interpreting the ``"positive"`` attribute. By `Deepak Cherian`_.