Skip to content

Commit

Permalink
Merge pull request #226 from boutproject/fix-to_field_aligned-wrong-d…
Browse files Browse the repository at this point in the history
…im-order

Fix `to_field_aligned()`/`from_field_aligned()` for transposed arrays
  • Loading branch information
johnomotani authored Dec 24, 2021
2 parents ffda4b2 + 91833c6 commit 76533af
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 6 deletions.
2 changes: 1 addition & 1 deletion xbout/boutdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _shift_z(self, zShift):

data_shifted_fft = data_fft * np.exp(phase.data)

data_shifted = fft.irfft(data_shifted_fft, n=nz)
data_shifted = fft.irfft(data_shifted_fft, n=nz, axis=axis)

# Return a DataArray with the same attributes as self, but values from
# data_shifted
Expand Down
62 changes: 57 additions & 5 deletions xbout/tests/test_boutdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def test_remove_yboundaries(
pytest.param(9, marks=pytest.mark.long),
],
)
def test_to_field_aligned(self, bout_xyt_example_files, nz):
@pytest.mark.parametrize(
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
)
def test_to_field_aligned(self, bout_xyt_example_files, nz, permute_dims):
dataset_list = bout_xyt_example_files(
None, lengths=(3, 3, 4, nz), nxpe=1, nype=1, nt=1
)
Expand Down Expand Up @@ -126,8 +129,15 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz):
n[t, x, y, z] = 1000.0 * t + 100.0 * x + 10.0 * y + z

n.attrs["direction_y"] = "Standard"

if permute_dims:
n = n.transpose("t", "zeta", "x", "theta").compute()

n_al = n.bout.to_field_aligned()

if permute_dims:
n_al = n_al.transpose("t", "x", "theta", "zeta").compute()

assert n_al.direction_y == "Aligned"

for t in range(ds.sizes["t"]):
Expand Down Expand Up @@ -195,7 +205,10 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz):
atol=0.0,
) # noqa: E501

def test_to_field_aligned_dask(self, bout_xyt_example_files):
@pytest.mark.parametrize(
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
)
def test_to_field_aligned_dask(self, bout_xyt_example_files, permute_dims):

nz = 6

Expand Down Expand Up @@ -231,8 +244,15 @@ def test_to_field_aligned_dask(self, bout_xyt_example_files):
assert isinstance(n.data, dask.array.Array)

n.attrs["direction_y"] = "Standard"

if permute_dims:
n = n.transpose("t", "zeta", "x", "theta").compute()

n_al = n.bout.to_field_aligned()

if permute_dims:
n_al = n_al.transpose("t", "x", "theta", "zeta").compute()

assert n_al.direction_y == "Aligned"

for t in range(ds.sizes["t"]):
Expand Down Expand Up @@ -309,7 +329,10 @@ def test_to_field_aligned_dask(self, bout_xyt_example_files):
pytest.param(9, marks=pytest.mark.long),
],
)
def test_from_field_aligned(self, bout_xyt_example_files, nz):
@pytest.mark.parametrize(
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
)
def test_from_field_aligned(self, bout_xyt_example_files, nz, permute_dims):
dataset_list = bout_xyt_example_files(
None, lengths=(3, 3, 4, nz), nxpe=1, nype=1, nt=1
)
Expand Down Expand Up @@ -337,8 +360,15 @@ def test_from_field_aligned(self, bout_xyt_example_files, nz):
n[t, x, y, z] = 1000.0 * t + 100.0 * x + 10.0 * y + z

n.attrs["direction_y"] = "Aligned"

if permute_dims:
n = n.transpose("t", "zeta", "x", "theta").compute()

n_nal = n.bout.from_field_aligned()

if permute_dims:
n_nal = n_nal.transpose("t", "x", "theta", "zeta").compute()

assert n_nal.direction_y == "Standard"

for t in range(ds.sizes["t"]):
Expand Down Expand Up @@ -407,7 +437,12 @@ def test_from_field_aligned(self, bout_xyt_example_files, nz):
) # noqa: E501

@pytest.mark.parametrize("stag_location", ["CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"])
def test_to_field_aligned_staggered(self, bout_xyt_example_files, stag_location):
@pytest.mark.parametrize(
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
)
def test_to_field_aligned_staggered(
self, bout_xyt_example_files, stag_location, permute_dims
):
dataset_list = bout_xyt_example_files(
None, lengths=(3, 3, 4, 8), nxpe=1, nype=1, nt=1
)
Expand All @@ -434,8 +469,14 @@ def test_to_field_aligned_staggered(self, bout_xyt_example_files, stag_location)
for z in range(ds.sizes["zeta"]):
n[t, x, y, z] = 1000.0 * t + 100.0 * x + 10.0 * y + z

if permute_dims:
n = n.transpose("t", "zeta", "x", "theta").compute()

n_al = n.bout.to_field_aligned().copy(deep=True)

if permute_dims:
n_al = n_al.transpose("t", "x", "theta", "zeta").compute()

assert n_al.direction_y == "Aligned"

# make 'n' staggered
Expand All @@ -459,7 +500,12 @@ def test_to_field_aligned_staggered(self, bout_xyt_example_files, stag_location)
npt.assert_equal(n_stag_al.values, n_al.values)

@pytest.mark.parametrize("stag_location", ["CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"])
def test_from_field_aligned_staggered(self, bout_xyt_example_files, stag_location):
@pytest.mark.parametrize(
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
)
def test_from_field_aligned_staggered(
self, bout_xyt_example_files, stag_location, permute_dims
):
dataset_list = bout_xyt_example_files(
None, lengths=(3, 3, 4, 8), nxpe=1, nype=1, nt=1
)
Expand Down Expand Up @@ -488,8 +534,14 @@ def test_from_field_aligned_staggered(self, bout_xyt_example_files, stag_locatio
n.attrs["direction_y"] = "Aligned"
ds["T"].attrs["direction_y"] = "Aligned"

if permute_dims:
n = n.transpose("t", "zeta", "x", "theta").compute()

n_nal = n.bout.from_field_aligned().copy(deep=True)

if permute_dims:
n_nal = n_nal.transpose("t", "x", "theta", "zeta").compute()

assert n_nal.direction_y == "Standard"

# make 'n' staggered
Expand Down

0 comments on commit 76533af

Please sign in to comment.