Skip to content

Commit

Permalink
fix!: if_else and case_when now return Series
Browse files: browse the repository at this point in the history
  • Loading branch information
machow committed Jul 19, 2022
1 parent 60f3e51 commit 34d2907
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
27 changes: 17 additions & 10 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,19 +657,28 @@ def _distinct(__data, *args, _keep_all = False, **kwargs):
def if_else(__data, *args, **kwargs):
"""
Example:
>>> ser1 = pd.Series([1,2,3,4])
>>> if_else(ser1 > 2, np.nan, ser1) # doctest: +SKIP
array([ 1., 2., nan, nan])
>>> ser1 = pd.Series([1,2,3])
>>> if_else(ser1 > 2, np.nan, ser1)
0 1.0
1 2.0
2 NaN
dtype: float64
>>> from siuba import _
>>> f = if_else(_ < 3, _, 3)
>>> f = if_else(_ < 2, _, 2)
>>> f(ser1)
array([1, 2, 3, 3])
0 1
1 2
2 2
dtype: int64
>>> import numpy as np
>>> ser2 = pd.Series(['NA', 'a', 'b'])
>>> if_else(ser2 == 'NA', np.nan, ser2)
array([nan, 'a', 'b'], dtype=object)
0 NaN
1 a
2 b
dtype: object
"""
raise_type_error(__data)
Expand All @@ -683,9 +692,7 @@ def _if_else(__data, *args, **kwargs):
def _if_else(cond, true_vals, false_vals):
result = np.where(cond.fillna(False), true_vals, false_vals)

# TODO: should functions that take a Series, return a Series?
# for now, just return "O" type. Sort out once better research.
return result
return pd.Series(result)


# case_when ----------------
Expand Down Expand Up @@ -729,7 +736,7 @@ def case_when(__data, cases):
out[:] = val_res

# by recreating an array, attempts to cast as best dtype
return np.array(list(out))
return pd.Series(list(out))

@case_when.register(Symbolic)
@case_when.register(Call)
Expand Down
6 changes: 3 additions & 3 deletions siuba/tests/test_verb_case_when.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from siuba.dply.verbs import case_when
from pandas.testing import assert_series_equal
from numpy.testing import assert_equal
from siuba.siu import _

Expand All @@ -29,10 +30,9 @@ def data():
#(np.array([True, True, False]), 0, [0, 0, None])
])
def test_case_when_single_cond(k, v, res, data):
arr_res = np.array(res)
out = case_when(data, {k: v})

assert_equal(out, arr_res)
assert_series_equal(out, pd.Series(res))


def test_case_when_cond_order(data):
Expand All @@ -41,5 +41,5 @@ def test_case_when_cond_order(data):
True : 999
})

assert_equal(out, np.array([0, 0, 999]))
assert_series_equal(out, pd.Series([0, 0, 999]))

0 comments on commit 34d2907

Please sign in to comment.