From 0c1a450d63c01c3410e6901d9b76778ab424576e Mon Sep 17 00:00:00 2001 From: "Riley X. Brady" Date: Tue, 4 Jun 2019 12:03:21 -0600 Subject: [PATCH] remove dataset loading --- pop_tools/budget.py | 201 ++++++++++++++++---------------------------- 1 file changed, 73 insertions(+), 128 deletions(-) diff --git a/pop_tools/budget.py b/pop_tools/budget.py index def5ad04..8060c193 100644 --- a/pop_tools/budget.py +++ b/pop_tools/budget.py @@ -20,83 +20,6 @@ def _process_grid(grid): return xr.merge([area, vol, z, mask]) -def _process_coords(ds, concat_dim='time', drop=True, extra_coord_vars=['time_bound']): - """Preprocessor function to drop all non-dim coords, which slows down concatenation - - Borrowed from @sridge in an issue thread - """ - - time_new = ds.time - datetime.timedelta(seconds=1) - ds['time'] = time_new # shift time back 1s so that the months are correct - # ex: 12/31 is misrepresented as 1/01 without this correction - - coord_vars = [v for v in ds.data_vars if concat_dim not in ds[v].dims] - for ecv in extra_coord_vars: - if ecv in ds: - coord_vars += extra_coord_vars - - if drop: - return ds.drop(coord_vars) - else: - return ds.set_coords(coord_vars) - - -def _load_tracer_terms(basepath, filebase, tracer, var_list=None, drop_time=True): - """Loads in the requested variables and processes them for the tracer budget - - Parameters - ---------- - var_list : list - List of variables to load in, designated by their prefix (e.g., 'UE') - drop_time : bool, optional - If True, drop the time coordinate in addition to spatial coordinates from the dataset. - """ - model_vars = { - 'UE': f'UE_{tracer}', - 'VN': f'VN_{tracer}', - 'WT': f'WT_{tracer}', - 'HDIFE': f'HDIFE_{tracer}', - 'HDIFN': f'HDIFN_{tracer}', - 'HDIFB': f'HDIFB_{tracer}', - 'DIA_IMPVF': f'DIA_IMPVF_{tracer}', - 'KPP_SRC': f'KPP_SRC_{tracer}', - 'FvICE': f'FvICE_{tracer}', - 'FvPER': f'FvPER_{tracer}', - # NOTE: Hard-coded for DIC for now. - 'STF': f'FG_CO2', - 'SMS': f'J_{tracer}', - } - - if not set(var_list).issubset(model_vars): - mismatched_var = [v for v in var_list if v not in model_vars] - error_msg = ( - f'Please input a `var_list` containing appropriate budget terms. ' - + f'{mismatched_var} not in {list(model_vars.keys())}' - ) - raise ValueError(error_msg) - - loadVars = dict((k, model_vars[k]) for k in var_list) - ds = xr.Dataset() - for new_var, raw_var in loadVars.items(): - # NOTE: This needs to be adapted to a new file loaded procedure. This assumes - # that all the files are in a single folder with one realization. - ds_i = xr.open_mfdataset(f'{basepath}/{filebase}*{raw_var}.*', preprocess=_process_coords) - ds_i = ds_i.rename({raw_var: new_var}) - ds = ds.merge(ds_i) - # Drop coordinates, since they get in the way of roll, shift, diff. - if drop_time: - drop_coords = ds.coords - else: - drop_coords = [c for c in ds.coords if c != 'time'] - ds = ds.drop(drop_coords) - # Rename all to z_t to avoid conflictions with xarray. Indices are figured out via the POP grid. - if 'z_w_top' in ds.dims: - ds = ds.rename({'z_w_top': 'z_t'}) - if 'z_w_bot' in ds.dims: - ds = ds.rename({'z_w_bot': 'z_t'}) - return ds - - def _compute_kmax(budget_depth, z): """Compute the k-index of the maximum budget depth""" kmax = (z <= budget_depth).argmin() - 1 @@ -152,36 +75,34 @@ def _compute_vertical_divergence(da): # Places a cap of zeros on top of the ocean. This makes it easy to use the `diff` function # with a positive z heading toward shallower depths. zero_cap = xr.DataArray(np.zeros((ny, nx)), dims=['nlat', 'nlon']) - vdiv = xr.concat([zero_cap, da], dim='z_t') + vdiv = xr.concat([zero_cap, da.drop('time')], dim='z_t') vdiv = vdiv.diff('z_t') return vdiv -def _compute_lateral_advection(basepath, filebase, tracer, grid, mask, kmax=None): +def _compute_lateral_advection(ds, grid, mask, kmax=None): """Compute lateral advection component of budget""" print('Computing lateral advection...') - ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['UE', 'VN']) - + ds = ds[['UE', 'VN']] if kmax is not None: ds = ds.isel(z_t=slice(0, kmax + 1)) - + ds = ds.groupby('time.year').mean('time').rename({'year': 'time'}) ds = _convert_to_tendency(ds, grid.vol, kmax=kmax) ladv_zonal = _compute_horizontal_divergence(ds.UE, mask, direction='zonal') ladv_merid = _compute_horizontal_divergence(ds.VN, mask, direction='meridional') ladv = (ladv_zonal + ladv_merid).rename('ladv').sum('z_t') ladv = _convert_units(ladv) ladv.attrs['long_name'] = 'lateral advection' - return ladv.load() + return ladv -def _compute_lateral_mixing(basepath, filebase, tracer, grid, mask, kmax=None): +def _compute_lateral_mixing(ds, grid, mask, kmax=None): """Compute lateral mixing component.""" print('Computing lateral mixing...') - ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['HDIFN', 'HDIFE', 'HDIFB']) - + ds = ds[['HDIFN', 'HDIFE', 'HDIFB']] if kmax is not None: ds = ds.isel(z_t=slice(0, kmax + 1)) - + ds = ds.groupby('time.year').mean('time').rename({'year': 'time'}) # Flip sign so that positive direction is upwards. ds = _convert_to_tendency(ds, grid.vol, sign=-1, kmax=kmax) lmix_zonal = _compute_horizontal_divergence(ds.HDIFE, mask, direction='zonal') @@ -193,41 +114,39 @@ def _compute_lateral_mixing(basepath, filebase, tracer, grid, mask, kmax=None): lmix = (lmix_merid + lmix_zonal + lmix_B).rename('lmix').sum('z_t') lmix = _convert_units(lmix) lmix.attrs['long_name'] = 'lateral mixing' - return lmix.load() + return lmix -def _compute_vertical_advection(basepath, filebase, tracer, grid, mask, kmax=None): +def _compute_vertical_advection(ds, grid, mask, kmax=None): """Compute vertical advection (WT)""" print('Computing vertical advection...') - ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['WT']) - + ds = ds['WT'] if kmax is not None: # Need one level below max depth to compute divergence into bottom layer. ds = ds.isel(z_t=slice(0, kmax + 2)) ds = ds.where(mask) - + ds = ds.groupby('time.year').mean('time').rename({'year': 'time'}) if kmax is not None: ds = _convert_to_tendency(ds, grid.vol, kmax=kmax + 1) else: ds = _convert_to_tendency(ds, grid.vol) # Compute divergence of vertical advection. - vadv = (ds.WT.shift(z_t=-1).fillna(0) - ds.WT).isel(z_t=slice(0, -1)) + vadv = (ds.shift(z_t=-1).fillna(0) - ds).isel(z_t=slice(0, -1)) vadv = vadv.sum('z_t').rename('vadv') vadv = _convert_units(vadv) vadv.attrs['long_name'] = 'vertical advection' - return vadv.load() + return vadv -def _compute_vertical_mixing(basepath, filebase, tracer, grid, mask, kmax=None): +def _compute_vertical_mixing(ds, grid, mask, kmax=None): """Compute contribution from vertical mixing.""" print('Computing vertical mixing...') - ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['DIA_IMPVF', 'KPP_SRC']) - + ds = ds[['DIA_IMPVF', 'KPP_SRC']] if kmax is not None: ds = ds.isel(z_t=slice(0, kmax + 1)) - ds = ds.where(mask) + ds = ds.groupby('time.year').mean('time').rename({'year': 'time'}) # Only need to flip sign of DIA_IMPVF. ds['DIA_IMPVF'] = _convert_to_tendency(ds['DIA_IMPVF'], grid.area, sign=-1, kmax=kmax) ds['KPP_SRC'] = _convert_to_tendency(ds['KPP_SRC'], grid.vol, kmax=kmax) @@ -239,29 +158,27 @@ def _compute_vertical_mixing(basepath, filebase, tracer, grid, mask, kmax=None): vmix = (ds.KPP_SRC + diadiff).rename('vmix').sum('z_t') vmix = _convert_units(vmix) vmix.attrs['long_name'] = 'vertical mixing' - return vmix.load() + return vmix -def _compute_SMS(basepath, filebase, tracer, grid, mask, kmax=None): +def _compute_SMS(ds, grid, mask, kmax=None): """Compute SMS term from biology.""" print('Computing source/sink...') - ds = _load_tracer_terms(basepath, filebase, tracer, var_list=['SMS'], drop_time=False) + ds = ds['SMS'] if kmax is not None: ds = ds.isel(z_t=slice(0, kmax + 1)) ds = ds.where(mask) - ds = _convert_to_tendency(ds, grid.vol, kmax=kmax).rename({'SMS': 'sms'}).sum('z_t') + ds = _convert_to_tendency(ds, grid.vol, kmax=kmax).rename('sms').sum('z_t') ds = _convert_units(ds) # SMS comes as monthly output. Need to resample to annual for comparison to the other tracer budget terms. ds = ds.groupby('time.year').mean('time').rename({'year': 'time'}) ds.attrs['long_name'] = 'source/sink' - return ds.load() + return ds -def _compute_surface_fluxes(basepath, filebase, tracer, grid, mask): +def _compute_surface_fluxes(ds, grid, mask): """Computes surface fluxes of tracer.""" - ds = _load_tracer_terms( - basepath, filebase, tracer, var_list=['FvICE', 'FvPER', 'STF'], drop_time=False - ) + ds = ds[['FvICE', 'FvPER', 'STF']] ds = ds.where(mask) ds = _convert_to_tendency(ds, grid.area) ds = _convert_units(ds) @@ -270,22 +187,53 @@ def _compute_surface_fluxes(basepath, filebase, tracer, grid, mask): stf = (ds.STF).rename('stf') vf.attrs['long_name'] = 'virtual flux' stf.attrs['long_name'] = 'surface tracer flux' - return vf.load(), stf.load() + return vf, stf -def regional_tracer_budget( - basepath, filebase, tracer, grid, mask=None, mask_int=None, budget_depth=None, sum_area=True -): +def _process_input_dataset(ds): + """Checks that input dataset has appropriate variables, etc.""" + mandatory_vars = ['UE', 'VN', 'WT', 'HDIFE', 'HDIFN', 'HDIFB', 'DIA_IMPVF', 'KPP_SRC'] + if not set(mandatory_vars).issubset(ds): + missing_vars = [v for v in mandatory_vars if v not in ds] + error_msg = ( + 'Input dataset does not contain the mandatory variables for budget analysis. ' + + f'`ds` is missing {missing_vars}' + ) + raise IOError(error_msg) + + # Drop coordinates for shift, roll, etc. + coord_vars = [c for c in ds.coords if c != 'time'] + ds = ds.drop(coord_vars) + if 'z_w_top' in ds.dims: + ds = ds.rename({'z_w_top': 'z_t'}) + if 'z_w_bot' in ds.dims: + ds = ds.rename({'z_w_bot': 'z_t'}) + + # Force chunking. + if not ds.chunks: + raise IOError('Please input a dataset with chunks.') + return ds + + +def regional_tracer_budget(ds, grid, mask=None, mask_int=None, budget_depth=None, sum_area=True): """Return a regional tracer budget on the POP grid. Parameters ---------- - basepath : str - Path to folder with raw POP output - filebase : str - Base name of file (e.g., 'g.DPLE.GECOIAF.T62_g16.009.chey.pop.h.') - tracer : str - Tracer variable name (e.g., 'DIC') + ds : `xarray.Dataset` + Dataset containing global POP output, with the tracer suffix removed: + * UE + * VN + * WT + * HDIFE + * HDIFN + * HDIFB + * DIA_IMPVF + * KPP_SRC + * FvICE (optional) + * FvPER (optional) + * STF (optional; rename from FG_{tracer} for CO2) + * SMS (optional; rename from J_{tracer}) grid : str POP grid (e.g., POP_gx3v7, POP_gx1v7, POP_tx0.1v3) mask : `xarray.DataArray`, optional @@ -302,7 +250,7 @@ def regional_tracer_budget( reg_budget: `xarray.Dataset` Dataset containing integrated budget terms over masked POP volume. """ - + ds = _process_input_dataset(ds) grid = _process_grid(grid) if mask is None: # Default to REGION_MASK from POP. @@ -320,19 +268,16 @@ def regional_tracer_budget( else: kmax = None - ladv = _compute_lateral_advection(basepath, filebase, tracer, grid, mask, kmax=kmax) - vadv = _compute_vertical_advection(basepath, filebase, tracer, grid, mask, kmax=kmax) - lmix = _compute_lateral_mixing(basepath, filebase, tracer, grid, mask, kmax=kmax) - vmix = _compute_vertical_mixing(basepath, filebase, tracer, grid, mask, kmax=kmax) - sms = _compute_SMS(basepath, filebase, tracer, grid, mask, kmax=kmax) - vf, stf = _compute_surface_fluxes(basepath, filebase, tracer, grid, mask) + ladv = _compute_lateral_advection(ds, grid, mask, kmax=kmax) + vadv = _compute_vertical_advection(ds, grid, mask, kmax=kmax) + lmix = _compute_lateral_mixing(ds, grid, mask, kmax=kmax) + vmix = _compute_vertical_mixing(ds, grid, mask, kmax=kmax) + sms = _compute_SMS(ds, grid, mask, kmax=kmax) + vf, stf = _compute_surface_fluxes(ds, grid, mask) # Merge into dataset. reg_budget = xr.merge([ladv, vadv, lmix, vmix, sms, vf, stf]) reg_budget.attrs['units'] = 'mol/yr' if sum_area: reg_budget = reg_budget.sum(['nlat', 'nlon']) - # Append mask for user to reference. - reg_budget['mask'] = mask - reg_budget['mask'].attrs['long_name'] = 'mask over which tracer budget was computed' - return reg_budget + return reg_budget.load()