Skip to content

Commit

Permalink
Allow kw in clustering (#218)
Browse files Browse the repository at this point in the history
* allow kw in clustering

* fix typo

* fix references
  • Loading branch information
JoranAngevaare authored Feb 24, 2025
1 parent ee5e1f6 commit b4abcba
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
25 changes: 17 additions & 8 deletions optim_esm_tools/analyze/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from optim_esm_tools.config import config
from optim_esm_tools.config import get_logger
from optim_esm_tools.utils import timed
from optim_esm_tools.utils import timed, deprecated
from optim_esm_tools.utils import tqdm


Expand Down Expand Up @@ -232,11 +232,17 @@ def _check_input(
return lat, lon


def _split_to_continous(
@deprecated
def _split_to_continous(*a, **kw):
return _split_to_continuous(*a, **kw)


def _split_to_continuous(
masks: ty.List,
**kw,
) -> ty.List[np.ndarray]:
no_group = -1
mask_groups = masks_array_to_coninuous_sets(masks, no_group_value=no_group)
mask_groups = masks_array_to_coninuous_sets(masks, no_group_value=no_group, **kw)
continous_masks = []
for grouped_members in mask_groups:
for group_id in np.unique(grouped_members):
Expand Down Expand Up @@ -296,7 +302,7 @@ def _build_cluster_with_kw(
masks.append(np.array(full_2d_mask))

if force_continuity:
masks = _split_to_continous(masks=masks)
masks = _split_to_continuous(masks=masks)

clusters = [_find_lat_lon_values(m, lats=lat, lons=lon) for m in masks]

Expand Down Expand Up @@ -493,18 +499,21 @@ def masks_array_to_coninuous_sets(

result_groups = np.ones_like(masks[0], dtype=np.int64) * no_group_value
check_buffer = np.zeros_like(masks[0], dtype=np.bool_)

kw_cont_sets = dict(
len_x=len_x,
len_y=len_y,
add_diagonal=add_diagonal,
)
kw_cont_sets.update(kw)
# Warning, do notice that the result_buffer and check_buffer are modified in place! However, _group_mask_in_continous_sets does reset the buffer each time
# Therefore, we have to copy the result each time! Otherwise that result will be overwritten in the next iteration
return [
_group_mask_in_continous_sets(
mask=mask,
no_group_value=no_group_value,
add_diagonal=add_diagonal,
len_x=len_x,
len_y=len_y,
result_buffer=result_groups,
check_buffer=check_buffer,
**kw_cont_sets,
).copy()
for mask in masks
]
Expand Down
2 changes: 1 addition & 1 deletion optim_esm_tools/region_finding/keep_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_masks(
coords.append(np.array(this_coords))
masks.append(this_mask)
if force_continuity:
masks = oet.analyze.clustering._split_to_continous(masks=masks)
masks = oet.analyze.clustering._split_to_continuous(masks=masks)
lat, lon = np.meshgrid(lats, lons)
coords = [
oet.analyze.clustering._find_lat_lon_values(m, lats=lat.T, lons=lon.T)
Expand Down

0 comments on commit b4abcba

Please sign in to comment.