Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make binary operators work between fixed-point and floating args #16116

Merged
merged 2 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions cpp/include/cudf/binaryop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,56 @@ enum class binary_operator : int32_t {
///< (null, false) is null, and (valid, valid) == LOGICAL_OR(valid, valid)
INVALID_BINARY ///< invalid operation
};

/// Binary operation common type default
template <typename L, typename R, typename = void>
struct binary_op_common_type {};

/// Binary operation common type specialization
template <typename L, typename R>
struct binary_op_common_type<L, R, std::enable_if_t<has_common_type_v<L, R>>> {
/// The common type of the template parameters
using type = std::common_type_t<L, R>;
};

/// Binary operation common type specialization
template <typename L, typename R>
struct binary_op_common_type<
L,
R,
std::enable_if_t<is_fixed_point<L>() && cuda::std::is_floating_point_v<R>>> {
/// The common type of the template parameters
using type = L;
};

/// Binary operation common type specialization
template <typename L, typename R>
struct binary_op_common_type<
L,
R,
std::enable_if_t<is_fixed_point<R>() && cuda::std::is_floating_point_v<L>>> {
/// The common type of the template parameters
using type = R;
};

/// Binary operation common type helper
template <typename L, typename R>
using binary_op_common_type_t = typename binary_op_common_type<L, R>::type;

namespace detail {
template <typename AlwaysVoid, typename L, typename R>
struct binary_op_has_common_type_impl : std::false_type {};

template <typename L, typename R>
struct binary_op_has_common_type_impl<std::void_t<binary_op_common_type_t<L, R>>, L, R>
: std::true_type {};
} // namespace detail

/// Checks if binary operation types have a common type
template <typename L, typename R>
constexpr inline bool binary_op_has_common_type_v =
detail::binary_op_has_common_type_impl<void, L, R>::value;

