Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename accessor classes and methods for API consistency #142

Merged
merged 3 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ New Features

- Add ``DatasetBoundsAccessor`` class for filling missing bounds,
returning mapping of bounds, returning names of bounds keys
- Add ``XCDATBoundsAccessor`` class for accessing xcdat public methods
- Add ``BoundsAccessor`` class for accessing xcdat public methods
from other accessor classes

- This will be probably be the API endpoint for most users, unless
Expand Down
6 changes: 3 additions & 3 deletions tests/test_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import xarray as xr

from tests.fixtures import generate_dataset, lat_bnds, lon_bnds, time_bnds
from xcdat.bounds import DatasetBoundsAccessor
from xcdat.bounds import BoundsAccessor


class TestDatasetBoundsAccessor:
class TestBoundsAccessor:
@pytest.fixture(autouse=True)
def setup(self):
self.ds = generate_dataset(cf_compliant=True, has_bounds=False)
self.ds_with_bnds = generate_dataset(cf_compliant=True, has_bounds=True)

def test__init__(self):
obj = DatasetBoundsAccessor(self.ds)
obj = BoundsAccessor(self.ds)
assert obj._dataset.identical(self.ds)

def test_decorator_call(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_non_cf_compliant_time_is_decoded(self):

result_ds = open_dataset(self.file_path, data_var="ts")
# Replicates decode_times=False, which adds units to "time" coordinate.
# Refer to xcdat.bounds.DatasetBoundsAccessor._add_bounds() for
# Refer to xcdat.bounds.BoundsAccessor._add_bounds() for
# how attributes propagate from coord to coord bounds.
result_ds["time_bnds"].attrs["units"] = "months since 2000-01-01"

Expand Down Expand Up @@ -129,7 +129,7 @@ def test_only_keeps_specified_var(self):
result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")

# Replicates decode_times=False, which adds units to "time" coordinate.
# Refer to xcdat.bounds.DatasetBoundsAccessor._add_bounds() for
# Refer to xcdat.bounds.BoundsAccessor._add_bounds() for
# how attributes propagate from coord to coord bounds.
result_ds.time_bnds.attrs["units"] = "months since 2000-01-01"

Expand Down Expand Up @@ -160,7 +160,7 @@ def test_non_cf_compliant_time_is_decoded(self):

result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts")
# Replicates decode_times=False, which adds units to "time" coordinate.
# Refer to xcdat.bounds.DatasetBoundsAccessor._add_bounds() for
# Refer to xcdat.bounds.BoundsAccessor._add_bounds() for
# how attributes propagate from coord to coord bounds.
result_ds.time_bnds.attrs["units"] = "months since 2000-01-01"

Expand Down
32 changes: 25 additions & 7 deletions tests/test_spatial_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,29 @@
import pytest
import xarray as xr

import xcdat.spatial_avg # noqa: F401
from tests.fixtures import generate_dataset
from xcdat.spatial_avg import SpatialAverageAccessor


class TestSpatialAverage:
class TestSpatialAverageAcccessor:
@pytest.fixture(autouse=True)
def setup(self):
self.ds = generate_dataset(cf_compliant=True, has_bounds=True)

def test__init__(self):
ds = self.ds.copy()
obj = SpatialAverageAccessor(ds)

assert obj._dataset.identical(ds)

def test_decorator_call(self):
ds = self.ds.copy()
obj = ds.spatial

assert obj._dataset.identical(ds)


class TestSpatialAvg:
@pytest.fixture(autouse=True)
def setup(self):
self.ds = generate_dataset(cf_compliant=True, has_bounds=True)
Expand All @@ -20,7 +38,7 @@ def setup(self):

def test_raises_error_if_data_var_not_in_dataset(self):
with pytest.raises(KeyError):
self.ds.spatial.avg(
self.ds.spatial.spatial_avg(
"not_a_data_var",
axis=["lat", "incorrect_axess"],
)
Expand All @@ -32,7 +50,7 @@ def test_weighted_spatial_average_for_lat_and_lon_region_for_an_inferred_data_va
ds.attrs["xcdat_infer"] = "ts"

# `data_var` kwarg is not specified, so an inference is attempted
result = ds.spatial.avg(
result = ds.spatial.spatial_avg(
axis=["lat", "lon"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1)
)

Expand All @@ -50,7 +68,7 @@ def test_weighted_spatial_average_for_lat_and_lon_region_for_explicit_data_var(
self,
):
ds = self.ds.copy()
result = ds.spatial.avg(
result = ds.spatial.spatial_avg(
"ts", axis=["lat", "lon"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1)
)

Expand All @@ -67,7 +85,7 @@ def test_weighted_spatial_average_for_lat_region(self):
ds = self.ds.copy()

# Specifying axis as a str instead of list of str.
result = ds.spatial.avg(
result = ds.spatial.spatial_avg(
"ts", axis="lat", lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1)
)

Expand All @@ -86,7 +104,7 @@ def test_chunked_weighted_spatial_average_for_lat_region(self):
ds = self.ds.copy().chunk(2)

# Specifying axis as a str instead of list of str.
result = ds.spatial.avg(
result = ds.spatial.spatial_avg(
"ts", axis="lat", lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1)
)

Expand Down
2 changes: 2 additions & 0 deletions xcdat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Top-level package for xcdat."""
from xcdat.bounds import BoundsAccessor # noqa: F401
from xcdat.dataset import decode_time_units, open_dataset, open_mfdataset # noqa: F401
from xcdat.spatial_avg import SpatialAverageAccessor # noqa: F401
from xcdat.xcdat import XCDATAccessor # noqa: F401

__version__ = "0.1.0"
4 changes: 2 additions & 2 deletions xcdat/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@


@xr.register_dataset_accessor("bounds")
class DatasetBoundsAccessor:
"""A class to represent the DatasetBoundsAccessor.
class BoundsAccessor:
"""A class to represent the BoundsAccessor.

Examples
---------
Expand Down
16 changes: 8 additions & 8 deletions xcdat/spatial_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@


@xr.register_dataset_accessor("spatial")
class DatasetSpatialAverageAccessor:
"""A class to represent the DatasetSpatialAverageAccessor."""
class SpatialAverageAccessor:
"""A class to represent the SpatialAverageAccessor."""

def __init__(self, dataset: xr.Dataset):
self._dataset: xr.Dataset = dataset

def avg(
def spatial_avg(
self,
data_var: Optional[str] = None,
axis: Union[List[SupportedAxes], SupportedAxes] = ["lat", "lon"],
Expand Down Expand Up @@ -103,17 +103,17 @@ def avg(

Get global average time series:

>>> ts_global = ds.spatial.avg("tas", axis=["lat", "lon"])["tas"]
>>> ts_global = ds.spatial.spatial_avg("tas", axis=["lat", "lon"])["tas"]

Get time series in Nino 3.4 domain:

>>> ts_n34 = ds.spatial.avg("ts", axis=["lat", "lon"],
>>> ts_n34 = ds.spatial.spatial_avg("ts", axis=["lat", "lon"],
>>> lat_bounds=(-5, 5),
>>> lon_bounds=(-170, -120))["ts"]

Get zonal mean time series:

>>> ts_zonal = ds.spatial.avg("tas", axis=['lon'])["tas"]
>>> ts_zonal = ds.spatial.spatial_avg("tas", axis=['lon'])["tas"]

Using custom weights for averaging:

Expand All @@ -124,7 +124,7 @@ def avg(
>>> dims=["lat", "lon"],
>>> )
>>>
>>> ts_global = ds.spatial.avg("tas", axis=["lat","lon"],
>>> ts_global = ds.spatial.spatial_avg("tas", axis=["lat","lon"],
>>> weights=weights)["tas"]
"""
dataset = self._dataset.copy()
Expand Down Expand Up @@ -603,7 +603,7 @@ def _validate_weights(
This methods checks for the dimensional alignment between the
``weights`` and ``data_var``. It assumes that ``data_var`` has the same
keys that are specified in ``axis``, which has already been validated
using ``self._validate_axis()`` in ``self.avg()``.
using ``self._validate_axis()`` in ``self.spatial_avg()``.

Parameters
----------
Expand Down
27 changes: 13 additions & 14 deletions xcdat/xcdat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import xarray as xr

from xcdat.bounds import Coord, DatasetBoundsAccessor # noqa: F401
from xcdat.spatial_avg import DatasetSpatialAverageAccessor # noqa: F401
from xcdat.spatial_avg import RegionAxisBounds, SupportedAxes
from xcdat.bounds import BoundsAccessor, Coord
from xcdat.spatial_avg import RegionAxisBounds, SpatialAverageAccessor, SupportedAxes
from xcdat.utils import is_documented_by


Expand All @@ -32,7 +31,7 @@ class XCDATAccessor:
def __init__(self, dataset: xr.Dataset):
self._dataset: xr.Dataset = dataset

@is_documented_by(DatasetSpatialAverageAccessor.avg)
@is_documented_by(SpatialAverageAccessor.spatial_avg)
def spatial_avg(
self,
data_var: Optional[str] = None,
Expand All @@ -41,26 +40,26 @@ def spatial_avg(
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
) -> xr.Dataset:
obj = DatasetSpatialAverageAccessor(self._dataset)
return obj.avg(data_var, axis, weights, lat_bounds, lon_bounds)
obj = SpatialAverageAccessor(self._dataset)
return obj.spatial_avg(data_var, axis, weights, lat_bounds, lon_bounds)

@property # type: ignore
@is_documented_by(DatasetBoundsAccessor.bounds)
@is_documented_by(BoundsAccessor.bounds)
def bounds(self) -> Dict[str, Optional[xr.DataArray]]:
obj = DatasetBoundsAccessor(self._dataset)
obj = BoundsAccessor(self._dataset)
return obj.bounds

@is_documented_by(DatasetBoundsAccessor.fill_missing)
@is_documented_by(BoundsAccessor.fill_missing)
def fill_missing_bounds(self) -> xr.Dataset:
obj = DatasetBoundsAccessor(self._dataset)
obj = BoundsAccessor(self._dataset)
return obj.fill_missing()

@is_documented_by(DatasetBoundsAccessor.get_bounds)
@is_documented_by(BoundsAccessor.get_bounds)
def get_bounds(self, coord: Coord) -> xr.DataArray:
obj = DatasetBoundsAccessor(self._dataset)
obj = BoundsAccessor(self._dataset)
return obj.get_bounds(coord)

@is_documented_by(DatasetBoundsAccessor.add_bounds)
@is_documented_by(BoundsAccessor.add_bounds)
def add_bounds(self, coord: Coord, width: float = 0.5) -> xr.Dataset:
obj = DatasetBoundsAccessor(self._dataset)
obj = BoundsAccessor(self._dataset)
return obj.add_bounds(coord, width)