Skip to content

Commit

Permalink
Merge pull request #78 from dcherian/optimize
Browse files Browse the repository at this point in the history
Some optimizations for numpy aggregations
  • Loading branch information
ml31415 authored Apr 26, 2023
2 parents 6b7877b + 16ad9a2 commit 5e265ef
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
19 changes: 12 additions & 7 deletions numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype_scalar(fill_value, dtype, a)

if np.ndim(a) == 0:
- ret = np.bincount(group_idx, minlength=size).astype(dtype)
+ ret = np.bincount(group_idx, minlength=size).astype(dtype, copy=False)
if a != 1:
ret *= a
else:
Expand All @@ -33,7 +33,9 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
ret.real = np.bincount(group_idx, weights=a.real, minlength=size)
ret.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
else:
- ret = np.bincount(group_idx, weights=a, minlength=size).astype(dtype)
+ ret = np.bincount(group_idx, weights=a, minlength=size).astype(
+ dtype, copy=False
+ )

if fill_value != 0:
_fill_untouched(group_idx, ret, fill_value)
Expand Down Expand Up @@ -146,19 +148,19 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
else:
- sums = np.bincount(group_idx, weights=a, minlength=size).astype(dtype)
+ sums = np.bincount(group_idx, weights=a, minlength=size).astype(dtype, copy=False)

with np.errstate(divide="ignore", invalid="ignore"):
- ret = sums.astype(dtype) / counts
+ ret = sums.astype(dtype, copy=False) / counts
if not np.isnan(fill_value):
ret[counts == 0] = fill_value
return ret


def _sum_of_squres(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
ret = np.bincount(group_idx, weights=a * a, minlength=size)
- counts = np.bincount(group_idx, minlength=size)
if fill_value != 0:
+ counts = np.bincount(group_idx, minlength=size)
ret[counts == 0] = fill_value
return ret

Expand All @@ -171,7 +173,7 @@ def _var(
counts = np.bincount(group_idx, minlength=size)
sums = np.bincount(group_idx, weights=a, minlength=size)
with np.errstate(divide="ignore", invalid="ignore"):
- means = sums.astype(dtype) / counts
+ means = sums.astype(dtype, copy=False) / counts
counts = np.where(counts > ddof, counts - ddof, 0)
ret = (
np.bincount(group_idx, (a - means[group_idx]) ** 2, minlength=size) / counts
Expand Down Expand Up @@ -299,6 +301,7 @@ def _aggregate_base(
dtype=None,
axis=None,
_impl_dict=_impl_dict,
+ is_pandas=False,
**kwargs
):
iv = input_validation(group_idx, a, size=size, order=order, axis=axis, func=func)
Expand All @@ -324,7 +327,9 @@ def _aggregate_base(
kwargs["_nansqueeze"] = True
else:
good = ~np.isnan(a)
- a = a[good]
+ if "len" not in func or is_pandas:
+ # a is not needed for len, nanlen!
+ a = a[good]
group_idx = group_idx[good]

dtype = check_dtype(dtype, func, a, flat_size)
Expand Down
1 change: 1 addition & 0 deletions numpy_groupies/aggregate_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def aggregate(
func=func,
axis=axis,
_impl_dict=_impl_dict,
+ is_pandas=True,
**kwargs
)

Expand Down

0 comments on commit 5e265ef

Please sign in to comment.