Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): fix boolean schema in agg_max/min #5678

Merged
merged 2 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions polars/polars-core/src/chunked_array/upstream_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,21 @@ impl FromParallelIterator<bool> for BooleanChunked {
}
}

impl FromParallelIterator<Option<bool>> for BooleanChunked {
    /// Builds a single-chunk, unnamed `BooleanChunked` from a parallel iterator of
    /// optional booleans.
    ///
    /// The items are first gathered with `collect_into_linked_list`, then flattened
    /// into one `BooleanArray`.
    fn from_par_iter<I>(iter: I) -> Self
    where
        I: IntoParallelIterator<Item = Option<bool>>,
    {
        let collected = collect_into_linked_list(iter);
        let total_len: usize = get_capacity_from_par_results(&collected);

        // SAFETY (assumption to confirm): `get_capacity_from_par_results` is presumed to
        // return the exact number of items held in `collected`, which is the length the
        // flattened iterator is declared to have via `trust_my_length`.
        let values = unsafe {
            let flattened = collected.into_iter().flatten().trust_my_length(total_len);
            BooleanArray::from_trusted_len_iter(flattened)
        };

        // The resulting array carries an empty name.
        Self::from_chunks("", vec![Box::new(values)])
    }
}

impl<Ptr> FromParallelIterator<Ptr> for Utf8Chunked
where
Ptr: PolarsAsRef<str> + Send + Sync,
Expand Down
97 changes: 95 additions & 2 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ where
ca.into_series()
}

/// Runs `f` over every `(first, indices)` group of `groups` through a parallel
/// iterator installed on `POOL`, and gathers the `Option<bool>` results into a
/// boolean `Series`.
pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &Vec<IdxSize>)) -> Option<bool> + Send + Sync,
{
    POOL.install(|| {
        groups
            .into_par_iter()
            .map(f)
            .collect::<BooleanChunked>()
    })
    .into_series()
}

// helper that iterates on the `all: Vec<Vec<u32>>` collection
// this doesn't have to traverse the `first: Vec<u32>` memory and is therefore faster
fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
Expand All @@ -183,12 +191,97 @@ where
ca.into_series()
}

/// Runs `f` over every `[first, len]` pair of `groups` through a parallel iterator
/// installed on `POOL`, and gathers the `Option<bool>` results into a boolean `Series`.
pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
{
    POOL.install(|| {
        groups
            .par_iter()
            .copied()
            .map(f)
            .collect::<BooleanChunked>()
    })
    .into_series()
}

