fix: Deal with masked out list elements (#20161)
coastalwhite authored Dec 5, 2024
1 parent cbc0ea0 commit b019e42
Showing 2 changed files with 107 additions and 14 deletions.
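
The situation being handled: a List column can carry child values that sit behind a null (masked-out) outer row, for example after casting a fixed-size Array column with nulls to List. The change tracks the maximum encoded width of such masked-out values, reserves that many extra bytes at the end of the output buffer, and points every masked-out value at that scratch region (masked_out_write_offset) so real rows are unaffected. A minimal sketch of how such a series arises, lifted from the new test added below:

import polars as pl

# Casting an Array column with null rows to List keeps the child values of the
# null rows in the underlying values buffer, so the row encoder sees list
# elements that are "masked out" by the outer validity mask.
values = [[1, 2], None, [4, 5], [None, 3]]
array_series = pl.Series(values, dtype=pl.Array(pl.Int64(), 2))
list_from_array_series = array_series.cast(pl.List(pl.Int64()))
# Row-encoding and decoding this series is exactly what the new test round-trips.
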
109 changes: 95 additions & 14 deletions crates/polars-row/src/encode.rs
@@ -61,11 +61,20 @@ pub fn convert_columns_amortized<'a>(
     fields: impl IntoIterator<Item = (RowEncodingOptions, Option<&'a RowEncodingCatOrder>)> + Clone,
     rows: &mut RowsEncoded,
 ) {
+    let mut masked_out_max_length = 0;
     let mut row_widths = RowWidths::new(num_rows);
     let mut encoders = columns
         .iter()
         .zip(fields.clone())
-        .map(|(column, (opt, dicts))| get_encoder(column.as_ref(), opt, dicts, &mut row_widths))
+        .map(|(column, (opt, dicts))| {
+            get_encoder(
+                column.as_ref(),
+                opt,
+                dicts,
+                &mut row_widths,
+                &mut masked_out_max_length,
+            )
+        })
         .collect::<Vec<_>>();

     // Create an offsets array, we append 0 at the beginning here so it can serve as the final
@@ -76,9 +85,10 @@ pub fn convert_columns_amortized<'a>(

     // Create a buffer without initializing everything to zero.
     let total_num_bytes = row_widths.sum();
-    let mut out = Vec::<u8>::with_capacity(total_num_bytes);
-    let buffer = &mut out.spare_capacity_mut()[..total_num_bytes];
+    let mut out = Vec::<u8>::with_capacity(total_num_bytes + masked_out_max_length);
+    let buffer = &mut out.spare_capacity_mut()[..total_num_bytes + masked_out_max_length];

+    let masked_out_write_offset = total_num_bytes;
     let mut scratches = EncodeScratches::default();
     for (encoder, (opt, dict)) in encoders.iter_mut().zip(fields) {
         unsafe {
@@ -88,6 +98,7 @@ pub fn convert_columns_amortized<'a>(
                 opt,
                 dict,
                 &mut offsets[1..],
+                masked_out_write_offset,
                 &mut scratches,
             )
         };
@@ -108,13 +119,20 @@ fn list_num_column_bytes<O: Offset>(
     opt: RowEncodingOptions,
     dicts: Option<&RowEncodingCatOrder>,
     row_widths: &mut RowWidths,
+    masked_out_max_width: &mut usize,
 ) -> Encoder {
     let array = array.as_any().downcast_ref::<ListArray<O>>().unwrap();
     let array = array.trim_to_normalized_offsets_recursive();
     let values = array.values();

     let mut list_row_widths = RowWidths::new(values.len());
-    let encoder = get_encoder(values.as_ref(), opt, dicts, &mut list_row_widths);
+    let encoder = get_encoder(
+        values.as_ref(),
+        opt,
+        dicts,
+        &mut list_row_widths,
+        masked_out_max_width,
+    );

     match array.validity() {
         None => row_widths.push_iter(array.offsets().offset_and_length_iter().map(
@@ -133,6 +151,12 @@ fn list_num_column_bytes<O: Offset>(
                 .zip(validity.iter())
                 .map(|((offset, length), is_valid)| {
                     if !is_valid {
+                        if length > 0 {
+                            for i in offset..offset + length {
+                                *masked_out_max_width =
+                                    (*masked_out_max_width).max(list_row_widths.get(i));
+                            }
+                        }
                         return 1;
                     }

@@ -261,6 +285,7 @@ fn get_encoder(
     opt: RowEncodingOptions,
     dict: Option<&RowEncodingCatOrder>,
     row_widths: &mut RowWidths,
+    masked_out_max_width: &mut usize,
 ) -> Encoder {
     use ArrowDataType as D;
     let dtype = array.dtype();
@@ -275,8 +300,13 @@ fn get_encoder(

             debug_assert_eq!(array.values().len(), array.len() * width);
             let mut nested_row_widths = RowWidths::new(array.values().len());
-            let nested_encoder =
-                get_encoder(array.values().as_ref(), opt, dict, &mut nested_row_widths);
+            let nested_encoder = get_encoder(
+                array.values().as_ref(),
+                opt,
+                dict,
+                &mut nested_row_widths,
+                masked_out_max_width,
+            );
             Some(EncoderState::FixedSizeList(
                 Box::new(nested_encoder),
                 *width,
@@ -297,6 +327,7 @@ fn get_encoder(
                                 opt,
                                 None,
                                 &mut RowWidths::new(row_widths.num_rows()),
+                                masked_out_max_width,
                             )
                         })
                         .collect(),
@@ -310,6 +341,7 @@ fn get_encoder(
                                 opt,
                                 dict.as_ref(),
                                 &mut RowWidths::new(row_widths.num_rows()),
+                                masked_out_max_width,
                             )
                         })
                         .collect(),
@@ -333,8 +365,13 @@ fn get_encoder(

             debug_assert_eq!(array.values().len(), array.len() * width);
             let mut nested_row_widths = RowWidths::new(array.values().len());
-            let nested_encoder =
-                get_encoder(array.values().as_ref(), opt, dict, &mut nested_row_widths);
+            let nested_encoder = get_encoder(
+                array.values().as_ref(),
+                opt,
+                dict,
+                &mut nested_row_widths,
+                masked_out_max_width,
+            );

             let mut fsl_row_widths = nested_row_widths.collapse_chunks(*width, array.len());
             fsl_row_widths.push_constant(1); // validity byte
@@ -358,13 +395,25 @@ fn get_encoder(
             match dict {
                 None => {
                     for array in array.values() {
-                        let encoder = get_encoder(array.as_ref(), opt, None, row_widths);
+                        let encoder = get_encoder(
+                            array.as_ref(),
+                            opt,
+                            None,
+                            row_widths,
+                            masked_out_max_width,
+                        );
                         nested_encoders.push(encoder);
                     }
                 },
                 Some(RowEncodingCatOrder::Struct(dicts)) => {
                     for (array, dict) in array.values().iter().zip(dicts) {
-                        let encoder = get_encoder(array.as_ref(), opt, dict.as_ref(), row_widths);
+                        let encoder = get_encoder(
+                            array.as_ref(),
+                            opt,
+                            dict.as_ref(),
+                            row_widths,
+                            masked_out_max_width,
+                        );
                         nested_encoders.push(encoder);
                     }
                 },
@@ -376,8 +425,12 @@ fn get_encoder(
             }
         },

-        D::List(_) => list_num_column_bytes::<i32>(array, opt, dict, row_widths),
-        D::LargeList(_) => list_num_column_bytes::<i64>(array, opt, dict, row_widths),
+        D::List(_) => {
+            list_num_column_bytes::<i32>(array, opt, dict, row_widths, masked_out_max_width)
+        },
+        D::LargeList(_) => {
+            list_num_column_bytes::<i64>(array, opt, dict, row_widths, masked_out_max_width)
+        },

         D::BinaryView => {
             let dc_array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
@@ -654,6 +707,9 @@ unsafe fn encode_array(
     opt: RowEncodingOptions,
     dict: Option<&RowEncodingCatOrder>,
     offsets: &mut [usize],
+    masked_out_write_offset: usize, // Masked out values need to be written somewhere. We just
+    // reserved space at the end and tell all values to write
+    // there.
     scratches: &mut EncodeScratches,
 ) {
     let Some(state) = &encoder.state else {
@@ -709,6 +765,13 @@ unsafe fn encode_array(
                 if !is_valid {
                     buffer[offsets[i]] = MaybeUninit::new(list_null_sentinel);
                     offsets[i] += 1;
+
+                    // Values might have been masked out.
+                    if length > 0 {
+                        nested_offsets
+                            .extend(std::iter::repeat_n(masked_out_write_offset, length));
+                    }
+
                     continue;
                 }

@@ -732,6 +795,7 @@ unsafe fn encode_array(
                     opt,
                     dict,
                     nested_offsets,
+                    masked_out_write_offset,
                     &mut EncodeScratches::default(),
                 )
             };
@@ -756,6 +820,7 @@ unsafe fn encode_array(
                 opt,
                 dict,
                 &mut child_offsets,
+                masked_out_write_offset,
                 scratches,
             );
             for (i, offset) in offsets.iter_mut().enumerate() {
@@ -768,12 +833,28 @@ unsafe fn encode_array(
             match dict {
                 None => {
                     for array in arrays {
-                        encode_array(buffer, array, opt, None, offsets, scratches);
+                        encode_array(
+                            buffer,
+                            array,
+                            opt,
+                            None,
+                            offsets,
+                            masked_out_write_offset,
+                            scratches,
+                        );
                     }
                 },
                 Some(RowEncodingCatOrder::Struct(dicts)) => {
                     for (array, dict) in arrays.iter().zip(dicts) {
-                        encode_array(buffer, array, opt, dict.as_ref(), offsets, scratches);
+                        encode_array(
+                            buffer,
+                            array,
+                            opt,
+                            dict.as_ref(),
+                            offsets,
+                            masked_out_write_offset,
+                            scratches,
+                        );
                     }
                 },
                 _ => unreachable!(),
12 changes: 12 additions & 0 deletions py-polars/tests/unit/test_row_encoding.py
@@ -316,6 +316,18 @@ def test_list_nulls(field: tuple[bool, bool, bool]) -> None:
     roundtrip_series_re([[None], [None, None], [None, None, None]], dtype, field)


+@pytest.mark.parametrize("field", FIELD_COMBS)
+def test_masked_out_list_20151(field: tuple[bool, bool, bool]) -> None:
+    dtype = pl.List(pl.Int64())
+
+    values = [[1, 2], None, [4, 5], [None, 3]]
+
+    array_series = pl.Series(values, dtype=pl.Array(pl.Int64(), 2))
+    list_from_array_series = array_series.cast(dtype)
+
+    roundtrip_series_re(list_from_array_series, dtype, field)
+
+
 def test_int_after_null() -> None:
     roundtrip_re(
         pl.DataFrame(
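Beyond the dedicated roundtrip test, any operation that falls back to the row encoding internally exercises the patched path; multi-column sorting is, to my understanding, one such operation. A hedged end-to-end sketch, not taken from this commit, assuming the installed Polars version supports sorting by List columns:

import polars as pl

# Build the same list column with masked-out child values and sort a frame by it;
# this is a hypothetical way to hit the row-encoding path outside the test helper.
masked = pl.Series(
    "a", [[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)
).cast(pl.List(pl.Int64()))
df = pl.DataFrame([masked, pl.Series("b", [3, 1, 2, 0])])
print(df.sort("a", "b"))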
