From 235a0551fd21b1265b75549e608f8f6be2797870 Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Thu, 13 Oct 2022 10:18:21 +0200
Subject: [PATCH] feat(rust, python): is_sorted aggregation fast path for Utf8Chunked

---
 .../src/chunked_array/ops/aggregate.rs        | 77 +++++++++++--------
 1 file changed, 43 insertions(+), 34 deletions(-)

diff --git a/polars/polars-core/src/chunked_array/ops/aggregate.rs b/polars/polars-core/src/chunked_array/ops/aggregate.rs
index 5344760a578b..49e0f912eeb8 100644
--- a/polars/polars-core/src/chunked_array/ops/aggregate.rs
+++ b/polars/polars-core/src/chunked_array/ops/aggregate.rs
@@ -6,7 +6,6 @@ use arrow::types::simd::Simd;
 use num::{Float, ToPrimitive};
 use polars_arrow::prelude::QuantileInterpolOptions;
 
-use crate::chunked_array::builder::get_list_builder;
 use crate::chunked_array::ChunkedArray;
 use crate::datatypes::{BooleanChunked, PolarsNumericType};
 use crate::prelude::*;
@@ -738,22 +737,30 @@ impl ChunkAggSeries for Utf8Chunked {
         Utf8Chunked::full_null(self.name(), 1).into_series()
     }
     fn max_as_series(&self) -> Series {
-        Series::new(
-            self.name(),
-            &[self
-                .downcast_iter()
-                .filter_map(compute::aggregate::max_string)
-                .fold_first_(|acc, v| if acc > v { acc } else { v })],
-        )
+        match self.is_sorted2() {
+            IsSorted::Ascending => Series::new(self.name(), &[self.get(self.len() - 1)]),
+            IsSorted::Descending => Series::new(self.name(), &[self.get(0)]),
+            IsSorted::Not => Series::new(
+                self.name(),
+                &[self
+                    .downcast_iter()
+                    .filter_map(compute::aggregate::max_string)
+                    .fold_first_(|acc, v| if acc > v { acc } else { v })],
+            ),
+        }
     }
     fn min_as_series(&self) -> Series {
-        Series::new(
-            self.name(),
-            &[self
-                .downcast_iter()
-                .filter_map(compute::aggregate::min_string)
-                .fold_first_(|acc, v| if acc < v { acc } else { v })],
-        )
+        match self.is_sorted2() {
+            IsSorted::Ascending => Series::new(self.name(), &[self.get(0)]),
+            IsSorted::Descending => Series::new(self.name(), &[self.get(self.len() - 1)]),
+            IsSorted::Not => Series::new(
+                self.name(),
+                &[self
+                    .downcast_iter()
+                    .filter_map(compute::aggregate::min_string)
+                    .fold_first_(|acc, v| if acc < v { acc } else { v })],
+            ),
+        }
     }
 }
 
@@ -782,23 +789,15 @@ impl ChunkAggSeries for BinaryChunked {
     }
 }
 
-macro_rules! one_null_list {
-    ($self:ident, $dtype: expr) => {{
-        let mut builder = get_list_builder(&$dtype, 0, 1, $self.name()).unwrap();
-        builder.append_opt_series(None);
-        builder.finish().into_series()
-    }};
-}
-
 impl ChunkAggSeries for ListChunked {
     fn sum_as_series(&self) -> Series {
-        one_null_list!(self, self.inner_dtype())
+        ListChunked::full_null_with_dtype(self.name(), 1, &self.inner_dtype()).into_series()
     }
     fn max_as_series(&self) -> Series {
-        one_null_list!(self, self.inner_dtype())
+        ListChunked::full_null_with_dtype(self.name(), 1, &self.inner_dtype()).into_series()
     }
     fn min_as_series(&self) -> Series {
-        one_null_list!(self, self.inner_dtype())
+        ListChunked::full_null_with_dtype(self.name(), 1, &self.inner_dtype()).into_series()
     }
 }
 
@@ -810,16 +809,26 @@ where
     T: PolarsNumericType,
 {
     fn arg_min(&self) -> Option<usize> {
-        self.into_iter()
-            .enumerate()
-            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
-            .map(|tpl| tpl.0)
+        match self.is_sorted2() {
+            IsSorted::Ascending => Some(0),
+            IsSorted::Descending => self.len().checked_sub(1), // last index; `None` if empty
+            IsSorted::Not => self
+                .into_iter()
+                .enumerate()
+                .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
+                .map(|tpl| tpl.0),
+        }
     }
     fn arg_max(&self) -> Option<usize> {
-        self.into_iter()
-            .enumerate()
-            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
-            .map(|tpl| tpl.0)
+        match self.is_sorted2() {
+            IsSorted::Ascending => self.len().checked_sub(1), // last index; `None` if empty
+            IsSorted::Descending => Some(0),
+            IsSorted::Not => self
+                .into_iter()
+                .enumerate()
+                .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
+                .map(|tpl| tpl.0),
+        }
     }
 }
 