Skip to content

Commit

Permalink
dev
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhassell committed Sep 19, 2023
1 parent 529b143 commit 6b30280
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions cf/data/dask_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,19 @@ def regrid(
dst_shape: sequence of `int`
The shape of the destination grid.
axis_order, sequence of `int`
axis_order: sequence of `int`
The axis order that transposes *a* so that the regrid axes
become the trailing dimensions, ordered consistently with
the order used to create the weights matrix; and the
non-regrid axes become the leading dimensions.
ref_src_mask, `numpy.ndarray` or `None`
*Parameter example:*
``[0, 3, 2, 1]``
*Parameter example:*
``[0, 2, 1]``
ref_src_mask: `numpy.ndarray` or `None`
If a `numpy.ndarray` with shape *src_shape* then this is
the reference source grid mask that was used during the
creation of the weights matrix given by *weights*, and the
Expand Down Expand Up @@ -143,15 +149,15 @@ def regrid(
**Linear regridding**
Destination grid cell j will only be masked if a) it is
masked in destination grid definition; or b) ``w_ji >=
masked in the destination grid definition; or b) ``w_ji >=
min_weight`` for those masked source grid cells i for
which ``w_ji > 0``.
**Conservative first-order regridding**
Destination grid cell j will only be masked if a) it is
masked in destination grid definition; or b) The sum of
``w_ji`` for all non-masked source grid cells i is
masked in the destination grid definition; or b) the sum
of ``w_ji`` for all non-masked source grid cells i is
strictly less than *min_weight*.
:Returns:
Expand All @@ -168,8 +174,9 @@ def regrid(
# are the gathered regridding axes and whose left-hand dimension
# represent of all the other dimensions.
# ----------------------------------------------------------------
n_src_axes = len(src_shape)
a = a.transpose(axis_order)
non_regrid_shape = a.shape[: a.ndim - len(src_shape)]
non_regrid_shape = a.shape[: a.ndim - n_src_axes]
dst_size, src_size = weights.shape
a = a.reshape(-1, src_size)
a = a.T
Expand Down Expand Up @@ -200,7 +207,7 @@ def regrid(
if variable_mask or (src_mask is None and ref_src_mask.any()):
raise ValueError(
f"Can't regrid with the {method!r} method when the source "
f"data mask varies over different {len(src_shape)}-d "
f"data mask varies over different {n_src_axes}-d "
"regridding slices"
)

Expand Down Expand Up @@ -279,10 +286,27 @@ def regrid(
a = a.T
a = a.reshape(non_regrid_shape + tuple(dst_shape))

n_dst_axes = len(dst_shape)
if n_src_axes == n_dst_axes:
pass
elif n_src_axes == 1 and n_dst_axes > 1:
# E.g. UGRID -> regular lat-lon; changes axis order from
# [0,2,1] to [0,3,1,2]

r = axis_order[-1]
axis_order = [i + n_dst_axes - 1 if i > r else i for i in axis_order]
axis_order.extend(range(r + 1, r + n_dst_axes))
elif n_dst_axes == 1 and n_src_axes > 1:
# E.g. regular lat-lon -> UGRID; changes axis order from
# [0,3,2,1] to [0,2,1]
pass # TODOUGRID
else:
raise ValueError("TODOUGRID")

d = {k: i for i, k in enumerate(axis_order)}
axis_reorder = [i for k, i in sorted(d.items())]

a = a.transpose(axis_reorder)

return a


Expand Down Expand Up @@ -514,8 +538,8 @@ def _regrid(
a = weights.dot(a)

if dst_mask is not None:
a = np.ma.array(a)
a[dst_mask] = np.ma.masked
a = np.ma.array(a, mask=dst_mask)
# a[dst_mask] = np.ma.masked

return a, src_mask, dst_mask, weights

Expand Down

0 comments on commit 6b30280

Please sign in to comment.