Skip to content

Commit

Permalink
Extend rebin_irregular() to handle masks.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanRyanIrish committed May 3, 2024
1 parent f6d2e64 commit 8f0ab29
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 19 deletions.
45 changes: 33 additions & 12 deletions stixpy/product/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,39 @@
from types import SimpleNamespace

import numpy as np
import pytest

from .. import tools

def test_rebin_irregular():
# Define inputs
axes_idx_edges = [0, 2, 3], [0, 2, 4], [0, 3, 5]
@pytest.fixture
def cube():
data = np.ones((3, 4, 5))
operation = np.sum
# Run function.
output = tools.rebin_irregular(data, axes_idx_edges, operation=operation)
# Compare with expected
expected = np.array([[[12., 8.],
[12., 8.]],
[[ 6., 4.],
[ 6., 4.]]])
np.testing.assert_allclose(output, expected)
mask = np.array([[[ True, True, True, False, False],
[ True, False, True, False, False],
[ True, False, False, False, False],
[ True, False, True, True, True]],
[[ True, True, True, True, True],
[ True, True, False, False, True],
[False, False, True, False, False],
[ True, True, False, False, True]],
[[False, False, True, True, False],
[False, True, False, False, True],
[ True, False, True, True, True],
[ True, True, True, False, False]]])
return SimpleNamespace(data=data, mask=mask, shape=data.shape)


def test_rebin_irregular(cube):
# Run function.
axes_idx_edges = [0, 2, 3], [0, 2, 4], [0, 3, 5]
data, mask = tools.rebin_irregular(cube, axes_idx_edges, operation=np.sum,
operation_ignores_mask=False, handle_mask=np.all)
# Define expected outputs.
expected_data = np.array([[[2., 5.],
[6., 5.]],
[[4., 2.],
[1., 2.]]])
expected_mask = np.zeros(expected_data.shape, dtype=bool)
# Compare outputs with expected.
np.testing.assert_allclose(data, expected_data)
assert (mask == expected_mask).all()
55 changes: 48 additions & 7 deletions stixpy/product/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,31 @@

import numpy as np

ARRAY_MASK_MAP = {}
ARRAY_MASK_MAP[np.ndarray] = np.ma.masked_array
_NUMPY_COPY_IF_NEEDED = False if np.__version__.startswith("1.") else None
try:
import dask.array
ARRAY_MASK_MAP[dask.array.core.Array] = dask.array.ma.masked_array
except ImportError:
pass

def rebin_irregular(data, axes_idx_edges, operation=np.mean):

# TODO: Delete once this function is broken out in ndcube to a util and import from there
def _convert_to_masked_array(data, mask, operation_ignores_mask):
m = None if (mask is None or mask is False or operation_ignores_mask) else mask
if m is not None:
for array_type, masked_type in ARRAY_MASK_MAP.items():
if isinstance(data, array_type):
break
else:
masked_type = np.ma.masked_array
warn_user("data and mask arrays of different or unrecognized types. Casting them into a numpy masked array.")
return masked_type(data, m)


def rebin_irregular(cube, axes_idx_edges, operation=np.mean, operation_ignores_mask=False,
handle_mask=np.all):
"""
Downsample array by combining irregularly sized, but contiguous, blocks of elements into bins.
Expand Down Expand Up @@ -39,17 +62,35 @@ def rebin_irregular(data, axes_idx_edges, operation=np.mean):
[[ 6., 4.],
[ 6., 4.]]])
"""
# TODO: extend this function to handle NDCube, not just an array.
ndim = data.ndim
# Sanitize inputs
ndim = len(cube.shape)
if len(axes_idx_edges) != ndim:
raise ValueError(f"length of axes_idx_edges must be {ndim}")
# Combine data and mask and determine whether mask also needs rebinning.
data = _convert_to_masked_array(cube.data, cube.mask, operation_ignores_mask)
rebin_mask = False
if handle_mask is None or operation_ignores_mask:
mask = None
else:
mask = cube.mask
if not isinstance(cube.mask, (type(None), bool)):
rebin_mask = True
# Iterate through dimensions and perform rebinning operation.
for i, edges in enumerate(axes_idx_edges):
# If no edge indices provided for dimension, skip to next one.
if edges is None:
continue
r = []
# Iterate through pixel blocks and collapse via rebinning operation.
tmp_data = []
tmp_mask = []
item = [slice(None)] * ndim
for j in range(len(edges)-1):
item[i] = slice(edges[j], edges[j+1])
r.append(operation(data[*item], axis=i))
data = np.stack(r, axis=i)
return data
tmp_data.append(operation(data[*item], axis=i))
if rebin_mask is True:
tmp_mask.append(handle_mask(mask[*item], axis=i))
data = np.stack(tmp_data, axis=i)
if rebin_mask:
mask = np.stack(tmp_mask, axis=i)
data = data.data
return data, mask

0 comments on commit 8f0ab29

Please sign in to comment.