Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
groupby: Transfer struct dtype into collected aggregate
Browse files Browse the repository at this point in the history
Columns returned from libcudf do not carry struct field labels, so, as
elsewhere, we need to reconstruct the struct dtype with the appropriate
labels. For groupby.agg(list) this is done by comparing the original
column's dtype with the element_type of the result column and, when
they differ, giving the result a new list dtype whose leaf is the
original column's dtype.

Closes rapidsai#11765.
  • Loading branch information
wence- committed Dec 2, 2022
1 parent 48a27b8 commit 7ed111a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
16 changes: 14 additions & 2 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from cudf.core.abc import Serializable
from cudf.core.column.column import ColumnBase, arange, as_column
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.dtypes import is_categorical_dtype
from cudf.core.mixins import Reducible, Scannable
from cudf.core.multiindex import MultiIndex
from cudf.utils.utils import GetAttrGetItemMixin, _cudf_nvtx_annotate
Expand Down Expand Up @@ -449,6 +450,7 @@ def agg(self, func):
2 3.0 3.00 1.0 1.0
"""
column_names, columns, normalized_aggs = self._normalize_aggs(func)
orig_dtypes = tuple(c.dtype for c in columns)

# Note: When there are no key columns, the below produces
# a Float64Index, while Pandas returns an Int64Index
Expand All @@ -465,15 +467,25 @@ def agg(self, func):

multilevel = _is_multi_agg(func)
data = {}
for col_name, aggs, cols in zip(
column_names, included_aggregations, result_columns
for col_name, aggs, cols, orig_dtype in zip(
column_names,
included_aggregations,
result_columns,
orig_dtypes,
):
for agg, col in zip(aggs, cols):
if multilevel:
agg_name = agg.__name__ if callable(agg) else agg
key = (col_name, agg_name)
else:
key = col_name
if (
agg in {list, "collect"}
and not is_categorical_dtype(orig_dtype)
and orig_dtype != col.dtype.element_type
):
# Structs lose their labels which we reconstruct here
col = col._with_type_metadata(cudf.ListDtype(orig_dtype))
data[key] = col
data = ColumnAccessor(data, multiindex=multilevel)
if not multilevel:
Expand Down
9 changes: 5 additions & 4 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,10 +1579,11 @@ def test_groupby_list_of_structs(list_agg):
)
gdf = cudf.from_pandas(pdf)

with pytest.raises(
pd.errors.DataError if PANDAS_GE_150 else pd.core.base.DataError
):
gdf.groupby("a").agg({"b": list_agg})
assert_groupby_results_equal(
pdf.groupby("a").agg({"b": list_agg}),
gdf.groupby("a").agg({"b": list_agg}),
check_dtype=True,
)


@pytest.mark.parametrize("list_agg", [list, "collect"])
Expand Down

0 comments on commit 7ed111a

Please sign in to comment.