diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 3db88e97937a..bc5d55fec3f7 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -26,8 +26,7 @@ use arrow_array::cast::AsArray; use arrow_array::types::ByteArrayType; use arrow_array::{ - downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, - FixedSizeBinaryArray, GenericByteArray, + downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; @@ -36,7 +35,7 @@ use arrow_select::take::take; use std::cmp::Ordering; use std::ops::Not; -use crate::ord::{build_compare, DynComparator}; +use crate::ord::build_compare; #[derive(Debug, Copy, Clone)] enum Op { @@ -177,11 +176,13 @@ fn compare_list(l: &dyn Array, r: &dyn Array) -> Result { let min_len = std::cmp::min(l_len, r_len); if let (DataType::List(_), DataType::List(_)) = (l_t, r_t) { + let l = l.as_list::(); + let r = r.as_list::(); // Since `compare_op` does not support inconsistent lengths, we compare the // prefix with `compare_op` only, and compare the left if the prefix is equal for i in 0..min_len { - let l = l.as_list::().value(i); - let r = r.as_list::().value(i); + let l = l.value(i); + let r = r.value(i); let ord = compare_list(&l, &r)?; if ord != Ordering::Equal { return Ok(ord); @@ -217,51 +218,33 @@ fn process_nested( ) -> Result, ArrowError> { use arrow_schema::DataType::*; if let (List(_), List(_)) = (l_t, r_t) { + fn process_ordering( + l: &dyn Array, + r: &dyn Array, + target_ord: Ordering, + len: usize, + ) -> Result { + let l = l.as_list::(); + let r = r.as_list::(); + let mut values = BooleanArray::builder(len); + for i in 0..len { + let l = l.value(i); + let r = r.value(i); + let ord = compare_list(&l, &r)?; + values.append_value(ord == target_ord); + } + Ok(values.finish()) + } + // Process nested data types match op { Op::Less => { - let l = l.as_list::(); - let r = r.as_list::(); - let mut values = BooleanArray::builder(len); - for i in 0..l.len() { - let l = l.value(i); - let r = r.value(i); - let v = compare_list(&l, &r)?; - values.append_value(v == Ordering::Less); - } - - let values = values.finish(); - Ok(Some(values)) + let v = process_ordering(l, r, Ordering::Less, len)?; + Ok(Some(v)) } Op::Equal => { - let l = l.as_list::(); - let r = r.as_list::(); - let mut values = BooleanArray::builder(len); - for i in 0..l.len() { - let l = l.value(i); - let r = r.value(i); - let l_len = l.len(); - let r_len = r.len(); - if l_len != r_len { - values.append_value(false); - continue; - } - - let eq_res = eq(&l, &r)?; - fn post_process(eq: &BooleanArray) -> bool { - for j in 0..eq.len() { - if !eq.value(j) { - return false; - } - } - true - } - - values.append_value(post_process(&eq_res)); - } - - let values = values.finish(); - Ok(Some(values)) + let v = process_ordering(l, r, Ordering::Equal, len)?; + Ok(Some(v)) } _ => Err(ArrowError::NotYetImplemented(format!( "Comparison for {op} is NYI"