Skip to content

Commit

Permalink
Fix nanmin, nanmax bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jan 8, 2025
1 parent 3fc25b7 commit 8f8d051
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
17 changes: 15 additions & 2 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,17 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
"nanmin",
chunk="nanmin",
combine="nanmin",
fill_value=dtypes.NA,
fill_value=dtypes.INF,
final_fill_value=dtypes.NA,
preserves_dtype=True,
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
nanmax = Aggregation(
"nanmax",
chunk="nanmax",
combine="nanmax",
fill_value=dtypes.NA,
fill_value=dtypes.NINF,
final_fill_value=dtypes.NA,
preserves_dtype=True,
)

Expand Down Expand Up @@ -845,6 +847,17 @@ def _initialize_aggregation(
# absent in one block, but present in another block
# We set it for numpy to get nansum, nanprod tests to pass
# where the identity element is 0, 1
# Also needed for nanmin, nanmax where intermediate fill_value is +-np.inf,
# but final_fill_value is dtypes.NA
if (
# TODO: this is a total hack.
agg.name in ["nanmin", "nanmax"]
and agg.fill_value["intermediate"] != (agg.fill_value[agg.name],)
and min_count == 0
):
min_count = 1
agg.fill_value["user"] = agg.fill_value["user"] or agg.fill_value[agg.name]

if min_count > 0:
agg.min_count = min_count
agg.numpy += ("nanlen",)
Expand Down
8 changes: 4 additions & 4 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,6 @@ def _finalize_results(
agg: Aggregation,
axis: T_Axes,
expected_groups: pd.Index | None,
fill_value: Any,
reindex: bool,
) -> FinalResultsDict:
"""Finalize results by
Expand All @@ -1142,6 +1141,7 @@ def _finalize_results(
else:
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)

fill_value = agg.fill_value["user"]
if min_count > 0:
count_mask = counts < min_count
if count_mask.any():
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def _aggregate(
) -> FinalResultsDict:
"""Final aggregation step of tree reduction"""
results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
return _finalize_results(results, agg, axis, expected_groups, reindex)


def _expand_dims(results: IntermediateDict) -> IntermediateDict:
Expand Down Expand Up @@ -1449,7 +1449,7 @@ def _reduce_blockwise(
if _is_arg_reduction(agg):
results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]

result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex)
result = _finalize_results(results, agg, axis, expected_groups, reindex=reindex)
return result


Expand Down Expand Up @@ -1926,7 +1926,7 @@ def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
def _groupby_aggregate(a):
# Convert cubed dict to one that _finalize_results works with
results = {"groups": expected_groups, "intermediates": a.values()}
out = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
out = _finalize_results(results, agg, axis, expected_groups, reindex)
return out[agg.name]

# convert list of dtypes to a structured dtype for cubed
Expand Down
5 changes: 2 additions & 3 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import warnings
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -75,7 +74,6 @@ def test_groupby_reduce(data, array, func: str) -> None:
assume(not (("quantile" in func or "var" in func or "std" in func) and array.dtype.kind == "c"))
# arg* with nans in array are weird
assume("arg" not in func and not np.any(np.isnan(array).ravel()))
assume(False if func in ["median", "quantile"] and math.prod(array.numblocks) > 1 else True)

axis = -1
by = data.draw(
Expand Down Expand Up @@ -143,7 +141,8 @@ def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
assume(not (("quantile" in func or "var" in func or "std" in func) and array.dtype.kind == "c"))
# # arg* with nans in array are weird
assume("arg" not in func and not np.any(np.isnan(numpy_array.ravel())))
assume(False if func in ["median", "quantile"] and math.prod(array.numblocks) > 1 else True)
if func in ["nanmedian", "nanquantile", "median", "quantile"]:
array = array.rechunk({-1: -1})

axis = -1
by = data.draw(by_arrays(shape=(array.shape[-1],)))
Expand Down

0 comments on commit 8f8d051

Please sign in to comment.