Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove data variable inference API #196

Merged
merged 2 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ Top-level API
dataset.has_cf_compliant_time
dataset.decode_non_cf_time
dataset.swap_lon_axis
dataset.infer_or_keep_var
dataset.get_inferred_var

.. currentmodule:: xarray

Expand Down
69 changes: 10 additions & 59 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

from tests.fixtures import generate_dataset
from xcdat.dataset import (
_keep_single_var,
_preprocess_non_cf_dataset,
_split_time_units_attr,
decode_non_cf_time,
get_inferred_var,
has_cf_compliant_time,
infer_or_keep_var,
open_dataset,
open_mfdataset,
)
Expand Down Expand Up @@ -45,7 +44,6 @@ def test_only_keeps_specified_var(self):

result = open_dataset(self.file_path, data_var="ts")
expected = ds.copy()
expected.attrs["xcdat_infer"] = "ts"
assert result.identical(expected)

def test_non_cf_compliant_time_is_not_decoded(self):
Expand All @@ -54,8 +52,6 @@ def test_non_cf_compliant_time_is_not_decoded(self):

result = open_dataset(self.file_path, decode_times=False)
expected = generate_dataset(cf_compliant=False, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"

assert result.identical(expected)

def test_non_cf_compliant_time_is_decoded(self):
Expand All @@ -64,7 +60,6 @@ def test_non_cf_compliant_time_is_decoded(self):

result = open_dataset(self.file_path, data_var="ts")
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
expected.time.attrs["calendar"] = "standard"
expected.time.attrs["units"] = "months since 2000-01-01"
expected.time.encoding = {
Expand All @@ -89,7 +84,6 @@ def test_preserves_lat_and_lon_bounds_if_they_exist(self):

result = open_dataset(self.file_path, data_var="ts")
expected = ds.copy()
expected.attrs["xcdat_infer"] = "ts"

assert result.identical(expected)

Expand Down Expand Up @@ -165,7 +159,6 @@ def test_swaps_from_180_to_360_and_sorts_with_prime_meridian_cell(self):
attrs={"is_generated": "True"},
),
},
attrs={"xcdat_infer": "None"},
)
assert result.identical(expected)

Expand All @@ -189,7 +182,6 @@ def test_only_keeps_specified_var(self):

result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
expected.time.attrs["calendar"] = "standard"
expected.time.attrs["units"] = "months since 2000-01-01"

Expand All @@ -214,7 +206,6 @@ def test_non_cf_compliant_time_is_not_decoded(self):
result = open_mfdataset([self.file_path1, self.file_path2], decode_times=False)

expected = ds1.merge(ds2)
expected.attrs["xcdat_infer"] = "None"
assert result.identical(expected)

def test_non_cf_compliant_time_is_decoded(self):
Expand All @@ -227,7 +218,6 @@ def test_non_cf_compliant_time_is_decoded(self):

result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
expected.time.attrs["units"] = "months since 2000-01-01"
expected.time.attrs["calendar"] = "standard"
expected.time.encoding = {
Expand All @@ -254,7 +244,6 @@ def test_preserves_lat_and_lon_bounds_if_they_exist(self):
ds2.to_netcdf(self.file_path2)

expected = generate_dataset(cf_compliant=True, has_bounds=True)
expected.attrs["xcdat_infer"] = "ts"
result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
assert result.identical(expected)

Expand Down Expand Up @@ -333,7 +322,6 @@ def test_swaps_from_180_to_360_and_sorts_with_prime_meridian_cell(self):
attrs={"is_generated": "True"},
),
},
attrs={"xcdat_infer": "None"},
)
assert result.identical(expected)

Expand Down Expand Up @@ -742,7 +730,7 @@ def test_decodes_years_with_a_reference_date_on_a_leap_year(self):
assert result.time_bnds.encoding == expected.time_bnds.encoding


