Skip to content

Commit

Permalink
Change syntax for subsetting an xarray Dataset
Browse files Browse the repository at this point in the history
We were using a subset_variables() function, whereas it is simpler
to just index the dataset with a list of the desired variables, see:
pydata/xarray#3552 (comment)
  • Loading branch information
xylar committed Nov 20, 2019
1 parent e22bcda commit cf8e2cb
Show file tree
Hide file tree
Showing 14 changed files with 16 additions and 92 deletions.
5 changes: 1 addition & 4 deletions mpas_analysis/ocean/climatology_map_antarctic_melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from mpas_analysis.ocean.plot_climatology_map_subtask import \
PlotClimatologyMapSubtask

from mpas_analysis.shared.mpas_xarray import mpas_xarray

from mpas_analysis.shared.constants import constants


Expand Down Expand Up @@ -203,8 +201,7 @@ def run_task(self): # {{{

# first, load the land-ice mask from the restart file
dsLandIceMask = xr.open_dataset(self.restartFileName)
dsLandIceMask = mpas_xarray.subset_variables(dsLandIceMask,
['landIceMask'])
dsLandIceMask = dsLandIceMask[['landIceMask']]
dsLandIceMask = dsLandIceMask.isel(Time=0)
self.landIceMask = dsLandIceMask.landIceMask > 0.

Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/ocean/climatology_map_argo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

from mpas_analysis.shared.climatology import RemapObservedClimatologySubtask

from mpas_analysis.shared.mpas_xarray import mpas_xarray


class ClimatologyMapArgoTemperature(AnalysisTask): # {{{
"""
Expand Down Expand Up @@ -439,7 +437,7 @@ def build_observational_dataset(self, fileName): # {{{

# no meaningful year since this is already a climatology
dsObs.coords['year'] = ('Time', np.ones(dsObs.dims['Time'], int))
dsObs = mpas_xarray.subset_variables(dsObs, [self.fieldName, 'month'])
dsObs = dsObs[[self.fieldName, 'month']]

slices = []
field = dsObs[self.fieldName]
Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/ocean/climatology_map_mld.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from mpas_analysis.ocean.plot_climatology_map_subtask import \
PlotClimatologyMapSubtask

from mpas_analysis.shared.mpas_xarray import mpas_xarray


class ClimatologyMapMLD(AnalysisTask): # {{{
"""
Expand Down Expand Up @@ -248,7 +246,7 @@ def build_observational_dataset(self, fileName): # {{{
# no meaningful year since this is already a climatology
dsObs.coords['year'] = ('Time', np.ones(dsObs.dims['Time'], int))

dsObs = mpas_xarray.subset_variables(dsObs, ['mld', 'month'])
dsObs = dsObs[['mld', 'month']]
return dsObs # }}}

# }}}
Expand Down
6 changes: 1 addition & 5 deletions mpas_analysis/ocean/compute_transects_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

from mpas_analysis.shared.climatology import RemapMpasClimatologySubtask

from mpas_analysis.shared.mpas_xarray import mpas_xarray

from mpas_analysis.shared.io.utility import build_config_full_path, \
make_directories
from mpas_analysis.shared.io import write_netcdf
Expand Down Expand Up @@ -221,9 +219,7 @@ def run_task(self): # {{{

# first, compute zMid and cell mask from the restart file
with xr.open_dataset(self.restartFileName) as ds:
ds = mpas_xarray.subset_variables(ds, ['maxLevelCell',
'bottomDepth',
'layerThickness'])
ds = ds[['maxLevelCell', 'bottomDepth', 'layerThickness']]
ds = ds.isel(Time=0)

self.maxLevelCell = ds.maxLevelCell - 1
Expand Down
5 changes: 2 additions & 3 deletions mpas_analysis/ocean/meridional_heat_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mpas_analysis.shared.plot import plot_vertical_section, plot_1D, savefig

from mpas_analysis.shared.io.utility import make_directories, build_obs_path
from mpas_analysis.shared.io import write_netcdf, subset_variables
from mpas_analysis.shared.io import write_netcdf

