remove dataset loading
bradyrx committed Jun 4, 2019
1 parent 96298ed commit 0c1a450
Showing 1 changed file with 73 additions and 128 deletions.
pop_tools/budget.py (201 changes: 73 additions & 128 deletions)
@@ -20,83 +20,6 @@ def _process_grid(grid):
return xr.merge([area, vol, z, mask])


def _process_coords(ds, concat_dim='time', drop=True, extra_coord_vars=['time_bound']):
"""Preprocessor function to drop all non-dim coords, which slows down concatenation
Borrowed from @sridge in an issue thread
"""

time_new = ds.time - datetime.timedelta(seconds=1)
ds['time'] = time_new # shift time back 1s so that the months are correct
# ex: 12/31 is misrepresented as 1/01 without this correction

coord_vars = [v for v in ds.data_vars if concat_dim not in ds[v].dims]
for ecv in extra_coord_vars:
if ecv in ds:
coord_vars.append(ecv)

if drop:
return ds.drop(coord_vars)
else:
return ds.set_coords(coord_vars)

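# A minimal sketch (synthetic values) of why the one-second shift above
# matters: POP stamps each monthly mean at the first instant of the
# *following* month, so e.g. a December average carries a 01-01 timestamp.
# Shifting back one second returns it to 12-31, so annual grouping bins the
# months correctly.
#
#   >>> import datetime
#   >>> stamp = datetime.datetime(2019, 1, 1)  # December mean, mislabeled
#   >>> stamp - datetime.timedelta(seconds=1)
#   datetime.datetime(2018, 12, 31, 23, 59, 59)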

def _load_tracer_terms(basepath, filebase, tracer, var_list=None, drop_time=True):
"""Loads in the requested variables and processes them for the tracer budget
Parameters
----------
var_list : list
List of variables to load in, designated by their prefix (e.g., 'UE')
drop_time : bool, optional
If True, drop the time coordinate in addition to spatial coordinates from the dataset.
"""
model_vars = {
'UE': f'UE_{tracer}',
'VN': f'VN_{tracer}',
'WT': f'WT_{tracer}',
'HDIFE': f'HDIFE_{tracer}',
'HDIFN': f'HDIFN_{tracer}',
'HDIFB': f'HDIFB_{tracer}',
'DIA_IMPVF': f'DIA_IMPVF_{tracer}',
'KPP_SRC': f'KPP_SRC_{tracer}',
'FvICE': f'FvICE_{tracer}',
'FvPER': f'FvPER_{tracer}',
# NOTE: STF is hard-coded to the DIC surface flux (FG_CO2) for now.
'STF': 'FG_CO2',
'SMS': f'J_{tracer}',
}

if not set(var_list).issubset(model_vars):
mismatched_var = [v for v in var_list if v not in model_vars]
error_msg = (
'Please input a `var_list` containing appropriate budget terms. '
f'{mismatched_var} not in {list(model_vars.keys())}'
)
raise ValueError(error_msg)

loadVars = dict((k, model_vars[k]) for k in var_list)
ds = xr.Dataset()
for new_var, raw_var in loadVars.items():
# NOTE: This needs to be adapted to a new file loading procedure. This assumes
# that all the files are in a single folder with one realization.
ds_i = xr.open_mfdataset(f'{basepath}/{filebase}*{raw_var}.*', preprocess=_process_coords)
ds_i = ds_i.rename({raw_var: new_var})
ds = ds.merge(ds_i)
# Drop coordinates, since they get in the way of roll, shift, diff.
if drop_time:
drop_coords = ds.coords
else:
drop_coords = [c for c in ds.coords if c != 'time']
ds = ds.drop(drop_coords)
# Rename all to z_t to avoid conflicts with xarray. Indices are determined via the POP grid.
if 'z_w_top' in ds.dims:
ds = ds.rename({'z_w_top': 'z_t'})
if 'z_w_bot' in ds.dims:
ds = ds.rename({'z_w_bot': 'z_t'})
return ds

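# With this loader removed (per this commit), callers construct `ds`
# themselves. A rough sketch of the equivalent preparation, with a
# hypothetical path and tracer:
#
#   >>> import xarray as xr
#   >>> ds = xr.open_mfdataset('/path/to/case/*.UE_DIC.*.nc')
#   >>> ds = ds.rename({'UE_DIC': 'UE'})  # strip the tracer suffix
#   >>> # ...repeat for VN, WT, HDIFE, HDIFN, HDIFB, DIA_IMPVF, KPP_SRC, etc.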

def _compute_kmax(budget_depth, z):
"""Compute the k-index of the maximum budget depth"""
kmax = (z <= budget_depth).argmin() - 1
@@ -152,36 +75,34 @@ def _compute_vertical_divergence(da):
# Places a cap of zeros on top of the ocean. This makes it easy to use the `diff` function
# with a positive z heading toward shallower depths.
zero_cap = xr.DataArray(np.zeros((ny, nx)), dims=['nlat', 'nlon'])
vdiv = xr.concat([zero_cap, da], dim='z_t')
vdiv = xr.concat([zero_cap, da.drop('time')], dim='z_t')
vdiv = vdiv.diff('z_t')
return vdiv

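# Toy illustration (hypothetical 1x1 horizontal grid, two depth levels) of
# the zero-cap trick: prepending zeros and differencing along z_t yields the
# per-layer flux divergence, with zero flux through the sea surface.
#
#   >>> import numpy as np
#   >>> import xarray as xr
#   >>> flux = xr.DataArray(np.array([[[1.0]], [[3.0]]]),
#   ...                     dims=['z_t', 'nlat', 'nlon'])
#   >>> cap = xr.DataArray(np.zeros((1, 1)), dims=['nlat', 'nlon'])
#   >>> xr.concat([cap, flux], dim='z_t').diff('z_t').values.ravel()
#   array([1., 2.])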

def _compute_lateral_advection(basepath, filebase, tracer, grid, mask, kmax=None):
def _compute_lateral_advection(ds, grid, mask, kmax=None):
"""Compute lateral advection component of budget"""
print('Computing lateral advection...')
ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['UE', 'VN'])

ds = ds[['UE', 'VN']]
if kmax is not None:
ds = ds.isel(z_t=slice(0, kmax + 1))

ds = ds.groupby('time.year').mean('time').rename({'year': 'time'})
ds = _convert_to_tendency(ds, grid.vol, kmax=kmax)
ladv_zonal = _compute_horizontal_divergence(ds.UE, mask, direction='zonal')
ladv_merid = _compute_horizontal_divergence(ds.VN, mask, direction='meridional')
ladv = (ladv_zonal + ladv_merid).rename('ladv').sum('z_t')
ladv = _convert_units(ladv)
ladv.attrs['long_name'] = 'lateral advection'
return ladv.load()
return ladv

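# The groupby/rename idiom above converts monthly means to annual means while
# keeping the coordinate named 'time'. A self-contained sketch with synthetic
# data:
#
#   >>> import numpy as np
#   >>> import pandas as pd
#   >>> import xarray as xr
#   >>> t = pd.date_range('2000-01-31', periods=24, freq='M')
#   >>> da = xr.DataArray(np.arange(24.0), dims='time', coords={'time': t})
#   >>> da.groupby('time.year').mean('time').rename({'year': 'time'}).values
#   array([ 5.5, 17.5])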

def _compute_lateral_mixing(basepath, filebase, tracer, grid, mask, kmax=None):
def _compute_lateral_mixing(ds, grid, mask, kmax=None):
"""Compute lateral mixing component."""
print('Computing lateral mixing...')
ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['HDIFN', 'HDIFE', 'HDIFB'])

ds = ds[['HDIFN', 'HDIFE', 'HDIFB']]
if kmax is not None:
ds = ds.isel(z_t=slice(0, kmax + 1))

ds = ds.groupby('time.year').mean('time').rename({'year': 'time'})
# Flip sign so that positive direction is upwards.
ds = _convert_to_tendency(ds, grid.vol, sign=-1, kmax=kmax)
lmix_zonal = _compute_horizontal_divergence(ds.HDIFE, mask, direction='zonal')
@@ -193,41 +114,39 @@ def _compute_lateral_mixing(basepath, filebase, tracer, grid, mask, kmax=None):
lmix = (lmix_merid + lmix_zonal + lmix_B).rename('lmix').sum('z_t')
lmix = _convert_units(lmix)
lmix.attrs['long_name'] = 'lateral mixing'
return lmix.load()
return lmix


def _compute_vertical_advection(basepath, filebase, tracer, grid, mask, kmax=None):
def _compute_vertical_advection(ds, grid, mask, kmax=None):
"""Compute vertical advection (WT)"""
print('Computing vertical advection...')
ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['WT'])

ds = ds['WT']
if kmax is not None:
# Need one level below max depth to compute divergence into bottom layer.
ds = ds.isel(z_t=slice(0, kmax + 2))
ds = ds.where(mask)

ds = ds.groupby('time.year').mean('time').rename({'year': 'time'})
if kmax is not None:
ds = _convert_to_tendency(ds, grid.vol, kmax=kmax + 1)
else:
ds = _convert_to_tendency(ds, grid.vol)

# Compute divergence of vertical advection.
vadv = (ds.WT.shift(z_t=-1).fillna(0) - ds.WT).isel(z_t=slice(0, -1))
vadv = (ds.shift(z_t=-1).fillna(0) - ds).isel(z_t=slice(0, -1))
vadv = vadv.sum('z_t').rename('vadv')
vadv = _convert_units(vadv)
vadv.attrs['long_name'] = 'vertical advection'
return vadv.load()
return vadv

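# The shift/fillna/diff pattern above computes (flux below - flux here) in
# each cell, treating flux through the level beneath the deepest retained
# one as zero. A toy column (hypothetical values):
#
#   >>> import numpy as np
#   >>> import xarray as xr
#   >>> wt = xr.DataArray(np.array([4.0, 3.0, 1.0]), dims='z_t')
#   >>> (wt.shift(z_t=-1).fillna(0) - wt).isel(z_t=slice(0, -1)).values
#   array([-1., -2.])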

def _compute_vertical_mixing(basepath, filebase, tracer, grid, mask, kmax=None):
def _compute_vertical_mixing(ds, grid, mask, kmax=None):
"""Compute contribution from vertical mixing."""
print('Computing vertical mixing...')
ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['DIA_IMPVF', 'KPP_SRC'])

ds = ds[['DIA_IMPVF', 'KPP_SRC']]
if kmax is not None:
ds = ds.isel(z_t=slice(0, kmax + 1))

ds = ds.where(mask)
ds = ds.groupby('time.year').mean('time').rename({'year': 'time'})
# Only need to flip sign of DIA_IMPVF.
ds['DIA_IMPVF'] = _convert_to_tendency(ds['DIA_IMPVF'], grid.area, sign=-1, kmax=kmax)
ds['KPP_SRC'] = _convert_to_tendency(ds['KPP_SRC'], grid.vol, kmax=kmax)
@@ -239,29 +158,27 @@ def _compute_vertical_mixing(basepath, filebase, tracer, grid, mask, kmax=None):
vmix = (ds.KPP_SRC + diadiff).rename('vmix').sum('z_t')
vmix = _convert_units(vmix)
vmix.attrs['long_name'] = 'vertical mixing'
return vmix.load()
return vmix

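# Note on the normalizations above: DIA_IMPVF is a flux through the cell
# bottom (per unit area), so it is scaled by grid.area, while KPP_SRC is
# already a volumetric tendency and is scaled by grid.vol.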

def _compute_SMS(basepath, filebase, tracer, grid, mask, kmax=None):
def _compute_SMS(ds, grid, mask, kmax=None):
"""Compute SMS term from biology."""
print('Computing source/sink...')
ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['SMS'], drop_time=False)
ds = ds['SMS']
if kmax is not None:
ds = ds.isel(z_t=slice(0, kmax + 1))
ds = ds.where(mask)
ds = _convert_to_tendency(ds, grid.vol, kmax=kmax).rename({'SMS': 'sms'}).sum('z_t')
ds = _convert_to_tendency(ds, grid.vol, kmax=kmax).rename('sms').sum('z_t')
ds = _convert_units(ds)
# SMS comes as monthly output. Need to resample to annual for comparison to the other tracer budget terms.
ds = ds.groupby('time.year').mean('time').rename({'year': 'time'})
ds.attrs['long_name'] = 'source/sink'
return ds.load()
return ds

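# Note: after `ds = ds['SMS']` the object is a DataArray, so the string form
# `.rename('sms')` applies; the dict form used elsewhere is for Datasets.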

def _compute_surface_fluxes(basepath, filebase, tracer, grid, mask):
def _compute_surface_fluxes(ds, grid, mask):
"""Computes surface fluxes of tracer."""
ds = _load_tracer_terms(
basepath, filebase, tracer, var_list=['FvICE', 'FvPER', 'STF'], drop_time=False
)
ds = ds[['FvICE', 'FvPER', 'STF']]
ds = ds.where(mask)
ds = _convert_to_tendency(ds, grid.area)
ds = _convert_units(ds)
@@ -270,22 +187,53 @@ def _compute_surface_fluxes(basepath, filebase, tracer, grid, mask):
stf = (ds.STF).rename('stf')
vf.attrs['long_name'] = 'virtual flux'
stf.attrs['long_name'] = 'surface tracer flux'
return vf.load(), stf.load()
return vf, stf

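# Sketch of preparing the optional flux terms expected here, assuming a DIC
# budget (variable names follow the docstring below):
#
#   >>> ds = ds.rename({'FvICE_DIC': 'FvICE', 'FvPER_DIC': 'FvPER',
#   ...                 'FG_CO2': 'STF'})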

def regional_tracer_budget(
basepath, filebase, tracer, grid, mask=None, mask_int=None, budget_depth=None, sum_area=True
):
def _process_input_dataset(ds):
"""Checks that input dataset has appropriate variables, etc."""
mandatory_vars = ['UE', 'VN', 'WT', 'HDIFE', 'HDIFN', 'HDIFB', 'DIA_IMPVF', 'KPP_SRC']
if not set(mandatory_vars).issubset(ds):
missing_vars = [v for v in mandatory_vars if v not in ds]
error_msg = (
'Input dataset does not contain the mandatory variables for budget analysis. '
+ f'`ds` is missing {missing_vars}'
)
raise ValueError(error_msg)

# Drop coordinates for shift, roll, etc.
coord_vars = [c for c in ds.coords if c != 'time']
ds = ds.drop(coord_vars)
if 'z_w_top' in ds.dims:
ds = ds.rename({'z_w_top': 'z_t'})
if 'z_w_bot' in ds.dims:
ds = ds.rename({'z_w_bot': 'z_t'})

# Require chunked (dask-backed) input.
if not ds.chunks:
raise ValueError('Please input a chunked (dask-backed) dataset.')
return ds

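# A minimal way to satisfy the chunk requirement (file name hypothetical):
# `open_mfdataset` returns dask-backed variables automatically, or an
# in-memory dataset can be chunked explicitly.
#
#   >>> import xarray as xr
#   >>> ds = xr.open_dataset('budget_terms.nc').chunk({'time': 12})
#   >>> bool(ds.chunks)
#   True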

def regional_tracer_budget(ds, grid, mask=None, mask_int=None, budget_depth=None, sum_area=True):
"""Return a regional tracer budget on the POP grid.
Parameters
----------
basepath : str
Path to folder with raw POP output
filebase : str
Base name of file (e.g., 'g.DPLE.GECOIAF.T62_g16.009.chey.pop.h.')
tracer : str
Tracer variable name (e.g., 'DIC')
ds : `xarray.Dataset`
Dataset containing global POP output, with the tracer suffix removed:
* UE
* VN
* WT
* HDIFE
* HDIFN
* HDIFB
* DIA_IMPVF
* KPP_SRC
* FvICE (optional)
* FvPER (optional)
* STF (optional; renamed from the surface flux field, e.g., FG_CO2 for DIC)
* SMS (optional; renamed from J_{tracer})
grid : str
POP grid (e.g., POP_gx3v7, POP_gx1v7, POP_tx0.1v3)
mask : `xarray.DataArray`, optional
@@ -302,7 +250,7 @@ def regional_tracer_budget(
reg_budget: `xarray.Dataset`
Dataset containing integrated budget terms over masked POP volume.
"""

ds = _process_input_dataset(ds)
grid = _process_grid(grid)
if mask is None:
# Default to REGION_MASK from POP.
@@ -320,19 +268,16 @@
else:
kmax = None

ladv = _compute_lateral_advection(basepath, filebase, tracer, grid, mask, kmax=kmax)
vadv = _compute_vertical_advection(basepath, filebase, tracer, grid, mask, kmax=kmax)
lmix = _compute_lateral_mixing(basepath, filebase, tracer, grid, mask, kmax=kmax)
vmix = _compute_vertical_mixing(basepath, filebase, tracer, grid, mask, kmax=kmax)
sms = _compute_SMS(basepath, filebase, tracer, grid, mask, kmax=kmax)
vf, stf = _compute_surface_fluxes(basepath, filebase, tracer, grid, mask)
ladv = _compute_lateral_advection(ds, grid, mask, kmax=kmax)
vadv = _compute_vertical_advection(ds, grid, mask, kmax=kmax)
lmix = _compute_lateral_mixing(ds, grid, mask, kmax=kmax)
vmix = _compute_vertical_mixing(ds, grid, mask, kmax=kmax)
sms = _compute_SMS(ds, grid, mask, kmax=kmax)
vf, stf = _compute_surface_fluxes(ds, grid, mask)

# Merge into dataset.
reg_budget = xr.merge([ladv, vadv, lmix, vmix, sms, vf, stf])
reg_budget.attrs['units'] = 'mol/yr'
if sum_area:
reg_budget = reg_budget.sum(['nlat', 'nlon'])
# Append mask for user to reference.
reg_budget['mask'] = mask
reg_budget['mask'].attrs['long_name'] = 'mask over which tracer budget was computed'
return reg_budget
return reg_budget.load()

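# Example end-to-end usage (paths, tracer, and depth are hypothetical):
#
#   >>> import xarray as xr
#   >>> ds = xr.open_mfdataset('/path/to/case.pop.h.*.nc')
#   >>> ds = ds.rename({'UE_DIC': 'UE'})  # ...and likewise for other terms
#   >>> budget = regional_tracer_budget(ds, 'POP_gx1v7', budget_depth=1000)
#   >>> budget.ladv  # annually averaged lateral advection, mol/yr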