From 925dfaf642b1718109836690121ab189064e19fb Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 18 Nov 2024 16:50:16 +0100 Subject: [PATCH] refactor(rust): Make chunked gathers generic over chunk bit width --- crates/polars-core/src/prelude.rs | 2 +- .../src/chunked_array/gather/chunked.rs | 101 ++++++++++++------ .../src/frame/join/dispatch_left_right.rs | 2 +- .../src/frame/join/hash_join/mod.rs | 4 +- .../sinks/joins/generic_probe_inner_left.rs | 7 +- .../sinks/joins/generic_probe_outer.rs | 12 +-- crates/polars-utils/src/index.rs | 2 - 7 files changed, 80 insertions(+), 50 deletions(-) diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index bd1ade2d9b90..558c34dd36e0 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -6,7 +6,7 @@ pub(crate) use arrow::array::*; pub use arrow::datatypes::{ArrowSchema, Field as ArrowField}; pub use arrow::legacy::prelude::*; pub(crate) use arrow::trusted_len::TrustedLen; -pub use polars_utils::index::{ChunkId, IdxSize, NullableChunkId, NullableIdxSize}; +pub use polars_utils::index::{ChunkId, IdxSize, NullableIdxSize}; pub use polars_utils::pl_str::PlSmallStr; pub(crate) use polars_utils::total_ord::{TotalEq, TotalOrd}; diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index 249e8dc1730a..e607fa77c068 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -12,34 +12,60 @@ use polars_core::with_match_physical_numeric_polars_type; use crate::frame::IntoDf; -pub trait DfTake: IntoDf { +/// Gather by [`ChunkId`] +pub trait TakeChunked { + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + ) -> Self; + + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self; +} + +impl TakeChunked for DataFrame { /// Take elements by a slice of [`ChunkId`]s. /// /// # Safety /// Does not do any bound checks. /// `sorted` indicates if the chunks are sorted. - unsafe fn _take_chunked_unchecked_seq(&self, idx: &[ChunkId], sorted: IsSorted) -> DataFrame { + unsafe fn take_chunked_unchecked( + &self, + idx: &[ChunkId], + sorted: IsSorted, + ) -> DataFrame { let cols = self .to_df() ._apply_columns(&|s| s.take_chunked_unchecked(idx, sorted)); unsafe { DataFrame::new_no_checks_height_from_first(cols) } } + /// Take elements by a slice of optional [`ChunkId`]s. /// /// # Safety /// Does not do any bound checks. - unsafe fn _take_opt_chunked_unchecked_seq(&self, idx: &[NullableChunkId]) -> DataFrame { + unsafe fn take_opt_chunked_unchecked(&self, idx: &[ChunkId]) -> DataFrame { let cols = self .to_df() ._apply_columns(&|s| s.take_opt_chunked_unchecked(idx)); unsafe { DataFrame::new_no_checks_height_from_first(cols) } } +} +pub trait TakeChunkedHorPar: IntoDf { /// # Safety /// Doesn't perform any bound checks - unsafe fn _take_chunked_unchecked(&self, idx: &[ChunkId], sorted: IsSorted) -> DataFrame { + unsafe fn _take_chunked_unchecked_hor_par( + &self, + idx: &[ChunkId], + sorted: IsSorted, + ) -> DataFrame { let cols = self .to_df() ._apply_columns_par(&|s| s.take_chunked_unchecked(idx, sorted)); @@ -51,7 +77,10 @@ pub trait DfTake: IntoDf { /// Doesn't perform any bound checks /// /// Check for null state in `ChunkId`. - unsafe fn _take_opt_chunked_unchecked(&self, idx: &[ChunkId]) -> DataFrame { + unsafe fn _take_opt_chunked_unchecked_hor_par( + &self, + idx: &[ChunkId], + ) -> DataFrame { let cols = self .to_df() ._apply_columns_par(&|s| s.take_opt_chunked_unchecked(idx)); @@ -60,18 +89,7 @@ pub trait DfTake: IntoDf { } } -impl DfTake for DataFrame {} - -/// Gather by [`ChunkId`] -pub trait TakeChunked { - /// # Safety - /// This function doesn't do any bound checks. - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self; - - /// # Safety - /// This function doesn't do any bound checks. - unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self; -} +impl TakeChunkedHorPar for DataFrame {} fn prepare_series(s: &Series) -> Cow { let phys = if s.dtype().is_nested() { @@ -89,14 +107,18 @@ fn prepare_series(s: &Series) -> Cow { } impl TakeChunked for Column { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + ) -> Self { // @scalar-opt let s = self.as_materialized_series(); let s = unsafe { s.take_chunked_unchecked(by, sorted) }; s.into_column() } - unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self { + unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self { // @scalar-opt let s = self.as_materialized_series(); let s = unsafe { s.take_opt_chunked_unchecked(by) }; @@ -105,7 +127,11 @@ impl TakeChunked for Column { } impl TakeChunked for Series { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + ) -> Self { let phys = prepare_series(self); use DataType::*; let out = match phys.dtype() { @@ -162,7 +188,7 @@ impl TakeChunked for Series { } /// Take function that checks of null state in `ChunkIdx`. - unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self { + unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self { let phys = prepare_series(self); use DataType::*; let out = match phys.dtype() { @@ -224,7 +250,11 @@ where T: PolarsDataType, T::Array: Debug, { - unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + ) -> Self { let arrow_dtype = self.dtype().to_arrow(CompatLevel::newest()); let mut out = if let Some(iter) = self.downcast_slices() { @@ -261,7 +291,7 @@ where } // Take function that checks of null state in `ChunkIdx`. - unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self { + unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self { let arrow_dtype = self.dtype().to_arrow(CompatLevel::newest()); if let Some(iter) = self.downcast_slices() { @@ -301,7 +331,11 @@ where } #[cfg(feature = "object")] -unsafe fn take_unchecked_object(s: &Series, by: &[ChunkId], _sorted: IsSorted) -> Series { +unsafe fn take_unchecked_object( + s: &Series, + by: &[ChunkId], + _sorted: IsSorted, +) -> Series { let DataType::Object(_, reg) = s.dtype() else { unreachable!() }; @@ -317,7 +351,7 @@ unsafe fn take_unchecked_object(s: &Series, by: &[ChunkId], _sorted: IsSorted) - } #[cfg(feature = "object")] -unsafe fn take_opt_unchecked_object(s: &Series, by: &[NullableChunkId]) -> Series { +unsafe fn take_opt_unchecked_object(s: &Series, by: &[ChunkId]) -> Series { let DataType::Object(_, reg) = s.dtype() else { unreachable!() }; @@ -358,9 +392,9 @@ fn create_buffer_offsets(ca: &BinaryChunked) -> Vec { } #[allow(clippy::unnecessary_cast)] -unsafe fn take_unchecked_binview( +unsafe fn take_unchecked_binview( ca: &BinaryChunked, - by: &[ChunkId], + by: &[ChunkId], sorted: IsSorted, ) -> BinaryChunked { let views = ca @@ -430,7 +464,10 @@ unsafe fn take_unchecked_binview( out } -unsafe fn take_unchecked_binview_opt(ca: &BinaryChunked, by: &[NullableChunkId]) -> BinaryChunked { +unsafe fn take_unchecked_binview_opt( + ca: &BinaryChunked, + by: &[ChunkId], +) -> BinaryChunked { let views = ca .downcast_iter() .map(|arr| arr.views().as_slice()) @@ -533,7 +570,7 @@ mod test { assert_eq!(s_1.n_chunks(), 3); // ## Ids without nulls; - let by = [ + let by: [ChunkId<24>; 7] = [ ChunkId::store(0, 0), ChunkId::store(0, 1), ChunkId::store(1, 1), @@ -549,7 +586,7 @@ mod test { assert!(out.equals(&expected)); // ## Ids with nulls; - let by: [ChunkId; 4] = [ + let by: [ChunkId<24>; 4] = [ ChunkId::null(), ChunkId::store(0, 1), ChunkId::store(1, 1), @@ -570,7 +607,7 @@ mod test { s_1.append(&s_2).unwrap(); // ## Ids without nulls; - let by = [ + let by: [ChunkId<24>; 4] = [ ChunkId::store(0, 0), ChunkId::store(0, 1), ChunkId::store(1, 1), @@ -583,7 +620,7 @@ mod test { assert!(out.equals_missing(&expected)); // ## Ids with nulls; - let by: [ChunkId; 4] = [ + let by: [ChunkId<24>; 4] = [ ChunkId::null(), ChunkId::store(0, 1), ChunkId::store(1, 1), diff --git a/crates/polars-ops/src/frame/join/dispatch_left_right.rs b/crates/polars-ops/src/frame/join/dispatch_left_right.rs index b3193ce76628..3c82773c7d9e 100644 --- a/crates/polars-ops/src/frame/join/dispatch_left_right.rs +++ b/crates/polars-ops/src/frame/join/dispatch_left_right.rs @@ -114,7 +114,7 @@ fn materialize_left_join( if let Some((offset, len)) = args.slice { right_idx = slice_slice(right_idx, offset, len); } - other._take_opt_chunked_unchecked(right_idx) + other._take_opt_chunked_unchecked_hor_par(right_idx) }, }; POOL.join(materialize_left, materialize_right) diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 8d5533c41e5b..1bf51d8c2e3b 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -23,7 +23,7 @@ pub(crate) use sort_merge::*; pub use super::*; #[cfg(feature = "chunked_ids")] -use crate::chunked_array::gather::chunked::DfTake; +use crate::chunked_array::gather::chunked::TakeChunkedHorPar; pub fn default_join_ids() -> ChunkJoinOptIds { #[cfg(feature = "chunked_ids")] @@ -75,7 +75,7 @@ pub trait JoinDispatch: IntoDf { } else { IsSorted::Not }; - df_self._take_chunked_unchecked(chunk_ids, sorted) + df_self._take_chunked_unchecked_hor_par(chunk_ids, sorted) } } diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs index 5337d517cb79..da3ed24f2fa4 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs @@ -3,9 +3,8 @@ use std::borrow::Cow; use arrow::array::{Array, BinaryArray}; use polars_core::prelude::*; use polars_core::series::IsSorted; -use polars_ops::chunked_array::DfTake; use polars_ops::frame::join::_finish_join; -use polars_ops::prelude::{JoinArgs, JoinType}; +use polars_ops::prelude::{JoinArgs, JoinType, TakeChunked}; use polars_utils::nulls::IsNull; use polars_utils::pl_str::PlSmallStr; @@ -208,7 +207,7 @@ impl GenericJoinProbe { .data ._take_unchecked_slice_sorted(&self.join_tuples_b, false, IsSorted::Ascending) }; - let right_df = unsafe { right_df._take_opt_chunked_unchecked_seq(&self.join_tuples_a) }; + let right_df = unsafe { right_df.take_opt_chunked_unchecked(&self.join_tuples_a) }; let out = self.finish_join(left_df, right_df)?; @@ -271,7 +270,7 @@ impl GenericJoinProbe { let left_df = unsafe { self.df_a - ._take_chunked_unchecked_seq(&self.join_tuples_a, IsSorted::Not) + .take_chunked_unchecked(&self.join_tuples_a, IsSorted::Not) }; let right_df = unsafe { let mut df = Cow::Borrowed(&chunk.data); diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs index 0b7cbfb2c534..82dea0326b7d 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs @@ -3,9 +3,8 @@ use std::sync::atomic::Ordering; use arrow::array::{Array, BinaryArray, MutablePrimitiveArray}; use polars_core::prelude::*; use polars_core::series::IsSorted; -use polars_ops::chunked_array::DfTake; use polars_ops::frame::join::_finish_join; -use polars_ops::prelude::_coalesce_full_join; +use polars_ops::prelude::{TakeChunked, _coalesce_full_join}; use polars_utils::pl_str::PlSmallStr; use crate::executors::sinks::joins::generic_build::*; @@ -40,7 +39,7 @@ pub struct GenericFullOuterJoinProbe { // amortize allocations // in inner join these are the left table // in left join there are the right table - join_tuples_a: Vec, + join_tuples_a: Vec, // in inner join these are the right table // in left join there are the left table join_tuples_b: MutablePrimitiveArray, @@ -224,10 +223,7 @@ impl GenericFullOuterJoinProbe { } self.hashes = hashes; - let left_df = unsafe { - self.df_a - ._take_opt_chunked_unchecked_seq(&self.join_tuples_a) - }; + let left_df = unsafe { self.df_a.take_opt_chunked_unchecked(&self.join_tuples_a) }; let right_df = unsafe { self.join_tuples_b.with_freeze(|idx| { let idx = IdxCa::from(idx.clone()); @@ -260,7 +256,7 @@ impl GenericFullOuterJoinProbe { let left_df = unsafe { self.df_a - ._take_chunked_unchecked_seq(&self.join_tuples_a, IsSorted::Not) + .take_chunked_unchecked(&self.join_tuples_a, IsSorted::Not) }; let size = left_df.height(); diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs index 9d8cf67f8a58..fb43a1958cd6 100644 --- a/crates/polars-utils/src/index.rs +++ b/crates/polars-utils/src/index.rs @@ -191,8 +191,6 @@ pub struct ChunkId { swizzled: u64, } -pub type NullableChunkId = ChunkId; - impl Debug for ChunkId { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { if self.is_null() {