from mpas_analysis.shared import AnalysisTask
from mpas_analysis.shared.html import write_image_xml
Expand Down Expand Up @@ -249,8 +249,7 @@ def run_task(self): # {{{
'timeMonthly_avg_meridionalHeatTransportLatZ']

annualClimatology = xr.open_dataset(climatologyFileName)
annualClimatology = subset_variables(annualClimatology,
variableList)
annualClimatology = annualClimatology[variableList]
if 'Time' in annualClimatology.dims:
annualClimatology = annualClimatology.isel(Time=0)

Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/ocean/regional_ts_diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
from mpas_analysis.shared.regions import ComputeRegionMasksSubtask, \
get_feature_list

from mpas_analysis.shared.mpas_xarray.mpas_xarray import subset_variables

from mpas_analysis.ocean.utility import compute_zmid

from mpas_analysis.shared.constants import constants
Expand Down Expand Up @@ -827,7 +825,7 @@ def _get_mpas_T_S(self, config): # {{{
variableList = ['timeMonthly_avg_activeTracers_temperature',
'timeMonthly_avg_activeTracers_salinity',
'timeMonthly_avg_layerThickness']
ds = subset_variables(ds, variableList)
ds = ds[variableList]

ds['zMid'] = compute_zmid(dsRestart.bottomDepth,
dsRestart.maxLevelCell,
Expand Down
6 changes: 1 addition & 5 deletions mpas_analysis/ocean/remap_depth_slices_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

from mpas_analysis.shared.climatology import RemapMpasClimatologySubtask

from mpas_analysis.shared.mpas_xarray import mpas_xarray

from mpas_analysis.ocean.utility import compute_zmid


Expand Down Expand Up @@ -112,9 +110,7 @@ def run_task(self): # {{{

# first, load the land-ice mask from the restart file
ds = xr.open_dataset(self.restartFileName)
ds = mpas_xarray.subset_variables(ds, ['maxLevelCell',
'bottomDepth',
'layerThickness'])
ds = ds[['maxLevelCell', 'bottomDepth', 'layerThickness']]
ds = ds.isel(Time=0)

self.maxLevelCell = ds.maxLevelCell - 1
Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/ocean/remap_sose_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from mpas_analysis.shared.climatology import RemapObservedClimatologySubtask, \
get_antarctic_stereographic_projection

from mpas_analysis.shared.mpas_xarray import mpas_xarray


class RemapSoseClimatology(RemapObservedClimatologySubtask):
# {{{
Expand Down Expand Up @@ -146,7 +144,7 @@ def build_observational_dataset(self, fileName): # {{{

if self.botFieldName is not None:
varList.append(self.botFieldName)
dsObs = mpas_xarray.subset_variables(dsObs, varList)
dsObs = dsObs[varList]

if self.depths is not None:
field = dsObs[self.fieldName]
Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/sea_ice/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

from mpas_analysis.shared.time_series import combine_time_series_with_ncrcat
from mpas_analysis.shared.io import open_mpas_dataset, write_netcdf
from mpas_analysis.shared.mpas_xarray.mpas_xarray import subset_variables

from mpas_analysis.shared.html import write_image_xml

Expand Down Expand Up @@ -612,8 +611,7 @@ def _compute_area_vol(self): # {{{

dsTimeSeries = {}
dsMesh = xr.open_dataset(self.restartFileName)
dsMesh = subset_variables(dsMesh,
variableList=['latCell', 'areaCell'])
dsMesh = dsMesh[['latCell', 'areaCell']]
# Load data
ds = open_mpas_dataset(
fileName=self.inputFile,
Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/shared/climatology/mpas_climatology_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@

from mpas_analysis.shared.constants import constants

from mpas_analysis.shared.mpas_xarray.mpas_xarray import subset_variables


class MpasClimatologyTask(AnalysisTask): # {{{
'''
Expand Down Expand Up @@ -676,7 +674,7 @@ def _compute_climatologies_with_xarray(self, inDirectory, outDirectory):
def _preprocess(ds):
# drop unused variables during preprocessing because only the
# variables we want are guaranteed to be in all the files
return subset_variables(ds, variableList)
return ds[variableList]

season = self.season
parentTask = self.parentTask
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
from mpas_analysis.shared.climatology.comparison_descriptors import \
get_comparison_descriptor

from mpas_analysis.shared.mpas_xarray import mpas_xarray


class RemapMpasClimatologySubtask(AnalysisTask): # {{{
'''
Expand Down Expand Up @@ -238,7 +236,7 @@ def run_task(self): # {{{
self.climatologyName))

