Skip to content

Commit

Permalink
Merge pull request #231 from machow/fix-summarize-validation
Browse files Browse the repository at this point in the history
fix: pandas summarize validate len, use Series arrays
  • Loading branch information
machow authored May 12, 2020
2 parents fd3ee53 + 0622683 commit cd977e2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
10 changes: 8 additions & 2 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from pandas.core.groupby import DataFrameGroupBy
from pandas.core.dtypes.inference import is_scalar
from siuba.siu import Symbolic, Call, strip_symbolic, MetaArg, BinaryOp, create_sym_call, Lazy

DPLY_FUNCTIONS = (
Expand Down Expand Up @@ -391,10 +392,15 @@ def summarize(__data, **kwargs):
for k, v in kwargs.items():
res = v(__data) if callable(v) else v

# TODO: validation?
# validate operations returned single result
if not is_scalar(res) and len(res) > 1:
raise ValueError("Summarize argument, %s, must return result of length 1 or a scalar." % k)

results[k] = res
# keep result, but use underlying array to avoid crazy index issues
# on DataFrame construction (#138)
results[k] = res.array if isinstance(res, pd.Series) else res

# must pass index, or raises error when using all scalar values
return DataFrame(results, index = [0])


Expand Down
17 changes: 16 additions & 1 deletion siuba/tests/test_verb_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,25 @@ def test_summarize_unnamed_args(df):
)


@pytest.mark.skip("TODO: Summarize should fail when result len > 1 (#138)")
def test_summarize_validates_length():
with pytest.raises(ValueError):
summarize(data_frame(x = [1,2]), res = _.x + 1)


def test_frame_mode_returns_many():
# related to length validation above
with pytest.raises(ValueError):
df = data_frame(x = [1, 2, 3])
res = summarize(df, result = _.x.mode())


def test_summarize_removes_series_index():
# Note: currently wouldn't work in postgresql, since _.x + _.y not an agg func
df = data_frame(g = ['a', 'b', 'c'], x = [1,2,3], y = [4,5,6])

assert_equal_query(
df,
group_by(_.g) >> summarize(res = _.x + _.y),
df.assign(res = df.x + df.y).drop(columns = ["x", "y"])
)

0 comments on commit cd977e2

Please sign in to comment.