diff --git a/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/primitive.rs b/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/primitive.rs
index 89149a9827fb..9054bbdfc677 100644
--- a/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/primitive.rs
+++ b/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/primitive.rs
@@ -16,6 +16,7 @@ use rayon::prelude::*;
 
 use super::aggregates::AggregateFn;
 use crate::executors::sinks::groupby::aggregates::AggregateFunction;
+use crate::executors::sinks::groupby::string::{apply_aggregate, write_agg_idx};
 use crate::executors::sinks::groupby::utils::compute_slices;
 use crate::executors::sinks::utils::load_vec;
 use crate::executors::sinks::HASHMAP_INIT_SIZE;
@@ -48,10 +49,9 @@ pub struct PrimitiveGroupbySink<K: PolarsNumericType> {
     // the aggregations are all tightly packed
     // the aggregation function of a group can be found
     // by:
-    // first get the correct vec by the partition index
     // * offset = (idx)
     // * end = (offset + n_aggs)
-    aggregators: Vec<Vec<AggregateFunction>>,
+    aggregators: Vec<AggregateFunction>,
     key: Arc<dyn PhysicalPipedExpr>,
     // the columns that will be aggregated
     aggregation_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
@@ -82,9 +82,8 @@ where
         let partitions = _set_partition_size();
 
         let pre_agg = load_vec(partitions, || PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE));
-        let aggregators = load_vec(partitions, || {
-            Vec::with_capacity(HASHMAP_INIT_SIZE * aggregation_columns.len())
-        });
+        let aggregators =
+            Vec::with_capacity(HASHMAP_INIT_SIZE * aggregation_columns.len() * partitions);
 
         Self {
             thread_no: 0,
@@ -107,20 +106,28 @@ where
     }
 
     fn pre_finalize(&mut self) -> PolarsResult<Vec<DataFrame>> {
-        let mut aggregators = std::mem::take(&mut self.aggregators);
+        // we create a pointer to the buffer of aggregation functions and
+        // deref the *mut on every partition thread.
+        // this is safe, as the partitions guarantee that the accesses don't alias.
+        let aggregators = self.aggregators.as_ptr() as usize;
+        let aggregators_len = self.aggregators.len();
         let slices = compute_slices(&self.pre_agg_partitions, self.slice);
 
         POOL.install(|| {
             let dfs = self
                 .pre_agg_partitions
                 .par_iter()
-                .zip(aggregators.par_iter_mut())
                 .zip(slices.par_iter())
-                .filter_map(|((agg_map, agg_fns), slice)| {
+                .filter_map(|(agg_map, slice)| {
                     let (offset, slice_len) = (*slice)?;
                     if agg_map.is_empty() {
                         return None;
                     }
+                    // safety:
+                    // we will not alias.
+                    let ptr = aggregators as *mut AggregateFunction;
+                    let agg_fns =
+                        unsafe { std::slice::from_raw_parts_mut(ptr, aggregators_len) };
                     let mut key_builder = PrimitiveChunkedBuilder::<K>::new(
                         self.output_schema.get_index(0).unwrap().0,
                         agg_map.len(),
@@ -191,6 +198,14 @@ where
         // cow -> &series -> &dyn series_trait -> &chunkedarray
         let ca: &ChunkedArray<K> = s.as_ref().as_ref();
 
+        // write the hashes to the self.hashes buffer
+        // s.vec_hash(self.hb.clone(), &mut self.hashes).unwrap();
+        // now that the hashes are written, we take a pointer to this buffer;
+        // we will write the aggregation_function indexes into the same buffer.
+        // this is unsafe, so we must ensure we only overwrite hash slots that
+        // were already read/taken, i.e. we write to the slots we have just read
+        let agg_idx_ptr = self.hashes.as_ptr() as *mut i64 as *mut IdxSize;
+        // todo! amortize allocation
         for phys_e in self.aggregation_columns.iter() {
             let s = phys_e.evaluate(&chunk, context.execution_state.as_ref())?;
@@ -198,19 +213,13 @@ where
             self.aggregation_series.push(s.rechunk());
         }
 
-        let mut agg_iters = self
-            .aggregation_series
-            .iter()
-            .map(|s| s.phys_iter())
-            .collect::<Vec<_>>();
-
         let arr = ca.downcast_iter().next().unwrap();
-        for (opt_v, &h) in arr.iter().zip(self.hashes.iter()) {
+        for (iteration_idx, (opt_v, &h)) in arr.iter().zip(self.hashes.iter()).enumerate() {
             let opt_v = opt_v.copied();
             let part = hash_to_partition(h, self.pre_agg_partitions.len());
             let current_partition =
                 unsafe { self.pre_agg_partitions.get_unchecked_release_mut(part) };
-            let current_aggregators = unsafe { self.aggregators.get_unchecked_release_mut(part) };
+            let current_aggregators = &mut self.aggregators;
 
             let entry = current_partition
                 .raw_entry_mut()
@@ -233,14 +242,27 @@ where
                 }
                 RawEntryMut::Occupied(entry) => *entry.get(),
             };
-            for (i, agg_iter) in (0 as IdxSize..num_aggs as IdxSize).zip(agg_iters.iter_mut()) {
-                let i = (agg_idx + i) as usize;
-                let agg_fn = unsafe { current_aggregators.get_unchecked_release_mut(i) };
+            // # Safety
+            // we write into the hashes buffer that we are iterating over;
+            // this is sound because the writes trail the reads
+            unsafe { write_agg_idx(agg_idx_ptr, iteration_idx, agg_idx) };
+        }
+
+        // note that this slice looks into the self.hashes buffer
+        let agg_idxs = unsafe { std::slice::from_raw_parts(agg_idx_ptr, ca.len()) };
 
-                agg_fn.pre_agg(chunk.chunk_index, agg_iter.as_mut())
-            }
+        for (agg_i, aggregation_s) in (0..num_aggs).zip(&self.aggregation_series) {
+            let has_physical_agg = self.agg_fns[agg_i].has_physical_agg();
+            apply_aggregate(
+                agg_i,
+                chunk.chunk_index,
+                agg_idxs,
+                aggregation_s,
+                has_physical_agg,
+                &mut self.aggregators,
+            );
         }
-        drop(agg_iters);
+
         self.aggregation_series.clear();
 
         Ok(SinkResult::CanHaveMoreInput)
     }
@@ -252,37 +274,36 @@ where
         self.pre_agg_partitions
             .iter_mut()
             .zip(other.pre_agg_partitions.iter())
-            .zip(self.aggregators.iter_mut())
-            .zip(other.aggregators.iter())
-            .for_each(
-                |(((map_self, map_other), aggregators_self), aggregators_other)| {
-                    for (key, &agg_idx_other) in map_other.iter() {
-                        unsafe {
-                            let entry = map_self.raw_entry_mut().from_key(key);
-
-                            let agg_idx_self = match entry {
-                                RawEntryMut::Vacant(entry) => {
-                                    let offset = NumCast::from(aggregators_self.len()).unwrap();
-                                    entry.insert(*key, offset);
-                                    // initialize the aggregators
-                                    for agg_fn in &self.agg_fns {
-                                        aggregators_self.push(agg_fn.split2())
-                                    }
-                                    offset
-                                }
-                                RawEntryMut::Occupied(entry) => *entry.get(),
-                            };
-                            for i in 0..self.aggregation_columns.len() {
-                                let agg_fn_other = aggregators_other
-                                    .get_unchecked_release(agg_idx_other as usize + i);
-                                let agg_fn_self = aggregators_self
-                                    .get_unchecked_release_mut(agg_idx_self as usize + i);
-                                agg_fn_self.combine(agg_fn_other.as_any())
+            .for_each(|(map_self, map_other)| {
+                for (key, &agg_idx_other) in map_other.iter() {
+                    let entry = map_self.raw_entry_mut().from_key(key);
+
+                    let agg_idx_self = match entry {
+                        RawEntryMut::Vacant(entry) => {
+                            let offset = NumCast::from(self.aggregators.len()).unwrap();
+                            entry.insert(*key, offset);
+                            // initialize the aggregators
+                            for agg_fn in &self.agg_fns {
+                                self.aggregators.push(agg_fn.split2())
                             }
+                            offset
+                        }
+                        RawEntryMut::Occupied(entry) => *entry.get(),
+                    };
+                    // combine the aggregation functions
+                    for i in 0..self.aggregation_columns.len() {
+                        unsafe {
+                            let agg_fn_other = other
+                                .aggregators
+                                .get_unchecked_release(agg_idx_other as usize + i);
+                            let agg_fn_self = self
+                                .aggregators
+                                .get_unchecked_release_mut(agg_idx_self as usize + i);
+                            agg_fn_self.combine(agg_fn_other.as_any())
                         }
                     }
-                },
-            );
+                }
+            });
     }
 
     fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult<FinalizedSink> {
diff --git a/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/string.rs b/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/string.rs
index f318c001de2f..0f6208519f9e 100644
--- a/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/string.rs
+++ b/polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/string.rs
@@ -34,7 +34,6 @@ pub struct Utf8GroupbySink {
     // the aggregations/keys are all tightly packed
     // the aggregation function of a group can be found
     // by:
-    // first get the correct vec by the partition index
     // * offset = (idx)
     // * end = (offset + 1)
     keys: Vec<Option<smartstring::alias::String>>,
@@ -361,11 +360,11 @@ impl Sink for Utf8GroupbySink {
 }
 
 // write agg_idx to the hashes buffer.
-unsafe fn write_agg_idx(h: *mut IdxSize, i: usize, agg_idx: IdxSize) {
+pub(super) unsafe fn write_agg_idx(h: *mut IdxSize, i: usize, agg_idx: IdxSize) {
     h.add(i).write(agg_idx)
 }
 
-fn apply_aggregate(
+pub(super) fn apply_aggregate(
     agg_i: usize,
     chunk_idx: IdxSize,
     agg_idxs: &[IdxSize],
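
The heart of the primitive.rs change is the move from one Vec<AggregateFunction> per partition to a single flat buffer shared by all partitions, with pre_finalize handing each rayon task a raw pointer to that buffer. The sketch below reproduces the aliasing argument in miniature; it is not code from the PR: it uses std::thread::scope instead of rayon, and contiguous per-partition index ranges instead of the hashmap-scattered group offsets the sink really has. The names PARTITIONS, GROUPS_PER_PART and N_AGGS are made up for illustration.

use std::thread;

const PARTITIONS: usize = 4;
const GROUPS_PER_PART: usize = 8;
const N_AGGS: usize = 2;

fn main() {
    // one flat buffer; partition p is assigned the index range
    // p * GROUPS_PER_PART * N_AGGS .. (p + 1) * GROUPS_PER_PART * N_AGGS
    let mut aggregators = vec![0u64; PARTITIONS * GROUPS_PER_PART * N_AGGS];
    let len = aggregators.len();
    // smuggle the pointer as usize so it can move into the worker closures
    let ptr = aggregators.as_mut_ptr() as usize;

    thread::scope(|s| {
        for part in 0..PARTITIONS {
            s.spawn(move || {
                // safety: every task re-creates the full slice but only
                // touches its own disjoint index range, so writes never alias
                let all = unsafe { std::slice::from_raw_parts_mut(ptr as *mut u64, len) };
                let start = part * GROUPS_PER_PART * N_AGGS;
                for v in &mut all[start..start + GROUPS_PER_PART * N_AGGS] {
                    *v = part as u64;
                }
            });
        }
    });

    assert_eq!(aggregators[0], 0);
    assert_eq!(aggregators[len - 1], (PARTITIONS - 1) as u64);
}

Casting the pointer to usize is what lets it cross into the parallel closures; the proof that the index sets are disjoint lives entirely in a comment, which is why the diff keeps a "safety" note next to every deref.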
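The second trick is in sink: once the hash at slot i has been consumed, the same slot is reused to store that row's group index, so no extra buffer is allocated. write_agg_idx writes behind the read cursor, and agg_idxs later reinterprets the front of the buffer. Below is a self-contained sketch of that "writes trail the reads" invariant, with a toy h % 3 standing in for the hashmap lookup. It deliberately mirrors the diff's as_ptr-to-*mut cast, which bends the strict aliasing rules (Miri would flag it); it is shown to illustrate the invariant, not as a pattern to copy. IdxSize = u32 matches polars' default index type without the bigidx feature.

// IdxSize mirrors polars' default index width (u32 without bigidx)
type IdxSize = u32;

fn main() {
    let hashes: Vec<u64> = vec![17, 42, 99, 7];
    // view the same allocation as a buffer of IdxSize slots
    let agg_idx_ptr = hashes.as_ptr() as *mut IdxSize;

    for (i, &h) in hashes.iter().enumerate() {
        // stand-in for the hashmap lookup that yields the group's offset
        let agg_idx = (h % 3) as IdxSize;
        // safety: the u32 written at slot i lands in u64 element i / 2,
        // which is never ahead of the element currently being read,
        // so every later read sees an untouched hash
        unsafe { agg_idx_ptr.add(i).write(agg_idx) };
    }

    // reinterpret the front of the hashes buffer as the group indexes
    let agg_idxs: &[IdxSize] = unsafe { std::slice::from_raw_parts(agg_idx_ptr, hashes.len()) };
    assert_eq!(agg_idxs, &[2, 0, 0, 1]);
}

Because IdxSize is half the width of a u64 hash, the write at slot i only ever touches element i / 2 of the original buffer, which the loop has already read; that is exactly the soundness argument the "# Safety" comment in the diff makes.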
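Both sinks keep their state in the tightly packed layout the struct comments describe: the hashmap stores only an offset idx, and a group's aggregation functions occupy aggregators[idx..idx + n_aggs]. combine grows the receiving buffer the same way sink does, by appending a fresh run of n_aggs functions for every unseen key. A minimal sketch of that layout with a hypothetical two-function setup (the AggFn enum and its Sum/Count variants are illustrative, not the PR's AggregateFunction):

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq)]
enum AggFn {
    Sum(i64),
    Count(u64),
}

fn main() {
    let n_aggs = 2;
    let templates = [AggFn::Sum(0), AggFn::Count(0)];

    // key -> offset of the group's first aggregation function
    let mut map: HashMap<i64, usize> = HashMap::new();
    // all groups' functions, tightly packed in one flat vec
    let mut aggregators: Vec<AggFn> = Vec::new();

    for key in [3i64, 5, 3, 3, 5] {
        let idx = *map.entry(key).or_insert_with(|| {
            // unseen key: append a fresh run of n_aggs functions;
            // the run's start is the offset stored in the map
            let offset = aggregators.len();
            aggregators.extend(templates.iter().cloned());
            offset
        });
        // update the group's packed run aggregators[idx..idx + n_aggs]
        for agg in &mut aggregators[idx..idx + n_aggs] {
            match agg {
                AggFn::Sum(s) => *s += key,
                AggFn::Count(c) => *c += 1,
            }
        }
    }

    assert_eq!(aggregators[0], AggFn::Sum(9)); // key 3: 3 + 3 + 3
    assert_eq!(aggregators[1], AggFn::Count(3));
    assert_eq!(aggregators[2], AggFn::Sum(10)); // key 5: 5 + 5
    assert_eq!(aggregators[3], AggFn::Count(2));
}

The payoff of the flat layout is what the diff exploits end to end: a group's offset is a plain integer, so it can be parked in the reused hashes buffer and handed to apply_aggregate as a &[IdxSize], and merging two sinks is just appending runs and combining element-wise.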