dsMask = xr.open_dataset(self.mpasClimatologyTask.inputFiles[0])
dsMask = mpas_xarray.subset_variables(dsMask, self.variableList)
dsMask = dsMask[self.variableList]
iselValues = {'Time': 0}
if self.iselValues is not None:
iselValues.update(self.iselValues)
Expand Down Expand Up @@ -526,8 +524,7 @@ def _mask_climatologies(self, season, dsMask): # {{{
if not os.path.exists(maskedClimatologyFileName):
# slice and mask the data set
climatology = xr.open_dataset(climatologyFileName)
climatology = mpas_xarray.subset_variables(climatology,
self.variableList)
climatology = climatology[self.variableList]
iselValues = {}
if 'Time' in climatology.dims:
iselValues['Time'] = 0
Expand Down
3 changes: 1 addition & 2 deletions mpas_analysis/shared/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
StreamsFile
from mpas_analysis.shared.io.utility import paths, decode_strings
from mpas_analysis.shared.io.write_netcdf import write_netcdf
from mpas_analysis.shared.io.mpas_reader import open_mpas_dataset, \
subset_variables
from mpas_analysis.shared.io.mpas_reader import open_mpas_dataset
48 changes: 1 addition & 47 deletions mpas_analysis/shared/io/mpas_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,53 +94,7 @@ def open_mpas_dataset(fileName, calendar,
days_to_datetime(startDate, calendar=calendar),
days_to_datetime(endDate, calendar=calendar)))
if variableList is not None:
ds = subset_variables(ds, variableList)

return ds # }}}


def subset_variables(ds, variableList): # {{{
"""
Given a data set and a list of variable names, returns a new data set that
contains only variables with those names.
Parameters
----------
ds : ``xarray.DataSet`` object
The data set from which a subset of variables is to be extracted.
variableList : string or list of strings
The names of the variables to be extracted.
Returns
-------
ds : ``xarray.DataSet`` object
A copy of the original data set with only the variables in
variableList.
Raises
------
ValueError
If the resulting data set is empty.
"""
# Authors
# -------
# Phillip J. Wolfram, Xylar Asay-Davis

allvars = ds.data_vars.keys()

# get set of variables to drop (all ds variables not in vlist)
dropvars = set(allvars) - set(variableList)

# drop variables not requested and coordinates that are no longer needed
ds = ds.drop_vars(dropvars)

if len(ds.data_vars.keys()) == 0:
raise ValueError(
'Empty dataset is returned.\n'
'Variables {}\n'
'are not found within the dataset '
'variables: {}.'.format(variableList, allvars))
ds = ds[variableList]

return ds # }}}

Expand Down
4 changes: 1 addition & 3 deletions mpas_analysis/shared/regions/compute_region_masks_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
make_directories, get_region_mask
from mpas_analysis.shared.io import write_netcdf

from mpas_analysis.shared.mpas_xarray import mpas_xarray


def get_feature_list(geojsonFileName):
'''
Expand Down Expand Up @@ -58,7 +56,7 @@ def compute_mpas_region_masks(geojsonFileName, meshFileName, maskFileName,
return

with xr.open_dataset(meshFileName) as dsMesh:
dsMesh = mpas_xarray.subset_variables(dsMesh, ['lonCell', 'latCell'])
dsMesh = dsMesh[['lonCell', 'latCell']]
latCell = numpy.rad2deg(dsMesh.latCell.values)

# transform longitudes to [-180, 180)
Expand Down

0 comments on commit cf8e2cb

Please sign in to comment.