Skip to content

Commit

Permalink
Pass metrics to xgcm.Grid by default. (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian authored Apr 18, 2022
1 parent 7454358 commit 66bd763
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 5 deletions.
30 changes: 27 additions & 3 deletions pop_tools/xgcm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,39 @@ def relabel_pop_dims(ds):
if coord in ds_new.coords:
ds_new = ds_new.drop_vars(coord)
if 'z_w_top' in ds_new.dims and 'z_w' in ds_new.dims:
ds_new = ds_new.drop('z_w_top').rename({'z_w': 'z_w_top'})
ds_new = ds_new.drop_vars('z_w_top').rename({'z_w': 'z_w_top'})
return ds_new


def to_xgcm_grid_dataset(ds, **kwargs):
def get_metrics(ds):
"""Finds metrics variables present in `ds`, returns a dict that can be passed to xgcm."""
metrics = {
('X',): ['DXU', 'DXT'], # X distances
('Y',): ['DYU', 'DYT'], # Y distances
('Z',): ['DZU', 'DZT'], # Z distances
('X', 'Y'): ['UAREA', 'TAREA'], # Areas
}
# filter to variables that are present
new_metrics = {}
for axis, names in metrics.items():
new_names = [name for name in names if name in ds]
if new_names:
new_metrics[axis] = new_names
return new_metrics


def to_xgcm_grid_dataset(ds, metrics='detect', **kwargs):
"""Modify POP model output to be compatible with xgcm.
Parameters
----------
ds : xarray.Dataset
An xarray Dataset
metrics : {"detect"} or dict, optional
Dictionary providing metrics to the `xgcm.Grid` contructor.
If ``"detect"``, will autodetect metrics that are present by searching for
variables named DXU, DXT, DYU, DYT, DZU, DZT, UAREA, TAREA.
If None, no metrics will be assigned.
kwargs:
Additional keyword arguments are passed through to `xgcm.Grid` class.
Expand Down Expand Up @@ -204,5 +226,7 @@ def to_xgcm_grid_dataset(ds, **kwargs):
"""to_xgcm_grid_dataset() function requires the `xgcm` package. \nYou can install it via PyPI or Conda"""
)
ds_new = relabel_pop_dims(ds)
grid = xgcm.Grid(ds_new, **kwargs)
if metrics == 'detect':
metrics = get_metrics(ds_new)
grid = xgcm.Grid(ds_new, metrics=metrics, **kwargs)
return grid, ds_new
52 changes: 50 additions & 2 deletions tests/test_xgcm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
],
)
def test_to_xgcm_grid_dataset(ds, old_spatial_coords, axes):
grid, ds_new = pop_tools.to_xgcm_grid_dataset(ds, metrics=None)
grid, ds_new = pop_tools.to_xgcm_grid_dataset(ds)
assert isinstance(grid, xgcm.Grid)
assert set(axes) == set(grid.axes.keys())
new_spatial_coords = ['nlon_u', 'nlat_u', 'nlon_t', 'nlat_t']
Expand All @@ -49,4 +49,52 @@ def test_to_xgcm_grid_dataset_missing_xgcm():
with mock.patch.dict(sys.modules, {'xgcm': None}):
filepath = DATASETS.fetch('tend_zint_100m_Fe.nc')
ds = xr.open_dataset(filepath)
_, _ = pop_tools.to_xgcm_grid_dataset(ds, metrics=None)
_, _ = pop_tools.to_xgcm_grid_dataset(ds)


def test_set_metrics():
from pop_tools.xgcm_util import get_metrics

ds = xr.Dataset({'DXU': 1, 'DYT': 2, 'DZT': 3})
actual = get_metrics(ds)
expected = {('X',): ['DXU'], ('Y',): ['DYT'], ('Z',): ['DZT']}
assert actual == expected

assert not get_metrics(xr.Dataset({}))


def test_metrics_assignment_no_metrics():
grid, _ = pop_tools.to_xgcm_grid_dataset(ds_c)
assert not grid._metrics


def get_metrics(grid):
return {
tuple(sorted(key)): [metric.name for metric in metrics]
for key, metrics in grid._metrics.items()
}


@pytest.mark.parametrize('ds', [ds_a, ds_b])
def test_metrics_assignment(ds):
grid, _ = pop_tools.to_xgcm_grid_dataset(ds)
expected = {
('X',): ['DXU', 'DXT'], # X distances
('Y',): ['DYU', 'DYT'], # Y distances
('X', 'Y'): ['UAREA', 'TAREA'], # Areas
}

if 'S_FLUX_ROFF_VSF' in ds:
expected[('X', 'Y')] = ['TAREA']
expected[('X',)] = ['DXU']

actual = get_metrics(grid)
assert actual == expected

grid, _ = pop_tools.to_xgcm_grid_dataset(ds, metrics=None)
assert not grid._metrics

expected = {('X',): ['DXU']}
grid, _ = pop_tools.to_xgcm_grid_dataset(ds, metrics={'X': ['DXU']})
actual = get_metrics(grid)
assert actual == expected

0 comments on commit 66bd763

Please sign in to comment.