From c5442cf2fd8f046e6ad75d2d5c7efb2899dd654d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 27 Mar 2022 03:46:56 -0700 Subject: [PATCH] Fix generate_non_canonical_map_case, fix `MapArray` equality (#1476) * Revamp list_equal for map type * Canonicalize schema * Add nullability and metadata --- arrow/src/array/equal/list.rs | 102 ++++++++++++++---- arrow/src/array/equal/utils.rs | 30 +++++- .../src/bin/arrow-json-integration-test.rs | 45 +++++++- 3 files changed, 153 insertions(+), 24 deletions(-) diff --git a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs index 20e6400d9520..000b31a1f785 100644 --- a/arrow/src/array/equal/list.rs +++ b/arrow/src/array/equal/list.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::datatypes::DataType; use crate::{ array::ArrayData, array::{data::count_nulls, OffsetSizeTrait}, @@ -22,7 +23,9 @@ use crate::{ util::bit_util::get_bit, }; -use super::{equal_range, utils::child_logical_null_buffer}; +use super::{ + equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls, +}; fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { // invariant from `base_equal` @@ -58,22 +61,47 @@ fn offset_value_equal( lhs_pos: usize, rhs_pos: usize, len: usize, + data_type: &DataType, ) -> bool { let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; - lhs_len == rhs_len - && equal_range( - lhs_values, - rhs_values, - lhs_nulls, - rhs_nulls, - lhs_start, - rhs_start, - lhs_len.to_usize().unwrap(), - ) + lhs_len == rhs_len && { + match data_type { + DataType::Map(_, _) => { + // Don't use `equal_range` which calls `utils::base_equal` that checks + // struct fields, but we don't enforce struct field names. 
+ equal_nulls( + lhs_values, + rhs_values, + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + lhs_len.to_usize().unwrap(), + ) && equal_values( + lhs_values, + rhs_values, + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + lhs_len.to_usize().unwrap(), + ) + } + _ => equal_range( + lhs_values, + rhs_values, + lhs_nulls, + rhs_nulls, + lhs_start, + rhs_start, + lhs_len.to_usize().unwrap(), + ), + } + } } pub(super) fn list_equal( @@ -131,17 +159,46 @@ pub(super) fn list_equal( lengths_equal( &lhs_offsets[lhs_start..lhs_start + len], &rhs_offsets[rhs_start..rhs_start + len], - ) && equal_range( - lhs_values, - rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets[lhs_start].to_usize().unwrap(), - rhs_offsets[rhs_start].to_usize().unwrap(), - (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) - .to_usize() - .unwrap(), - ) + ) && { + match lhs.data_type() { + DataType::Map(_, _) => { + // Don't use `equal_range` which calls `utils::base_equal` that checks + // struct fields, but we don't enforce struct field names. 
+ equal_nulls( + lhs_values, + rhs_values, + child_lhs_nulls.as_ref(), + child_rhs_nulls.as_ref(), + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) + .to_usize() + .unwrap(), + ) && equal_values( + lhs_values, + rhs_values, + child_lhs_nulls.as_ref(), + child_rhs_nulls.as_ref(), + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) + .to_usize() + .unwrap(), + ) + } + _ => equal_range( + lhs_values, + rhs_values, + child_lhs_nulls.as_ref(), + child_rhs_nulls.as_ref(), + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) + .to_usize() + .unwrap(), + ), + } + } } else { // get a ref of the parent null buffer bytes, to use in testing for nullness let lhs_null_bytes = lhs_nulls.unwrap().as_slice(); @@ -166,6 +223,7 @@ pub(super) fn list_equal( lhs_pos, rhs_pos, 1, + lhs.data_type(), ) }) } diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index 819ae32c5709..1bced978c1b5 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -66,7 +66,35 @@ pub(super) fn equal_nulls( #[inline] pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() + let equal_type = match (lhs.data_type(), rhs.data_type()) { + (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => { + let field_equal = match (l_field.data_type(), r_field.data_type()) { + (DataType::Struct(l_fields), DataType::Struct(r_fields)) + if l_fields.len() == 2 && r_fields.len() == 2 => + { + let l_key_field = l_fields.get(0).unwrap(); + let r_key_field = r_fields.get(0).unwrap(); + let l_value_field = l_fields.get(1).unwrap(); + let r_value_field = r_fields.get(1).unwrap(); + + // We don't 
enforce the equality of field names + let data_type_equal = l_key_field.data_type() + == r_key_field.data_type() + && l_value_field.data_type() == r_value_field.data_type(); + let nullability_equal = l_key_field.is_nullable() + == r_key_field.is_nullable() + && l_value_field.is_nullable() == r_value_field.is_nullable(); + let metadata_equal = l_key_field.metadata() == r_key_field.metadata() + && l_value_field.metadata() == r_value_field.metadata(); + data_type_equal && nullability_equal && metadata_equal + } + _ => panic!("Map type should have 2 fields Struct in its field"), + }; + field_equal && l_sorted == r_sorted + } + (l_data_type, r_data_type) => l_data_type == r_data_type, + }; + equal_type && lhs.len() == rhs.len() } // whether the two memory regions are equal diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs index 17d2528e07ff..69b73b19f222 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Field}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; @@ -107,6 +109,47 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> Ok(()) } +fn canonicalize_schema(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Map(child_field, sorted) => match child_field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + let first_field = fields.get(0).unwrap(); + let key_field = Field::new( + "key", + first_field.data_type().clone(), + first_field.is_nullable(), + ); + let second_field = fields.get(1).unwrap(); + let value_field = Field::new( + "value", + second_field.data_type().clone(), + second_field.is_nullable(), + ); + + let struct_type = DataType::Struct(vec![key_field, value_field]); + let child_field = + Field::new("entries", struct_type, child_field.is_nullable()); + + Field::new( + field.name().as_str(), + DataType::Map(Box::new(child_field), *sorted), + field.is_nullable(), + ) + } + _ => panic!( + "The child field of Map type should be Struct type with 2 fields." + ), + }, + _ => field.clone(), + }) + .collect::<Vec<_>>(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { eprintln!("Validating {} and {}", arrow_name, json_name); @@ -121,7 +164,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { let arrow_schema = arrow_reader.schema().as_ref().to_owned(); // compare schemas - if json_file.schema != arrow_schema { + if canonicalize_schema(&json_file.schema) != canonicalize_schema(&arrow_schema) { return Err(ArrowError::ComputeError(format!( "Schemas do not match. JSON: {:?}. Arrow: {:?}", json_file.schema, arrow_schema