-
Notifications
You must be signed in to change notification settings - Fork 175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] List chunk expression #2491
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,7 +56,7 @@ fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i | |
} | ||
} | ||
|
||
pub fn get_slices_helper( | ||
fn get_slices_helper( | ||
mut parent_offsets: impl Iterator<Item = i64>, | ||
field: Arc<Field>, | ||
child_data_type: &DataType, | ||
|
@@ -118,6 +118,84 @@ pub fn get_slices_helper( | |
.into_series()) | ||
} | ||
|
||
/// Helper function that gets chunks of a given `size` from each list in the Series. Discards excess | ||
/// elements that do not fit into the chunks. | ||
/// | ||
/// This function has two paths. The first is a fast path that is taken when all lists in the | ||
/// Series have a length that is a multiple of `size`, which means they can be chunked cleanly | ||
/// without leftover elements. In the fast path, we simply pass the underlying array of elements to | ||
/// the result, but reinterpret it as a list of fixed sized lists. | ||
/// | ||
/// If there is at least one list that cannot be chunked cleanly, the underlying array of elements | ||
/// has to be compacted to remove the excess elements. In this case we take the slower path that | ||
/// does this compaction. | ||
/// | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `flat_child` - The Series that we're extracting chunks from. | ||
/// * `field` - The field of the parent list. | ||
/// * `validity` - The parent list's validity. | ||
/// * `size` - The size for each chunk. | ||
/// * `total_elements_to_skip` - The number of elements in the Series that do not fit cleanly into | ||
/// chunks. We take the fast path iff this value is 0. | ||
/// * `to_skip` - An optional iterator of the number of elements to skip for each list. Elements | ||
/// are skipped when they cannot fit into their parent list's chunks. | ||
/// * `new_offsets` - The new offsets to use for the topmost list array, this is computed based on | ||
/// the number of chunks extracted from each list. | ||
fn get_chunks_helper( | ||
[Review comment] Reviewer: Lots of arguments here — want to add some docstrings to help explain them? We aren't very good with documentation of our code, but maybe this is a good time to start 😛 [Reply] Author: Sounds good, added some docs. I also switched around the argument list while writing the docs because there was a more natural order to them: input/parent values (flat_child, field, validity) -> input/expression argument (size) -> "output"/result of preprocessing (elements to skip and new offsets) |
||
flat_child: &Series, | ||
field: Arc<Field>, | ||
validity: Option<&arrow2::bitmap::Bitmap>, | ||
size: usize, | ||
total_elements_to_skip: usize, | ||
to_skip: Option<impl Iterator<Item = usize>>, | ||
new_offsets: Vec<i64>, | ||
) -> DaftResult<Series> { | ||
if total_elements_to_skip == 0 { | ||
let inner_list_field = field.to_exploded_field()?.to_fixed_size_list_field(size)?; | ||
let inner_list = FixedSizeListArray::new( | ||
inner_list_field.clone(), | ||
flat_child.clone(), | ||
None, // Since we're creating an extra layer of lists, this layer doesn't have any | ||
// validity information. The topmost list takes the parent's validity, and the | ||
// child list is unaffected by the chunking operation and maintains its validity. | ||
// This reasoning applies to the places that follow where validity is set. | ||
); | ||
Ok(ListArray::new( | ||
inner_list_field.to_list_field()?, | ||
inner_list.into_series(), | ||
arrow2::offset::OffsetsBuffer::try_from(new_offsets)?, | ||
validity.cloned(), // Copy the parent's validity. | ||
) | ||
.into_series()) | ||
} else { | ||
let mut growable: Box<dyn Growable> = make_growable( | ||
&field.name, | ||
&field.to_exploded_field()?.dtype, | ||
vec![flat_child], | ||
false, // There's no validity to set, see the comment above. | ||
flat_child.len() - total_elements_to_skip, | ||
); | ||
let mut starting_idx = 0; | ||
for (i, to_skip) in to_skip.unwrap().enumerate() { | ||
let num_chunks = new_offsets.get(i + 1).unwrap() - new_offsets.get(i).unwrap(); | ||
let slice_len = num_chunks as usize * size; | ||
growable.extend(0, starting_idx, slice_len); | ||
starting_idx += slice_len + to_skip; | ||
} | ||
let inner_list_field = field.to_exploded_field()?.to_fixed_size_list_field(size)?; | ||
let inner_list = FixedSizeListArray::new(inner_list_field.clone(), growable.build()?, None); | ||
Ok(ListArray::new( | ||
inner_list_field.to_list_field()?, | ||
inner_list.into_series(), | ||
arrow2::offset::OffsetsBuffer::try_from(new_offsets)?, | ||
validity.cloned(), // Copy the parent's validity. | ||
) | ||
.into_series()) | ||
} | ||
} | ||
|
||
impl ListArray { | ||
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> { | ||
let counts = match (mode, self.flat_child.validity()) { | ||
|
@@ -274,6 +352,34 @@ impl ListArray { | |
end_iter, | ||
) | ||
} | ||
|
||
pub fn get_chunks(&self, size: usize) -> DaftResult<Series> { | ||
let mut to_skip = Vec::with_capacity(self.flat_child.len()); | ||
let mut new_offsets = Vec::with_capacity(self.flat_child.len() + 1); | ||
let mut total_elements_to_skip = 0; | ||
new_offsets.push(0); | ||
for i in 0..self.offsets().len() - 1 { | ||
let slice_len = self.offsets().get(i + 1).unwrap() - self.offsets().get(i).unwrap(); | ||
let modulo = slice_len as usize % size; | ||
to_skip.push(modulo); | ||
total_elements_to_skip += modulo; | ||
new_offsets.push(new_offsets.last().unwrap() + (slice_len / size as i64)); | ||
} | ||
let to_skip = if total_elements_to_skip == 0 { | ||
None | ||
} else { | ||
Some(to_skip.iter().copied()) | ||
}; | ||
get_chunks_helper( | ||
&self.flat_child, | ||
self.field.clone(), | ||
self.validity(), | ||
size, | ||
total_elements_to_skip, | ||
to_skip, | ||
new_offsets, | ||
) | ||
} | ||
} | ||
|
||
impl FixedSizeListArray { | ||
|
@@ -420,6 +526,34 @@ impl FixedSizeListArray { | |
end_iter, | ||
) | ||
} | ||
|
||
pub fn get_chunks(&self, size: usize) -> DaftResult<Series> { | ||
let list_size = self.fixed_element_len(); | ||
let num_chunks = list_size / size; | ||
let modulo = list_size % size; | ||
let total_elements_to_skip = modulo * self.len(); | ||
let new_offsets: Vec<i64> = if !self.is_empty() && num_chunks > 0 { | ||
(0..=((self.len() * num_chunks) as i64)) | ||
.step_by(num_chunks) | ||
.collect() | ||
} else { | ||
vec![0; self.len() + 1] | ||
}; | ||
let to_skip = if total_elements_to_skip == 0 { | ||
None | ||
} else { | ||
Some(std::iter::repeat(modulo).take(self.len())) | ||
}; | ||
get_chunks_helper( | ||
&self.flat_child, | ||
self.field.clone(), | ||
self.validity(), | ||
size, | ||
total_elements_to_skip, | ||
to_skip, | ||
new_offsets, | ||
) | ||
} | ||
} | ||
|
||
macro_rules! impl_aggs_list_array { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
use crate::ExprRef; | ||
use daft_core::{datatypes::Field, schema::Schema, series::Series}; | ||
|
||
use super::{super::FunctionEvaluator, ListExpr}; | ||
use crate::functions::FunctionExpr; | ||
use common_error::{DaftError, DaftResult}; | ||
|
||
pub(super) struct ChunkEvaluator {} | ||
|
||
impl FunctionEvaluator for ChunkEvaluator { | ||
fn fn_name(&self) -> &'static str { | ||
"chunk" | ||
} | ||
|
||
fn to_field( | ||
&self, | ||
inputs: &[ExprRef], | ||
schema: &Schema, | ||
expr: &FunctionExpr, | ||
) -> DaftResult<Field> { | ||
let size = match expr { | ||
FunctionExpr::List(ListExpr::Chunk(size)) => size, | ||
_ => panic!("Expected Chunk Expr, got {expr}"), | ||
}; | ||
match inputs { | ||
[input] => { | ||
let input_field = input.to_field(schema)?; | ||
Ok(input_field | ||
.to_exploded_field()? | ||
.to_fixed_size_list_field(*size)? | ||
.to_list_field()?) | ||
} | ||
_ => Err(DaftError::SchemaMismatch(format!( | ||
"Expected 1 input args, got {}", | ||
inputs.len() | ||
))), | ||
} | ||
} | ||
|
||
fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult<Series> { | ||
let size = match expr { | ||
FunctionExpr::List(ListExpr::Chunk(size)) => size, | ||
_ => panic!("Expected Chunk Expr, got {expr}"), | ||
}; | ||
match inputs { | ||
[input] => input.list_chunk(*size), | ||
_ => Err(DaftError::ValueError(format!( | ||
"Expected 1 input args, got {}", | ||
inputs.len() | ||
))), | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also add this into our Expressions docs:
docs/source/api_docs/expressions.rst
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha. I didn't do this for the slice expression too, so I'll add it at the same time.