/**
* @brief Performs a binary operation between a scalar and a column.
*
Expand Down
14 changes: 11 additions & 3 deletions cpp/src/binaryop/compiled/binary_ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,16 @@ struct type_casted_accessor {
column_device_view const& col,
bool is_scalar) const
{
if constexpr (column_device_view::has_element_accessor<Element>() and
std::is_convertible_v<Element, CastType>)
return static_cast<CastType>(col.element<Element>(is_scalar ? 0 : i));
if constexpr (column_device_view::has_element_accessor<Element>()) {
auto const element = col.element<Element>(is_scalar ? 0 : i);
if constexpr (std::is_convertible_v<Element, CastType>) {
return static_cast<CastType>(element);
} else if constexpr (is_fixed_point<Element>() && cuda::std::is_floating_point_v<CastType>) {
return convert_fixed_to_floating<CastType>(element);
} else if constexpr (is_fixed_point<CastType>() && cuda::std::is_floating_point_v<Element>) {
return convert_floating_to_fixed<CastType>(element, numeric::scale_type{0});
}
}
return {};
}
};
Expand Down Expand Up @@ -159,6 +166,7 @@ struct ops2_wrapper {
TypeRhs y = rhs.element<TypeRhs>(is_rhs_scalar ? 0 : i);
auto result = [&]() {
if constexpr (std::is_same_v<BinaryOperator, ops::NullEquals> or
std::is_same_v<BinaryOperator, ops::NullNotEquals> or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.

std::is_same_v<BinaryOperator, ops::NullLogicalAnd> or
std::is_same_v<BinaryOperator, ops::NullLogicalOr> or
std::is_same_v<BinaryOperator, ops::NullMax> or
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/binaryop/compiled/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct common_type_functor {
template <typename TypeLhs, typename TypeRhs>
std::optional<data_type> operator()() const
{
if constexpr (cudf::has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (binary_op_has_common_type_v<TypeLhs, TypeRhs>) {
using TypeCommon = binary_op_common_type_t<TypeLhs, TypeRhs>;
return data_type{type_to_id<TypeCommon>()};
}

Expand Down Expand Up @@ -85,8 +85,8 @@ struct is_binary_operation_supported {
{
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>()) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (binary_op_has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = binary_op_common_type_t<TypeLhs, TypeRhs>;
return std::is_invocable_v<BinaryOperator, common_t, common_t>;
} else {
return std::is_invocable_v<BinaryOperator, TypeLhs, TypeRhs>;
Expand All @@ -102,8 +102,8 @@ struct is_binary_operation_supported {
if constexpr (column_device_view::has_element_accessor<TypeLhs>() and
column_device_view::has_element_accessor<TypeRhs>()) {
if (has_mutable_element_accessor(out_type) or is_fixed_point(out_type)) {
if constexpr (has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = std::common_type_t<TypeLhs, TypeRhs>;
if constexpr (binary_op_has_common_type_v<TypeLhs, TypeRhs>) {
using common_t = binary_op_common_type_t<TypeLhs, TypeRhs>;
if constexpr (std::is_invocable_v<BinaryOperator, common_t, common_t>) {
using ReturnType = std::invoke_result_t<BinaryOperator, common_t, common_t>;
return is_constructible<ReturnType>(out_type) or
Expand Down
58 changes: 58 additions & 0 deletions cpp/tests/binaryop/binop-compiled-fixed_point-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,3 +843,61 @@ TYPED_TEST(FixedPointTest_64_128_Reps, FixedPoint_64_128_ComparisonTests)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(h->view(), falses);
}
}

template <typename ResultType>
void test_fixed_floating(cudf::binary_operator op,
double floating_value,
int decimal_value,
int decimal_scale,
ResultType expected)
{
auto const scale = numeric::scale_type{decimal_scale};
auto const result_type = cudf::data_type(cudf::type_to_id<ResultType>());
auto const nullable =
(op == cudf::binary_operator::NULL_EQUALS || op == cudf::binary_operator::NULL_NOT_EQUALS ||
op == cudf::binary_operator::NULL_MIN || op == cudf::binary_operator::NULL_MAX);

cudf::test::fixed_width_column_wrapper<double> floating_col({floating_value});
cudf::test::fixed_point_column_wrapper<int> decimal_col({decimal_value}, scale);

auto result = binary_operation(floating_col, decimal_col, op, result_type);

if constexpr (cudf::is_fixed_point<ResultType>()) {
using wrapper_type = cudf::test::fixed_point_column_wrapper<typename ResultType::rep>;
auto const expected_col = nullable ? wrapper_type({expected.value()}, {true}, expected.scale())
: wrapper_type({expected.value()}, expected.scale());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_col, *result.get());
} else {
using wrapper_type = cudf::test::fixed_width_column_wrapper<ResultType>;
auto const expected_col =
nullable ? wrapper_type({expected}, {true}) : wrapper_type({expected});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_col, *result.get());
}
}

TYPED_TEST(FixedPointCompiledTest, FixedPointWithFloating)
{
using namespace numeric;

// BOOLEAN
test_fixed_floating(cudf::binary_operator::EQUAL, 1.0, 10, -1, true);
test_fixed_floating(cudf::binary_operator::NOT_EQUAL, 1.0, 10, -1, false);
test_fixed_floating(cudf::binary_operator::LESS, 2.0, 10, -1, false);
test_fixed_floating(cudf::binary_operator::GREATER, 2.0, 10, -1, true);
test_fixed_floating(cudf::binary_operator::LESS_EQUAL, 2.0, 20, -1, true);
test_fixed_floating(cudf::binary_operator::GREATER_EQUAL, 2.0, 30, -1, false);
test_fixed_floating(cudf::binary_operator::NULL_EQUALS, 1.0, 10, -1, true);
test_fixed_floating(cudf::binary_operator::NULL_NOT_EQUALS, 1.0, 10, -1, false);

// PRIMARY ARITHMETIC
auto const decimal_result = numeric::decimal32(4, numeric::scale_type{0});
test_fixed_floating(cudf::binary_operator::ADD, 1.0, 30, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::SUB, 6.0, 20, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::MUL, 2.0, 20, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::DIV, 8.0, 2, 0, decimal_result);
test_fixed_floating(cudf::binary_operator::MOD, 9.0, 50, -1, decimal_result);

// OTHER ARITHMETIC
test_fixed_floating(cudf::binary_operator::NULL_MAX, 4.0, 20, -1, decimal_result);
test_fixed_floating(cudf::binary_operator::NULL_MIN, 4.0, 200, -1, decimal_result);
}
Loading