diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index 2ed45c71633..c4a1ad4747a 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -45,6 +46,7 @@ #include namespace cudf { + namespace experimental { /** @@ -68,16 +70,17 @@ struct dispatch_void_if_nested { }; namespace row { - namespace lexicographic { /** - * @brief Computes whether one row is lexicographically *less* than another row. + * @brief Computes the lexicographic comparison between 2 rows. * * Lexicographic ordering is determined by: * - Two rows are compared element by element. * - The first mismatching element defines which row is lexicographically less * or greater than the other. + * - If the rows are compared without mismatched elements, the rows are equivalent + * * * Lexicographic ordering is exactly equivalent to doing an alphabetical sort of * two words, for example, `aac` would be *less* than (or precede) `abb`. The @@ -85,11 +88,13 @@ namespace lexicographic { * `aac < abb`. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam NanConfig default configuration nans are equal, if set to true triggers specialized IEEE + * 754 compliant nan handling */ -template +template class device_row_comparator { - friend class self_comparator; - + // friend class self_comparator; + public: // needs to be removed, pending strict typing for indices /** * @brief Construct a function object for performing a lexicographic * comparison between the rows of two tables. @@ -139,13 +144,20 @@ class device_row_comparator { * @param null_precedence Indicates how null values are ordered with other values * @param depth The depth of the column if part of a nested column @see * preprocessed_table::depths + * @param nan_result Specifies what value should be returned if either element is `nan` */ __device__ element_comparator(Nullate check_nulls, column_device_view lhs, column_device_view rhs, null_order null_precedence = null_order::BEFORE, - int depth = 0) - : _lhs{lhs}, _rhs{rhs}, _nulls{check_nulls}, _null_precedence{null_precedence}, _depth{depth} + int depth = 0, + weak_ordering nan_result = weak_ordering::EQUIVALENT) + : _lhs{lhs}, + _rhs{rhs}, + _check_nulls{check_nulls}, + _null_precedence{null_precedence}, + _depth{depth}, + _nan_result{nan_result} { } @@ -162,7 +174,7 @@ class device_row_comparator { __device__ cuda::std::pair operator()( size_type const lhs_element_index, size_type const rhs_element_index) const noexcept { - if (_nulls) { + if (_check_nulls) { bool const lhs_is_null{_lhs.is_null(lhs_element_index)}; bool const rhs_is_null{_rhs.is_null(rhs_element_index)}; @@ -171,8 +183,12 @@ class device_row_comparator { } } - return cuda::std::pair(relational_compare(_lhs.element(lhs_element_index), - _rhs.element(rhs_element_index)), + return cuda::std::pair(NanConfig + ? relational_compare(_lhs.element(lhs_element_index), + _rhs.element(rhs_element_index), + _nan_result) + : relational_compare(_lhs.element(lhs_element_index), + _rhs.element(rhs_element_index)), std::numeric_limits::max()); } @@ -211,7 +227,8 @@ class device_row_comparator { ++depth; } - auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth}; + auto const comparator = + element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth, _nan_result}; return cudf::type_dispatcher( lcol.type(), comparator, lhs_element_index, rhs_element_index); } @@ -219,21 +236,23 @@ class device_row_comparator { private: column_device_view const _lhs; column_device_view const _rhs; - Nullate const _nulls; + Nullate const _check_nulls; null_order const _null_precedence; int const _depth; + weak_ordering _nan_result; }; public: /** * @brief Checks whether the row at `lhs_index` in the `lhs` table compares - * lexicographically less than the row at `rhs_index` in the `rhs` table. + * lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table. * * @param lhs_index The index of row in the `lhs` table to examine * @param rhs_index The index of the row in the `rhs` table to examine - * @return `true` if row from the `lhs` table compares less than row in the `rhs` table + * @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs` + * table */ - __device__ bool operator()(size_type const lhs_index, size_type const rhs_index) const noexcept + __device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept { int last_null_depth = std::numeric_limits::max(); for (size_type i = 0; i < _lhs.num_columns(); ++i) { @@ -247,7 +266,19 @@ class device_row_comparator { _null_precedence.has_value() ? (*_null_precedence)[i] : null_order::BEFORE; auto const comparator = - element_comparator{_check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth}; + NanConfig ? element_comparator{_check_nulls, + _lhs.column(i), + _rhs.column(i), + null_precedence, + depth, + ascending ? weak_ordering::GREATER : weak_ordering::LESS} + + : element_comparator{_check_nulls, + _lhs.column(i), + _rhs.column(i), + null_precedence, + depth, + weak_ordering::EQUIVALENT}; weak_ordering state; cuda::std::tie(state, last_null_depth) = @@ -255,9 +286,11 @@ class device_row_comparator { if (state == weak_ordering::EQUIVALENT) { continue; } - return state == (ascending ? weak_ordering::LESS : weak_ordering::GREATER); + return ascending + ? state + : (state == weak_ordering::GREATER ? weak_ordering::LESS : weak_ordering::GREATER); } - return false; + return weak_ordering::EQUIVALENT; } private: @@ -269,6 +302,42 @@ class device_row_comparator { std::optional> const _null_precedence; }; // class device_row_comparator +/** + * @brief Wraps and interprets the result of templated Comparator that returns a weak_ordering. + * Returns true if the weak_ordering matches any of the templated values. + * + * Note that this should never be used with only `weak_ordering::EQUIVALENT`. + * An equality comparator should be used instead for optimal performance. + * + * @tparam Comparator generic comparator that returns a weak_ordering. + * @tparam values weak_ordering parameter pack of orderings to interpret as true + */ +template +struct weak_ordering_comparator_impl { + __device__ bool operator()(size_type const& lhs, size_type const& rhs) + { + weak_ordering const result = comparator(lhs, rhs); + return ((result == values) || ...); + } + Comparator comparator; +}; + +/** + * @brief Wraps and interprets the result of device_row_comparator, true if the result is + * weak_ordering::LESS meaning one row is lexicographically *less* than another row. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + */ +template +using less_comparator = + weak_ordering_comparator_impl, weak_ordering::LESS>; + +template +using less_equivalent_comparator = + weak_ordering_comparator_impl, + weak_ordering::LESS, + weak_ordering::EQUIVALENT>; + struct preprocessed_table { using table_device_view_owner = std::invoke_result_t; @@ -416,11 +485,11 @@ class self_comparator { * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. */ - template - device_row_comparator device_comparator(Nullate nullate = {}) const + template + less_comparator device_comparator(Nullate nullate = {}) const { - return device_row_comparator( - nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence()); + return less_comparator{device_row_comparator( + nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())}; } private: diff --git a/cpp/include/cudf/table/row_operators.cuh b/cpp/include/cudf/table/row_operators.cuh index 4d503cd53b8..c551cac3a93 100644 --- a/cpp/include/cudf/table/row_operators.cuh +++ b/cpp/include/cudf/table/row_operators.cuh @@ -97,6 +97,26 @@ __device__ weak_ordering relational_compare(Element lhs, Element rhs) return detail::compare_elements(lhs, rhs); } +/** + * @brief A specialization for floating-point `Element` type relational comparison + * to derive the order of the elements with respect to `lhs`. Returns specified weak_ordering if + * either value is `nan`, enabling IEEE 754 compliant comparison. + * + * This specialization allows `nan` values to be evaluated as not equal to any other value, while + * also not evaluating as greater or less than + * + * @param lhs first element + * @param rhs second element + * @param nan_result specifies what value should be returned if either element is `nan` + * @return Indicates the relationship between the elements in + * the `lhs` and `rhs` columns. + */ +template ::value>* = nullptr> +__device__ weak_ordering relational_compare(Element lhs, Element rhs, weak_ordering nan_result) +{ + return isnan(lhs) or isnan(rhs) ? nan_result : detail::compare_elements(lhs, rhs); +} + /** * @brief Compare the nulls according to null order. * @@ -123,11 +143,14 @@ inline __device__ auto null_compare(bool lhs_is_null, bool rhs_is_null, null_ord * * @param[in] lhs first element * @param[in] rhs second element + * @param nan_result ignored for non-floating point operation * @return Indicates the relationship between the elements in * the `lhs` and `rhs` columns. */ template >* = nullptr> -__device__ weak_ordering relational_compare(Element lhs, Element rhs) +__device__ weak_ordering relational_compare(Element lhs, + Element rhs, + weak_ordering const nan_result = weak_ordering::GREATER) { return detail::compare_elements(lhs, rhs); } @@ -138,12 +161,15 @@ __device__ weak_ordering relational_compare(Element lhs, Element rhs) * * @param lhs first element * @param rhs second element + * @param nan_result specifies what value should be returned if either element is `nan` * @return `true` if `lhs` == `rhs` else `false`. */ template >* = nullptr> -__device__ bool equality_compare(Element lhs, Element rhs) +__device__ bool equality_compare(Element lhs, + Element rhs, + nan_equality const nan_result = nan_equality::ALL_EQUAL) { - if (isnan(lhs) and isnan(rhs)) { return true; } + if (isnan(lhs) and isnan(rhs)) { return nan_result == nan_equality::ALL_EQUAL; } return lhs == rhs; } @@ -153,10 +179,13 @@ __device__ bool equality_compare(Element lhs, Element rhs) * * @param lhs first element * @param rhs second element + * @param nan_result ignored for non-floating point operation * @return `true` if `lhs` == `rhs` else `false`. */ template >* = nullptr> -__device__ bool equality_compare(Element const lhs, Element const rhs) +__device__ bool equality_compare(Element const lhs, + Element const rhs, + nan_equality const nan_result = nan_equality::ALL_EQUAL) { return lhs == rhs; } @@ -179,13 +208,19 @@ class element_equality_comparator { * @param lhs The column containing the first element * @param rhs The column containing the second element (may be the same as lhs) * @param nulls_are_equal Indicates if two null elements are treated as equivalent + * @param nan_result specifies what value should be returned if either element is `nan` */ __host__ __device__ element_equality_comparator(Nullate has_nulls, column_device_view lhs, column_device_view rhs, - null_equality nulls_are_equal = null_equality::EQUAL) - : lhs{lhs}, rhs{rhs}, nulls{has_nulls}, nulls_are_equal{nulls_are_equal} + null_equality nulls_are_equal = null_equality::EQUAL, + nan_equality nans_are_equal = nan_equality::ALL_EQUAL) + : lhs{lhs}, + rhs{rhs}, + nulls{has_nulls}, + nulls_are_equal{nulls_are_equal}, + nans_are_equal{nans_are_equal} { } @@ -212,7 +247,8 @@ class element_equality_comparator { } return equality_compare(lhs.element(lhs_element_index), - rhs.element(rhs_element_index)); + rhs.element(rhs_element_index), + nans_are_equal); } template @@ -235,8 +272,13 @@ class row_equality_comparator { row_equality_comparator(Nullate has_nulls, table_device_view lhs, table_device_view rhs, - null_equality nulls_are_equal = null_equality::EQUAL) - : lhs{lhs}, rhs{rhs}, nulls{has_nulls}, nulls_are_equal{nulls_are_equal} + null_equality nulls_are_equal = null_equality::EQUAL, + nan_equality nans_are_equal = nan_equality::ALL_EQUAL) + : lhs{lhs}, + rhs{rhs}, + nulls{has_nulls}, + nulls_are_equal{nulls_are_equal}, + nans_are_equal{nans_are_equal} { CUDF_EXPECTS(lhs.num_columns() == rhs.num_columns(), "Mismatched number of columns."); } @@ -244,10 +286,11 @@ class row_equality_comparator { __device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept { auto equal_elements = [=](column_device_view l, column_device_view r) { - return cudf::type_dispatcher(l.type(), - element_equality_comparator{nulls, l, r, nulls_are_equal}, - lhs_row_index, - rhs_row_index); + return cudf::type_dispatcher( + l.type(), + element_equality_comparator{nulls, l, r, nulls_are_equal, nans_are_equal}, + lhs_row_index, + rhs_row_index); }; return thrust::equal(thrust::seq, lhs.begin(), lhs.end(), rhs.begin(), equal_elements); @@ -258,6 +301,7 @@ class row_equality_comparator { table_device_view rhs; Nullate nulls; null_equality nulls_are_equal; + nan_equality nans_are_equal; }; /** diff --git a/cpp/src/binaryop/compiled/binary_ops.cu b/cpp/src/binaryop/compiled/binary_ops.cu index d260aa6d6a0..995046b056d 100644 --- a/cpp/src/binaryop/compiled/binary_ops.cu +++ b/cpp/src/binaryop/compiled/binary_ops.cu @@ -52,7 +52,15 @@ struct scalar_as_column_view { column_view(s.type(), 1, h_scalar_type_view.data(), (bitmask_type const*)s.validity_data()); return std::pair{col_v, std::unique_ptr(nullptr)}; } - template ())>* = nullptr> + template ())>* = nullptr> + return_type operator()(scalar const& s, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto col = make_column_from_scalar(s, 1, stream, mr); + return std::pair{col->view(), std::move(col)}; + } + template () and !is_struct())>* = nullptr> return_type operator()(scalar const&, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) { CUDF_FAIL("Unsupported type"); @@ -303,6 +311,8 @@ void operator_dispatcher(mutable_column_view& out, binary_operator op, rmm::cuda_stream_view stream) { + if (!is_supported_operation(out.type(), lhs, rhs, op)) + CUDF_FAIL("Unsupported operator for these types"); // clang-format off switch (op) { case binary_operator::ADD: apply_binary_op(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break; diff --git a/cpp/src/binaryop/compiled/binary_ops.cuh b/cpp/src/binaryop/compiled/binary_ops.cuh index d88d2be2499..b69cd2f8033 100644 --- a/cpp/src/binaryop/compiled/binary_ops.cuh +++ b/cpp/src/binaryop/compiled/binary_ops.cuh @@ -19,9 +19,15 @@ #include "binary_ops.hpp" #include "operation.cuh" +#include #include #include +#include +#include #include +#include +#include +#include #include #include @@ -284,27 +290,71 @@ void apply_binary_op(mutable_column_view& out, bool is_rhs_scalar, rmm::cuda_stream_view stream) { - auto common_dtype = get_common_type(out.type(), lhs.type(), rhs.type()); + if (is_struct(lhs.type()) && is_struct(rhs.type())) { + auto op_order = detail::is_any_v + ? order::DESCENDING + : order::ASCENDING; + auto accept_equality = detail::is_any_v; + auto const nullability = + structs::detail::contains_null_structs(lhs) || structs::detail::contains_null_structs(rhs) + ? structs::detail::column_nullability::FORCE + : structs::detail::column_nullability::MATCH_INCOMING; + auto const lhs_flattened = + structs::detail::flatten_nested_columns(table_view{{lhs}}, {}, {}, nullability); + auto const rhs_flattened = + structs::detail::flatten_nested_columns(table_view{{rhs}}, {}, {}, nullability); + + auto lhsd = table_device_view::create(lhs_flattened); + auto rhsd = table_device_view::create(rhs_flattened); + auto compare_orders = + cudf::detail::make_device_uvector_async(std::vector(lhs.size(), op_order), stream); + auto comparator = + experimental::row::lexicographic::device_row_comparator{ + nullate::DYNAMIC{has_nested_nulls(lhs_flattened) || has_nested_nulls(rhs_flattened)}, + *lhsd, + *rhsd, + std::nullopt, + device_span{compare_orders}, + std::nullopt}; + + auto outd = column_device_view::create(out, stream); + auto optional_iter = + cudf::detail::make_optional_iterator(*outd, nullate::DYNAMIC{out.has_nulls()}); + thrust::tabulate( + rmm::exec_policy(stream), + out.begin(), + out.end(), + [optional_iter, is_lhs_scalar, is_rhs_scalar, accept_equality, comparator] __device__( + size_type i) { + auto lhs = is_lhs_scalar ? 0 : i; + auto rhs = is_rhs_scalar ? 0 : i; + return optional_iter[i].has_value() && + (accept_equality ? comparator(lhs, rhs) != weak_ordering::GREATER + : comparator(lhs, rhs) == weak_ordering::LESS); + }); - auto lhsd = column_device_view::create(lhs, stream); - auto rhsd = column_device_view::create(rhs, stream); - auto outd = mutable_column_device_view::create(out, stream); - // Create binop functor instance - if (common_dtype) { - // Execute it on every element - for_each(stream, - out.size(), - binary_op_device_dispatcher{ - *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); } else { - // Execute it on every element - for_each(stream, - out.size(), - binary_op_double_device_dispatcher{ - *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + auto common_dtype = get_common_type(out.type(), lhs.type(), rhs.type()); + + auto lhsd = column_device_view::create(lhs, stream); + auto rhsd = column_device_view::create(rhs, stream); + auto outd = mutable_column_device_view::create(out, stream); + // Create binop functor instance + if (common_dtype) { + // Execute it on every element + for_each(stream, + out.size(), + binary_op_device_dispatcher{ + *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + } else { + // Execute it on every element + for_each(stream, + out.size(), + binary_op_double_device_dispatcher{ + *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + } } } - } // namespace compiled } // namespace binops } // namespace cudf diff --git a/cpp/src/binaryop/compiled/binary_ops.hpp b/cpp/src/binaryop/compiled/binary_ops.hpp index d1a40e15326..7a43f5465bf 100644 --- a/cpp/src/binaryop/compiled/binary_ops.hpp +++ b/cpp/src/binaryop/compiled/binary_ops.hpp @@ -172,6 +172,25 @@ std::optional get_common_type(data_type out, data_type lhs, data_type */ bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op); +/** + * @brief Check if input binary operation is supported for the given input columns and output types. + * + * If the left and right columns are struct columns, recursively checks if the input columns have + * the same number of children and the corresponding child columns are supported for the specified + * operation and output type. If either input column is not a struct, returns the result of + * is_supported_operation for the input column types. + * + * @param out output type of the binary operation + * @param lhs left column of the binary operation + * @param rhs right column of the binary operation + * @param op binary operator enum + * @return true if given binary operator supports given input columns and output types. + */ +bool is_supported_operation(data_type out, + column_view const& lhs, + column_view const& rhs, + binary_operator op); + // Defined in individual .cu files. /** * @brief Deploys single type or double type dispatcher that runs binary operation on each element diff --git a/cpp/src/binaryop/compiled/equality_ops.cu b/cpp/src/binaryop/compiled/equality_ops.cu index 61f02252a26..344884ca51c 100644 --- a/cpp/src/binaryop/compiled/equality_ops.cu +++ b/cpp/src/binaryop/compiled/equality_ops.cu @@ -16,6 +16,9 @@ #include "binary_ops.cuh" +#include +#include + namespace cudf::binops::compiled { void dispatch_equality_op(mutable_column_view& out, column_view const& lhs, @@ -27,33 +30,69 @@ void dispatch_equality_op(mutable_column_view& out, { CUDF_EXPECTS(op == binary_operator::EQUAL || op == binary_operator::NOT_EQUAL, "Unsupported operator for these types"); - auto common_dtype = get_common_type(out.type(), lhs.type(), rhs.type()); - auto outd = mutable_column_device_view::create(out, stream); - auto lhsd = column_device_view::create(lhs, stream); - auto rhsd = column_device_view::create(rhs, stream); - if (common_dtype) { - if (op == binary_operator::EQUAL) { - for_each(stream, - out.size(), - binary_op_device_dispatcher{ - *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); - } else if (op == binary_operator::NOT_EQUAL) { - for_each(stream, - out.size(), - binary_op_device_dispatcher{ - *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); - } + if (is_struct(lhs.type()) && is_struct(rhs.type())) { + auto const nullability = + structs::detail::contains_null_structs(lhs) || structs::detail::contains_null_structs(rhs) + ? structs::detail::column_nullability::FORCE + : structs::detail::column_nullability::MATCH_INCOMING; + auto const lhs_flattened = + structs::detail::flatten_nested_columns(table_view{{lhs}}, {}, {}, nullability); + auto const rhs_flattened = + structs::detail::flatten_nested_columns(table_view{{rhs}}, {}, {}, nullability); + auto lhsd = table_device_view::create(lhs_flattened); + auto rhsd = table_device_view::create(rhs_flattened); + auto comparator = row_equality_comparator{ + nullate::DYNAMIC{has_nested_nulls(lhs_flattened) || has_nested_nulls(rhs_flattened)}, + *lhsd, + *rhsd, + null_equality::EQUAL, + nan_equality::UNEQUAL}; + + auto outd = column_device_view::create(out, stream); + auto optional_iter = + cudf::detail::make_optional_iterator(*outd, nullate::DYNAMIC{out.has_nulls()}); + thrust::tabulate(rmm::exec_policy(stream), + out.begin(), + out.end(), + [optional_iter, + is_lhs_scalar, + is_rhs_scalar, + flip_output = (op == binary_operator::NOT_EQUAL), + comparator] __device__(size_type i) { + auto lhs = is_lhs_scalar ? 0 : i; + auto rhs = is_rhs_scalar ? 0 : i; + return optional_iter[i].has_value() and + (flip_output ? not comparator(lhs, rhs) : comparator(lhs, rhs)); + }); } else { - if (op == binary_operator::EQUAL) { - for_each(stream, - out.size(), - binary_op_double_device_dispatcher{ - *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); - } else if (op == binary_operator::NOT_EQUAL) { - for_each(stream, - out.size(), - binary_op_double_device_dispatcher{ - *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + auto common_dtype = get_common_type(out.type(), lhs.type(), rhs.type()); + auto outd = mutable_column_device_view::create(out, stream); + auto lhsd = column_device_view::create(lhs, stream); + auto rhsd = column_device_view::create(rhs, stream); + if (common_dtype) { + if (op == binary_operator::EQUAL) { + for_each(stream, + out.size(), + binary_op_device_dispatcher{ + *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + } else if (op == binary_operator::NOT_EQUAL) { + for_each(stream, + out.size(), + binary_op_device_dispatcher{ + *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + } + } else { + if (op == binary_operator::EQUAL) { + for_each(stream, + out.size(), + binary_op_double_device_dispatcher{ + *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + } else if (op == binary_operator::NOT_EQUAL) { + for_each(stream, + out.size(), + binary_op_double_device_dispatcher{ + *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar}); + } } } } diff --git a/cpp/src/binaryop/compiled/util.cpp b/cpp/src/binaryop/compiled/util.cpp index 91fa04be6e2..2a5c16ecde8 100644 --- a/cpp/src/binaryop/compiled/util.cpp +++ b/cpp/src/binaryop/compiled/util.cpp @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include @@ -215,6 +217,26 @@ std::optional get_common_type(data_type out, data_type lhs, data_type bool is_supported_operation(data_type out, data_type lhs, data_type rhs, binary_operator op) { - return double_type_dispatcher(lhs, rhs, is_supported_operation_functor{}, out, op); + return double_type_dispatcher(lhs, rhs, is_supported_operation_functor{}, out, op) || + (is_struct(lhs) && is_struct(rhs) && + (op == binary_operator::EQUAL || op == binary_operator::NOT_EQUAL || + op == binary_operator::LESS || op == binary_operator::LESS_EQUAL || + op == binary_operator::GREATER || op == binary_operator::GREATER_EQUAL)); +} + +bool is_supported_operation(data_type out, + column_view const& lhs, + column_view const& rhs, + binary_operator op) +{ + return is_struct(lhs.type()) && is_struct(rhs.type()) + ? (lhs.num_children() == rhs.num_children() || lhs.num_children() == 1 || + rhs.num_children() == 1) && + std::all_of(thrust::counting_iterator(0), + thrust::counting_iterator(lhs.num_children()), + [&](size_type i) { + return is_supported_operation(out, lhs.child(i), rhs.child(i), op); + }) + : is_supported_operation(out, lhs.type(), rhs.type(), op); } } // namespace cudf::binops::compiled diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index eadcd985de3..d3a6210c517 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -193,6 +193,7 @@ ConfigureTest( binaryop/binop-compiled-test.cpp binaryop/binop-compiled-fixed_point-test.cpp binaryop/binop-generic-ptx-test.cpp + binaryop/binop-struct-test.cpp ) # ################################################################################################## diff --git a/cpp/tests/binaryop/binop-struct-test.cpp b/cpp/tests/binaryop/binop-struct-test.cpp new file mode 100644 index 00000000000..d95908c32f3 --- /dev/null +++ b/cpp/tests/binaryop/binop-struct-test.cpp @@ -0,0 +1,391 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include + +namespace cudf::test { +template +struct TypedBinopStructCompare : BaseFixture { +}; + +using NumericTypesNotBool = + cudf::test::Concat; +TYPED_TEST_SUITE(TypedBinopStructCompare, NumericTypesNotBool); +TYPED_TEST(TypedBinopStructCompare, binopcompare_no_nulls) +{ + using T = TypeParam; + + auto col1 = fixed_width_column_wrapper{26, 0, 14, 116, 89, 62, 63, 0, 121}; + auto col2 = fixed_width_column_wrapper{117, 34, 23, 29, 2, 37, 63, 0, 121}; + + auto strings1 = strings_column_wrapper{"0a", "1c", "2d", "3b", "5c", "6", "7d", "9g", "0h"}; + auto strings2 = strings_column_wrapper{"0b", "0c", "2d", "3a", "4c", "6", "8e", "9f", "0h"}; + + auto lhs = structs_column_wrapper{col1, strings1}; + auto rhs = structs_column_wrapper{col2, strings2}; + data_type dt = cudf::data_type(type_id::BOOL8); + + auto res_eq = binary_operation(lhs, rhs, binary_operator::EQUAL, dt); + auto res_neq = binary_operation(lhs, rhs, binary_operator::NOT_EQUAL, dt); + auto res_lt = binary_operation(lhs, rhs, binary_operator::LESS, dt); + auto res_lteq = binary_operation(lhs, rhs, binary_operator::LESS_EQUAL, dt); + auto res_gt = binary_operation(lhs, rhs, binary_operator::GREATER, dt); + auto res_gteq = binary_operation(lhs, rhs, binary_operator::GREATER_EQUAL, dt); + + auto expected_eq = fixed_width_column_wrapper{0, 0, 0, 0, 0, 0, 0, 0, 1}; + auto expected_neq = fixed_width_column_wrapper{1, 1, 1, 1, 1, 1, 1, 1, 0}; + auto expected_lt = fixed_width_column_wrapper{1, 1, 1, 0, 0, 0, 1, 0, 0}; + auto expected_lteq = fixed_width_column_wrapper{1, 1, 1, 0, 0, 0, 1, 0, 1}; + auto expected_gt = fixed_width_column_wrapper{0, 0, 0, 1, 1, 1, 0, 1, 0}; + auto expected_gteq = fixed_width_column_wrapper{0, 0, 0, 1, 1, 1, 0, 1, 1}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_eq, expected_eq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_neq, expected_neq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lt, expected_lt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lteq, expected_lteq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gt, expected_gt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gteq, expected_gteq); +} + +TYPED_TEST(TypedBinopStructCompare, binopcompare_with_nulls) +{ + using T = TypeParam; + + auto col1 = fixed_width_column_wrapper{ + {26, 0, 14, 116, 89, 62, 63, 0, 121, 26, 0, 14, 116, 89, 62, 63, 0, 121, 1, 1, 1}, + {0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}; + auto col2 = fixed_width_column_wrapper{ + {117, 34, 23, 29, 2, 37, 63, 0, 121, 117, 34, 23, 29, 2, 37, 63, 0, 121, 1, 1, 1}, + {1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}; + + auto strings1 = + strings_column_wrapper{{"0b", "", "1c", "2a", "", "5d", "6e", "8f", "", "0a", "1c", + "2d", "3b", "5c", "6", "7d", "9g", "0h", "1f", "2g", "3h"}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1}}; + auto strings2 = + strings_column_wrapper{{"0a", "", "1d", "2a", "3c", "4", "7d", "9", "", "0b", "0c", + "2d", "3a", "4c", "6", "8e", "9f", "0h", "1f", "2g", "3h"}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1}}; + + auto lhs = structs_column_wrapper{ + {col1, strings1}, + std::vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto rhs = structs_column_wrapper{ + {col2, strings2}, + std::vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + data_type dt = cudf::data_type(cudf::type_id::BOOL8); + + auto res_eq = binary_operation(lhs, rhs, binary_operator::EQUAL, dt); + auto res_neq = binary_operation(lhs, rhs, binary_operator::NOT_EQUAL, dt); + auto res_lt = binary_operation(lhs, rhs, binary_operator::LESS, dt); + auto res_lteq = binary_operation(lhs, rhs, binary_operator::LESS_EQUAL, dt); + auto res_gt = binary_operation(lhs, rhs, binary_operator::GREATER, dt); + auto res_gteq = binary_operation(lhs, rhs, binary_operator::GREATER_EQUAL, dt); + + auto expected_eq = fixed_width_column_wrapper{ + {0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto expected_neq = fixed_width_column_wrapper{ + {1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto expected_lt = fixed_width_column_wrapper{ + {1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto expected_lteq = fixed_width_column_wrapper{ + {1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto expected_gt = fixed_width_column_wrapper{ + {0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto expected_gteq = fixed_width_column_wrapper{ + {0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_eq, expected_eq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_neq, expected_neq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lt, expected_lt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lteq, expected_lteq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gt, expected_gt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gteq, expected_gteq); +} + +TYPED_TEST(TypedBinopStructCompare, binopcompare_nested_structs) +{ + using T = TypeParam; + + auto col1 = fixed_width_column_wrapper{ + 104, 40, 105, 1, 86, 128, 25, 47, 39, 117, 125, 92, 101, 59, 69, 48, 36, 50}; + auto col2 = fixed_width_column_wrapper{ + {104, 40, 105, 1, 86, 128, 25, 47, 39, 117, 125, 92, 101, 59, 69, 48, 36, 50}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}; + auto col3 = fixed_width_column_wrapper{ + {26, 0, 14, 116, 89, 62, 63, 0, 121, 26, 0, 14, 116, 89, 62, 63, 0, 121}, + {0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + auto col4 = fixed_width_column_wrapper{ + {117, 34, 23, 29, 2, 37, 63, 0, 121, 117, 34, 23, 29, 2, 37, 63, 0, 121}, + {1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + + auto s1 = strings_column_wrapper{{"0b", + "", + "1c", + "2a", + "", + "5d", + "6e", + "8f", + "", + "0a", + "1c", + "2d", + "3b", + "5c", + "6", + "7d", + "9g", + "0h"}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0}}; + auto s2 = strings_column_wrapper{{"0a", + "", + "1d", + "2a", + "3c", + "4", + "7d", + "9", + "", + "0b", + "0c", + "2d", + "3a", + "4c", + "6", + "8e", + "9f", + "0h"}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0}}; + + auto struct_col1 = structs_column_wrapper{col3, s1}; + auto struct_col2 = structs_column_wrapper{col4, s2}; + auto lhs = structs_column_wrapper{col1, struct_col1}; + auto rhs = structs_column_wrapper{col2, struct_col2}; + data_type dt = cudf::data_type(cudf::type_id::BOOL8); + + auto res_eq = binary_operation(lhs, rhs, binary_operator::EQUAL, dt); + auto res_neq = binary_operation(lhs, rhs, binary_operator::NOT_EQUAL, dt); + auto res_lt = binary_operation(lhs, rhs, binary_operator::LESS, dt); + auto res_lteq = binary_operation(lhs, rhs, binary_operator::LESS_EQUAL, dt); + auto res_gt = binary_operation(lhs, rhs, binary_operator::GREATER, dt); + auto res_gteq = binary_operation(lhs, rhs, binary_operator::GREATER_EQUAL, dt); + + auto expected_eq = + fixed_width_column_wrapper{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1}; + auto expected_neq = + fixed_width_column_wrapper{1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0}; + auto expected_lt = + fixed_width_column_wrapper{1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0}; + auto expected_lteq = + fixed_width_column_wrapper{1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1}; + auto expected_gt = + fixed_width_column_wrapper{0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0}; + auto expected_gteq = + fixed_width_column_wrapper{0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_eq, expected_eq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_neq, expected_neq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lt, expected_lt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lteq, expected_lteq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gt, expected_gt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gteq, expected_gteq); +} + +TYPED_TEST(TypedBinopStructCompare, binopcompare_scalars) +{ + using T = TypeParam; + + auto col1 = + fixed_width_column_wrapper{40, 105, 68, 25, 86, 68, 25, 127, 68, 68, 68, 68, 68, 68, 68}; + auto col2 = + fixed_width_column_wrapper{{26, 0, 14, 116, 89, 62, 63, 0, 121, 5, 115, 18, 0, 88, 18}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0}}; + auto s1 = strings_column_wrapper{ + {"6S", "5G", "4a", "5G", "", "5Z", "5e", "9a", "5G", "5", "5Gs", "5G", "", "5G2", "5G"}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}}; + auto struct_col1 = structs_column_wrapper{col2, s1}; + auto col_val = structs_column_wrapper{col1, struct_col1}; + + cudf::test::fixed_width_column_wrapper col3{68}; + cudf::test::fixed_width_column_wrapper col4{{18}, {0}}; + auto strings2 = strings_column_wrapper{"5G"}; + auto struct_col2 = structs_column_wrapper{col4, strings2}; + cudf::table_view tbl({col3, struct_col2}); + cudf::struct_scalar struct_val(tbl); + data_type dt = cudf::data_type(cudf::type_id::BOOL8); + + auto res_eq = binary_operation(col_val, struct_val, binary_operator::EQUAL, dt); + auto res_neq = binary_operation(col_val, struct_val, binary_operator::NOT_EQUAL, dt); + auto res_lt = binary_operation(col_val, struct_val, binary_operator::LESS, dt); + auto res_gt = binary_operation(col_val, struct_val, binary_operator::GREATER, dt); + auto res_gteq = binary_operation(col_val, struct_val, binary_operator::GREATER_EQUAL, dt); + auto res_lteq = binary_operation(col_val, struct_val, binary_operator::LESS_EQUAL, dt); + + auto expected_eq = fixed_width_column_wrapper{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0}; + auto expected_neq = fixed_width_column_wrapper{1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1}; + auto expected_lt = fixed_width_column_wrapper{1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1}; + auto expected_lteq = + fixed_width_column_wrapper{1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1}; + auto expected_gt = fixed_width_column_wrapper{0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0}; + auto expected_gteq = + fixed_width_column_wrapper{0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_eq, expected_eq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_neq, expected_neq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lt, expected_lt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lteq, expected_lteq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gt, expected_gt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gteq, expected_gteq); +} + +struct BinopStructCompareNAN : public cudf::test::BaseFixture { +}; + +TEST_F(BinopStructCompareNAN, float_nans) +{ + cudf::test::fixed_width_column_wrapper lhs{ + -NAN, -NAN, -NAN, NAN, NAN, NAN, 1.0f, 0.0f, -54.3f}; + cudf::test::fixed_width_column_wrapper rhs{ + -32.5f, -NAN, NAN, -0.0f, -NAN, NAN, 111.0f, -NAN, NAN}; + data_type dt = cudf::data_type(cudf::type_id::BOOL8); + + auto expected_eq = binary_operation(lhs, rhs, binary_operator::EQUAL, dt); + auto expected_neq = binary_operation(lhs, rhs, binary_operator::NOT_EQUAL, dt); + auto expected_lt = binary_operation(lhs, rhs, binary_operator::LESS, dt); + auto expected_gt = binary_operation(lhs, rhs, binary_operator::GREATER, dt); + auto expected_gteq = binary_operation(lhs, rhs, binary_operator::GREATER_EQUAL, dt); + auto expected_lteq = binary_operation(lhs, rhs, binary_operator::LESS_EQUAL, dt); + + auto struct_lhs = structs_column_wrapper{lhs}; + auto struct_rhs = structs_column_wrapper{rhs}; + auto res_eq = binary_operation(struct_lhs, struct_rhs, binary_operator::EQUAL, dt); + auto res_neq = binary_operation(struct_lhs, struct_rhs, binary_operator::NOT_EQUAL, dt); + auto res_lt = binary_operation(struct_lhs, struct_rhs, binary_operator::LESS, dt); + auto res_gt = binary_operation(struct_lhs, struct_rhs, binary_operator::GREATER, dt); + auto res_gteq = binary_operation(struct_lhs, struct_rhs, binary_operator::GREATER_EQUAL, dt); + auto res_lteq = binary_operation(struct_lhs, struct_rhs, binary_operator::LESS_EQUAL, dt); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_eq, *expected_eq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_neq, *expected_neq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lt, *expected_lt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lteq, *expected_lteq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gt, *expected_gt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gteq, *expected_gteq); +}; + +TEST_F(BinopStructCompareNAN, double_nans) +{ + cudf::test::fixed_width_column_wrapper lhs{ + -NAN, -NAN, -NAN, NAN, NAN, NAN, 1.0f, 0.0f, -54.3f}; + cudf::test::fixed_width_column_wrapper rhs{ + -32.5f, -NAN, NAN, -0.0f, -NAN, NAN, 111.0f, -NAN, NAN}; + data_type dt = cudf::data_type(cudf::type_id::BOOL8); + + auto expected_eq = binary_operation(lhs, rhs, binary_operator::EQUAL, dt); + auto expected_neq = binary_operation(lhs, rhs, binary_operator::NOT_EQUAL, dt); + auto expected_lt = binary_operation(lhs, rhs, binary_operator::LESS, dt); + auto expected_gt = binary_operation(lhs, rhs, binary_operator::GREATER, dt); + auto expected_gteq = binary_operation(lhs, rhs, binary_operator::GREATER_EQUAL, dt); + auto expected_lteq = binary_operation(lhs, rhs, binary_operator::LESS_EQUAL, dt); + + auto struct_lhs = structs_column_wrapper{lhs}; + auto struct_rhs = structs_column_wrapper{rhs}; + auto res_eq = binary_operation(struct_lhs, struct_rhs, binary_operator::EQUAL, dt); + auto res_neq = binary_operation(struct_lhs, struct_rhs, binary_operator::NOT_EQUAL, dt); + auto res_lt = binary_operation(struct_lhs, struct_rhs, binary_operator::LESS, dt); + auto res_gt = binary_operation(struct_lhs, struct_rhs, binary_operator::GREATER, dt); + auto res_gteq = binary_operation(struct_lhs, struct_rhs, binary_operator::GREATER_EQUAL, dt); + auto res_lteq = binary_operation(struct_lhs, struct_rhs, binary_operator::LESS_EQUAL, dt); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_eq, *expected_eq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_neq, *expected_neq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lt, *expected_lt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_lteq, *expected_lteq); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gt, *expected_gt); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*res_gteq, *expected_gteq); +}; + +struct BinopStructCompareFailures : public cudf::test::BaseFixture { + void attempt_struct_binop(binary_operator op, + data_type dt = cudf::data_type(cudf::type_id::BOOL8)) + { + auto col = fixed_width_column_wrapper{0, 89, 121}; + auto struct_col = structs_column_wrapper{col}; + binary_operation(struct_col, struct_col, op, dt); + } +}; + +TEST_F(BinopStructCompareFailures, binopcompare_lists) +{ + auto list_col = lists_column_wrapper{{0, 0}, {127, 3, 55}, {7, 3}}; + auto struct_col = structs_column_wrapper{list_col}; + auto dt = cudf::data_type(cudf::type_id::BOOL8); + + EXPECT_THROW(binary_operation(struct_col, struct_col, binary_operator::EQUAL, dt), + cudf::logic_error); + EXPECT_THROW(binary_operation(struct_col, struct_col, binary_operator::NOT_EQUAL, dt), + cudf::logic_error); + EXPECT_THROW(binary_operation(struct_col, struct_col, binary_operator::LESS, dt), + cudf::logic_error); + EXPECT_THROW(binary_operation(struct_col, struct_col, binary_operator::GREATER, dt), + cudf::logic_error); + EXPECT_THROW(binary_operation(struct_col, struct_col, binary_operator::GREATER_EQUAL, dt), + cudf::logic_error); + EXPECT_THROW(binary_operation(struct_col, struct_col, binary_operator::LESS_EQUAL, dt), + cudf::logic_error); +} + +TEST_F(BinopStructCompareFailures, binopcompare_unsupported_ops) +{ + EXPECT_THROW(attempt_struct_binop(binary_operator::ADD), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::SUB), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::MUL), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::DIV), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::TRUE_DIV), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::FLOOR_DIV), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::MOD), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::PMOD), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::PYMOD), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::POW), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::LOG_BASE), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::ATAN2), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::SHIFT_LEFT), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::SHIFT_RIGHT), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::SHIFT_RIGHT_UNSIGNED), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::BITWISE_AND), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::BITWISE_OR), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::BITWISE_XOR), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::LOGICAL_AND), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::LOGICAL_OR), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::NULL_EQUALS), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::NULL_MAX), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::NULL_MIN), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::GENERIC_BINARY), cudf::logic_error); + EXPECT_THROW(attempt_struct_binop(binary_operator::INVALID_BINARY), cudf::logic_error); +} + +} // namespace cudf::test