Skip to content

Commit

Permalink
Improve fill_value handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ml31415 committed Dec 23, 2019
1 parent 0911e9c commit 197bde5
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 42 deletions.
10 changes: 7 additions & 3 deletions numpy_groupies/aggregate_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __call__(self, group_idx, a, size=None, fill_value=0, order='C',

# TODO: The typecheck should be done by the class itself, not by check_dtype
dtype = check_dtype(dtype, self.func, a, len(group_idx))
check_fill_value(fill_value, dtype)
check_fill_value(fill_value, dtype, func=self.func)
input_dtype = type(a) if np.isscalar(a) else a.dtype
ret, counter, mean, outer = self._initialize(flat_size, fill_value, dtype, input_dtype, group_idx.size)
group_idx = np.ascontiguousarray(group_idx)
Expand Down Expand Up @@ -85,7 +85,10 @@ def _initialize(cls, flat_size, fill_value, dtype, input_dtype, input_size):
@classmethod
def _finalize(cls, ret, counter, fill_value):
if cls.forced_fill_value is not None and fill_value != cls.forced_fill_value:
ret[counter] = fill_value
if cls.counter_dtype == bool:
ret[counter] = fill_value
else:
ret[~counter.astype(bool)] = fill_value

@classmethod
def callable(cls, nans=False, reverse=False, scalar=False):
Expand Down Expand Up @@ -192,7 +195,7 @@ def __call__(self, group_idx, a, size=None, fill_value=0, order='C',

# TODO: The typecheck should be done by the class itself, not by check_dtype
dtype = check_dtype(dtype, self.func, a, len(group_idx))
check_fill_value(fill_value, dtype)
check_fill_value(fill_value, dtype, func=self.func)
input_dtype = type(a) if np.isscalar(a) else a.dtype
ret, _, _, _= self._initialize(flat_size, fill_value, dtype, input_dtype, group_idx.size)
group_idx = np.ascontiguousarray(group_idx)
Expand Down Expand Up @@ -354,6 +357,7 @@ def _inner(ri, val, ret, counter, mean):


class Mean(Aggregate2pass):
forced_fill_value = 0
counter_fill_value = 0
counter_dtype = int

Expand Down
3 changes: 2 additions & 1 deletion numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .utils import check_boolean, funcs_no_separate_nan, get_func, aggregate_common_doc, isstr
from .utils_numpy import (aliasing, minimum_dtype, input_validation,
check_dtype, minimum_dtype_scalar)
check_dtype, check_fill_value, minimum_dtype_scalar)


def _sum(group_idx, a, size, fill_value, dtype=None):
Expand Down Expand Up @@ -271,6 +271,7 @@ def _aggregate_base(group_idx, a, func='sum', size=None, fill_value=0,
group_idx = group_idx[good]

dtype = check_dtype(dtype, func, a, flat_size)
check_fill_value(fill_value, dtype, func=func)
func = _impl_dict[func]
ret = func(group_idx, a, flat_size, fill_value=fill_value, dtype=dtype,
**kwargs)
Expand Down
28 changes: 16 additions & 12 deletions numpy_groupies/aggregate_weave.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
except ImportError:
from scipy.weave import inline

from .utils import get_func, isstr, funcs_no_separate_nan, aggregate_common_doc
from .utils import get_func, isstr, funcs_no_separate_nan, aggregate_common_doc, check_boolean
from .utils_numpy import check_dtype, aliasing, check_fill_value, input_validation


Expand Down Expand Up @@ -175,7 +175,7 @@ def get_cfuncs():
return c_funcs


c_funcs = get_cfuncs()
c_funcs.update(get_cfuncs())


c_step_count = c_size('group_idx') + r"""
Expand Down Expand Up @@ -221,6 +221,9 @@ def step_indices(group_idx):
return indices


_force_fill_0 = frozenset({'sum', 'any', 'len', 'anynan', 'mean', 'std', 'var', 'nansum', 'nanlen', 'nanmean', 'nanstd', 'nanvar'})
_force_fill_1 = frozenset({'prod', 'all', 'allnan', 'nanprod'})

def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
dtype=None, axis=None, **kwargs):
func = get_func(func, aliasing, optimized_funcs)
Expand All @@ -233,15 +236,15 @@ def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
order=order,
axis=axis)
dtype = check_dtype(dtype, func, a, len(group_idx))
check_fill_value(fill_value, dtype)
check_fill_value(fill_value, dtype, func=func)
nans = func.startswith('nan')

if nans:
flat_size += 1

if func in ('sum', 'any', 'len', 'anynan', 'nansum', 'nanlen'):
if func in _force_fill_0:
ret = np.zeros(flat_size, dtype=dtype)
elif func in ('prod', 'all', 'allnan', 'nanprod'):
elif func in _force_fill_1:
ret = np.ones(flat_size, dtype=dtype)
else:
ret = np.full(flat_size, fill_value, dtype=dtype)
Expand All @@ -250,14 +253,14 @@ def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
inline_vars = dict(group_idx=np.ascontiguousarray(group_idx), a=np.ascontiguousarray(a),
ret=ret, fill_value=fill_value)
# TODO: Have this fixed by proper raveling
if func in ('std', 'var', 'nanstd', 'nanvar'):
if func in {'std', 'var', 'nanstd', 'nanvar'}:
counter = np.zeros_like(ret, dtype=int)
inline_vars['means'] = np.zeros_like(ret)
inline_vars['ddof'] = kwargs.pop('ddof', 0)
elif func in ('mean', 'nanmean'):
elif func in {'mean', 'nanmean'}:
counter = np.zeros_like(ret, dtype=int)
else:
# Using inverse logic, marking anyting touched with zero for later removal
# Using inverse logic, marking anything touched with zero for later removal
counter = np.ones_like(ret, dtype=bool)
inline_vars['counter'] = counter

Expand All @@ -267,10 +270,11 @@ def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C',
inline(c_funcs[func], inline_vars.keys(), local_dict=inline_vars, define_macros=c_macros, extra_compile_args=c_args)

# Postprocessing
if func in ('sum', 'any', 'anynan', 'nansum') and fill_value != 0:
ret[counter] = fill_value
elif func in ('prod', 'all', 'allnan', 'nanprod') and fill_value != 1:
ret[counter] = fill_value
if func in _force_fill_0 and fill_value != 0 or func in _force_fill_1 and fill_value != 1:
if counter.dtype == np.bool_:
ret[counter] = fill_value
else:
ret[~counter.astype(bool)] = fill_value

if nans:
# Restore the shifted return array
Expand Down
8 changes: 6 additions & 2 deletions numpy_groupies/benchmarks/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def arbitrary(iterator):
np.nanmean, np.nanvar, np.nanstd, 'nanfirst', 'nanlast',
'cumsum', 'cumprod', 'cummax', 'cummin', arbitrary, 'sort')


def benchmark(implementations, size=5e5, repeat=5, seed=100):
def benchmark_data(size=5e5, seed=100):
rnd = np.random.RandomState(seed=seed)
group_idx = rnd.randint(0, int(1e3), int(size))
a = rnd.random_sample(group_idx.size)
Expand All @@ -46,6 +45,11 @@ def benchmark(implementations, size=5e5, repeat=5, seed=100):
nana[(nana < 0.2) & (nana != 0)] = np.nan
nan_share = np.mean(np.isnan(nana))
assert 0.15 < nan_share < 0.25, "%3f%% nans" % (nan_share * 100)
return a, nana, group_idx


def benchmark(implementations, repeat=5, size=5e5, seed=100):
a, nana, group_idx = benchmark_data(size=size, seed=seed)

print("function" + ''.join(impl.__name__.rsplit('_', 1)[1].rjust(14) for impl in implementations))
print("-" * (9 + 14 * len(implementations)))
Expand Down
45 changes: 29 additions & 16 deletions numpy_groupies/tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
may throw NotImplementedError in order to show missing functionality without throwing
test errors.
"""
import itertools
from itertools import product
import numpy as np
import pytest

Expand All @@ -18,7 +18,8 @@ class AttrDict(dict):

@pytest.fixture(params=['np/py', 'weave/np', 'ufunc/np', 'numba/np', 'pandas/np'], scope='module')
def aggregate_cmp(request, seed=100):
if request.param == 'np/py':
test_pair = request.param
if test_pair == 'np/py':
# Some functions in purepy are not implemented
func_ref = _wrap_notimplemented_xfail(aggregate_purepy.aggregate)
func = aggregate_numpy.aggregate
Expand Down Expand Up @@ -72,22 +73,34 @@ def func_preserve_order(iterator):
return tmp


func_list = ('sum', 'prod', 'min', 'max', 'all', 'any', 'mean', 'std', 'len',
'argmin', 'argmax', 'anynan', 'allnan', 'cumsum',
'nansum', 'nanprod', 'nanmin', 'nanmax', 'nanmean', 'nanstd', 'nanlen',
func_arbitrary, func_preserve_order)

@pytest.mark.parametrize("func", func_list, ids=lambda x: getattr(x, '__name__', x))
def test_cmp(aggregate_cmp, func, decimal=10):
a = aggregate_cmp.nana if 'nan' in getattr(func, '__name__', func) else aggregate_cmp.a
res = aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func)
ref = aggregate_cmp.func_ref(aggregate_cmp.group_idx, a, func=func)
if isinstance(ref, np.ndarray):
assert res.dtype == ref.dtype
np.testing.assert_allclose(res, ref, rtol=10**-decimal)
func_list = ('sum', 'prod', 'min', 'max', 'all', 'any', 'mean', 'std', 'var', 'len',
'argmin', 'argmax', 'anynan', 'allnan', 'cumsum', func_arbitrary, func_preserve_order,
'nansum', 'nanprod', 'nanmin', 'nanmax', 'nanmean', 'nanstd', 'nanvar','nanlen')


@pytest.mark.parametrize(["ndim", "order"], itertools.product([2, 3], ["C", "F"]))
@pytest.mark.parametrize(["func", "fill_value"], product(func_list, [0, 1, np.nan]),
ids=lambda x: getattr(x, '__name__', x))
def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
a = aggregate_cmp.nana if 'nan' in getattr(func, '__name__', func) else aggregate_cmp.a
try:
ref = aggregate_cmp.func_ref(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
except ValueError:
with pytest.raises(ValueError):
aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
else:
try:
res = aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
except ValueError:
if np.isnan(fill_value) and aggregate_cmp.test_pair.endswith('py'):
pytest.skip("pure python version uses lists and does not raise ValueErrors when inserting nan into integers")
else:
raise
if isinstance(ref, np.ndarray):
assert res.dtype == ref.dtype
np.testing.assert_allclose(res, ref, rtol=10**-decimal)


@pytest.mark.parametrize(["ndim", "order"], product([2, 3], ["C", "F"]))
def test_cmp_ndim(aggregate_cmp, ndim, order, outsize=100, decimal=14):
nindices = int(outsize ** ndim)
outshape = tuple([outsize] * ndim)
Expand Down
2 changes: 1 addition & 1 deletion numpy_groupies/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_nan_input(aggregate_all, func, groups=100):

def test_nan_input_len(aggregate_all, groups=100, group_size=5):
if aggregate_all.__name__.endswith('pandas'):
pytest.skip("pandas automatically skip nan values")
pytest.skip("pandas always skips nan values")
group_idx = np.arange(0, groups, dtype=int).repeat(group_size)
a = np.random.random(len(group_idx))
a[::2] = np.nan
Expand Down
17 changes: 10 additions & 7 deletions numpy_groupies/utils_numpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Common helper functions for typing and general numpy tools."""
import numpy as np

from .utils import get_aliasing
from .utils import get_aliasing, check_boolean

_alias_numpy = {
np.add: 'sum',
Expand Down Expand Up @@ -168,12 +168,15 @@ def check_dtype(dtype, func_str, a, n):
return a_dtype


def check_fill_value(fill_value, dtype):
try:
return dtype.type(fill_value)
except ValueError:
raise ValueError("fill_value must be convertible into %s"
% dtype.type.__name__)
def check_fill_value(fill_value, dtype, func=None):
if func in ('all', 'any', 'allnan', 'anynan'):
check_boolean(fill_value)
else:
try:
return dtype.type(fill_value)
except ValueError:
raise ValueError("fill_value must be convertible into %s"
% dtype.type.__name__)


def check_group_idx(group_idx, a=None, check_min=True):
Expand Down

0 comments on commit 197bde5

Please sign in to comment.