Merge branch 'master' into fix-to_field_aligned-wrong-dim-order
johnomotani authored Dec 24, 2021
2 parents d01d28c + ffda4b2 commit 91833c6
Showing 10 changed files with 422 additions and 85 deletions.
xbout/boutdataset.py (14 changes: 8 additions & 6 deletions)
@@ -238,8 +238,8 @@ def find_with_dims(first_var, dims):
if first_var is None:
raise ValueError(
f"Could not find variable to interpolate with both "
f"{ds.metadata.get('bout_xdim', 'x')} and "
f"{ds.metadata.get('bout_ydim', 'y')} dimensions"
f"{self.data.metadata.get('bout_xdim', 'x')} and "
f"{self.data.metadata.get('bout_ydim', 'y')} dimensions"
)
variables.remove(first_var)
ds = self.data[first_var].bout.interpolate_parallel(
@@ -440,19 +440,21 @@ def integrate_midpoints(self, variable, *, dims=None, cumulative_t=False):

spatial_dims = set(dims) - set([tcoord])

integrand = variable * spatial_volume_element

# Need to check if the variable being integrated is a Field2D, which does not
# have a z-dimension to sum over. Other variables are OK because metric
# coefficients, dx and dy all have both x- and y-dimensions so variable would be
# broadcast to include them if necessary
missing_z_sum = zcoord in dims and zcoord not in variable.dims

integrand = variable * spatial_volume_element

integral = integrand.sum(dim=spatial_dims)

# If integrand is a Field2D, need to multiply by nz if integrating over z
if missing_z_sum:
spatial_dims -= set(zcoord)
integral = integrand.sum(dim=spatial_dims)
integral = integral * ds.sizes[zcoord]
else:
integral = integrand.sum(dim=spatial_dims)

if tcoord in dims:
if cumulative_t:
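The reordered integrate_midpoints() logic above sums the integrand only over the dimensions it actually has and compensates a missing z-dimension by multiplying by nz. A minimal standalone sketch of that idea (not taken from this commit; the toy grid sizes and arrays are illustrative):

```python
import numpy as np
import xarray as xr

# Toy "Field2D"-like variable: values depend on x and y only, but the
# integral is requested over x, y and z.
nx, ny, nz = 4, 3, 8
dz = 2 * np.pi / nz
field2d = xr.DataArray(np.ones((nx, ny)), dims=("x", "y"))
volume_element = xr.DataArray(np.full((nx, ny), dz), dims=("x", "y"))

integrand = field2d * volume_element

spatial_dims = {"x", "y", "z"}
# The integrand has no z-dimension to sum over, so sum over what exists and
# multiply by nz, mirroring the missing_z_sum branch above.
if "z" not in integrand.dims:
    integral = integrand.sum(dim=spatial_dims - {"z"}) * nz
else:
    integral = integrand.sum(dim=spatial_dims)

print(float(integral))  # nx * ny * nz * dz
```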
xbout/geometries.py (77 changes: 51 additions & 26 deletions)
@@ -4,7 +4,7 @@
import xarray as xr
import numpy as np

from .region import Region, _create_regions_toroidal
from .region import Region, _create_regions_toroidal, _create_single_region
from .utils import (
_add_attrs_to_var,
_set_attrs_on_all_vars,
@@ -183,8 +183,9 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):

# In BOUT++ v5, dz is either a Field2D or Field3D.
# We can use it as a 1D coordinate if it's a Field3D, _or_ if nz == 1
bout_v5 = updated_ds.metadata["BOUT_VERSION"] > 5.0 or (
updated_ds.metadata["BOUT_VERSION"] == 5.0 and updated_ds["dz"].ndim == 2
bout_version = updated_ds.metadata.get("BOUT_VERSION", 4.3)
bout_v5 = bout_version > 5.0 or (
bout_version == 5.0 and updated_ds["dz"].ndim >= 2
)
use_metric_3d = updated_ds.metadata.get("use_metric_3d", False)
can_use_1d_z_coord = (nz == 1) or use_metric_3d
@@ -197,14 +198,20 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
raise ValueError(
f"Spacing is not constant. Cannot create z coordinate"
)
dz = updated_ds["dz"][0, 0]

dz = updated_ds["dz"].min()
else:
dz = updated_ds["dz"]

z0 = 2 * np.pi * updated_ds.metadata["ZMIN"]
z1 = z0 + nz * dz
if not np.isclose(
z1, 2.0 * np.pi * updated_ds.metadata["ZMAX"], rtol=1.0e-15, atol=0.0
if not np.all(
np.isclose(
z1,
2.0 * np.pi * updated_ds.metadata["ZMAX"],
rtol=1.0e-15,
atol=0.0,
)
):
warn(
f"Size of toroidal domain as calculated from nz*dz ({str(z1 - z0)}"
@@ -272,7 +279,10 @@ def _set_default_toroidal_coordinates(coordinates, ds):
coordinates = {}

# Replace any values that have not been passed in with defaults
coordinates["t"] = coordinates.get("t", ds.metadata["bout_tdim"])
if ds.metadata["is_restart"] == 0:
# Don't need "t" coordinate for restart files which have no time dimension, and
# adding it breaks the check for reloading in open_boutdataset
coordinates["t"] = coordinates.get("t", ds.metadata["bout_tdim"])

default_x = (
ds.metadata["bout_xdim"] if ds.metadata["bout_xdim"] != "x" else "psi_poloidal"
@@ -288,6 +298,28 @@ def _set_default_toroidal_coordinates(coordinates, ds):
return coordinates


def _add_vars_from_grid(ds, grid, variables):
# Get extra geometry information from grid file if it's not in the dump files
for v in variables:
if v not in ds:
if grid is None:
raise ValueError(
f"Grid file is required to provide {v}. Pass the grid "
f"file name as the 'gridfilepath' argument to "
f"open_boutdataset()."
)
# ds[v] = grid[v]
# Work around issue where xarray drops attributes on coordinates when a new
# DataArray is assigned to the Dataset, see
# https://github.com/pydata/xarray/issues/4415
# https://github.com/pydata/xarray/issues/4393
# This way adds as a 'Variable' instead of as a 'DataArray'
ds[v] = (grid[v].dims, grid[v].values)

_add_attrs_to_var(ds, v)
return ds
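The commented-out ds[v] = grid[v] and the tuple assignment above work around xarray dropping attrs on coordinates when a DataArray is assigned into a Dataset. A self-contained sketch of that workaround pattern (the toy Dataset and the 'Rxy' name here are illustrative, not from the commit):

```python
import numpy as np
import xarray as xr

# A Dataset with a coordinate that carries attrs we do not want to lose
ds = xr.Dataset(
    {"n": (("x", "y"), np.zeros((3, 2)))},
    coords={"x": ("x", np.arange(3), {"units": "m"})},
)
grid_var = xr.DataArray(np.arange(6.0).reshape(3, 2), dims=("x", "y"))

# Assigning grid_var directly as a DataArray can drop attrs on existing
# coordinates in affected xarray versions (pydata/xarray#4415, #4393);
# assigning a (dims, values) tuple adds it as a plain Variable instead.
ds["Rxy"] = (grid_var.dims, grid_var.values)
print(ds["x"].attrs)  # {'units': 'm'}
```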


@register_geometry("toroidal")
def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):

@@ -311,24 +343,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
)

# Get extra geometry information from grid file if it's not in the dump files
needed_variables = ["psixy", "Rxy", "Zxy"]
for v in needed_variables:
if v not in ds:
if grid is None:
raise ValueError(
f"Grid file is required to provide {v}. Pass the grid "
f"file name as the 'gridfilepath' argument to "
f"open_boutdataset()."
)
# ds[v] = grid[v]
# Work around issue where xarray drops attributes on coordinates when a new
# DataArray is assigned to the Dataset, see
# https://github.com/pydata/xarray/issues/4415
# https://github.com/pydata/xarray/issues/4393
# This way adds as a 'Variable' instead of as a 'DataArray'
ds[v] = (grid[v].dims, grid[v].values)

_add_attrs_to_var(ds, v)
ds = _add_vars_from_grid(ds, grid, ["psixy", "Rxy", "Zxy"])

if "t" in ds.dims:
# Rename 't' if user requested it
@@ -343,7 +358,8 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
ds[coordinates["x"]].attrs["units"] = "Wb"

# Record which dimensions 't', 'x', and 'y' were renamed to.
ds.metadata["bout_tdim"] = coordinates["t"]
if ds.metadata["is_restart"] == 0:
ds.metadata["bout_tdim"] = coordinates["t"]
# x dimension not renamed, so this is still 'x'
ds.metadata["bout_xdim"] = "x"
ds.metadata["bout_ydim"] = coordinates["y"]
@@ -418,3 +434,12 @@ def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None):
del ds["hthe"]

return ds


@register_geometry("fci")
def add_fci_geometry_coords(ds, *, coordinates=None, grid=None):
assert coordinates is None, "Not implemented"
ds = _add_vars_from_grid(ds, grid, ["R", "Z"])
ds = ds.set_coords(("R", "Z"))
ds = _create_single_region(ds, periodic_y=True)
return ds
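A hedged usage sketch of the new "fci" geometry registered above (file paths are placeholders; a grid file providing R and Z is assumed, as required by _add_vars_from_grid):

```python
from xbout import open_boutdataset

# Hypothetical paths; the FCI grid file must supply R and Z
ds = open_boutdataset(
    "data/BOUT.dmp.*.nc",
    gridfilepath="data/fci_grid.nc",
    geometry="fci",
)
print(ds["R"].dims, ds["Z"].dims)
```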
xbout/load.py (80 changes: 60 additions & 20 deletions)
@@ -24,12 +24,11 @@
"wtime_per_rhs",
"wtime_per_rhs_e",
"wtime_per_rhs_i",
"hist_hi",
"tt",
"PE_XIND",
"PE_YIND",
"MYPE",
]
_BOUT_PER_PROC_VARIABLES_REQUIRED_FROM_RESTARTS = ["hist_hi", "tt"]
_BOUT_TIME_DEPENDENT_META_VARS = ["iteration"]


@@ -67,11 +66,12 @@ def open_boutdataset(
keep_yboundaries=False,
run_name=None,
info=True,
is_restart=None,
**kwargs,
):
"""
Load a dataset from a set of BOUT output files, including the input options
file. Can also load from a grid file.
file. Can also load from a grid file or from restart files.
Note that when reloading a Dataset that was saved by xBOUT, the state of the saved
Dataset is restored, and the values of `keep_xboundaries`, `keep_yboundaries`, and
@@ -137,6 +137,12 @@
Useful if you are going to open multiple simulations and compare the
results.
info : bool or "terse", optional
is_restart : bool, optional
Restart files require some special handling (e.g. working around variables that
are not present in restart files). By default, this special handling is enabled
if the files do not have a time dimension and `restart` is present in the file
name in `datapath`. This option can be set to True or False to explicitly enable
or disable the restart file handling.
kwargs : optional
Keyword arguments are passed down to `xarray.open_mfdataset`, which in
turn passes extra kwargs down to `xarray.open_dataset`.
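A hedged usage sketch of the new is_restart option documented above (paths are placeholders):

```python
from xbout import open_boutdataset

# Auto-detected: no time dimension and "restart" in the file names
ds = open_boutdataset("data/BOUT.restart.*.nc")

# Forced explicitly, e.g. for restart files with unusual names
ds = open_boutdataset("data/final.*.nc", is_restart=True)
```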
@@ -151,6 +157,11 @@

input_type = _check_dataset_type(datapath)

if is_restart is None:
is_restart = input_type == "restart"
elif is_restart is True:
input_type = "restart"

if "reload" in input_type:
if input_type == "reload":
if isinstance(datapath, Path):
@@ -232,13 +243,14 @@ def attrs_remove_section(obj, section):

# Determine if file is a grid file or data dump files
remove_yboundaries = False
if "dump" in input_type:
if "dump" in input_type or "restart" in input_type:
# Gather pointers to all numerical data from BOUT++ output files
ds, remove_yboundaries = _auto_open_mfboutdataset(
datapath=datapath,
chunks=chunks,
keep_xboundaries=keep_xboundaries,
keep_yboundaries=keep_yboundaries,
is_restart=is_restart,
**kwargs,
)
elif "grid" in input_type:
@@ -257,20 +269,22 @@ def attrs_remove_section(obj, section):
# bool attributes
metadata["keep_xboundaries"] = int(keep_xboundaries)
metadata["keep_yboundaries"] = int(keep_yboundaries)
metadata["is_restart"] = int(is_restart)
ds = _set_attrs_on_all_vars(ds, "metadata", metadata)

if remove_yboundaries:
# If remove_yboundaries is True, we need to keep y-boundaries when opening the
# grid file, as they will be removed from the full Dataset below
keep_yboundaries = True

for var in _BOUT_TIME_DEPENDENT_META_VARS:
if var in ds:
# Assume different processors in x & y have same iteration etc.
latest_top_left = {dim: 0 for dim in ds[var].dims}
if "t" in ds[var].dims:
latest_top_left["t"] = -1
ds[var] = ds[var].isel(latest_top_left).squeeze(drop=True)
if not is_restart:
for var in _BOUT_TIME_DEPENDENT_META_VARS:
if var in ds:
# Assume different processors in x & y have same iteration etc.
latest_top_left = {dim: 0 for dim in ds[var].dims}
if "t" in ds[var].dims:
latest_top_left["t"] = -1
ds[var] = ds[var].isel(latest_top_left).squeeze(drop=True)

ds = _add_options(ds, inputfilepath)

@@ -450,6 +464,8 @@ def _check_dataset_type(datapath):
- only one file, and no time dimension
(iii) produced by BOUT++
- one or several files
(iv) restart files produced by BOUT++
- one or several files, no time dimension, filenames include `restart`
"""

if not _is_path(datapath):
@@ -483,12 +499,18 @@
if "metadata:keep_yboundaries" in ds.attrs:
# (i)
return "reload"
elif len(filepaths) > 1 or "t" in ds.dims:
elif "t" in ds.dims:
# (iii)
return "dump"
else:
elif all(["restart" in Path(p).name for p in filepaths]):
# (iv)
return "restart"
elif len(filepaths) == 1:
# (ii)
return "grid"
else:
# fall back to opening as dump files
return "dump"


def _auto_open_mfboutdataset(
@@ -497,11 +519,17 @@
info=True,
keep_xboundaries=False,
keep_yboundaries=False,
is_restart=False,
**kwargs,
):
if chunks is None:
chunks = {}

if is_restart:
data_vars = "minimal"
else:
data_vars = _BOUT_TIME_DEPENDENT_META_VARS

if _is_path(datapath):
filepaths, filetype = _expand_filepaths(datapath)

@@ -527,6 +555,7 @@
keep_boundaries={"x": keep_xboundaries, "y": keep_yboundaries},
nxpe=nxpe,
nype=nype,
is_restart=is_restart,
)

paths_grid, concat_dims = _arrange_for_concatenation(filepaths, nxpe, nype)
@@ -535,7 +564,7 @@
paths_grid,
concat_dim=concat_dims,
combine="nested",
data_vars=_BOUT_TIME_DEPENDENT_META_VARS,
data_vars=data_vars,
preprocess=_preprocess,
engine=filetype,
chunks=chunks,
@@ -573,6 +602,7 @@
keep_boundaries={"x": keep_xboundaries, "y": keep_yboundaries},
nxpe=nxpe,
nype=nype,
is_restart=is_restart,
)

datapath = [_preprocess(x) for x in datapath]
@@ -582,15 +612,17 @@
ds = xr.combine_nested(
ds_grid,
concat_dim=concat_dims,
data_vars=_BOUT_TIME_DEPENDENT_META_VARS,
data_vars=data_vars,
join="exact",
combine_attrs="no_conflicts",
)

# Remove any duplicate time values from concatenation
_, unique_indices = unique(ds["t_array"], return_index=True)
if not is_restart:
# Remove any duplicate time values from concatenation
_, unique_indices = unique(ds["t_array"], return_index=True)
ds = ds.isel(t=unique_indices)

return ds.isel(t=unique_indices), remove_yboundaries
return ds, remove_yboundaries
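Duplicate time points can appear where concatenated outputs overlap; the code above keeps the first occurrence of each value in t_array (and now skips this step for restart files, which have no time dimension). A standalone sketch of that de-duplication (the toy t_array is illustrative):

```python
import numpy as np
import xarray as xr

# Toy dataset with a repeated time value after concatenation
ds = xr.Dataset(
    {"t_array": ("t", np.array([0.0, 1.0, 1.0, 2.0])), "n": ("t", np.arange(4.0))}
)

# return_index gives the index of the first occurrence of each unique time;
# isel then keeps only those entries
_, unique_indices = np.unique(ds["t_array"], return_index=True)
ds = ds.isel(t=unique_indices)
print(ds["t_array"].values)  # [0. 1. 2.]
```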


def _expand_filepaths(datapath):
@@ -749,7 +781,7 @@ def _arrange_for_concatenation(filepaths, nxpe=1, nype=1):
return paths_grid, concat_dims


def _trim(ds, *, guards, keep_boundaries, nxpe, nype):
def _trim(ds, *, guards, keep_boundaries, nxpe, nype, is_restart):
"""
Trims all guard (and optionally boundary) cells off a single dataset read from a
single BOUT dump file, to prepare for concatenation.
Expand All @@ -767,6 +799,8 @@ def _trim(ds, *, guards, keep_boundaries, nxpe, nype):
Number of processors in x direction
nype : int
Number of processors in y direction
is_restart : bool
Is data being loaded from restart files?
"""

if any(keep_boundaries.values()):
@@ -791,7 +825,13 @@
):
trimmed_ds = trimmed_ds.drop_vars(name)

return trimmed_ds.drop_vars(_BOUT_PER_PROC_VARIABLES, errors="ignore")
to_drop = _BOUT_PER_PROC_VARIABLES
if not is_restart:
# These variables are required to be consistent when loading restart files, so
# that they can be written out again in to_restart()
to_drop = to_drop + _BOUT_PER_PROC_VARIABLES_REQUIRED_FROM_RESTARTS

return trimmed_ds.drop_vars(to_drop, errors="ignore")
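With the split above, hist_hi and tt survive when loading restart files (they are needed again by to_restart()), while the other per-processor variables are always dropped. A small sketch of the same pattern (the variable lists and toy Dataset are shortened stand-ins, not the real ones):

```python
import numpy as np
import xarray as xr

_BOUT_PER_PROC_VARIABLES = ["wtime", "PE_XIND", "PE_YIND", "MYPE"]
_BOUT_PER_PROC_VARIABLES_REQUIRED_FROM_RESTARTS = ["hist_hi", "tt"]

def drop_per_proc_vars(ds, is_restart):
    # Always drop the purely per-processor diagnostics; additionally drop
    # hist_hi and tt unless loading restart files, where they must be kept
    to_drop = list(_BOUT_PER_PROC_VARIABLES)
    if not is_restart:
        to_drop = to_drop + _BOUT_PER_PROC_VARIABLES_REQUIRED_FROM_RESTARTS
    return ds.drop_vars(to_drop, errors="ignore")

ds = xr.Dataset({"n": ("t", np.zeros(3)), "hist_hi": ((), 42), "MYPE": ((), 0)})
print(list(drop_per_proc_vars(ds, is_restart=True).data_vars))   # ['n', 'hist_hi']
print(list(drop_per_proc_vars(ds, is_restart=False).data_vars))  # ['n']
```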


def _infer_contains_boundaries(ds, nxpe, nype):