Skip to content

Commit

Permalink
Fix ListArray and StructArray equality (apache#626)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Apr 19, 2022
1 parent 786792c commit ecf753b
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 230 deletions.
183 changes: 67 additions & 116 deletions arrow/src/array/equal/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use crate::datatypes::DataType;
use crate::{
array::ArrayData,
array::{data::count_nulls, OffsetSizeTrait},
buffer::Buffer,
util::bit_util::get_bit,
};

use super::{
equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
};
use super::equal_range;

fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
// invariant from `base_equal`
Expand All @@ -49,61 +46,6 @@ fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
})
}

/// Compares the child values covered by `len` consecutive list slots, starting at
/// slot `lhs_pos` of `lhs_offsets` and slot `rhs_pos` of `rhs_offsets`.
///
/// The covered child range on each side begins at the offset stored for the start
/// slot and spans the difference between the offsets at `pos + len` and `pos`.
/// Returns `false` as soon as the two spans differ in length; otherwise the child
/// ranges are compared with the supplied null buffers.
///
/// # Panics
///
/// Panics if `pos + len` is out of bounds for either offsets slice, or if an offset
/// (or the span length) cannot be converted to `usize`.
#[allow(clippy::too_many_arguments)]
#[inline]
fn offset_value_equal<T: OffsetSizeTrait>(
    lhs_values: &ArrayData,
    rhs_values: &ArrayData,
    lhs_nulls: Option<&Buffer>,
    rhs_nulls: Option<&Buffer>,
    lhs_offsets: &[T],
    rhs_offsets: &[T],
    lhs_pos: usize,
    rhs_pos: usize,
    len: usize,
    data_type: &DataType,
) -> bool {
    let lhs_first = lhs_offsets[lhs_pos].to_usize().unwrap();
    let rhs_first = rhs_offsets[rhs_pos].to_usize().unwrap();
    let lhs_span = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos];
    let rhs_span = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos];

    // Child ranges of different lengths can never be equal.
    if lhs_span != rhs_span {
        return false;
    }
    let span = lhs_span.to_usize().unwrap();

    if let DataType::Map(_, _) = data_type {
        // Don't use `equal_range` which calls `utils::base_equal` that checks
        // struct fields, but we don't enforce struct field names.
        equal_nulls(
            lhs_values, rhs_values, lhs_nulls, rhs_nulls, lhs_first, rhs_first, span,
        ) && equal_values(
            lhs_values, rhs_values, lhs_nulls, rhs_nulls, lhs_first, rhs_first, span,
        )
    } else {
        equal_range(
            lhs_values, rhs_values, lhs_nulls, rhs_nulls, lhs_first, rhs_first, span,
        )
    }
}

pub(super) fn list_equal<T: OffsetSizeTrait>(
lhs: &ArrayData,
rhs: &ArrayData,
Expand All @@ -123,7 +65,7 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
// no child values. This causes panics when trying to count set bits.
//
// We caught this by chance from an accidental test-case, but due to the nature of this
// crash only occuring on list equality checks, we are adding a check here, instead of
// crash only occurring on list equality checks, we are adding a check here, instead of
// on the buffer/bitmap utilities, as a length check would incur a penalty for almost all
// other use-cases.
//
Expand All @@ -149,82 +91,91 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);

// compute the child logical bitmap
let child_lhs_nulls =
child_logical_null_buffer(lhs, lhs_nulls, lhs.child_data().get(0).unwrap());
let child_rhs_nulls =
child_logical_null_buffer(rhs, rhs_nulls, rhs.child_data().get(0).unwrap());
if lhs_null_count != rhs_null_count {
return false;
}

if lhs_null_count == 0 && rhs_null_count == 0 {
lengths_equal(
&lhs_offsets[lhs_start..lhs_start + len],
&rhs_offsets[rhs_start..rhs_start + len],
) && {
match lhs.data_type() {
DataType::Map(_, _) => {
// Don't use `equal_range` which calls `utils::base_equal` that checks
// struct fields, but we don't enforce struct field names.
equal_nulls(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
) && equal_values(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
)
}
_ => equal_range(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
),
}
}
) && equal_range(
lhs_values,
rhs_values,
lhs_values.null_buffer(),
rhs_values.null_buffer(),
lhs_offsets[lhs_start].to_usize().unwrap(),
rhs_offsets[rhs_start].to_usize().unwrap(),
(lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
.to_usize()
.unwrap(),
)
} else {
// get a ref of the parent null buffer bytes, to use in testing for nullness
let lhs_null_bytes = lhs_nulls.unwrap().as_slice();
let rhs_null_bytes = rhs_nulls.unwrap().as_slice();

// with nulls, we need to compare item by item whenever it is not null
// TODO: Could potentially compare runs of not NULL values
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
let rhs_pos = rhs_start + i;

let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset());
let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset());

if lhs_is_null != rhs_is_null {
return false;
}

let lhs_offset_start = lhs_offsets[lhs_pos].to_usize().unwrap();
let lhs_offset_end = lhs_offsets[lhs_pos + 1].to_usize().unwrap();
let rhs_offset_start = rhs_offsets[rhs_pos].to_usize().unwrap();
let rhs_offset_end = rhs_offsets[rhs_pos + 1].to_usize().unwrap();

let lhs_len = lhs_offset_end - lhs_offset_start;
let rhs_len = rhs_offset_end - rhs_offset_start;

lhs_is_null
|| (lhs_is_null == rhs_is_null)
&& offset_value_equal::<T>(
|| (lhs_len == rhs_len
&& equal_range(
lhs_values,
rhs_values,
child_lhs_nulls.as_ref(),
child_rhs_nulls.as_ref(),
lhs_offsets,
rhs_offsets,
lhs_pos,
rhs_pos,
1,
lhs.data_type(),
)
lhs_values.null_buffer(),
rhs_values.null_buffer(),
lhs_offset_start,
rhs_offset_start,
lhs_len,
))
})
}
}

#[cfg(test)]
mod tests {
    use crate::array::{Int64Builder, ListBuilder};

    /// A null list slot must compare equal regardless of how many child values
    /// sit underneath it: one array leaves the null slot empty, the other backs
    /// it with two (null) child values.
    #[test]
    fn list_array_non_zero_nulls() {
        // [[1, 2, 3], null] where the null slot covers no child values.
        let mut compact = ListBuilder::new(Int64Builder::new(10));
        for v in 1..=3 {
            compact.values().append_value(v).unwrap();
        }
        compact.append(true).unwrap();
        compact.append(false).unwrap();
        let compact = compact.finish();

        // [[1, 2, 3], null] where the null slot covers two child values.
        let mut padded = ListBuilder::new(Int64Builder::new(10));
        for v in 1..=3 {
            padded.values().append_value(v).unwrap();
        }
        padded.append(true).unwrap();
        for _ in 0..2 {
            padded.values().append_null().unwrap();
        }
        padded.append(false).unwrap();
        let padded = padded.finish();

        assert_eq!(compact, padded);
    }
}
3 changes: 1 addition & 2 deletions arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,7 @@ fn equal_range(
rhs_start: usize,
len: usize,
) -> bool {
utils::base_equal(lhs, rhs)
&& utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
&& equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
}

Expand Down
Loading

0 comments on commit ecf753b

Please sign in to comment.