class TestInferOrKeepVar:
class TestKeepSingleVar:
@pytest.fixture(autouse=True)
def setup(self):
self.ds = generate_dataset(cf_compliant=True, has_bounds=True)
Expand All @@ -756,25 +744,24 @@ def tests_raises_logger_debug_if_only_bounds_data_variables_exist(self, caplog):
ds = self.ds.copy()
ds = ds.drop_vars("ts")

infer_or_keep_var(ds, data_var=None)
_keep_single_var(ds, data_var=None)
assert "This dataset only contains bounds data variables." in caplog.text

def test_raises_error_if_specified_data_var_does_not_exist(self):
ds = self.ds_mod.copy()
with pytest.raises(KeyError):
infer_or_keep_var(ds, data_var="nonexistent")
_keep_single_var(ds, data_var="nonexistent")

def test_raises_error_if_specified_data_var_is_a_bounds_var(self):
ds = self.ds_mod.copy()
with pytest.raises(KeyError):
infer_or_keep_var(ds, data_var="lat_bnds")
_keep_single_var(ds, data_var="lat_bnds")

def test_returns_dataset_if_it_only_has_one_non_bounds_data_var(self):
ds = self.ds.copy()

result = infer_or_keep_var(ds, data_var=None)
result = _keep_single_var(ds, data_var=None)
expected = ds.copy()
expected.attrs["xcdat_infer"] = "ts"

assert result.identical(expected)

Expand All @@ -784,65 +771,29 @@ def test_returns_dataset_if_it_contains_multiple_non_bounds_data_var_with_logger
caplog.set_level(logging.DEBUG)

ds = self.ds_mod.copy()
result = infer_or_keep_var(ds, data_var=None)
result = _keep_single_var(ds, data_var=None)
expected = ds.copy()
expected.attrs["xcdat_infer"] = "None"

assert result.identical(expected)
assert (
"This dataset contains more than one regular data variable: ['tas', 'ts']. "
"If desired, pass the `data_var` kwarg to limit the dataset to a single data var."
) in caplog.text

def test_returns_dataset_with_specified_data_var_and_inference_attr(self):
result = infer_or_keep_var(self.ds_mod, data_var="ts")
def test_returns_dataset_with_specified_data_var(self):
result = _keep_single_var(self.ds_mod, data_var="ts")
expected = self.ds.copy()
expected.attrs["xcdat_infer"] = "ts"

assert result.identical(expected)
assert not result.identical(self.ds_mod)

def test_bounds_always_persist(self):
ds = infer_or_keep_var(self.ds_mod, data_var="ts")
ds = _keep_single_var(self.ds_mod, data_var="ts")
assert ds.get("lat_bnds") is not None
assert ds.get("lon_bnds") is not None
assert ds.get("time_bnds") is not None


class TestGetInferredVar:
    """Exercise ``get_inferred_var`` against the 'xcdat_infer' Dataset attr."""

    @pytest.fixture(autouse=True)
    def setup(self):
        # Build a fresh generated dataset before every test in this class.
        self.ds = generate_dataset(cf_compliant=True, has_bounds=True)

    def _tagged_copy(self, var_name):
        """Return a copy of the fixture dataset with 'xcdat_infer' set."""
        tagged = self.ds.copy()
        tagged.attrs["xcdat_infer"] = var_name
        return tagged

    def test_raises_error_if_inference_attr_is_none(self):
        # No 'xcdat_infer' attr is assigned to the dataset in this test.
        with pytest.raises(KeyError):
            get_inferred_var(self.ds)

    def test_raises_error_if_inference_attr_is_set_to_nonexistent_data_var(self):
        tagged = self._tagged_copy("nonexistent_var")

        with pytest.raises(KeyError):
            get_inferred_var(tagged)

    def test_raises_error_if_inference_attr_is_set_to_bounds_var(self):
        tagged = self._tagged_copy("lat_bnds")

        with pytest.raises(KeyError):
            get_inferred_var(tagged)

    def test_returns_inferred_data_var(self):
        tagged = self._tagged_copy("ts")

        # The returned variable must be identical to the tagged "ts" variable.
        assert get_inferred_var(tagged).identical(tagged.ts)


