Skip to content

Commit

Permalink
feat(rust, python): is_sorted aggregation fast path for Utf8Chunked (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 13, 2022
1 parent aab7eb1 commit 3f59000
Showing 1 changed file with 43 additions and 34 deletions.
77 changes: 43 additions & 34 deletions polars/polars-core/src/chunked_array/ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use arrow::types::simd::Simd;
use num::{Float, ToPrimitive};
use polars_arrow::prelude::QuantileInterpolOptions;

use crate::chunked_array::builder::get_list_builder;
use crate::chunked_array::ChunkedArray;
use crate::datatypes::{BooleanChunked, PolarsNumericType};
use crate::prelude::*;
Expand Down Expand Up @@ -738,22 +737,30 @@ impl ChunkAggSeries for Utf8Chunked {
Utf8Chunked::full_null(self.name(), 1).into_series()
}
/// Returns a one-element `Series`, named after this array, holding the
/// largest string (by `&str` ordering), or a single null when no value exists.
///
/// When the array is flagged as sorted, the answer is read from one end
/// instead of scanning every chunk.
fn max_as_series(&self) -> Series {
    // The end element is only trusted when it exists and cannot be a null:
    // an empty array would underflow `len() - 1`, and a sorted array may keep
    // its nulls at either end, which the scan below would have skipped.
    let ends_are_extremes = self.len() > 0 && self.null_count() == 0;

    let max = match self.is_sorted2() {
        IsSorted::Ascending if ends_are_extremes => self.get(self.len() - 1),
        IsSorted::Descending if ends_are_extremes => self.get(0),
        // Unsorted, empty, or containing nulls: reduce the per-chunk maxima.
        IsSorted::Ascending | IsSorted::Descending | IsSorted::Not => self
            .downcast_iter()
            .filter_map(compute::aggregate::max_string)
            .fold_first_(|acc, v| if acc > v { acc } else { v }),
    };
    Series::new(self.name(), &[max])
}
/// Returns a one-element `Series`, named after this array, holding the
/// smallest string (by `&str` ordering), or a single null when no value exists.
///
/// When the array is flagged as sorted, the answer is read from one end
/// instead of scanning every chunk.
fn min_as_series(&self) -> Series {
    // The end element is only trusted when it exists and cannot be a null:
    // an empty array would underflow `len() - 1`, and a sorted array may keep
    // its nulls at either end, which the scan below would have skipped.
    let ends_are_extremes = self.len() > 0 && self.null_count() == 0;

    let min = match self.is_sorted2() {
        IsSorted::Ascending if ends_are_extremes => self.get(0),
        IsSorted::Descending if ends_are_extremes => self.get(self.len() - 1),
        // Unsorted, empty, or containing nulls: reduce the per-chunk minima.
        IsSorted::Ascending | IsSorted::Descending | IsSorted::Not => self
            .downcast_iter()
            .filter_map(compute::aggregate::min_string)
            .fold_first_(|acc, v| if acc < v { acc } else { v }),
    };
    Series::new(self.name(), &[min])
}
}

Expand Down Expand Up @@ -782,23 +789,15 @@ impl ChunkAggSeries for BinaryChunked {
}
}

// Expands to a `Series` produced by a list builder that received exactly one
// `None` entry. The builder is named after `$ca` and typed by `$inner_dtype`.
// NOTE(review): the literals `0` and `1` are presumably the value and list
// capacities expected by `get_list_builder` — confirm against its signature.
macro_rules! one_null_list {
    ($ca:ident, $inner_dtype:expr) => {{
        let mut null_list = get_list_builder(&$inner_dtype, 0, 1, $ca.name()).unwrap();
        null_list.append_opt_series(None);
        null_list.finish().into_series()
    }};
}

/// Series-level aggregations for list arrays.
///
/// Every aggregation here yields the same result: a length-1, all-null list
/// `Series` that keeps this array's name and inner dtype.
impl ChunkAggSeries for ListChunked {
    /// Returns a one-element null list series (see the impl-level note).
    fn sum_as_series(&self) -> Series {
        ListChunked::full_null_with_dtype(self.name(), 1, &self.inner_dtype()).into_series()
    }

    /// Returns a one-element null list series (see the impl-level note).
    fn max_as_series(&self) -> Series {
        ListChunked::full_null_with_dtype(self.name(), 1, &self.inner_dtype()).into_series()
    }

    /// Returns a one-element null list series (see the impl-level note).
    fn min_as_series(&self) -> Series {
        ListChunked::full_null_with_dtype(self.name(), 1, &self.inner_dtype()).into_series()
    }
}

Expand All @@ -810,16 +809,26 @@ where
T: PolarsNumericType,
{
/// Returns the index of the smallest element, or `None` for an empty array.
///
/// Elements are compared as `Option` values, so a null orders below every
/// non-null value. When the array is flagged as sorted, the index is taken
/// from one end without scanning.
fn arg_min(&self) -> Option<usize> {
    // `checked_sub` yields `None` on an empty array, matching what the scan
    // returns, and gives the last valid index otherwise.
    let last_idx = self.len().checked_sub(1)?;

    match self.is_sorted2() {
        IsSorted::Ascending => Some(0),
        // The minimum sits at the last position, `len() - 1`; `len()` itself
        // is one past the end.
        // Note: with duplicate minima this is the last occurrence, whereas
        // the scan below keeps the first one.
        IsSorted::Descending => Some(last_idx),
        IsSorted::Not => self
            .into_iter()
            .enumerate()
            // Strict `>` keeps the earliest index among equal elements.
            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
            .map(|tpl| tpl.0),
    }
}
/// Returns the index of the largest element, or `None` for an empty array.
///
/// Elements are compared as `Option` values, so a null orders below every
/// non-null value. When the array is flagged as sorted, the index is taken
/// from one end without scanning.
fn arg_max(&self) -> Option<usize> {
    // `checked_sub` yields `None` on an empty array, matching what the scan
    // returns, and gives the last valid index otherwise.
    let last_idx = self.len().checked_sub(1)?;

    match self.is_sorted2() {
        // The maximum sits at the last position, `len() - 1`; `len()` itself
        // is one past the end.
        // Note: with duplicate maxima this is the last occurrence, whereas
        // the scan below keeps the first one.
        IsSorted::Ascending => Some(last_idx),
        IsSorted::Descending => Some(0),
        IsSorted::Not => self
            .into_iter()
            .enumerate()
            // Strict `<` keeps the earliest index among equal elements.
            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
            .map(|tpl| tpl.0),
    }
}
}

Expand Down

0 comments on commit 3f59000

Please sign in to comment.