Skip to content

Commit

Permalink
Gate dyn comparison of dictionary arrays behind dyn_cmp_dict (#2597)
Browse files Browse the repository at this point in the history
* Add dyn_cmp_dict feature flag

* Fix tests

* Clippy
  • Loading branch information
tustvold authored Aug 27, 2022
1 parent 86446ea commit 6ab208c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/arrow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ jobs:
- name: Test
run: |
cargo test -p arrow
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
run: |
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
- name: Test --features=nan_ordering
run: |
cargo test -p arrow --features "nan_ordering"
Expand Down Expand Up @@ -175,4 +175,4 @@ jobs:
rustup component add clippy
- name: Run clippy
run: |
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression --all-targets -- -D warnings
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings
11 changes: 7 additions & 4 deletions arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ path = "src/lib.rs"
bench = false

[target.'cfg(target_arch = "wasm32")'.dependencies]
ahash = { version = "0.8", default-features = false, features=["compile-time-rng"] }
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ahash = { version = "0.8", default-features = false, features=["runtime-rng"] }
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }

[dependencies]
serde = { version = "1.0", default-features = false }
Expand Down Expand Up @@ -90,6 +90,9 @@ force_validate = []
ffi = []
# Enable NaN-ordering behavior on comparison kernels
nan_ordering = []
# Enable dyn-comparison of dictionary arrays with other arrays
# Note: this does not impact comparison against scalars
dyn_cmp_dict = []

[dev-dependencies]
rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] }
Expand All @@ -102,7 +105,7 @@ tempfile = { version = "3", default-features = false }
[[example]]
name = "dynamic_types"
required-features = ["prettyprint"]
path="./examples/dynamic_types.rs"
path = "./examples/dynamic_types.rs"

[[bench]]
name = "aggregate_kernels"
Expand Down Expand Up @@ -144,7 +147,7 @@ required-features = ["test_utils"]
[[bench]]
name = "comparison_kernels"
harness = false
required-features = ["test_utils"]
required-features = ["test_utils", "dyn_cmp_dict"]

[[bench]]
name = "filter_kernels"
Expand Down
40 changes: 40 additions & 0 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2089,6 +2089,7 @@ where
compare_op(left_array, right_array, op)
}

#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_dict_non_dict_cmp {
($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{
match $LEFT_KEY_TYPE {
Expand Down Expand Up @@ -2132,6 +2133,7 @@ macro_rules! typed_dict_non_dict_cmp {
}};
}

#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_cmp_dict_non_dict {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
Expand Down Expand Up @@ -2182,6 +2184,16 @@ macro_rules! typed_cmp_dict_non_dict {
}};
}

#[cfg(not(feature = "dyn_cmp_dict"))]
macro_rules! typed_cmp_dict_non_dict {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
Err(ArrowError::CastError(format!(
"Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature",
$LEFT.data_type(), $RIGHT.data_type()
)))
}}
}

macro_rules! typed_compares {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
Expand Down Expand Up @@ -2298,6 +2310,7 @@ macro_rules! typed_compares {
}

/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT
#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_dict_cmp {
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr, $KT: tt) => {{
match ($LEFT.value_type(), $RIGHT.value_type()) {
Expand Down Expand Up @@ -2430,6 +2443,7 @@ macro_rules! typed_dict_cmp {
}};
}

#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_dict_compares {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{
Expand Down Expand Up @@ -2494,8 +2508,19 @@ macro_rules! typed_dict_compares {
}};
}

#[cfg(not(feature = "dyn_cmp_dict"))]
macro_rules! typed_dict_compares {
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{
Err(ArrowError::CastError(format!(
"Comparing array of type {} with array of type {} requires \"dyn_cmp_dict\" feature",
$LEFT.data_type(), $RIGHT.data_type()
)))
}}
}

/// Perform given operation on `DictionaryArray` and `PrimitiveArray`. The value
/// type of `DictionaryArray` is same as `PrimitiveArray`'s type.
#[cfg(feature = "dyn_cmp_dict")]
fn cmp_dict_primitive<K, T, F>(
left: &DictionaryArray<K>,
right: &dyn Array,
Expand All @@ -2516,6 +2541,7 @@ where
/// Perform given operation on two `DictionaryArray`s which value type is
/// primitive type. Returns an error if the two arrays have different value
/// type
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2535,6 +2561,7 @@ where

/// Perform the given operation on two `DictionaryArray`s which value type is
/// `DataType::Boolean`.
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict_bool<K, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2553,6 +2580,7 @@ where

/// Perform the given operation on two `DictionaryArray`s which value type is
/// `DataType::Utf8` or `DataType::LargeUtf8`.
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict_utf8<K, OffsetSize: OffsetSizeTrait, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2574,6 +2602,7 @@ where

/// Perform the given operation on two `DictionaryArray`s which value type is
/// `DataType::Binary` or `DataType::LargeBinary`.
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict_binary<K, OffsetSize: OffsetSizeTrait, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand Down Expand Up @@ -5476,6 +5505,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
Expand All @@ -5496,6 +5526,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_u64_array() {
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);

Expand All @@ -5517,6 +5548,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_utf8_array() {
let test1 = vec!["a", "a", "b", "c"];
let test2 = vec!["a", "b", "b", "c"];
Expand Down Expand Up @@ -5544,6 +5576,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_binary_array() {
let values: BinaryArray = ["hello", "", "parquet"]
.into_iter()
Expand All @@ -5568,6 +5601,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_interval_array() {
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);

Expand All @@ -5589,6 +5623,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_date_array() {
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);

Expand All @@ -5610,6 +5645,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

Expand All @@ -5631,6 +5667,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_lt_dyn_gt_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
Expand Down Expand Up @@ -5660,6 +5697,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_lt_dyn_gt_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

Expand Down Expand Up @@ -5702,6 +5740,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() {
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
Expand Down Expand Up @@ -5736,6 +5775,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() {
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
Expand Down

0 comments on commit 6ab208c

Please sign in to comment.