impl BooleanChunked {
/// Computes the minimum of every group and returns it as a boolean `Series`.
///
/// Each group yields `None` when it is empty, the single value when it has exactly
/// one element, and otherwise the minimum over its elements. The result is collected
/// as booleans; the values are not cast to `IDX_DTYPE`.
///
/// # Safety
/// Group indices are gathered with `take_unchecked`, and the group bounds are only
/// verified by `debug_assert!`. The caller must therefore guarantee that `groups`
/// only refers to positions that exist in `self`.
pub(crate) unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series {
    // faster paths: with no nulls and a known sort order the minimum of a group is
    // simply its first (ascending) or last (descending) element.
    match (self.is_sorted2(), self.null_count()) {
        (IsSorted::Ascending, 0) => {
            return self.clone().into_series().agg_first(groups);
        }
        (IsSorted::Descending, 0) => {
            return self.clone().into_series().agg_last(groups);
        }
        _ => {}
    }

    match groups {
        GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
            debug_assert!(idx.len() <= self.len());
            if idx.is_empty() {
                None
            } else if idx.len() == 1 {
                // A single-element group: its only member is `first`.
                self.get(first as usize)
            } else {
                // TODO! optimize this
                // can just check if any is false and early stop
                let take = self.take_unchecked(idx.into());
                // `min()` yields a number here; `1` corresponds to `true`.
                take.min().map(|v| v == 1)
            }
        }),
        GroupsProxy::Slice {
            groups: groups_slice,
            ..
        } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
            debug_assert!(len <= self.len() as IdxSize);
            match len {
                0 => None,
                1 => self.get(first as usize),
                _ => {
                    let arr_group = _slice_from_offsets(self, first, len);
                    // `min()` yields a number here; `1` corresponds to `true`.
                    arr_group.min().map(|v| v == 1)
                }
            }
        }),
    }
}
/// Computes the maximum of every group and returns it as a boolean `Series`.
///
/// Each group yields `None` when it is empty, the single value when it has exactly
/// one element, and otherwise the maximum over its elements. The result is collected
/// as booleans; the values are not cast to `IDX_DTYPE`.
///
/// # Safety
/// Group indices are gathered with `take_unchecked`, and the group bounds are only
/// verified by `debug_assert!`. The caller must therefore guarantee that `groups`
/// only refers to positions that exist in `self`.
pub(crate) unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series {
    // faster paths: with no nulls and a known sort order the maximum of a group is
    // simply its last (ascending) or first (descending) element.
    match (self.is_sorted2(), self.null_count()) {
        (IsSorted::Ascending, 0) => {
            return self.clone().into_series().agg_last(groups);
        }
        (IsSorted::Descending, 0) => {
            return self.clone().into_series().agg_first(groups);
        }
        _ => {}
    }

    match groups {
        GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
            debug_assert!(idx.len() <= self.len());
            if idx.is_empty() {
                None
            } else if idx.len() == 1 {
                // A single-element group: its only member is `first`.
                self.get(first as usize)
            } else {
                // TODO! optimize this
                // can just check if any is true and early stop
                let take = self.take_unchecked(idx.into());
                // `max()` yields a number here; `1` corresponds to `true`.
                take.max().map(|v| v == 1)
            }
        }),
        GroupsProxy::Slice {
            groups: groups_slice,
            ..
        } => _agg_helper_slice_bool(groups_slice, |[first, len]| {
            debug_assert!(len <= self.len() as IdxSize);
            match len {
                0 => None,
                1 => self.get(first as usize),
                _ => {
                    let arr_group = _slice_from_offsets(self, first, len);
                    // `max()` yields a number here; `1` corresponds to `true`.
                    arr_group.max().map(|v| v == 1)
                }
            }
        }),
    }
}
pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series {
self.cast(&IDX_DTYPE).unwrap().agg_sum(groups)
Expand Down
44 changes: 22 additions & 22 deletions py-polars/polars/internals/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,17 +600,17 @@ def min(self) -> pli.DataFrame:
... )
>>> df.groupby("d", maintain_order=True).min()
shape: (3, 4)
┌────────┬─────┬──────┬─────┐
│ d ┆ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ u32
╞════════╪═════╪══════╪═════╡
│ Apple ┆ 1 ┆ 0.5 ┆ 0
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Orange ┆ 2 ┆ 0.5 ┆ 1
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Banana ┆ 4 ┆ 13.0 ┆ 0
└────────┴─────┴──────┴─────┘
┌────────┬─────┬──────┬───────
│ d ┆ a ┆ b ┆ c
│ --- ┆ --- ┆ --- ┆ ---
│ str ┆ i64 ┆ f64 ┆ bool
╞════════╪═════╪══════╪═══════
│ Apple ┆ 1 ┆ 0.5 ┆ false
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌
│ Orange ┆ 2 ┆ 0.5 ┆ true
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌
│ Banana ┆ 4 ┆ 13.0 ┆ false
└────────┴─────┴──────┴───────

"""
return self.agg(pli.all().min())
Expand All @@ -631,17 +631,17 @@ def max(self) -> pli.DataFrame:
... )
>>> df.groupby("d", maintain_order=True).max()
shape: (3, 4)
┌────────┬─────┬──────┬─────┐
│ d ┆ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ u32
╞════════╪═════╪══════╪═════╡
│ Apple ┆ 3 ┆ 10.0 ┆ 1
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Orange ┆ 2 ┆ 0.5 ┆ 1
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Banana ┆ 5 ┆ 14.0 ┆ 1
└────────┴─────┴──────┴─────┘
┌────────┬─────┬──────┬─────
│ d ┆ a ┆ b ┆ c
│ --- ┆ --- ┆ --- ┆ ---
│ str ┆ i64 ┆ f64 ┆ bool
╞════════╪═════╪══════╪═════
│ Apple ┆ 3 ┆ 10.0 ┆ true
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌
│ Orange ┆ 2 ┆ 0.5 ┆ true
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌
│ Banana ┆ 5 ┆ 14.0 ┆ true
└────────┴─────┴──────┴─────

"""
return self.agg(pli.all().max())
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,21 @@ def test_diff_duration_dtype() -> None:
False,
True,
]


def test_boolean_agg_schema() -> None:
    """Check that `max` over a boolean column reports a Boolean schema.

    The collected schema (with and without streaming) must match both the lazy
    frame's own schema and the expected Int64/Boolean mapping.
    """
    lazy_frame = pl.DataFrame({"x": [1, 1, 1], "y": [False, True, False]}).lazy()
    agg_df = lazy_frame.groupby("x").agg(pl.col("y").max().alias("max_y"))

    expected_schema = {"x": pl.Int64, "max_y": pl.Boolean}

    for streaming in (True, False):
        collected_schema = agg_df.collect(streaming=streaming).schema
        planned_schema = agg_df.schema
        assert collected_schema == planned_schema == expected_schema