
Fix keep_yboundaries=False for squashed, double-null cases #180

Merged · 7 commits · Mar 19, 2021
5 changes: 1 addition & 4 deletions xbout/boutdataset.py
@@ -238,10 +238,7 @@ def find_with_dims(first_var, dims):
         ds = ds.reset_coords("dy")
 
     # Apply geometry
-    if hasattr(ds, "geometry"):
-        ds = apply_geometry(ds, ds.geometry)
-    # if no geometry was originally applied, then ds has no geometry attribute and we
-    # can continue without applying geometry here
+    ds = apply_geometry(ds, ds.geometry)
 
     return ds
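Dropping the hasattr() guard is safe because, together with the geometries.py change below, loading now always stamps a "geometry" attribute onto the Dataset, using the empty string when no geometry was applied, and apply_geometry() handles the empty name itself. A minimal runnable sketch of that invariant (the helper is an illustration, not the real xbout implementation):

import xarray as xr

def apply_geometry_sketch(ds, geometry_name):
    # Mirrors the new apply_geometry() contract: a falsy name records an
    # empty geometry string and returns the Dataset otherwise unchanged.
    if geometry_name is None or geometry_name == "":
        ds.attrs["geometry"] = ""
        return ds
    ds.attrs["geometry"] = geometry_name
    # ... coordinate construction for registered geometries goes here ...
    return ds

ds = xr.Dataset({"n": ("x", [1.0, 2.0])})
ds = apply_geometry_sketch(ds, None)
# Code that reloads this Dataset can call apply_geometry unconditionally:
assert ds.geometry == ""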
22 changes: 16 additions & 6 deletions xbout/geometries.py
@@ -47,7 +47,8 @@ def apply_geometry(ds, geometry_name, *, coordinates=None, grid=None):
     UnregisteredGeometryError
     """
 
-    if geometry_name is None:
+    if geometry_name is None or geometry_name == "":
+        ds = _set_attrs_on_all_vars(ds, "geometry", "")
         updated_ds = ds
     else:
         ds = _set_attrs_on_all_vars(ds, "geometry", geometry_name)
@@ -209,12 +210,18 @@ def _set_default_toroidal_coordinates(coordinates, ds):
         coordinates = {}
 
     # Replace any values that have not been passed in with defaults
-    coordinates["t"] = coordinates.get("t", ds.metadata.get("bout_tdim", "t"))
-    coordinates["x"] = coordinates.get(
-        "x", ds.metadata.get("bout_xdim", "psi_poloidal")
-    )
-    coordinates["y"] = coordinates.get("y", ds.metadata.get("bout_ydim", "theta"))
-    coordinates["z"] = coordinates.get("z", ds.metadata.get("bout_zdim", "zeta"))
+    coordinates["t"] = coordinates.get("t", ds.metadata["bout_tdim"])
+
+    default_x = (
+        ds.metadata["bout_xdim"] if ds.metadata["bout_xdim"] != "x" else "psi_poloidal"
+    )
+    coordinates["x"] = coordinates.get("x", default_x)
+
+    default_y = ds.metadata["bout_ydim"] if ds.metadata["bout_ydim"] != "y" else "theta"
+    coordinates["y"] = coordinates.get("y", default_y)
+
+    default_z = ds.metadata["bout_zdim"] if ds.metadata["bout_zdim"] != "z" else "zeta"
+    coordinates["z"] = coordinates.get("z", default_z)
 
     return coordinates
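The new defaults change behaviour when a Dataset is reloaded: metadata records the dimension names from the first apply_geometry() pass, and only the raw BOUT++ names ("x", "y", "z") fall back to the toroidal defaults, so dimensions are not renamed twice. A small runnable sketch of the rule (plain dicts stand in for ds.metadata):

def default_coord(metadata, key, raw_name, fallback):
    # Reuse a previously-recorded name; only a raw BOUT++ dimension name
    # falls through to the toroidal default.
    recorded = metadata[key]
    return recorded if recorded != raw_name else fallback

fresh = {"bout_xdim": "x", "bout_ydim": "y", "bout_zdim": "z"}
assert default_coord(fresh, "bout_xdim", "x", "psi_poloidal") == "psi_poloidal"

# On a reloaded Dataset the recorded name wins, avoiding a double rename:
custom = {"bout_ydim": "poloidal_angle"}
assert default_coord(custom, "bout_ydim", "y", "theta") == "poloidal_angle"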

@@ -285,6 +292,9 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
     # Record which dimension 'z' was renamed to.
     ds.metadata["bout_zdim"] = coordinates["z"]
 
+    # Ensure metadata is the same on all variables
+    ds = _set_attrs_on_all_vars(ds, "metadata", ds.metadata)
+
     # Add 2D Cylindrical coordinates
     if ("R" not in ds) and ("Z" not in ds):
         ds = ds.rename(Rxy="R", Zxy="Z")
50 changes: 45 additions & 5 deletions xbout/load.py
@@ -188,9 +188,10 @@ def attrs_remove_section(obj, section):
         return ds
 
     # Determine if file is a grid file or data dump files
+    remove_yboundaries = False
     if "dump" in input_type:
         # Gather pointers to all numerical data from BOUT++ output files
-        ds = _auto_open_mfboutdataset(
+        ds, remove_yboundaries = _auto_open_mfboutdataset(
             datapath=datapath,
             chunks=chunks,
             keep_xboundaries=keep_xboundaries,
@@ -215,6 +216,11 @@ def attrs_remove_section(obj, section):
     metadata["keep_yboundaries"] = int(keep_yboundaries)
     ds = _set_attrs_on_all_vars(ds, "metadata", metadata)
 
+    if remove_yboundaries:
+        # If remove_yboundaries is True, we need to keep y-boundaries when opening the
+        # grid file, as they will be removed from the full Dataset below
+        keep_yboundaries = True
+
     for var in _BOUT_TIME_DEPENDENT_META_VARS:
         if var in ds:
             # Assume different processors in x & y have same iteration etc.
@@ -250,6 +256,9 @@ def attrs_remove_section(obj, section):
     # Update coordinates to match particular geometry of grid
     ds = geometries.apply_geometry(ds, geometry, grid=grid)
 
+    if remove_yboundaries:
+        ds = ds.bout.remove_yboundaries()
+
     # TODO read and store git commit hashes from output files
 
     if run_name:
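Putting the hunks together, the flag threads through open_boutdataset() as follows: the loader may force y-boundaries to be kept and hand responsibility back to the caller; the grid file is then also opened with boundaries so its shape matches the Dataset; and the cells are only stripped after the geometry (which may need them) has been applied. A condensed, runnable sketch with stand-in helpers (not the real xbout API):

def auto_open_stub(keep_yboundaries, squashed_doublenull):
    # Stand-in for _auto_open_mfboutdataset: for squashed double-null data it
    # loads boundaries regardless and asks the caller to strip them later.
    if squashed_doublenull:
        return {"has_yboundaries": True}, not keep_yboundaries
    return {"has_yboundaries": keep_yboundaries}, False

def open_flow(keep_yboundaries=False, squashed_doublenull=True):
    ds, remove_yboundaries = auto_open_stub(keep_yboundaries, squashed_doublenull)
    if remove_yboundaries:
        # Grid file must be opened *with* y-boundaries to match the Dataset,
        # which still carries them at this point.
        keep_yboundaries = True
    # ... open grid with keep_yboundaries, apply geometry ...
    if remove_yboundaries:
        ds["has_yboundaries"] = False  # stand-in for ds.bout.remove_yboundaries()
    return ds

assert open_flow(keep_yboundaries=False)["has_yboundaries"] is False
assert open_flow(keep_yboundaries=True)["has_yboundaries"] is True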
@@ -339,7 +348,7 @@ def collect(
 
     datapath = join(path, prefix + "*.nc")
 
-    ds = _auto_open_mfboutdataset(
+    ds, _ = _auto_open_mfboutdataset(
         datapath, keep_xboundaries=xguards, keep_yboundaries=yguards, info=info
     )

@@ -456,7 +465,20 @@ def _auto_open_mfboutdataset(
     filepaths, filetype = _expand_filepaths(datapath)
 
     # Open just one file to read processor splitting
-    nxpe, nype, mxg, myg, mxsub, mysub = _read_splitting(filepaths[0], info)
+    nxpe, nype, mxg, myg, mxsub, mysub, is_squashed_doublenull = _read_splitting(
+        filepaths[0], info
+    )
+
+    if is_squashed_doublenull:
+        # Need to remove y-boundaries after loading: (i) in case we are loading a
+        # squashed data-set, in which case we cannot easily remove the upper
+        # boundary cells in _trim(); (ii) because using the remove_yboundaries()
+        # method for non-squashed data-sets is simpler than replicating that logic
+        # in _trim().
+        remove_yboundaries = not keep_yboundaries
+        keep_yboundaries = True
+    else:
+        remove_yboundaries = False
 
     _preprocess = partial(
         _trim,
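The jyseps2_1 != jyseps1_2 test identifies a double-null topology (two X-points), and the single-file check identifies data squashed into one file. The combination matters because in a squashed double-null file the upper-target boundary cells sit in the middle of the global y dimension, where _trim()'s per-file slicing cannot reach them. A runnable toy illustration with made-up sizes:

import numpy as np

myg = 2           # y-boundary guard cells
ny_inner = 8      # cells in the inner half of a double-null domain
ny_outer = 8      # cells in the outer half
# Squashed double-null y layout:
#   [guard | inner | guard] [guard | outer | guard]
y = np.arange(ny_inner + ny_outer + 4 * myg)
# Slicing off only the outermost guards, as _trim() would for a single file,
# leaves the two upper-target guard blocks stranded in the interior:
trimmed = y[myg:-myg]
assert len(trimmed) == ny_inner + ny_outer + 2 * myg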
@@ -490,6 +512,21 @@ def _auto_open_mfboutdataset(
         myg = int(datapath[0]["MYG"])
         nxpe = int(datapath[0]["NXPE"])
         nype = int(datapath[0]["NYPE"])
+        is_squashed_doublenull = (
+            len(datapath) == 1
+            and (datapath[0]["jyseps2_1"] != datapath[0]["jyseps1_2"]).values
+        )
+
+        if is_squashed_doublenull:
+            # Need to remove y-boundaries after loading: (i) in case we are loading a
+            # squashed data-set, in which case we cannot easily remove the upper
+            # boundary cells in _trim(); (ii) because using the remove_yboundaries()
+            # method for non-squashed data-sets is simpler than replicating that logic
+            # in _trim().
+            remove_yboundaries = not keep_yboundaries
+            keep_yboundaries = True
+        else:
+            remove_yboundaries = False
 
     _preprocess = partial(
         _trim,
@@ -512,7 +549,8 @@ def _auto_open_mfboutdataset(
 
     # Remove any duplicate time values from concatenation
     _, unique_indices = unique(ds["t_array"], return_index=True)
-    return ds.isel(t=unique_indices)
+
+    return ds.isel(t=unique_indices), remove_yboundaries
 
 
 def _expand_filepaths(datapath):
@@ -598,6 +636,7 @@ def get_nonnegative_scalar(ds, key, default=1, info=True):
     ny = ds["ny"].values
     nx_file = ds.dims["x"]
     ny_file = ds.dims["y"]
+    is_squashed_doublenull = False
     if nxpe > 1 or nype > 1:
         # if nxpe = nype = 1, was only one process anyway, so no need to check for
         # squashing
@@ -620,11 +659,12 @@ def get_nonnegative_scalar(ds, key, default=1, info=True):
 
             nxpe = 1
             nype = 1
+            is_squashed_doublenull = (ds["jyseps2_1"] != ds["jyseps1_2"]).values
 
     # Avoid trying to open this file twice
     ds.close()
 
-    return nxpe, nype, mxg, myg, mxsub, mysub
+    return nxpe, nype, mxg, myg, mxsub, mysub, is_squashed_doublenull
 
 
 def _arrange_for_concatenation(filepaths, nxpe=1, nype=1):
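With these changes, the failing case in the PR title works end-to-end; a hedged usage example (file paths are hypothetical):

from xbout import open_boutdataset

ds = open_boutdataset(
    "data/squashed_doublenull/BOUT.dmp.nc",  # hypothetical squashed output file
    gridfilepath="data/squashed_doublenull/grid.nc",  # hypothetical grid file
    geometry="toroidal",
    keep_yboundaries=False,  # previously failed for squashed double-null data
)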