fix: Ensure that split ChunkedArray also flattens chunks #16837

Merged · 1 commit · Jun 9, 2024

Changes from all commits
170 changes: 142 additions & 28 deletions crates/polars-core/src/utils/mod.rs
@@ -104,6 +104,7 @@ macro_rules! split_array {
     }};
 }
 
+// This one splits, but doesn't flatten chunks.
 pub fn split_ca<T>(ca: &ChunkedArray<T>, n: usize) -> PolarsResult<Vec<ChunkedArray<T>>>
 where
     T: PolarsDataType,
@@ -139,46 +140,147 @@ pub fn split_series(s: &Series, n: usize) -> PolarsResult<Vec<Series>> {
     split_array!(s, n, i64)
 }
 
-/// Split a [`DataFrame`] in `target` elements. The target doesn't have to be respected if not
-/// strict. Deviation of the target might be done to create more equal size chunks.
-///
-/// # Panics
-/// if chunks are not aligned
-pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
-    let total_len = df.height();
+#[allow(clippy::len_without_is_empty)]
+pub trait Container: Clone {
+    fn slice(&self, offset: i64, len: usize) -> Self;
+
+    fn len(&self) -> usize;
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self>;
+
+    fn n_chunks(&self) -> usize;
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize>;
+}
+
+impl Container for DataFrame {
+    fn slice(&self, offset: i64, len: usize) -> Self {
+        DataFrame::slice(self, offset, len)
+    }
+
+    fn len(&self) -> usize {
+        self.height()
+    }
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
+        flatten_df_iter(self)
+    }
+
+    fn n_chunks(&self) -> usize {
+        DataFrame::n_chunks(self)
+    }
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
+        self.get_columns()[0].chunk_lengths()
+    }
+}
+
+impl<T: PolarsDataType> Container for ChunkedArray<T> {
+    fn slice(&self, offset: i64, len: usize) -> Self {
+        ChunkedArray::slice(self, offset, len)
+    }
+
+    fn len(&self) -> usize {
+        ChunkedArray::len(self)
+    }
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
+        self.downcast_iter()
+            .map(|arr| Self::with_chunk(self.name(), arr.clone()))
+    }
+
+    fn n_chunks(&self) -> usize {
+        self.chunks().len()
+    }
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
+        ChunkedArray::chunk_lengths(self)
+    }
+}
+
+impl Container for Series {
+    fn slice(&self, offset: i64, len: usize) -> Self {
+        self.0.slice(offset, len)
+    }
+
+    fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
+        (0..self.0.n_chunks()).map(|i| self.select_chunk(i))
+    }
+
+    fn n_chunks(&self) -> usize {
+        self.chunks().len()
+    }
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
+        self.0.chunk_lengths()
+    }
+}
+
+fn split_impl<C: Container>(container: &C, target: usize, chunk_size: usize) -> Vec<C> {
+    let total_len = container.len();
+    let mut out = Vec::with_capacity(target);
+
+    for i in 0..target {
+        let offset = i * chunk_size;
+        let len = if i == (target - 1) {
+            total_len.saturating_sub(offset)
+        } else {
+            chunk_size
+        };
+        let container = container.slice((i * chunk_size) as i64, len);
+        out.push(container);
+    }
+    out
+}
+
+pub fn split<C: Container>(container: &C, target: usize) -> Vec<C> {
+    let total_len = container.len();
     if total_len == 0 {
-        return vec![df.clone()];
+        return vec![container.clone()];
     }
 
     let chunk_size = std::cmp::max(total_len / target, 1);
 
-    if df.n_chunks() == target
-        && df.get_columns()[0]
+    if container.n_chunks() == target
+        && container
             .chunk_lengths()
             .all(|len| len.abs_diff(chunk_size) < 100)
     {
-        return flatten_df_iter(df).collect();
+        return container.iter_chunks().collect();
     }
+    split_impl(container, target, chunk_size)
+}
 
-    let mut out = Vec::with_capacity(target);
+/// Split a [`Container`] in `target` elements. The target doesn't have to be respected;
+/// deviation from the target might be done to create more equal size chunks.
+pub fn split_and_flatten<C: Container>(container: &C, target: usize) -> Vec<C> {
+    let total_len = container.len();
+    if total_len == 0 {
+        return vec![container.clone()];
+    }
 
-    if df.n_chunks() == 1 || strict {
-        for i in 0..target {
-            let offset = i * chunk_size;
-            let len = if i == (target - 1) {
-                total_len.saturating_sub(offset)
-            } else {
-                chunk_size
-            };
-            let df = df.slice((i * chunk_size) as i64, len);
-            out.push(df);
-        }
+    let chunk_size = std::cmp::max(total_len / target, 1);
+
+    if container.n_chunks() == target
+        && container
+            .chunk_lengths()
+            .all(|len| len.abs_diff(chunk_size) < 100)
+    {
+        return container.iter_chunks().collect();
+    }
+
+    if container.n_chunks() == 1 {
+        split_impl(container, target, chunk_size)
     } else {
-        let chunks = flatten_df_iter(df);
+        let mut out = Vec::with_capacity(target);
+        let chunks = container.iter_chunks();
 
         'new_chunk: for mut chunk in chunks {
             loop {
-                let h = chunk.height();
+                let h = chunk.len();
                 if h < chunk_size {
                     // TODO if the chunk is much smaller than chunk size, we should try to merge it with the next one.
                     out.push(chunk);
@@ -191,14 +293,26 @@ pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
                     continue 'new_chunk;
                 }
 
-                // This would be faster if we had a `split` operation.
+                // TODO! use `split` operation here. That saves a null count.
                 out.push(chunk.slice(0, chunk_size));
                 chunk = chunk.slice(chunk_size as i64, h - chunk_size);
             }
         }
+        out
     }
-
-    out
 }
 
+/// Split a [`DataFrame`] in `target` elements. The target doesn't have to be respected if not
+/// strict. Deviation of the target might be done to create more equal size chunks.
+///
+/// # Panics
+/// if chunks are not aligned
+pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
+    if strict {
+        split(df, target)
+    } else {
+        split_and_flatten(df, target)
+    }
+}
 
 #[doc(hidden)]
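A quick illustration of the contract the new helpers establish (not part of the diff): `split` only slices a container into roughly equal pieces, while `split_and_flatten` additionally guarantees that every returned piece is a single chunk. The following is a minimal sketch, assuming a `polars-core` version from around this release where `split` and `split_and_flatten` are exported from `polars_core::utils`, column names are plain `&str`, and `ChunkedArray::append` is fallible:

```rust
use polars_core::prelude::*;
use polars_core::utils::{split, split_and_flatten};

fn main() -> PolarsResult<()> {
    // Build a ChunkedArray with three unequal chunks by appending.
    let mut ca = Int32Chunked::from_slice("a", &[1, 2, 3, 4]);
    ca.append(&Int32Chunked::from_slice("a", &[5, 6]))?;
    ca.append(&Int32Chunked::from_slice("a", &[7, 8, 9, 10]))?;
    assert_eq!(ca.chunks().len(), 3);

    // `split` slices into roughly equal pieces, but a piece may still
    // span multiple chunks of the source array.
    let pieces = split(&ca, 2);
    assert_eq!(pieces.iter().map(|p| p.len()).sum::<usize>(), 10);

    // `split_and_flatten` covers the same rows, and in addition every
    // piece is exactly one chunk, so a single `downcast_iter().next()`
    // sees all of its values.
    for piece in split_and_flatten(&ca, 2) {
        assert_eq!(piece.chunks().len(), 1);
    }
    Ok(())
}
```

That single-chunk guarantee is exactly what the asof-join code below relies on when it downcasts each piece to one arrow array.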
30 changes: 16 additions & 14 deletions crates/polars-ops/src/frame/join/asof/groups.rs
@@ -9,7 +9,7 @@ use polars_core::hashing::{
 };
 use polars_core::prelude::*;
 use polars_core::utils::flatten::flatten_nullable;
-use polars_core::utils::{_set_partition_size, split_ca, split_df};
+use polars_core::utils::{_set_partition_size, split_and_flatten};
 use polars_core::{with_match_physical_float_polars_type, IdBuildHasher, POOL};
 use polars_utils::abs_diff::AbsDiff;
 use polars_utils::hashing::{hash_to_partition, DirtyHash};
@@ -169,21 +169,24 @@ where
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
-    let left_asof = left_asof.rechunk();
-    let right_asof = right_asof.rechunk();
+    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
     let left_val_arr = left_asof.downcast_iter().next().unwrap();
     let right_val_arr = right_asof.downcast_iter().next().unwrap();
 
     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_ca(by_left, n_threads).unwrap();
-    let split_by_right = split_ca(by_right, n_threads).unwrap();
+    // `strict` is false so that we always flatten, even if there are more chunks than threads.
+    let split_by_left = split_and_flatten(by_left, n_threads);
+    let split_by_right = split_and_flatten(by_right, n_threads);
     let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));
 
     // TODO: handle nulls more efficiently. Right now we just join on the value
     // ignoring the validity mask, and ignore the nulls later.
     let right_slices = split_by_right
         .iter()
-        .map(|ca| ca.downcast_iter().next().unwrap().values_iter().copied())
+        .map(|ca| {
+            assert_eq!(ca.chunks().len(), 1);
+            ca.downcast_iter().next().unwrap().values_iter().copied()
+        })
         .collect();
     let hash_tbls = build_tables(right_slices, false);
     let n_tables = hash_tbls.len();
@@ -197,6 +200,7 @@ where
             let mut group_states: PlHashMap<IdxSize, A> =
                 PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);
 
+            assert_eq!(by_left.chunks().len(), 1);
             let by_left_chunk = by_left.downcast_iter().next().unwrap();
             for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() {
                 let Some(by_left_k) = opt_by_left_k else {
@@ -245,14 +249,17 @@ where
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
-    let left_asof = left_asof.rechunk();
-    let right_asof = right_asof.rechunk();
+    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
     let left_val_arr = left_asof.downcast_iter().next().unwrap();
     let right_val_arr = right_asof.downcast_iter().next().unwrap();
 
     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_ca(by_left, n_threads).unwrap();
-    let split_by_right = split_ca(by_right, n_threads).unwrap();
+    let split_by_left = split_and_flatten(by_left, n_threads);
+    let split_by_right = split_and_flatten(by_right, n_threads);
     let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));
 
     let hb = RandomState::default();
@@ -311,14 +314,13 @@ where
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
-    let left_asof = left_asof.rechunk();
-    let right_asof = right_asof.rechunk();
+    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
     let left_val_arr = left_asof.downcast_iter().next().unwrap();
     let right_val_arr = right_asof.downcast_iter().next().unwrap();
 
     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_df(by_left, n_threads, false);
-    let split_by_right = split_df(by_right, n_threads, false);
+    let split_by_left = split_and_flatten(by_left, n_threads);
+    let split_by_right = split_and_flatten(by_right, n_threads);
 
     let (build_hashes, random_state) =
         _df_rows_to_hashes_threaded_vertical(&split_by_right, None).unwrap();
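To close, a rough sketch of the fan-out pattern these call sites share: rechunk the asof keys in parallel, split the group keys so each piece is one chunk, and derive per-piece global row offsets. This assumes `polars-core` and `rayon` as dependencies; `join_sketch` is hypothetical, `compute_len_offsets` is reimplemented here for illustration (the real one lives in `polars-ops`), and the per-piece kernel is reduced to index bookkeeping:

```rust
use polars_core::prelude::*;
use polars_core::utils::split_and_flatten;
use polars_core::POOL;
use rayon::prelude::*;

/// Exclusive prefix sum: the global row offset at which each piece starts.
fn compute_len_offsets(lens: impl Iterator<Item = usize>) -> Vec<usize> {
    let mut offset = 0;
    lens.map(|len| {
        let start = offset;
        offset += len;
        start
    })
    .collect()
}

fn join_sketch(
    left_asof: &Int64Chunked,
    right_asof: &Int64Chunked,
    by_left: &Int64Chunked,
) -> Vec<IdxSize> {
    // Rechunk both asof columns in parallel, as the diff does with POOL.join.
    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
    let _left_val_arr = left_asof.downcast_iter().next().unwrap();
    let _right_val_arr = right_asof.downcast_iter().next().unwrap();

    // Flattening guarantees one chunk per piece, even when the input
    // has more chunks than there are threads.
    let n_threads = POOL.current_num_threads();
    let split_by_left = split_and_flatten(by_left, n_threads);
    let offsets = compute_len_offsets(split_by_left.iter().map(|ca| ca.len()));

    POOL.install(|| {
        split_by_left
            .par_iter()
            .zip(offsets)
            .flat_map(|(ca, offset)| {
                // Safe: split_and_flatten returned single-chunk pieces.
                assert_eq!(ca.chunks().len(), 1);
                // Map piece-local row indices to global row numbers; the
                // real kernel would do the asof matching here instead.
                (0..ca.len()).map(|i| (offset + i) as IdxSize).collect::<Vec<_>>()
            })
            .collect()
    })
}
```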