diff --git a/HISTORY.rst b/HISTORY.rst index c4b3b912..39098727 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -25,7 +25,7 @@ New Features - Add ``DatasetBoundsAccessor`` class for filling missing bounds, returning mapping of bounds, returning names of bounds keys -- Add ``XCDATBoundsAccessor`` class for accessing xcdat public methods +- Add ``BoundsAccessor`` class for accessing xcdat public methods from other accessor classes - This will be probably be the API endpoint for most users, unless diff --git a/tests/test_bounds.py b/tests/test_bounds.py index ff192b4d..d58b6ffd 100644 --- a/tests/test_bounds.py +++ b/tests/test_bounds.py @@ -3,17 +3,17 @@ import xarray as xr from tests.fixtures import generate_dataset, lat_bnds, lon_bnds, time_bnds -from xcdat.bounds import DatasetBoundsAccessor +from xcdat.bounds import BoundsAccessor -class TestDatasetBoundsAccessor: +class TestBoundsAccessor: @pytest.fixture(autouse=True) def setup(self): self.ds = generate_dataset(cf_compliant=True, has_bounds=False) self.ds_with_bnds = generate_dataset(cf_compliant=True, has_bounds=True) def test__init__(self): - obj = DatasetBoundsAccessor(self.ds) + obj = BoundsAccessor(self.ds) assert obj._dataset.identical(self.ds) def test_decorator_call(self): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c3ec5f2c..0a6e8221 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -54,7 +54,7 @@ def test_non_cf_compliant_time_is_decoded(self): result_ds = open_dataset(self.file_path, data_var="ts") # Replicates decode_times=False, which adds units to "time" coordinate. - # Refer to xcdat.bounds.DatasetBoundsAccessor._add_bounds() for + # Refer to xcdat.bounds.BoundsAccessor._add_bounds() for # how attributes propagate from coord to coord bounds. result_ds["time_bnds"].attrs["units"] = "months since 2000-01-01" @@ -129,7 +129,7 @@ def test_only_keeps_specified_var(self): result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") # Replicates decode_times=False, which adds units to "time" coordinate. - # Refer to xcdat.bounds.DatasetBoundsAccessor._add_bounds() for + # Refer to xcdat.bounds.BoundsAccessor._add_bounds() for # how attributes propagate from coord to coord bounds. result_ds.time_bnds.attrs["units"] = "months since 2000-01-01" @@ -160,7 +160,7 @@ def test_non_cf_compliant_time_is_decoded(self): result_ds = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") # Replicates decode_times=False, which adds units to "time" coordinate. - # Refer to xcdat.bounds.DatasetBoundsAccessor._add_bounds() for + # Refer to xcdat.bounds.BoundsAccessor._add_bounds() for # how attributes propagate from coord to coord bounds. result_ds.time_bnds.attrs["units"] = "months since 2000-01-01" diff --git a/tests/test_spatial_avg.py b/tests/test_spatial_avg.py index f0090938..d6f1bc48 100644 --- a/tests/test_spatial_avg.py +++ b/tests/test_spatial_avg.py @@ -2,11 +2,29 @@ import pytest import xarray as xr -import xcdat.spatial_avg # noqa: F401 from tests.fixtures import generate_dataset +from xcdat.spatial_avg import SpatialAverageAccessor -class TestSpatialAverage: +class TestSpatialAverageAcccessor: + @pytest.fixture(autouse=True) + def setup(self): + self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + + def test__init__(self): + ds = self.ds.copy() + obj = SpatialAverageAccessor(ds) + + assert obj._dataset.identical(ds) + + def test_decorator_call(self): + ds = self.ds.copy() + obj = ds.spatial + + assert obj._dataset.identical(ds) + + +class TestSpatialAvg: @pytest.fixture(autouse=True) def setup(self): self.ds = generate_dataset(cf_compliant=True, has_bounds=True) @@ -20,7 +38,7 @@ def setup(self): def test_raises_error_if_data_var_not_in_dataset(self): with pytest.raises(KeyError): - self.ds.spatial.avg( + self.ds.spatial.spatial_avg( "not_a_data_var", axis=["lat", "incorrect_axess"], ) @@ -32,7 +50,7 @@ def test_weighted_spatial_average_for_lat_and_lon_region_for_an_inferred_data_va ds.attrs["xcdat_infer"] = "ts" # `data_var` kwarg is not specified, so an inference is attempted - result = ds.spatial.avg( + result = ds.spatial.spatial_avg( axis=["lat", "lon"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1) ) @@ -50,7 +68,7 @@ def test_weighted_spatial_average_for_lat_and_lon_region_for_explicit_data_var( self, ): ds = self.ds.copy() - result = ds.spatial.avg( + result = ds.spatial.spatial_avg( "ts", axis=["lat", "lon"], lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1) ) @@ -67,7 +85,7 @@ def test_weighted_spatial_average_for_lat_region(self): ds = self.ds.copy() # Specifying axis as a str instead of list of str. - result = ds.spatial.avg( + result = ds.spatial.spatial_avg( "ts", axis="lat", lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1) ) @@ -86,7 +104,7 @@ def test_chunked_weighted_spatial_average_for_lat_region(self): ds = self.ds.copy().chunk(2) # Specifying axis as a str instead of list of str. - result = ds.spatial.avg( + result = ds.spatial.spatial_avg( "ts", axis="lat", lat_bounds=(-5.0, 5), lon_bounds=(-170, -120.1) ) diff --git a/xcdat/__init__.py b/xcdat/__init__.py index fbf0a7a7..31515ed2 100644 --- a/xcdat/__init__.py +++ b/xcdat/__init__.py @@ -1,5 +1,7 @@ """Top-level package for xcdat.""" +from xcdat.bounds import BoundsAccessor # noqa: F401 from xcdat.dataset import decode_time_units, open_dataset, open_mfdataset # noqa: F401 +from xcdat.spatial_avg import SpatialAverageAccessor # noqa: F401 from xcdat.xcdat import XCDATAccessor # noqa: F401 __version__ = "0.1.0" diff --git a/xcdat/bounds.py b/xcdat/bounds.py index 2b4d31cb..be3cdfd2 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -17,8 +17,8 @@ @xr.register_dataset_accessor("bounds") -class DatasetBoundsAccessor: - """A class to represent the DatasetBoundsAccessor. +class BoundsAccessor: + """A class to represent the BoundsAccessor. Examples --------- diff --git a/xcdat/spatial_avg.py b/xcdat/spatial_avg.py index 8a7c2b2b..3060ebe5 100644 --- a/xcdat/spatial_avg.py +++ b/xcdat/spatial_avg.py @@ -21,13 +21,13 @@ @xr.register_dataset_accessor("spatial") -class DatasetSpatialAverageAccessor: - """A class to represent the DatasetSpatialAverageAccessor.""" +class SpatialAverageAccessor: + """A class to represent the SpatialAverageAccessor.""" def __init__(self, dataset: xr.Dataset): self._dataset: xr.Dataset = dataset - def avg( + def spatial_avg( self, data_var: Optional[str] = None, axis: Union[List[SupportedAxes], SupportedAxes] = ["lat", "lon"], @@ -103,17 +103,17 @@ def avg( Get global average time series: - >>> ts_global = ds.spatial.avg("tas", axis=["lat", "lon"])["tas"] + >>> ts_global = ds.spatial.spatial_avg("tas", axis=["lat", "lon"])["tas"] Get time series in Nino 3.4 domain: - >>> ts_n34 = ds.spatial.avg("ts", axis=["lat", "lon"], + >>> ts_n34 = ds.spatial.spatial_avg("ts", axis=["lat", "lon"], >>> lat_bounds=(-5, 5), >>> lon_bounds=(-170, -120))["ts"] Get zonal mean time series: - >>> ts_zonal = ds.spatial.avg("tas", axis=['lon'])["tas"] + >>> ts_zonal = ds.spatial.spatial_avg("tas", axis=['lon'])["tas"] Using custom weights for averaging: @@ -124,7 +124,7 @@ def avg( >>> dims=["lat", "lon"], >>> ) >>> - >>> ts_global = ds.spatial.avg("tas", axis=["lat","lon"], + >>> ts_global = ds.spatial.spatial_avg("tas", axis=["lat","lon"], >>> weights=weights)["tas"] """ dataset = self._dataset.copy() @@ -603,7 +603,7 @@ def _validate_weights( This methods checks for the dimensional alignment between the ``weights`` and ``data_var``. It assumes that ``data_var`` has the same keys that are specified in ``axis``, which has already been validated - using ``self._validate_axis()`` in ``self.avg()``. + using ``self._validate_axis()`` in ``self.spatial_avg()``. Parameters ---------- diff --git a/xcdat/xcdat.py b/xcdat/xcdat.py index 85a17cda..88615e3c 100644 --- a/xcdat/xcdat.py +++ b/xcdat/xcdat.py @@ -3,9 +3,8 @@ import xarray as xr -from xcdat.bounds import Coord, DatasetBoundsAccessor # noqa: F401 -from xcdat.spatial_avg import DatasetSpatialAverageAccessor # noqa: F401 -from xcdat.spatial_avg import RegionAxisBounds, SupportedAxes +from xcdat.bounds import BoundsAccessor, Coord +from xcdat.spatial_avg import RegionAxisBounds, SpatialAverageAccessor, SupportedAxes from xcdat.utils import is_documented_by @@ -32,7 +31,7 @@ class XCDATAccessor: def __init__(self, dataset: xr.Dataset): self._dataset: xr.Dataset = dataset - @is_documented_by(DatasetSpatialAverageAccessor.avg) + @is_documented_by(SpatialAverageAccessor.spatial_avg) def spatial_avg( self, data_var: Optional[str] = None, @@ -41,26 +40,26 @@ def spatial_avg( lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, ) -> xr.Dataset: - obj = DatasetSpatialAverageAccessor(self._dataset) - return obj.avg(data_var, axis, weights, lat_bounds, lon_bounds) + obj = SpatialAverageAccessor(self._dataset) + return obj.spatial_avg(data_var, axis, weights, lat_bounds, lon_bounds) @property # type: ignore - @is_documented_by(DatasetBoundsAccessor.bounds) + @is_documented_by(BoundsAccessor.bounds) def bounds(self) -> Dict[str, Optional[xr.DataArray]]: - obj = DatasetBoundsAccessor(self._dataset) + obj = BoundsAccessor(self._dataset) return obj.bounds - @is_documented_by(DatasetBoundsAccessor.fill_missing) + @is_documented_by(BoundsAccessor.fill_missing) def fill_missing_bounds(self) -> xr.Dataset: - obj = DatasetBoundsAccessor(self._dataset) + obj = BoundsAccessor(self._dataset) return obj.fill_missing() - @is_documented_by(DatasetBoundsAccessor.get_bounds) + @is_documented_by(BoundsAccessor.get_bounds) def get_bounds(self, coord: Coord) -> xr.DataArray: - obj = DatasetBoundsAccessor(self._dataset) + obj = BoundsAccessor(self._dataset) return obj.get_bounds(coord) - @is_documented_by(DatasetBoundsAccessor.add_bounds) + @is_documented_by(BoundsAccessor.add_bounds) def add_bounds(self, coord: Coord, width: float = 0.5) -> xr.Dataset: - obj = DatasetBoundsAccessor(self._dataset) + obj = BoundsAccessor(self._dataset) return obj.add_bounds(coord, width)