refactor(rust): Use HashKeys abstraction (#19785)
orlp authored Nov 15, 2024
1 parent f45c9e9 commit e59626d
Showing 10 changed files with 183 additions and 48 deletions.
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/compute/take/binary.rs
@@ -21,6 +21,8 @@ use crate::array::{Array, BinaryArray, PrimitiveArray};
 use crate::offset::Offset;
 
 /// `take` implementation for utf8 arrays
+/// # Safety
+/// The indices must be in-bounds.
 pub unsafe fn take_unchecked<O: Offset, I: Index>(
     values: &BinaryArray<O>,
     indices: &PrimitiveArray<I>,
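
The take modules become public later in this commit (see mod.rs below), so this kernel is now reachable from other crates. As context for the new `# Safety` section, a caller must guarantee every index is in-bounds before the `unsafe` call; a minimal illustrative wrapper (the helper name is made up, and the `arrow` crate alias mirrors the import style in hash_keys.rs below) might look like:

use arrow::array::{BinaryArray, PrimitiveArray};
use arrow::compute::take::binary::take_unchecked;
use polars_utils::IdxSize;

/// Illustrative bounds-checked wrapper around the unsafe kernel.
fn take_checked(values: &BinaryArray<i64>, indices: &PrimitiveArray<IdxSize>) -> BinaryArray<i64> {
    let len = values.len();
    // Uphold the documented contract before entering `unsafe`:
    // every index must be a valid position in `values`.
    assert!(indices.values_iter().all(|i| (*i as usize) < len));
    // SAFETY: all indices were just checked to be in-bounds.
    unsafe { take_unchecked(values, indices) }
}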
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/compute/take/boolean.rs
@@ -59,6 +59,8 @@ unsafe fn take_values_indices_validity(
 }
 
 /// `take` implementation for boolean arrays
+/// # Safety
+/// The indices must be in-bounds.
 pub unsafe fn take_unchecked(
     values: &BooleanArray,
     indices: &PrimitiveArray<IdxSize>,
18 changes: 9 additions & 9 deletions crates/polars-arrow/src/compute/take/mod.rs
@@ -25,15 +25,15 @@ use crate::compute::take::binview::take_binview_unchecked;
 use crate::datatypes::{ArrowDataType, IdxArr};
 use crate::types::Index;
 
-mod binary;
-mod binview;
-mod bitmap;
-mod boolean;
-mod fixed_size_list;
-mod generic_binary;
-mod list;
-mod primitive;
-mod structure;
+pub mod binary;
+pub mod binview;
+pub mod bitmap;
+pub mod boolean;
+pub mod fixed_size_list;
+pub mod generic_binary;
+pub mod list;
+pub mod primitive;
+pub mod structure;
 
 use crate::with_match_primitive_type_full;
 
2 changes: 1 addition & 1 deletion crates/polars-core/src/hashing/vector_hasher.rs
@@ -454,7 +454,7 @@ pub fn _df_rows_to_hashes_threaded_vertical(
     Ok((hashes, hasher_builder))
 }
 
-pub(crate) fn columns_to_hashes(
+pub fn columns_to_hashes(
     keys: &[Column],
     build_hasher: Option<PlRandomState>,
     hashes: &mut Vec<u64>,
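
Making columns_to_hashes public supports the vechash TODO in the new hash_keys.rs below, whose commented-out call shows the intended shape. A sketch of that usage — the module path is assumed from this file's location, and the return type is assumed to be a PolarsResult like its sibling above:

use polars_core::hashing::columns_to_hashes;
use polars_core::prelude::*;

/// Illustrative helper: one combined u64 hash per row of `df`.
fn row_hashes(df: &DataFrame, random_state: PlRandomState) -> PolarsResult<Vec<u64>> {
    let mut hashes = Vec::with_capacity(df.height());
    // Fills `hashes` with one hash per row, combined across all columns.
    columns_to_hashes(df.get_columns(), Some(random_state), &mut hashes)?;
    Ok(hashes)
}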
12 changes: 5 additions & 7 deletions crates/polars-expr/src/groups/mod.rs
@@ -2,11 +2,12 @@ use std::any::Any;
 use std::path::Path;
 
 use polars_core::prelude::*;
-use polars_utils::aliases::PlRandomState;
 use polars_utils::cardinality_sketch::CardinalitySketch;
 use polars_utils::hashing::HashPartitioner;
 use polars_utils::IdxSize;
 
+use crate::hash_keys::HashKeys;
+
 mod row_encoded;
 
 /// A Grouper maps keys to groups, such that duplicate keys map to the same group.
@@ -22,7 +23,7 @@ pub trait Grouper: Any + Send + Sync {
 
     /// Inserts the given keys into this Grouper, mutating groups_idxs such
     /// that group_idxs[i] is the group index of keys[..][i].
-    fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec<IdxSize>);
+    fn insert_keys(&mut self, keys: HashKeys, group_idxs: &mut Vec<IdxSize>);
 
     /// Adds the given Grouper into this one, mutating groups_idxs such that
     /// the ith group of other now has group index group_idxs[i] in self.
@@ -69,9 +70,6 @@ pub trait Grouper: Any + Send + Sync {
     fn as_any(&self) -> &dyn Any;
 }
 
-pub fn new_hash_grouper(key_schema: Arc<Schema>, random_state: PlRandomState) -> Box<dyn Grouper> {
-    Box::new(row_encoded::RowEncodedHashGrouper::new(
-        key_schema,
-        random_state,
-    ))
+pub fn new_hash_grouper(key_schema: Arc<Schema>) -> Box<dyn Grouper> {
+    Box::new(row_encoded::RowEncodedHashGrouper::new(key_schema))
 }
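
Putting the pieces together, a hypothetical end-to-end use of the reworked API (illustrative only: the `df!` construction, the schema handling, and first-seen group numbering are assumptions, not code from this commit):

use std::sync::Arc;
use polars_core::prelude::*;
use polars_expr::groups::new_hash_grouper;
use polars_expr::hash_keys::HashKeys;

fn grouper_demo() -> PolarsResult<()> {
    let keys = polars_core::df!("k" => ["a", "b", "a"])?;
    let mut grouper = new_hash_grouper(Arc::new(keys.schema()));
    // Hash and row-encode the keys once, then hand them to the grouper.
    let hash_keys = HashKeys::from_df(&keys, PlRandomState::new(), true);
    let mut group_idxs = Vec::new();
    grouper.insert_keys(hash_keys, &mut group_idxs);
    // Duplicate keys map to the same group: rows 0 and 2 share a group,
    // so assuming first-seen numbering this yields [0, 1, 0].
    assert_eq!(group_idxs, vec![0, 1, 0]);
    Ok(())
}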
37 changes: 10 additions & 27 deletions crates/polars-expr/src/groups/row_encoded.rs
@@ -1,12 +1,10 @@
 use hashbrown::hash_table::{Entry, HashTable};
-use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_unordered;
 use polars_row::EncodingField;
-use polars_utils::aliases::PlRandomState;
 use polars_utils::cardinality_sketch::CardinalitySketch;
-use polars_utils::itertools::Itertools;
 use polars_utils::vec::PushUnchecked;
 
 use super::*;
+use crate::hash_keys::HashKeys;
 
 const BASE_KEY_DATA_CAPACITY: usize = 1024;
 
@@ -31,19 +29,15 @@ pub struct RowEncodedHashGrouper {
     group_keys: Vec<Key>,
     key_data: Vec<Vec<u8>>,
 
-    // Used for computing canonical hashes.
-    random_state: PlRandomState,
-
     // Internal random seed used to keep hash iteration order decorrelated.
     // We simply store a random odd number and multiply the canonical hash by it.
     seed: u64,
 }
 
 impl RowEncodedHashGrouper {
-    pub fn new(key_schema: Arc<Schema>, random_state: PlRandomState) -> Self {
+    pub fn new(key_schema: Arc<Schema>) -> Self {
         Self {
             key_schema,
-            random_state,
             seed: rand::random::<u64>() | 1,
             key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)],
             ..Default::default()
@@ -119,10 +113,7 @@ impl RowEncodedHashGrouper {
 
 impl Grouper for RowEncodedHashGrouper {
     fn new_empty(&self) -> Box<dyn Grouper> {
-        Box::new(Self::new(
-            self.key_schema.clone(),
-            self.random_state.clone(),
-        ))
+        Box::new(Self::new(self.key_schema.clone()))
     }
 
     fn reserve(&mut self, additional: usize) {
@@ -137,23 +128,15 @@ impl Grouper for RowEncodedHashGrouper {
         self.table.len() as IdxSize
     }
 
-    fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec<IdxSize>) {
-        let series = keys
-            .get_columns()
-            .iter()
-            .map(|c| c.as_materialized_series().clone())
-            .collect_vec();
-        let keys_encoded = _get_rows_encoded_unordered(&series[..])
-            .unwrap()
-            .into_array();
-        assert!(keys_encoded.len() == keys[0].len());
-
+    fn insert_keys(&mut self, keys: HashKeys, group_idxs: &mut Vec<IdxSize>) {
+        let HashKeys::RowEncoded(keys) = keys else {
+            unreachable!()
+        };
         group_idxs.clear();
-        group_idxs.reserve(keys_encoded.len());
-        for key in keys_encoded.values_iter() {
-            let hash = self.random_state.hash_one(key);
+        group_idxs.reserve(keys.hashes.len());
+        for (hash, key) in keys.hashes.iter().zip(keys.keys.values_iter()) {
             unsafe {
-                group_idxs.push_unchecked(self.insert_key(hash, key));
+                group_idxs.push_unchecked(self.insert_key(*hash, key));
             }
         }
     }
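
With hashing moved out to HashKeys, the surviving `seed` field is this grouper's only remaining source of randomness. A minimal sketch of the trick its comment describes (illustrative, not the commit's exact code):

/// Scramble a canonical hash for this instance's table placement.
/// An odd multiplier is invertible mod 2^64, i.e. a bijection on u64,
/// so distinct hashes stay distinct while iteration order is
/// decorrelated between grouper instances.
fn scramble(canonical_hash: u64, seed: u64) -> u64 {
    debug_assert!(seed & 1 == 1, "seed must be odd to stay invertible");
    canonical_hash.wrapping_mul(seed)
}

fn new_seed() -> u64 {
    // Forcing the low bit on guarantees an odd multiplier.
    rand::random::<u64>() | 1
}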
142 changes: 142 additions & 0 deletions crates/polars-expr/src/hash_keys.rs
@@ -0,0 +1,142 @@
use arrow::array::BinaryArray;
use arrow::compute::take::binary::take_unchecked;
use polars_core::frame::DataFrame;
use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
use polars_core::prelude::PlRandomState;
use polars_core::series::Series;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::vec::PushUnchecked;
use polars_utils::IdxSize;

/// Represents a DataFrame plus a hash per row, intended for keys in grouping
/// or joining. The hashes may or may not actually be physically pre-computed,
/// this depends per type.
pub enum HashKeys {
    RowEncoded(RowEncodedKeys),
    Single(SingleKeys),
}

impl HashKeys {
    pub fn from_df(df: &DataFrame, random_state: PlRandomState, force_row_encoding: bool) -> Self {
        if df.width() > 1 || force_row_encoding {
            let keys = df
                .get_columns()
                .iter()
                .map(|c| c.as_materialized_series().clone())
                .collect_vec();
            let keys_encoded = _get_rows_encoded_unordered(&keys[..]).unwrap().into_array();
            assert!(keys_encoded.len() == df.height());

            // TODO: use vechash? Not supported yet for lists.
            // let mut hashes = Vec::with_capacity(df.height());
            // columns_to_hashes(df.get_columns(), Some(random_state), &mut hashes).unwrap();

            let hashes = keys_encoded
                .values_iter()
                .map(|k| random_state.hash_one(k))
                .collect();
            Self::RowEncoded(RowEncodedKeys {
                hashes,
                keys: keys_encoded,
            })
        } else {
            todo!()
            // Self::Single(SingleKeys {
            //     random_state,
            //     hashes: todo!(),
            //     keys: df[0].as_materialized_series().clone(),
            // })
        }
    }

    pub fn gen_partition_idxs(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
    ) {
        match self {
            Self::RowEncoded(s) => s.gen_partition_idxs(partitioner, partition_idxs),
            Self::Single(s) => s.gen_partition_idxs(partitioner, partition_idxs),
        }
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
        match self {
            Self::RowEncoded(s) => Self::RowEncoded(s.gather(idxs)),
            Self::Single(s) => Self::Single(s.gather(idxs)),
        }
    }
}

pub struct RowEncodedKeys {
    pub hashes: Vec<u64>,
    pub keys: BinaryArray<i64>,
}

impl RowEncodedKeys {
    pub fn gen_partition_idxs(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
    ) {
        assert!(partitioner.num_partitions() == partition_idxs.len());
        for (i, h) in self.hashes.iter().enumerate() {
            unsafe {
                // SAFETY: we assured the number of partitions matches.
                let p = partitioner.hash_to_partition(*h);
                partition_idxs.get_unchecked_mut(p).push(i as IdxSize);
            }
        }
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
        let mut hashes = Vec::with_capacity(idxs.len());
        for idx in idxs {
            hashes.push_unchecked(*self.hashes.get_unchecked(*idx as usize));
        }
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        let keys = take_unchecked(&self.keys, &idx_arr);
        Self { hashes, keys }
    }
}

/// Single keys. Does not pre-hash for boolean & integer types, only for strings
/// and nested types.
pub struct SingleKeys {
    pub random_state: PlRandomState,
    pub hashes: Option<Vec<u64>>,
    pub keys: Series,
}

impl SingleKeys {
    pub fn gen_partition_idxs(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
    ) {
        assert!(partitioner.num_partitions() == partition_idxs.len());
        todo!()
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
        let hashes = self.hashes.as_ref().map(|hashes| {
            let mut out = Vec::with_capacity(idxs.len());
            for idx in idxs {
                out.push_unchecked(*hashes.get_unchecked(*idx as usize));
            }
            out
        });
        Self {
            random_state: self.random_state.clone(),
            hashes,
            keys: self.keys.take_slice_unchecked(idxs),
        }
    }
}
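
As a usage sketch, the partitioning path composes like this: hash a key DataFrame once with from_df, then scatter each row index into the bucket its hash selects. The helper function and its name here are illustrative; the calls are the ones defined above:

use polars_core::prelude::*;
use polars_expr::hash_keys::HashKeys;
use polars_utils::hashing::HashPartitioner;
use polars_utils::IdxSize;

/// Illustrative: bucket every row index of `keys` by its row hash.
fn partition_rows(keys: &DataFrame, partitioner: &HashPartitioner) -> Vec<Vec<IdxSize>> {
    // `true` forces row encoding even for a single key column.
    let hash_keys = HashKeys::from_df(keys, PlRandomState::new(), true);
    let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
    // Row i is pushed into the bucket its hash maps to.
    hash_keys.gen_partition_idxs(partitioner, &mut partition_idxs);
    partition_idxs
}

Each partition's rows can then be gathered with the unsafe gather above; the collected indices are in-bounds by construction, satisfying its safety contract.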
1 change: 1 addition & 0 deletions crates/polars-expr/src/lib.rs
@@ -1,5 +1,6 @@
 mod expressions;
 pub mod groups;
+pub mod hash_keys;
 pub mod planner;
 pub mod prelude;
 pub mod reduce;
10 changes: 8 additions & 2 deletions crates/polars-stream/src/nodes/group_by.rs
@@ -1,10 +1,11 @@
 use std::sync::Arc;
 
-use polars_core::prelude::IntoColumn;
+use polars_core::prelude::{IntoColumn, PlRandomState};
 use polars_core::schema::Schema;
 use polars_core::utils::accumulate_dataframes_vertical_unchecked;
 use polars_core::POOL;
 use polars_expr::groups::Grouper;
+use polars_expr::hash_keys::HashKeys;
 use polars_expr::reduce::GroupedReduction;
 use polars_utils::cardinality_sketch::CardinalitySketch;
 use polars_utils::hashing::HashPartitioner;
@@ -40,6 +41,7 @@ struct GroupBySinkState {
     grouper: Box<dyn Grouper>,
     grouped_reductions: Vec<Box<dyn GroupedReduction>>,
     local: Vec<LocalGroupBySinkState>,
+    random_state: PlRandomState,
 }
 
 impl GroupBySinkState {
@@ -63,6 +65,7 @@ impl GroupBySinkState {
         for (mut recv, local) in receivers.into_iter().zip(&mut self.local) {
             let key_selectors = &self.key_selectors;
             let grouped_reduction_selectors = &self.grouped_reduction_selectors;
+            let random_state = &self.random_state;
             join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                 let mut group_idxs = Vec::new();
                 while let Ok(morsel) = recv.recv().await {
@@ -74,7 +77,8 @@
                         key_columns.push(s.into_column());
                     }
                     let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;
-                    local.grouper.insert_keys(&keys, &mut group_idxs);
+                    let hash_keys = HashKeys::from_df(&keys, random_state.clone(), true);
+                    local.grouper.insert_keys(hash_keys, &mut group_idxs);
 
                     // Update reductions.
                     for (selector, reduction) in grouped_reduction_selectors
@@ -241,6 +245,7 @@ impl GroupByNode {
         grouped_reductions: Vec<Box<dyn GroupedReduction>>,
         grouper: Box<dyn Grouper>,
         output_schema: Arc<Schema>,
+        random_state: PlRandomState,
     ) -> Self {
         Self {
             state: GroupByState::Sink(GroupBySinkState {
@@ -249,6 +254,7 @@
                 grouped_reductions,
                 grouper,
                 local: Vec::new(),
+                random_state,
             }),
             output_schema,
         }
5 changes: 3 additions & 2 deletions crates/polars-stream/src/physical_plan/to_graph.rs
@@ -1,6 +1,7 @@
 use std::sync::Arc;
 
 use parking_lot::Mutex;
+use polars_core::prelude::PlRandomState;
 use polars_core::schema::{Schema, SchemaExt};
 use polars_error::PolarsResult;
 use polars_expr::groups::new_hash_grouper;
@@ -395,8 +396,7 @@ fn to_graph_rec<'a>(
             let input_schema = &ctx.phys_sm[*input].output_schema;
             let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)?
                 .materialize_unknown_dtypes()?;
-            let random_state = Default::default();
-            let grouper = new_hash_grouper(Arc::new(key_schema), random_state);
+            let grouper = new_hash_grouper(Arc::new(key_schema));
 
             let key_selectors = key
                 .iter()
@@ -424,6 +424,7 @@
                     grouped_reductions,
                     grouper,
                     node.output_schema.clone(),
+                    PlRandomState::new(),
                 ),
                 [input_key],
             )