class TestPreProcessNonCFDataset:
@pytest.fixture(autouse=True)
def setup(self):
Expand Down
23 changes: 1 addition & 22 deletions tests/test_spatial_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,7 @@ def test_spatial_average_for_lat_and_lon_region_using_custom_weights(self):

assert result.identical(expected)

def test_spatial_average_for_lat_and_lon_region_for_an_inferred_data_var(self):
    """Spatial average over a lat/lon region when the data var is inferred."""
    tagged = self.ds.copy()
    tagged.attrs["xcdat_infer"] = "ts"

    # The data var is not passed, so `spatial_avg` has to rely on the
    # 'xcdat_infer' attr set above.
    region = {"lat_bounds": (-5.0, 5), "lon_bounds": (-170, -120.1)}
    result = tagged.spatial.spatial_avg(axis=["lat", "lon"], **region)

    expected = self.ds.copy()
    expected.attrs["xcdat_infer"] = "ts"
    averaged = np.array([2.25, 1.0, 1.0])
    expected["ts"] = xr.DataArray(
        data=averaged,
        coords={"time": expected.time},
        dims="time",
    )

    assert result.identical(expected)

def test_spatial_average_for_lat_and_lon_region_for_explicit_data_var(
self,
):
def test_spatial_average_for_lat_and_lon_region(self):
ds = self.ds.copy()
result = ds.spatial.spatial_avg(
"ts", axis=["lat", "lon"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1)
Expand Down
2 changes: 0 additions & 2 deletions xcdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from xcdat.bounds import BoundsAccessor # noqa: F401
from xcdat.dataset import ( # noqa: F401
decode_non_cf_time,
get_inferred_var,
has_cf_compliant_time,
infer_or_keep_var,
open_dataset,
open_mfdataset,
)
Expand Down
90 changes: 6 additions & 84 deletions xcdat/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def open_dataset(
else:
ds = xr.open_dataset(path, decode_times=False, **kwargs)

ds = infer_or_keep_var(ds, data_var)
ds = _keep_single_var(ds, data_var)
ds = ds.bounds.add_missing_bounds()
if ds.cf.dims.get("X") is not None and lon_orient is not None:
ds = swap_lon_axis(ds, to=lon_orient, sort_ascending=True)
Expand Down Expand Up @@ -233,7 +233,7 @@ def open_mfdataset(
preprocess=preprocess,
**kwargs,
)
ds = infer_or_keep_var(ds, data_var)
ds = _keep_single_var(ds, data_var)
ds = ds.bounds.add_missing_bounds()
if ds.cf.dims.get("X") is not None and lon_orient is not None:
ds = swap_lon_axis(ds, to=lon_orient, sort_ascending=True)
Expand Down Expand Up @@ -431,27 +431,13 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset:
return dataset


def infer_or_keep_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Dataset:
"""Infer or explicitly keep a specific data variable in the Dataset.

If ``data_var`` is None, then this function checks the number of
regular (non-bounds) data variables in the Dataset. If there is a single
regular data var, then it will add an 'xcdat_infer' attr pointing to it in
the Dataset. XCDAT APIs can then call `get_inferred_var()` to get the data
var linked to the 'xcdat_infer' attr. If there are multiple regular data
variables, the 'xcdat_infer' attr is not set and the Dataset is returned
as is.
def _keep_single_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Dataset:
"""Keep a single data variable in the Dataset.

If ``data_var`` is not None, then this function checks if the ``data_var``
exists in the Dataset and if it is a regular data var. If those checks pass,
it will subset the Dataset to retain that ``data_var`` and all bounds data
vars. An 'xcdat_infer' attr pointing to the ``data_var`` is also added
to the Dataset.

This utility function is useful for designing XCDAT APIs with an optional
``data_var`` kwarg. If ``data_var`` is None, an inference to the desired
data var is performed with a call to this function. Otherwise, perform the
API operation explicitly on ``data_var``.
vars.

Parameters
----------
Expand All @@ -473,10 +459,6 @@ def infer_or_keep_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Datase
If the user specifies a bounds variable to keep.
"""
ds = dataset.copy()
# Make sure the "xcdat_infer" attr is "None" because a Dataset may be
# written with this attr already set.
ds.attrs["xcdat_infer"] = "None"

all_vars = ds.data_vars.keys()
bounds_vars = ds.bounds.names
regular_vars = sorted(list(set(all_vars) ^ set(bounds_vars)))
Expand All @@ -485,9 +467,7 @@ def infer_or_keep_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Datase
logger.debug("This dataset only contains bounds data variables.")

if data_var is None:
if len(regular_vars) == 1:
ds.attrs["xcdat_infer"] = regular_vars[0]
elif len(regular_vars) > 1:
if len(regular_vars) > 1:
logger.debug(
"This dataset contains more than one regular data variable: "
f"{regular_vars}. If desired, pass the `data_var` kwarg to "
Expand All @@ -502,68 +482,10 @@ def infer_or_keep_var(dataset: xr.Dataset, data_var: Optional[str]) -> xr.Datase
raise KeyError("Please specify a regular (non-bounds) data variable.")

ds = dataset[[data_var] + bounds_vars]
ds.attrs["xcdat_infer"] = data_var

return ds


def get_inferred_var(dataset: xr.Dataset) -> xr.DataArray:
    """Get the inferred data variable that is tagged in the Dataset.

    This function looks for the "xcdat_infer" attribute pointing to the
    desired data var in the Dataset, which can be set through
    ``xcdat.open_dataset()``, ``xcdat.open_mfdataset()``, or manually.

    The attribute counts as "not set" when it is missing or when it holds the
    placeholder string ``"None"`` (the value written to mark that no data
    variable was inferred).

    This utility function is useful for designing XCDAT APIs with an optional
    ``data_var`` kwarg. If ``data_var`` is None, an inference to the desired
    data var is performed with a call to this function. Otherwise, perform the
    API operation explicitly on ``data_var``.

    Parameters
    ----------
    dataset : xr.Dataset
        The Dataset.

    Returns
    -------
    xr.DataArray
        A copy of the inferred data variable.

    Raises
    ------
    KeyError
        If the 'xcdat_infer' attr is not set in the Dataset.
    KeyError
        If the 'xcdat_infer' attr points to a non-existent data var.
    KeyError
        If the 'xcdat_infer' attr points to a bounds data var.
    """
    inferred_var = dataset.attrs.get("xcdat_infer", None)
    bounds_vars = dataset.bounds.names

    # The placeholder is the *string* "None", not the None object, so both
    # spellings have to be reported as "not set" rather than falling through
    # to the misleading "non-existent data variable, 'None'" error.
    is_placeholder = isinstance(inferred_var, str) and inferred_var == "None"
    if inferred_var is None or is_placeholder:
        raise KeyError(
            "Dataset attr 'xcdat_infer' is not set so the desired data variable "
            "cannot be inferred. You must pass the `data_var` kwarg to this operation."
        )

    data_var = dataset.get(inferred_var, None)
    if data_var is None:
        raise KeyError(
            "Dataset attr 'xcdat_infer' is set to non-existent data variable, "
            f"'{inferred_var}'. Either pass the `data_var` kwarg to this operation, "
            "or set 'xcdat_infer' to a regular (non-bounds) data variable."
        )

    if inferred_var in bounds_vars:
        raise KeyError(
            "Dataset attr `xcdat_infer` is set to the bounds data variable, "
            f"'{inferred_var}'. Either pass the `data_var` kwarg, or set "
            "'xcdat_infer' to a regular (non-bounds) data variable."
        )

    # Hand back a copy so callers cannot mutate the variable held by `dataset`.
    return data_var.copy()


def _preprocess_non_cf_dataset(
ds: xr.Dataset, callable: Optional[Callable] = None
) -> xr.Dataset:
Expand Down
Loading