Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] List chunk expression #2491

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ class PyExpr:
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def list_slice(self, start: PyExpr, end: PyExpr) -> PyExpr: ...
def list_chunk(self, size: int) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def map_get(self, key: PyExpr) -> PyExpr: ...
def url_download(
Expand Down
12 changes: 12 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2597,6 +2597,18 @@ def slice(self, start: int | Expression, end: int | Expression) -> Expression:
end_expr = Expression._to_expression(end)
return Expression._from_pyexpr(self._expr.list_slice(start_expr._expr, end_expr._expr))

def chunk(self, size: int) -> Expression:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add this into our Expressions docs: docs/source/api_docs/expressions.rst

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha. I didn't do this for the slice expression too, so I'll add it at the same time.

"""Splits each list into chunks of the given size

Args:
size: size of chunks to split the list into. Must be greater than 0
Returns:
Expression: an expression with lists of fixed size lists of the type of the list values
"""
if not (isinstance(size, int) and size > 0):
raise ValueError(f"Invalid value for `size`: {size}")
return Expression._from_pyexpr(self._expr.list_chunk(size))

def sum(self) -> Expression:
"""Sums each list. Empty lists and lists with all nulls yield null.

Expand Down
2 changes: 2 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ List
Expression.list.join
Expression.list.lengths
Expression.list.get
Expression.list.slice
Expression.list.chunk

Struct
######
Expand Down
136 changes: 135 additions & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i
}
}

pub fn get_slices_helper(
fn get_slices_helper(
mut parent_offsets: impl Iterator<Item = i64>,
field: Arc<Field>,
child_data_type: &DataType,
Expand Down Expand Up @@ -118,6 +118,84 @@ pub fn get_slices_helper(
.into_series())
}

/// Helper function that gets chunks of a given `size` from each list in the Series. Discards excess
/// elements that do not fit into the chunks.
///
/// This function has two paths. The first is a fast path that is taken when all lists in the
/// Series have a length that is a multiple of `size`, which means they can be chunked cleanly
/// without leftover elements. In the fast path, we simply pass the underlying array of elements to
/// the result, but reinterpret it as a list of fixed sized lists.
///
/// If there is at least one list that cannot be chunked cleanly, the underlying array of elements
/// has to be compacted to remove the excess elements. In this case we take the slower path that
/// does this compaction.
///
///
/// # Arguments
///
/// * `flat_child` - The Series that we're extracting chunks from.
/// * `field` - The field of the parent list.
/// * `validity` - The parent list's validity.
/// * `size` - The size for each chunk.
/// * `total_elements_to_skip` - The number of elements in the Series that do not fit cleanly into
/// chunks. We take the fast path iff this value is 0.
/// * `to_skip` - An optional iterator of the number of elements to skip for each list. Elements
/// are skipped when they cannot fit into their parent list's chunks.
/// * `new_offsets` - The new offsets to use for the topmost list array, this is computed based on
/// the number of chunks extracted from each list.
fn get_chunks_helper(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of arguments here, wanna add some docstrings to help explain them?

We aren't very good with documentation of our code, but maybe a good time to start 😛

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, added some docs. I also switched around the argument list while writing the docs because there was a more natural order to them: input/parent values (flat_child, field, validity) -> input/expression argument (size) -> "output"/result of preprocessing (elements to skip and new offsets)

flat_child: &Series,
field: Arc<Field>,
validity: Option<&arrow2::bitmap::Bitmap>,
size: usize,
total_elements_to_skip: usize,
to_skip: Option<impl Iterator<Item = usize>>,
new_offsets: Vec<i64>,
) -> DaftResult<Series> {
if total_elements_to_skip == 0 {
let inner_list_field = field.to_exploded_field()?.to_fixed_size_list_field(size)?;
let inner_list = FixedSizeListArray::new(
inner_list_field.clone(),
flat_child.clone(),
None, // Since we're creating an extra layer of lists, this layer doesn't have any
// validity information. The topmost list takes the parent's validity, and the
// child list is unaffected by the chunking operation and maintains its validity.
// This reasoning applies to the places that follow where validity is set.
);
Ok(ListArray::new(
inner_list_field.to_list_field()?,
inner_list.into_series(),
arrow2::offset::OffsetsBuffer::try_from(new_offsets)?,
validity.cloned(), // Copy the parent's validity.
)
.into_series())
} else {
let mut growable: Box<dyn Growable> = make_growable(
&field.name,
&field.to_exploded_field()?.dtype,
vec![flat_child],
false, // There's no validity to set, see the comment above.
flat_child.len() - total_elements_to_skip,
);
let mut starting_idx = 0;
for (i, to_skip) in to_skip.unwrap().enumerate() {
let num_chunks = new_offsets.get(i + 1).unwrap() - new_offsets.get(i).unwrap();
let slice_len = num_chunks as usize * size;
growable.extend(0, starting_idx, slice_len);
starting_idx += slice_len + to_skip;
}
let inner_list_field = field.to_exploded_field()?.to_fixed_size_list_field(size)?;
let inner_list = FixedSizeListArray::new(inner_list_field.clone(), growable.build()?, None);
Ok(ListArray::new(
inner_list_field.to_list_field()?,
inner_list.into_series(),
arrow2::offset::OffsetsBuffer::try_from(new_offsets)?,
validity.cloned(), // Copy the parent's validity.
)
.into_series())
}
}

impl ListArray {
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let counts = match (mode, self.flat_child.validity()) {
Expand Down Expand Up @@ -274,6 +352,34 @@ impl ListArray {
end_iter,
)
}

pub fn get_chunks(&self, size: usize) -> DaftResult<Series> {
let mut to_skip = Vec::with_capacity(self.flat_child.len());
let mut new_offsets = Vec::with_capacity(self.flat_child.len() + 1);
let mut total_elements_to_skip = 0;
new_offsets.push(0);
for i in 0..self.offsets().len() - 1 {
let slice_len = self.offsets().get(i + 1).unwrap() - self.offsets().get(i).unwrap();
let modulo = slice_len as usize % size;
to_skip.push(modulo);
total_elements_to_skip += modulo;
new_offsets.push(new_offsets.last().unwrap() + (slice_len / size as i64));
}
let to_skip = if total_elements_to_skip == 0 {
None
} else {
Some(to_skip.iter().copied())
};
get_chunks_helper(
&self.flat_child,
self.field.clone(),
self.validity(),
size,
total_elements_to_skip,
to_skip,
new_offsets,
)
}
}

impl FixedSizeListArray {
Expand Down Expand Up @@ -420,6 +526,34 @@ impl FixedSizeListArray {
end_iter,
)
}

pub fn get_chunks(&self, size: usize) -> DaftResult<Series> {
let list_size = self.fixed_element_len();
let num_chunks = list_size / size;
let modulo = list_size % size;
let total_elements_to_skip = modulo * self.len();
let new_offsets: Vec<i64> = if !self.is_empty() && num_chunks > 0 {
(0..=((self.len() * num_chunks) as i64))
.step_by(num_chunks)
.collect()
} else {
vec![0; self.len() + 1]
};
let to_skip = if total_elements_to_skip == 0 {
None
} else {
Some(std::iter::repeat(modulo).take(self.len()))
};
get_chunks_helper(
&self.flat_child,
self.field.clone(),
self.validity(),
size,
total_elements_to_skip,
to_skip,
new_offsets,
)
}
}

macro_rules! impl_aggs_list_array {
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ impl Field {
})
}

pub fn to_fixed_size_list_field(&self, size: usize) -> DaftResult<Self> {
if self.dtype.is_python() {
return Ok(self.clone());
}
let list_dtype = DataType::FixedSizeList(Box::new(self.dtype.clone()), size);
Ok(Self {
name: self.name.clone(),
dtype: list_dtype,
metadata: self.metadata.clone(),
})
}

pub fn to_exploded_field(&self) -> DaftResult<Self> {
match &self.dtype {
DataType::List(child_dtype) | DataType::FixedSizeList(child_dtype, _) => {
Expand Down
10 changes: 10 additions & 0 deletions src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ impl Series {
}
}

pub fn list_chunk(&self, size: usize) -> DaftResult<Series> {
match self.data_type() {
DataType::List(_) => self.list()?.get_chunks(size),
DataType::FixedSizeList(..) => self.fixed_size_list()?.get_chunks(size),
dt => Err(DaftError::TypeError(format!(
"list chunk not implemented for {dt}"
))),
}
}

pub fn list_sum(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::List(_) => self.list()?.sum(),
Expand Down
53 changes: 53 additions & 0 deletions src/daft-dsl/src/functions/list/chunk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::ExprRef;
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use super::{super::FunctionEvaluator, ListExpr};
use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

pub(super) struct ChunkEvaluator {}

impl FunctionEvaluator for ChunkEvaluator {
fn fn_name(&self) -> &'static str {
"chunk"
}

fn to_field(
&self,
inputs: &[ExprRef],
schema: &Schema,
expr: &FunctionExpr,
) -> DaftResult<Field> {
let size = match expr {
FunctionExpr::List(ListExpr::Chunk(size)) => size,
_ => panic!("Expected Chunk Expr, got {expr}"),
};
match inputs {
[input] => {
let input_field = input.to_field(schema)?;
Ok(input_field
.to_exploded_field()?
.to_fixed_size_list_field(*size)?
.to_list_field()?)
}
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult<Series> {
let size = match expr {
FunctionExpr::List(ListExpr::Chunk(size)) => size,
_ => panic!("Expected Chunk Expr, got {expr}"),
};
match inputs {
[input] => input.list_chunk(*size),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/list/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod chunk;
mod count;
mod explode;
mod get;
Expand All @@ -8,6 +9,7 @@ mod min;
mod slice;
mod sum;

use chunk::ChunkEvaluator;
use count::CountEvaluator;
use daft_core::CountMode;
use explode::ExplodeEvaluator;
Expand Down Expand Up @@ -35,6 +37,7 @@ pub enum ListExpr {
Min,
Max,
Slice,
Chunk(usize),
}

impl ListExpr {
Expand All @@ -51,6 +54,7 @@ impl ListExpr {
Min => &MinEvaluator {},
Max => &MaxEvaluator {},
Slice => &SliceEvaluator {},
Chunk(_) => &ChunkEvaluator {},
}
}
}
Expand Down Expand Up @@ -126,3 +130,11 @@ pub fn slice(input: ExprRef, start: ExprRef, end: ExprRef) -> ExprRef {
}
.into()
}

pub fn chunk(input: ExprRef, size: usize) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::List(ListExpr::Chunk(size)),
inputs: vec![input],
}
.into()
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,11 @@ impl PyExpr {
Ok(slice(self.into(), start.into(), end.into()).into())
}

pub fn list_chunk(&self, size: usize) -> PyResult<Self> {
use crate::functions::list::chunk;
Ok(chunk(self.into(), size).into())
}

pub fn struct_get(&self, name: &str) -> PyResult<Self> {
use crate::functions::struct_::get;
Ok(get(self.into(), name).into())
Expand Down
Loading
Loading