diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 02036ff0bbf..55a0b89e446 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include #include @@ -99,7 +99,7 @@ void aggregate_result_functor::operator()(aggregation const& a agg, detail::group_sum( get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -111,7 +111,7 @@ void aggregate_result_functor::operator()(aggregation cons agg, detail::group_product( get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -126,7 +126,7 @@ void aggregate_result_functor::operator()(aggregation const helper.key_sort_order(stream), stream, mr)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -141,7 +141,7 @@ void aggregate_result_functor::operator()(aggregation const helper.key_sort_order(stream), stream, mr)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -181,7 +181,7 @@ void aggregate_result_functor::operator()(aggregation const& a }(); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -221,7 +221,7 @@ void aggregate_result_functor::operator()(aggregation const& a }(); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -248,7 +248,7 @@ void aggregate_result_functor::operator()(aggregation const& stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -263,7 +263,7 @@ void aggregate_result_functor::operator()(aggregation const& ag values, agg, detail::group_m2(get_grouped_values(), mean_result, helper.group_labels(stream), stream, mr)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -286,7 +286,7 @@ void aggregate_result_functor::operator()(aggregation con stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -300,7 +300,7 @@ void aggregate_result_functor::operator()(aggregation const& a auto result = cudf::detail::unary_operation(var_result, unary_operator::SQRT, stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -321,7 +321,7 @@ void aggregate_result_functor::operator()(aggregation con stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -341,7 +341,7 @@ void aggregate_result_functor::operator()(aggregation const stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -358,7 +358,7 @@ void aggregate_result_functor::operator()(aggregation cons stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -404,7 +404,7 @@ void aggregate_result_functor::operator()(aggregation stream, mr); cache.add_result(values, agg, std::move(result)); -}; +} template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -426,9 +426,9 @@ void aggregate_result_functor::operator()(aggregation cache.add_result( values, agg, - lists::detail::drop_list_duplicates( - lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr)); -}; + lists::detail::distinct( + lists_column_view{collect_result->view()}, nulls_equal, nans_equal, stream, mr)); +} /** * @brief Perform merging for the lists that correspond to the same key value. @@ -455,7 +455,7 @@ void aggregate_result_functor::operator()(aggregation agg, detail::group_merge_lists( get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr)); -}; +} /** * @brief Perform merging for the lists corresponding to the same key value, then dropping duplicate @@ -473,13 +473,13 @@ void aggregate_result_functor::operator()(aggregation * column for this aggregation. * * Firstly, this aggregation performs `MERGE_LISTS` to concatenate the input lists (corresponding to - * the same key) into intermediate lists, then it calls `lists::drop_list_duplicates` on them to + * the same key) into intermediate lists, then it calls `lists::distinct` on them to * remove duplicate list entries. As such, the input (partial results) to this aggregation should be * generated by (distributed) `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily * removing duplicate entries for the partial results. * * Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality` - * are needed for calling to `lists::drop_list_duplicates`. + * are needed for calling `lists::distinct`. */ template <> void aggregate_result_functor::operator()(aggregation const& agg) @@ -494,12 +494,12 @@ void aggregate_result_functor::operator()(aggregation c auto const& merge_sets_agg = dynamic_cast(agg); cache.add_result(values, agg, - lists::detail::drop_list_duplicates(lists_column_view(merged_result->view()), - merge_sets_agg._nulls_equal, - merge_sets_agg._nans_equal, - stream, - mr)); -}; + lists::detail::distinct(lists_column_view{merged_result->view()}, + merge_sets_agg._nulls_equal, + merge_sets_agg._nans_equal, + stream, + mr)); +} /** * @brief Perform merging for the M2 values that correspond to the same key value. @@ -528,7 +528,7 @@ void aggregate_result_functor::operator()(aggregation con agg, detail::group_merge_m2( get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr)); -}; +} /** * @brief Creates column views with only valid elements in both input column views @@ -600,7 +600,7 @@ void aggregate_result_functor::operator()(aggregation c cov_agg._ddof, stream, mr)); -}; +} /** * @brief Perform correlation between two child columns of non-nullable struct column. @@ -710,7 +710,7 @@ void aggregate_result_functor::operator()(aggregation cons max_centroids, stream, mr)); -}; +} /** * @brief Generate a merged tdigest column from a grouped set of input tdigest columns. @@ -752,7 +752,7 @@ void aggregate_result_functor::operator()(aggregatio max_centroids, stream, mr)); -}; +} } // namespace detail diff --git a/cpp/src/reductions/collect_ops.cu b/cpp/src/reductions/collect_ops.cu index c9bd06a1171..4d6a32b528a 100644 --- a/cpp/src/reductions/collect_ops.cu +++ b/cpp/src/reductions/collect_ops.cu @@ -18,8 +18,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -27,25 +26,28 @@ namespace cudf { namespace reduction { -std::unique_ptr drop_duplicates(list_scalar const& scalar, - null_equality nulls_equal, - nan_equality nans_equal, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +namespace { + +/** + * @brief Check if we need to handle nulls in the input column. + * + * @param input The input column + * @param null_handling The null handling policy + * @return A boolean value indicating if we need to handle nulls + */ +bool need_handle_nulls(column_view const& input, null_policy null_handling) { - auto list_wrapper = lists::detail::make_lists_column_from_scalar(scalar, 1, stream, mr); - auto lcw = lists_column_view(list_wrapper->view()); - auto no_dup_wrapper = lists::drop_list_duplicates(lcw, nulls_equal, nans_equal, mr); - auto no_dup = lists_column_view(no_dup_wrapper->view()).get_sliced_child(stream); - return make_list_scalar(no_dup, stream, mr); + return null_handling == null_policy::EXCLUDE && input.has_nulls(); } +} // namespace + std::unique_ptr collect_list(column_view const& col, null_policy null_handling, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - if (null_handling == null_policy::EXCLUDE && col.has_nulls()) { + if (need_handle_nulls(col, null_handling)) { auto d_view = column_device_view::create(col, stream); auto filter = detail::validity_accessor(*d_view); auto null_purged_table = detail::copy_if(table_view{{col}}, filter, stream, mr); @@ -72,9 +74,27 @@ std::unique_ptr collect_set(column_view const& col, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto scalar = collect_list(col, null_handling, stream, mr); - auto ls = dynamic_cast(scalar.get()); - return drop_duplicates(*ls, nulls_equal, nans_equal, stream, mr); + // `input_as_collect_list` is the result of the input column that has been processed to obey + // the given null handling behavior. + [[maybe_unused]] auto const [input_as_collect_list, unused_scalar] = [&] { + if (need_handle_nulls(col, null_handling)) { + // Only call `collect_list` when we need to handle nulls. + auto scalar = collect_list(col, null_handling, stream, mr); + return std::pair(static_cast(scalar.get())->view(), std::move(scalar)); + } + + return std::pair(col, std::unique_ptr(nullptr)); + }(); + + auto distinct_table = detail::distinct(table_view{{input_as_collect_list}}, + std::vector{0}, + duplicate_keep_option::KEEP_ANY, + nulls_equal, + nans_equal, + stream, + mr); + + return std::make_unique(std::move(distinct_table->get_column(0)), true, stream, mr); } std::unique_ptr merge_sets(lists_column_view const& col, @@ -83,9 +103,15 @@ std::unique_ptr merge_sets(lists_column_view const& col, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto flatten_col = col.get_sliced_child(stream); - auto scalar = std::make_unique(flatten_col, true, stream, mr); - return drop_duplicates(*scalar, nulls_equal, nans_equal, stream, mr); + auto flatten_col = col.get_sliced_child(stream); + auto distinct_table = detail::distinct(table_view{{flatten_col}}, + std::vector{0}, + duplicate_keep_option::KEEP_ANY, + nulls_equal, + nans_equal, + stream, + mr); + return std::make_unique(std::move(distinct_table->get_column(0)), true, stream, mr); } } // namespace reduction diff --git a/cpp/src/rolling/detail/rolling.cuh b/cpp/src/rolling/detail/rolling.cuh index d5d30d1f699..933d0410df5 100644 --- a/cpp/src/rolling/detail/rolling.cuh +++ b/cpp/src/rolling/detail/rolling.cuh @@ -38,7 +38,7 @@ #include #include #include -#include +#include #include #include #include @@ -928,8 +928,8 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation stream, rmm::mr::get_current_device_resource()); - result = lists::detail::drop_list_duplicates( - lists_column_view(collected_list->view()), agg._nulls_equal, agg._nans_equal, stream, mr); + result = lists::detail::distinct( + lists_column_view{collected_list->view()}, agg._nulls_equal, agg._nans_equal, stream, mr); } // perform the element-wise square root operation on result of VARIANCE diff --git a/cpp/tests/groupby/collect_set_tests.cpp b/cpp/tests/groupby/collect_set_tests.cpp index c429dc72259..cf324cf3a8e 100644 --- a/cpp/tests/groupby/collect_set_tests.cpp +++ b/cpp/tests/groupby/collect_set_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,23 +14,61 @@ * limitations under the License. */ -#include - #include #include #include +#include #include +#include +#include +#include +#include namespace cudf { namespace test { -#define COL_K cudf::test::fixed_width_column_wrapper -#define COL_V cudf::test::fixed_width_column_wrapper -#define COL_S cudf::test::strings_column_wrapper -#define LCL_V cudf::test::lists_column_wrapper -#define LCL_S cudf::test::lists_column_wrapper -#define VALIDITY std::initializer_list +namespace { + +constexpr cudf::test::debug_output_level verbosity{cudf::test::debug_output_level::FIRST_ERROR}; + +using keys_col = cudf::test::fixed_width_column_wrapper; +using strings_col = cudf::test::strings_column_wrapper; +using strings_lists = cudf::test::lists_column_wrapper; +using validity_col = std::initializer_list; + +auto groupby_collect_set(cudf::column_view const& keys, + cudf::column_view const& values, + std::unique_ptr&& agg) +{ + std::vector requests; + requests.emplace_back(cudf::groupby::aggregation_request()); + requests[0].values = values; + requests[0].aggregations.emplace_back(std::move(agg)); + + auto const result = cudf::groupby::groupby(cudf::table_view({keys})).aggregate(requests); + auto const result_keys = result.first->view(); // <== table_view of 1 column + auto const result_vals = result.second[0].results[0]->view(); // <== column_view + + // Sort the output columns based on the output keys. + // This is to facilitate comparison of the output with the expected columns. + auto keys_vals_sorted = cudf::sort_by_key(cudf::table_view{{result_keys.column(0), result_vals}}, + result_keys, + {}, + {cudf::null_order::AFTER}) + ->release(); + + // After the columns were reordered, individual rows of the output values column (which are lists) + // also need to be sorted. + auto out_values = + cudf::lists::sort_lists(cudf::lists_column_view{keys_vals_sorted.back()->view()}, + cudf::order::ASCENDING, + cudf::null_order::AFTER); + + return std::pair(std::move(keys_vals_sorted.front()), std::move(out_values)); +} + +} // namespace struct CollectSetTest : public cudf::test::BaseFixture { static auto collect_set() @@ -61,74 +99,117 @@ TYPED_TEST_SUITE(CollectSetTypedTest, FixedWidthTypesNotBool); TYPED_TEST(CollectSetTypedTest, TrivialInput) { + using vals_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + // Empty input - test_single_agg(COL_K{}, COL_V{}, COL_K{}, LCL_V{}, CollectSetTest::collect_set()); + { + keys_col keys{}; + vals_col vals{}; + keys_col keys_expected{}; + lists_col vals_expected{}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } // Single key input { - COL_K keys{1}; - COL_V vals{10}; - COL_K keys_expected{1}; - LCL_V vals_expected{LCL_V{10}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + keys_col keys{1}; + vals_col vals{10}; + keys_col keys_expected{1}; + lists_col vals_expected{lists_col{10}}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } // Non-repeated keys { - COL_K keys{2, 1}; - COL_V vals{20, 10}; - COL_K keys_expected{1, 2}; - LCL_V vals_expected{LCL_V{10}, LCL_V{20}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + keys_col keys{2, 1}; + vals_col vals{20, 10}; + keys_col keys_expected{1, 2}; + lists_col vals_expected{lists_col{10}, lists_col{20}}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } } TYPED_TEST(CollectSetTypedTest, TypicalInput) { + using vals_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + // Pre-sorted keys { - COL_K keys{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; - COL_V vals{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; - COL_K keys_expected{1, 2, 3}; - LCL_V vals_expected{{10, 11}, {20, 21}, {30, 31, 32, 33}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + keys_col keys{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vals_col vals{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; + keys_col keys_expected{1, 2, 3}; + lists_col vals_expected{{10, 11}, {20, 21}, {30, 31, 32, 33}}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } // Expect the result keys to be sorted by sort-based groupby { - COL_K keys{4, 1, 2, 4, 3, 3, 2, 1}; - COL_V vals{40, 10, 20, 40, 30, 30, 20, 11}; - COL_K keys_expected{1, 2, 3, 4}; - LCL_V vals_expected{{10, 11}, {20}, {30}, {40}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + keys_col keys{4, 1, 2, 4, 3, 3, 2, 1}; + vals_col vals{40, 10, 20, 40, 30, 30, 20, 11}; + keys_col keys_expected{1, 2, 3, 4}; + lists_col vals_expected{{10, 11}, {20}, {30}, {40}}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } } // Keys and values columns are sliced columns TYPED_TEST(CollectSetTypedTest, SlicedColumnsInput) { - COL_K keys_original{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; - COL_V vals_original{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; + using vals_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + keys_col keys_original{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vals_col vals_original{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; { auto const keys = cudf::slice(keys_original, {0, 4})[0]; // { 1, 1, 1, 1 } auto const vals = cudf::slice(vals_original, {0, 4})[0]; // { 10, 11, 10, 10 } - auto const keys_expected = COL_K{1}; - auto const vals_expected = LCL_V{{10, 11}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + auto const keys_expected = keys_col{1}; + auto const vals_expected = lists_col{{10, 11}}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } { auto const keys = cudf::slice(keys_original, {2, 10})[0]; // { 1, 1, 2, 2, 2, 2, 3, 3 } auto const vals = cudf::slice(vals_original, {2, 10})[0]; // { 10, 10, 20, 21, 21, 20, 30, 33 } - auto const keys_expected = COL_K{1, 2, 3}; - auto const vals_expected = LCL_V{{10}, {20, 21}, {30, 33}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + auto const keys_expected = keys_col{1, 2, 3}; + auto const vals_expected = lists_col{{10}, {20, 21}, {30, 33}}; + + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } } TEST_F(CollectSetTest, StringInput) { - COL_K keys{1, 2, 3, 3, 2, 1, 2, 1, 2, 1, 1, 1, 1}; - COL_S vals{ + keys_col keys{1, 2, 3, 3, 2, 1, 2, 1, 2, 1, 1, 1, 1}; + strings_col vals{ "String 1, first", "String 2, first", "String 3, first", @@ -143,112 +224,171 @@ TEST_F(CollectSetTest, StringInput) "String 1, second", // repeated "String 1, second" // repeated }; - COL_K keys_expected{1, 2, 3}; - LCL_S vals_expected{{"String 1, first", "String 1, second"}, - {"String 2, first", "String 2, second"}, - {"String 3, first", "String 3, second"}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + keys_col keys_expected{1, 2, 3}; + strings_lists vals_expected{{"String 1, first", "String 1, second"}, + {"String 2, first", "String 2, second"}, + {"String 3, first", "String 3, second"}}; + + auto const [out_keys, out_lists] = groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); } TEST_F(CollectSetTest, FloatsWithNaN) { - COL_K keys{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + keys_col keys{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; cudf::test::fixed_width_column_wrapper vals{ {1.0f, 1.0f, -2.3e-5f, -2.3e-5f, 2.3e5f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f}, {true, true, true, true, true, true, true, true, true, true, false, false}}; - COL_K keys_expected{1}; + keys_col keys_expected{1}; + cudf::test::lists_column_wrapper vals_expected; + // null equal with nan unequal - cudf::test::lists_column_wrapper vals_expected{ - {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f}, - VALIDITY{true, true, true, true, true, true, true, false}}, - }; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + { + vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f}, + validity_col{true, true, true, true, true, true, true, false}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } + // null unequal with nan unequal - vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f}, - VALIDITY{true, true, true, true, true, true, true, false, false}}}; - test_single_agg( - keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal()); + { + vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f}, + validity_col{true, true, true, true, true, true, true, false, false}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_unequal()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } + // null exclude with nan unequal - vals_expected = {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN}}; - test_single_agg( - keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude()); + { + vals_expected = {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_exclude()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } + // null equal with nan equal - vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, NAN, 0.0f}, VALIDITY{true, true, true, true, false}}}; - test_single_agg(keys, - vals, - keys_expected, - vals_expected, - cudf::make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + { + vals_expected = { + {{-2.3e-5f, 1.0f, 2.3e5f, NAN, 0.0f}, validity_col{true, true, true, true, false}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, + vals, + cudf::make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } + // null unequal with nan equal - vals_expected = { - {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f, 0.0f}, VALIDITY{true, true, true, true, false, false}}}; - test_single_agg(keys, - vals, - keys_expected, - vals_expected, - cudf::make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL)); + { + vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f, 0.0f}, + validity_col{true, true, true, true, false, false}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, + vals, + cudf::make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL)); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } } TYPED_TEST(CollectSetTypedTest, CollectWithNulls) { + using vals_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + // Just use an arbitrary value to store null entries // Using this alias variable will make the code look cleaner constexpr int32_t null = 0; // Pre-sorted keys { - COL_K keys{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; - COL_V vals{{10, 10, null, null, 20, null, null, null, 30, 31, 30, 31}, - {true, true, false, false, true, false, false, false, true, true, true, true}}; - COL_K keys_expected{1, 2, 3}; + keys_col keys{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vals_col vals{{10, 10, null, null, 20, null, null, null, 30, 31, 30, 31}, + {true, true, false, false, true, false, false, false, true, true, true, true}}; + keys_col keys_expected{1, 2, 3}; + lists_col vals_expected; // By default, nulls are consider equals, thus only one null is kept per key - LCL_V vals_expected{{{10, null}, VALIDITY{true, false}}, - {{20, null}, VALIDITY{true, false}}, - {{30, 31}, VALIDITY{true, true}}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + { + vals_expected = {{{10, null}, validity_col{true, false}}, + {{20, null}, validity_col{true, false}}, + {{30, 31}, validity_col{true, true}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } // All nulls per key are kept (nulls are put at the end of each list) - vals_expected = LCL_V{{{10, null, null}, VALIDITY{true, false, false}}, - {{20, null, null, null}, VALIDITY{true, false, false, false}}, - {{30, 31}, VALIDITY{true, true}}}; - test_single_agg( - keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal()); + { + vals_expected = lists_col{{{10, null, null}, validity_col{true, false, false}}, + {{20, null, null, null}, validity_col{true, false, false, false}}, + {{30, 31}, validity_col{true, true}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_unequal()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } // All nulls per key are excluded - vals_expected = LCL_V{{10}, {20}, {30, 31}}; - test_single_agg( - keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude()); + { + vals_expected = lists_col{{10}, {20}, {30, 31}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_exclude()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } } // Expect the result keys to be sorted by sort-based groupby { - COL_K keys{4, 1, 2, 4, 3, 3, 3, 3, 2, 1}; - COL_V vals{{40, 10, 20, 40, null, null, null, null, 21, null}, - {true, true, true, true, false, false, false, false, true, false}}; - COL_K keys_expected{1, 2, 3, 4}; + keys_col keys{4, 1, 2, 4, 3, 3, 3, 3, 2, 1}; + vals_col vals{{40, 10, 20, 40, null, null, null, null, 21, null}, + {true, true, true, true, false, false, false, false, true, false}}; + keys_col keys_expected{1, 2, 3, 4}; + lists_col vals_expected; // By default, nulls are consider equals, thus only one null is kept per key - LCL_V vals_expected{{{10, null}, VALIDITY{true, false}}, - {{20, 21}, VALIDITY{true, true}}, - {{null}, VALIDITY{false}}, - {{40}, VALIDITY{true}}}; - test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + { + vals_expected = {{{10, null}, validity_col{true, false}}, + {{20, 21}, validity_col{true, true}}, + {{null}, validity_col{false}}, + {{40}, validity_col{true}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } // All nulls per key are kept (nulls are put at the end of each list) - vals_expected = LCL_V{{{10, null}, VALIDITY{true, false}}, - {{20, 21}, VALIDITY{true, true}}, - {{null, null, null, null}, VALIDITY{false, false, false, false}}, - {{40}, VALIDITY{true}}}; - test_single_agg( - keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal()); + { + vals_expected = + lists_col{{{10, null}, validity_col{true, false}}, + {{20, 21}, validity_col{true, true}}, + {{null, null, null, null}, validity_col{false, false, false, false}}, + {{40}, validity_col{true}}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_unequal()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } // All nulls per key are excluded - vals_expected = LCL_V{{10}, {20, 21}, {}, {40}}; - test_single_agg( - keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude()); + { + vals_expected = lists_col{{10}, {20, 21}, {}, {40}}; + auto const [out_keys, out_lists] = + groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_exclude()); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity); + } } } diff --git a/cpp/tests/groupby/merge_sets_tests.cpp b/cpp/tests/groupby/merge_sets_tests.cpp index 57f67f6b81a..55a25e40b53 100644 --- a/cpp/tests/groupby/merge_sets_tests.cpp +++ b/cpp/tests/groupby/merge_sets_tests.cpp @@ -23,6 +23,8 @@ #include #include #include +#include +#include #include using namespace cudf::test::iterators; @@ -45,9 +47,26 @@ auto merge_sets(vcol_views const& keys_cols, vcol_views const& values_cols) requests[0].aggregations.emplace_back( cudf::make_merge_sets_aggregation()); - auto gb_obj = cudf::groupby::groupby(cudf::table_view({*keys})); - auto result = gb_obj.aggregate(requests); - return std::pair(std::move(result.first->release()[0]), std::move(result.second[0].results[0])); + auto const result = cudf::groupby::groupby(cudf::table_view({*keys})).aggregate(requests); + auto const result_keys = result.first->view(); // <== table_view of 1 column + auto const result_vals = result.second[0].results[0]->view(); // <== column_view + + // Sort the output columns based on the output keys. + // This is to facilitate comparison of the output with the expected columns. + auto keys_vals_sorted = cudf::sort_by_key(cudf::table_view{{result_keys.column(0), result_vals}}, + result_keys, + {}, + {cudf::null_order::AFTER}) + ->release(); + + // After the columns were reordered, individual rows of the output values column (which are lists) + // also need to be sorted. + auto out_values = + cudf::lists::sort_lists(cudf::lists_column_view{keys_vals_sorted.back()->view()}, + cudf::order::ASCENDING, + cudf::null_order::AFTER); + + return std::pair(std::move(keys_vals_sorted.front()), std::move(out_values)); } } // namespace @@ -137,8 +156,6 @@ TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasNulls) using keys_col = cudf::test::fixed_width_column_wrapper; using lists_col = cudf::test::lists_column_wrapper; - // Note that the null elements here are not sorted, while the results from current collect_list - // and collect_set are sorted. auto const keys1 = keys_col{1, 2}; auto const keys2 = keys_col{1, 3}; auto const keys3 = keys_col{2, 3, 4}; @@ -213,8 +230,6 @@ TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasNullsAndEmptyLists) using keys_col = cudf::test::fixed_width_column_wrapper; using lists_col = cudf::test::lists_column_wrapper; - // Note that the null elements here are not sorted, while the results from current collect_list - // and collect_set are sorted. auto const keys1 = keys_col{1, 2, 3}; auto const keys2 = keys_col{1, 3, 4}; auto const keys3 = keys_col{2, 3, 4}; diff --git a/cpp/tests/reductions/collect_ops_tests.cpp b/cpp/tests/reductions/collect_ops_tests.cpp index d5b4c8e38f7..a0fdab5e994 100644 --- a/cpp/tests/reductions/collect_ops_tests.cpp +++ b/cpp/tests/reductions/collect_ops_tests.cpp @@ -21,11 +21,29 @@ #include #include +#include using namespace cudf::test::iterators; namespace cudf::test { +namespace { + +auto collect_set(cudf::column_view const& input, std::unique_ptr const& agg) +{ + auto const result_scalar = cudf::reduce(input, agg, data_type{type_id::LIST}); + + // The results of `collect_set` are unordered thus we need to sort them for comparison. + auto const result_sorted_table = + cudf::sort(cudf::table_view{{dynamic_cast(result_scalar.get())->view()}}, + {}, + {cudf::null_order::AFTER}); + + return std::make_unique(std::move(result_sorted_table->get_column(0))); +} + +} // namespace + template struct CollectTestFixedWidth : public cudf::test::BaseFixture { }; @@ -81,22 +99,22 @@ TYPED_TEST(CollectTestFixedWidth, CollectSet) null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL); // test without nulls - auto const ret = cudf::reduce(col, null_eq, data_type{type_id::LIST}); + auto const ret = collect_set(col, null_eq); fw_wrapper expected{{0, 5, 64, 99, 120}}; CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, dynamic_cast(ret.get())->view()); // null exclude - auto const ret1 = cudf::reduce(col_with_null, null_exclude, data_type{type_id::LIST}); + auto const ret1 = collect_set(col_with_null, null_exclude); fw_wrapper expected1{{0, 5, 64, 99}}; CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast(ret1.get())->view()); // null equal - auto const ret2 = cudf::reduce(col_with_null, null_eq, data_type{type_id::LIST}); + auto const ret2 = collect_set(col_with_null, null_eq); fw_wrapper expected2{{0, 5, 64, 99, -1}, {1, 1, 1, 1, 0}}; CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast(ret2.get())->view()); // null unequal - auto const ret3 = cudf::reduce(col_with_null, null_unequal, data_type{type_id::LIST}); + auto const ret3 = collect_set(col_with_null, null_unequal); fw_wrapper expected3{{0, 5, 64, 99, -1, -1, -1}, {1, 1, 1, 1, 0, 0, 0}}; CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected3, dynamic_cast(ret3.get())->view()); } @@ -139,8 +157,7 @@ TYPED_TEST(CollectTestFixedWidth, MergeSets) // test without nulls auto const lists1 = lists_col{{1, 2, 3}, {}, {}, {4}, {1, 3, 4}, {0, 3, 10}, {}}; auto const expected1 = fw_wrapper{{0, 1, 2, 3, 4, 10}}; - auto const ret1 = cudf::reduce( - lists1, make_merge_sets_aggregation(), data_type{type_id::LIST}); + auto const ret1 = collect_set(lists1, make_merge_sets_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast(ret1.get())->view()); // test with null_equal @@ -155,17 +172,14 @@ TYPED_TEST(CollectTestFixedWidth, MergeSets) }, null_at(5)}; auto const expected2 = fw_wrapper{{1, 2, 3, 4, 5, 0}, {1, 1, 1, 1, 1, 0}}; - auto const ret2 = cudf::reduce( - lists2, make_merge_sets_aggregation(), data_type{type_id::LIST}); + auto const ret2 = collect_set(lists2, make_merge_sets_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast(ret2.get())->view()); // test with null_unequal auto const& lists3 = lists2; auto const expected3 = fw_wrapper{{1, 2, 3, 4, 5, 0, 0, 0, 0, 0}, {1, 1, 1, 1, 1, 0, 0, 0, 0, 0}}; auto const ret3 = - cudf::reduce(lists3, - make_merge_sets_aggregation(null_equality::UNEQUAL), - data_type{type_id::LIST}); + collect_set(lists3, make_merge_sets_aggregation(null_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected3, dynamic_cast(ret3.get())->view()); } @@ -182,35 +196,31 @@ TEST_F(CollectTest, CollectSetWithNaN) // nan unequal with null equal fp_wrapper expected1{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f}, {1, 1, 1, 1, 1, 1, 1, 0}}; - auto const ret1 = - cudf::reduce(col, make_collect_set_aggregation(), data_type{type_id::LIST}); + auto const ret1 = collect_set(col, make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast(ret1.get())->view()); // nan unequal with null unequal fp_wrapper expected2{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f}, {1, 1, 1, 1, 1, 1, 1, 0, 0}}; - auto const ret2 = cudf::reduce( + auto const ret2 = collect_set( col, - make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL), - data_type{type_id::LIST}); + make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast(ret2.get())->view()); // nan equal with null equal fp_wrapper expected3{{-2.3e-5f, 1.0f, 2.3e5f, NAN, 0.0f}, {1, 1, 1, 1, 0}}; auto const ret3 = - cudf::reduce(col, - make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL), - data_type{type_id::LIST}); + collect_set(col, + make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected3, dynamic_cast(ret3.get())->view()); // nan equal with null unequal fp_wrapper expected4{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f, 0.0f}, {1, 1, 1, 1, 0, 0}}; auto const ret4 = - cudf::reduce(col, - make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL), - data_type{type_id::LIST}); + collect_set(col, + make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected4, dynamic_cast(ret4.get())->view()); } @@ -229,33 +239,28 @@ TEST_F(CollectTest, MergeSetsWithNaN) // nan unequal with null equal fp_wrapper expected1{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f}, {1, 1, 1, 1, 1, 1, 0}}; - auto const ret1 = - cudf::reduce(col, make_merge_sets_aggregation(), data_type{type_id::LIST}); + auto const ret1 = collect_set(col, make_merge_sets_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast(ret1.get())->view()); // nan unequal with null unequal fp_wrapper expected2{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f, 0.0f, 0.0f}, {1, 1, 1, 1, 1, 1, 0, 0, 0}}; auto const ret2 = - cudf::reduce(col, - make_merge_sets_aggregation(null_equality::UNEQUAL), - data_type{type_id::LIST}); + collect_set(col, make_merge_sets_aggregation(null_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast(ret2.get())->view()); // nan equal with null equal fp_wrapper expected3{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f}, {1, 1, 1, 1, 0}}; - auto const ret3 = cudf::reduce( + auto const ret3 = collect_set( col, - make_merge_sets_aggregation(null_equality::EQUAL, nan_equality::ALL_EQUAL), - data_type{type_id::LIST}); + make_merge_sets_aggregation(null_equality::EQUAL, nan_equality::ALL_EQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected3, dynamic_cast(ret3.get())->view()); // nan equal with null unequal fp_wrapper expected4{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f, 0.0f, 0.0f}, {1, 1, 1, 1, 0, 0, 0}}; - auto const ret4 = cudf::reduce(col, - make_merge_sets_aggregation( - null_equality::UNEQUAL, nan_equality::ALL_EQUAL), - data_type{type_id::LIST}); + auto const ret4 = collect_set(col, + make_merge_sets_aggregation( + null_equality::UNEQUAL, nan_equality::ALL_EQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected4, dynamic_cast(ret4.get())->view()); } @@ -282,16 +287,14 @@ TEST_F(CollectTest, CollectStrings) // collect_set with null_equal auto const expected3 = str_col{{"a", "b", "c", "d", "e", ""}, null_at(5)}; - auto const ret3 = cudf::reduce( - s_col, make_collect_set_aggregation(), data_type{type_id::LIST}); + auto const ret3 = collect_set(s_col, make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected3, dynamic_cast(ret3.get())->view()); // collect_set with null_unequal auto const expected4 = str_col{{"a", "b", "c", "d", "e", "", ""}, {1, 1, 1, 1, 1, 0, 0}}; - auto const ret4 = cudf::reduce( + auto const ret4 = collect_set( s_col, - make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL), - data_type{type_id::LIST}); + make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected4, dynamic_cast(ret4.get())->view()); lists_col strings{{"a"}, @@ -311,17 +314,14 @@ TEST_F(CollectTest, CollectStrings) // merge_sets with null_equal auto const expected6 = str_col{{"a", "b", "c", "d", "e", "null"}, {1, 1, 1, 1, 1, 0}}; - auto const ret6 = cudf::reduce( - strings, make_merge_sets_aggregation(), data_type{type_id::LIST}); + auto const ret6 = collect_set(strings, make_merge_sets_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected6, dynamic_cast(ret6.get())->view()); // merge_sets with null_unequal auto const expected7 = str_col{{"a", "b", "c", "d", "e", "null", "null", "null"}, {1, 1, 1, 1, 1, 0, 0, 0}}; auto const ret7 = - cudf::reduce(strings, - make_merge_sets_aggregation(null_equality::UNEQUAL), - data_type{type_id::LIST}); + collect_set(strings, make_merge_sets_aggregation(null_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected7, dynamic_cast(ret7.get())->view()); } @@ -335,8 +335,7 @@ TEST_F(CollectTest, CollectEmptys) empty, make_collect_list_aggregation(), data_type{type_id::LIST}); CUDF_TEST_EXPECT_COLUMNS_EQUAL(int_col{}, dynamic_cast(ret.get())->view()); - ret = cudf::reduce( - empty, make_collect_set_aggregation(), data_type{type_id::LIST}); + ret = collect_set(empty, make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(int_col{}, dynamic_cast(ret.get())->view()); // test collect all null columns @@ -345,8 +344,7 @@ TEST_F(CollectTest, CollectEmptys) all_nulls, make_collect_list_aggregation(), data_type{type_id::LIST}); CUDF_TEST_EXPECT_COLUMNS_EQUAL(int_col{}, dynamic_cast(ret.get())->view()); - ret = cudf::reduce( - all_nulls, make_collect_set_aggregation(), data_type{type_id::LIST}); + ret = collect_set(all_nulls, make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(int_col{}, dynamic_cast(ret.get())->view()); } diff --git a/cpp/tests/rolling/collect_ops_test.cpp b/cpp/tests/rolling/collect_ops_test.cpp index dff48e998b4..a0af8f150e3 100644 --- a/cpp/tests/rolling/collect_ops_test.cpp +++ b/cpp/tests/rolling/collect_ops_test.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -1284,6 +1285,67 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } +// The results of `collect_set` are unordered lists. +// Thus, we have to sort the lists for comparison. +namespace { + +template +auto rolling_collect_set(cudf::column_view const& input, + WindowType const& preceding_window, + WindowType const& following_window, + cudf::size_type min_periods, + cudf::rolling_aggregation const& agg) +{ + auto const result = + cudf::rolling_window(input, preceding_window, following_window, min_periods, agg); + EXPECT_EQ(result->type().id(), cudf::type_id::LIST); + + return cudf::lists::sort_lists( + cudf::lists_column_view{result->view()}, cudf::order::ASCENDING, cudf::null_order::AFTER); +} + +template +auto grouped_rolling_collect_set(cudf::table_view const& group_keys, + cudf::column_view const& input, + WindowType const& preceding_window, + WindowType const& following_window, + cudf::size_type min_periods, + cudf::rolling_aggregation const& agg) +{ + auto const result = cudf::grouped_rolling_window( + group_keys, input, preceding_window, following_window, min_periods, agg); + EXPECT_EQ(result->type().id(), cudf::type_id::LIST); + + return cudf::lists::sort_lists( + cudf::lists_column_view{result->view()}, cudf::order::ASCENDING, cudf::null_order::AFTER); +} + +template +auto grouped_time_range_rolling_collect_set(cudf::table_view const& group_keys, + cudf::column_view const& timestamp_column, + cudf::order const& timestamp_order, + cudf::column_view const& input, + WindowType const& preceding_window_in_days, + WindowType const& following_window_in_days, + cudf::size_type min_periods, + cudf::rolling_aggregation const& agg) +{ + auto const result = cudf::grouped_time_range_rolling_window(group_keys, + timestamp_column, + timestamp_order, + input, + preceding_window_in_days, + following_window_in_days, + min_periods, + agg); + EXPECT_EQ(result->type().id(), cudf::type_id::LIST); + + return cudf::lists::sort_lists( + cudf::lists_column_view{result->view()}, cudf::order::ASCENDING, cudf::null_order::AFTER); +} + +} // namespace + struct CollectSetTest : public cudf::test::BaseFixture { }; @@ -1314,11 +1376,11 @@ TYPED_TEST(TypedCollectSetTest, BasicRollingWindow) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, - prev_column, - foll_column, - 1, - *make_collect_set_aggregation()); + rolling_collect_set(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -1332,16 +1394,16 @@ TYPED_TEST(TypedCollectSetTest, BasicRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); - auto const result_fixed_window = - rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation()); + auto const result_fixed_window = rolling_collect_set( + input_column, 2, 1, 1, *make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - 2, - 1, - 1, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -1362,11 +1424,11 @@ TYPED_TEST(TypedCollectSetTest, RollingWindowWithEmptyOutputLists) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, - prev_column, - foll_column, - 0, - *make_collect_set_aggregation()); + rolling_collect_set(input_column, + prev_column, + foll_column, + 0, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -1382,11 +1444,11 @@ TYPED_TEST(TypedCollectSetTest, RollingWindowWithEmptyOutputLists) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - prev_column, - foll_column, - 0, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + prev_column, + foll_column, + 0, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -1407,11 +1469,11 @@ TYPED_TEST(TypedCollectSetTest, RollingWindowHonoursMinPeriods) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ {{}, {0, 1, 2}, {1, 2}, {2, 4}, {2, 4, 5}, {}}, @@ -1422,11 +1484,11 @@ TYPED_TEST(TypedCollectSetTest, RollingWindowHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -1434,11 +1496,11 @@ TYPED_TEST(TypedCollectSetTest, RollingWindowHonoursMinPeriods) following = 2; min_periods = 4; - auto result_2 = rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto result_2 = rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto expected_result_2 = lists_column_wrapper{ {{}, {0, 1, 2}, {1, 2, 4}, {2, 4, 5}, {}, {}}, cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { @@ -1448,11 +1510,11 @@ TYPED_TEST(TypedCollectSetTest, RollingWindowHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); auto result_2_with_nulls_excluded = - rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2_with_nulls_excluded->view()); @@ -1472,11 +1534,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsOnStrings) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ {{}, {"0", "1", "2"}, {"1", "2"}, {"2", "4"}, {"2", "4"}, {}}, @@ -1487,11 +1549,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsOnStrings) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -1499,11 +1561,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsOnStrings) following = 2; min_periods = 4; - auto result_2 = rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto result_2 = rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto expected_result_2 = lists_column_wrapper{ {{}, {"0", "1", "2"}, {"1", "2", "4"}, {"2", "4"}, {}, {}}, cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { @@ -1513,11 +1575,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsOnStrings) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); auto result_2_with_nulls_excluded = - rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2_with_nulls_excluded->view()); @@ -1539,11 +1601,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsWithDecimal) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto expected_result_child_values = std::vector{0, 1, 0, 1, 2, 1, 2, 3, 2, 3}; auto expected_result_child = @@ -1565,11 +1627,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsWithDecimal) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -1580,11 +1642,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsWithDecimal) auto preceding = 2; auto following = 2; auto min_periods = 4; - auto const result = rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 0, 1, 2, 3, 1, 2, 3}; auto expected_result_child = @@ -1606,11 +1668,11 @@ TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsWithDecimal) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -1631,12 +1693,13 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindow) auto const preceding = 2; auto const following = 1; auto const min_periods = 1; - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = + grouped_rolling_collect_set(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -1645,7 +1708,7 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = grouped_rolling_window( + auto const result_with_nulls_excluded = grouped_rolling_collect_set( table_view{std::vector{group_column}}, input_column, preceding, @@ -1674,12 +1737,12 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) { // Nulls included and nulls are equal. auto const result = - grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + grouped_rolling_collect_set(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); // Null values are sorted to the tails of lists (sets) auto expected_child = fixed_width_column_wrapper{{ 10, 0, // row 0 @@ -1719,13 +1782,14 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) { // Nulls included and nulls are NOT equal. - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::UNEQUAL)); + auto const result = + grouped_rolling_collect_set(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::UNEQUAL)); // Null values are sorted to the tails of lists (sets) auto expected_child = fixed_width_column_wrapper{{ 10, 0, // row 0 @@ -1765,7 +1829,7 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) { // Nulls excluded. - auto const result = grouped_rolling_window( + auto const result = grouped_rolling_collect_set( table_view{std::vector{group_column}}, input_column, preceding, @@ -1816,14 +1880,14 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedTimeRangeRollingWindow) auto const following = 1; auto const min_periods = 1; auto const result = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - *make_collect_list_aggregation()); + grouped_time_range_rolling_collect_set(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ {10, 11, 12, 13}, @@ -1838,7 +1902,7 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedTimeRangeRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + auto const result_with_nulls_excluded = grouped_time_range_rolling_collect_set( table_view{std::vector{group_column}}, time_column, cudf::order::ASCENDING, @@ -1846,7 +1910,7 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedTimeRangeRollingWindow) preceding, following, min_periods, - *make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -1867,14 +1931,14 @@ TYPED_TEST(TypedCollectSetTest, GroupedTimeRangeRollingWindowWithNulls) auto const following = 1; auto const min_periods = 1; auto const result = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + grouped_time_range_rolling_collect_set(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto null_at_1 = null_at(1); auto null_at_3 = null_at(3); @@ -1894,7 +1958,7 @@ TYPED_TEST(TypedCollectSetTest, GroupedTimeRangeRollingWindowWithNulls) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + auto const result_with_nulls_excluded = grouped_time_range_rolling_collect_set( table_view{std::vector{group_column}}, time_column, cudf::order::ASCENDING, @@ -1936,12 +2000,13 @@ TYPED_TEST(TypedCollectSetTest, SlicedGroupedRollingWindow) auto const preceding = 2; auto const following = 1; auto const min_periods = 1; - auto const result = grouped_rolling_window(table_view{std::vector{group_col}}, - input_col, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = + grouped_rolling_collect_set(table_view{std::vector{group_col}}, + input_col, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{{11, 13}, {11, 13}, {13}, {20, 21}, {20, 21}}.release(); @@ -1963,11 +2028,11 @@ TEST_F(CollectSetTest, BoolRollingWindow) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, - prev_column, - foll_column, - 1, - *make_collect_set_aggregation()); + rolling_collect_set(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -1981,16 +2046,16 @@ TEST_F(CollectSetTest, BoolRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); - auto const result_fixed_window = - rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation()); + auto const result_fixed_window = rolling_collect_set( + input_column, 2, 1, 1, *make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - 2, - 1, - 1, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -2007,12 +2072,13 @@ TEST_F(CollectSetTest, BoolGroupedRollingWindow) auto const preceding = 2; auto const following = 1; auto const min_periods = 1; - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = + grouped_rolling_collect_set(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ {false, true}, @@ -2027,7 +2093,7 @@ TEST_F(CollectSetTest, BoolGroupedRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = grouped_rolling_window( + auto const result_with_nulls_excluded = grouped_rolling_collect_set( table_view{std::vector{group_column}}, input_column, preceding, @@ -2052,12 +2118,13 @@ TEST_F(CollectSetTest, FloatGroupedRollingWindowWithNaNs) auto const following = 1; auto const min_periods = 1; // test on nan_equality::UNEQUAL - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation()); + auto const result = + grouped_rolling_collect_set(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ {{0.2341, 1.23}, std::initializer_list{true, true}}, @@ -2075,14 +2142,14 @@ TEST_F(CollectSetTest, FloatGroupedRollingWindowWithNaNs) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); // test on nan_equality::ALL_EQUAL - auto const result_nan_equal = - grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - *make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + auto const result_nan_equal = grouped_rolling_collect_set( + table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); auto const expected_result_nan_equal = lists_column_wrapper{ {{0.2341, 1.23}, std::initializer_list{true, true}}, @@ -2115,11 +2182,11 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, - prev_column, - foll_column, - 1, - *make_collect_set_aggregation()); + rolling_collect_set(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -2133,16 +2200,16 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); - auto const result_fixed_window = - rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation()); + auto const result_fixed_window = rolling_collect_set( + input_column, 2, 1, 1, *make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, - 2, - 1, - 1, - *make_collect_set_aggregation(null_policy::EXCLUDE)); + rolling_collect_set(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -2157,12 +2224,12 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs) .release(); auto const result_with_nan_equal = - rolling_window(input_column, - 2, - 1, - 1, - *make_collect_set_aggregation( - null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + rolling_collect_set(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_for_nan_equal->view(), result_with_nan_equal->view()); @@ -2190,11 +2257,11 @@ TEST_F(CollectSetTest, StructTypeRollingWindow) 0, {}); }(); - auto const result = rolling_window(input_column, - prev_column, - foll_column, - 1, - *make_collect_set_aggregation()); + auto const result = rolling_collect_set(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected->view(), result->view()); } @@ -2208,10 +2275,10 @@ TEST_F(CollectSetTest, ListTypeRollingWindow) auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2}; auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0}; - EXPECT_THROW(rolling_window(input_column, - prev_column, - foll_column, - 1, - *make_collect_set_aggregation()), + EXPECT_THROW(rolling_collect_set(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()), cudf::logic_error); } diff --git a/java/src/test/java/ai/rapids/cudf/ReductionTest.java b/java/src/test/java/ai/rapids/cudf/ReductionTest.java index 2efd23703bc..cc172204ed3 100644 --- a/java/src/test/java/ai/rapids/cudf/ReductionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ReductionTest.java @@ -399,13 +399,15 @@ private static Stream createFloatArrayParams() { } private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, - Double percentage) { + Double percentage) { if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getDouble(), result.getDouble(), percentage); } else if (expected.getType().typeId == DType.DTypeEnum.LIST) { - try (ColumnView e = expected.getListAsColumnView(); - ColumnView r = result.getListAsColumnView()) { - AssertUtils.assertColumnsAreEqual(e, r); + try (ColumnVector expectedAsList = ColumnVector.fromScalar(expected, 1); + ColumnVector resultAsList = ColumnVector.fromScalar(result, 1); + ColumnVector expectedSorted = expectedAsList.listSortRows(false, false); + ColumnVector resultSorted = resultAsList.listSortRows(false, false)) { + AssertUtils.assertColumnsAreEqual(expectedSorted, resultSorted); } } else { assertEquals(expected, result); @@ -413,13 +415,15 @@ private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, } private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, - Float percentage) { + Float percentage) { if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getFloat(), result.getFloat(), percentage); } else if (expected.getType().typeId == DType.DTypeEnum.LIST) { - try (ColumnView e = expected.getListAsColumnView(); - ColumnView r = result.getListAsColumnView()) { - AssertUtils.assertColumnsAreEqual(e, r); + try (ColumnVector expectedAsList = ColumnVector.fromScalar(expected, 1); + ColumnVector resultAsList = ColumnVector.fromScalar(result, 1); + ColumnVector expectedSorted = expectedAsList.listSortRows(false, false); + ColumnVector resultSorted = resultAsList.listSortRows(false, false)) { + AssertUtils.assertColumnsAreEqual(expectedSorted, resultSorted); } } else { assertEquals(expected, result); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 639d498d2f3..fbaead1e429 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4776,38 +4776,42 @@ void testWindowingCollectSet() { // a) excluding NULLs try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); + ColumnVector resultSorted = windowAggResults.getColumn(0).listSortRows(false, false); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.INT32)), Arrays.asList(5), Arrays.asList(1, 5), Arrays.asList(1, 5), Arrays.asList(1), Arrays.asList(1, 4), Arrays.asList(1, 3, 4), Arrays.asList(3, 4), Arrays.asList(3, 4), Arrays.asList(), Arrays.asList(6), Arrays.asList(6, 7), Arrays.asList(6, 7))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expected, resultSorted); } // b) including NULLs AND NULLs are equal try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollectWithEqNulls.onColumn(3).overWindow(winOpts)); + ColumnVector resultSorted = windowAggResults.getColumn(0).listSortRows(false, false); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.INT32)), Arrays.asList(5), Arrays.asList(1, 5), Arrays.asList(1, 5), Arrays.asList(1), Arrays.asList(1, 4), Arrays.asList(1, 3, 4), Arrays.asList(3, 4), Arrays.asList(3, 4), Arrays.asList((Integer) null), Arrays.asList(6, null), Arrays.asList(6, 7, null), Arrays.asList(6, 7))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expected, resultSorted); } // c) including NULLs AND NULLs are unequal try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollectWithUnEqNulls.onColumn(3).overWindow(winOpts)); + ColumnVector resultSorted = windowAggResults.getColumn(0).listSortRows(false, false); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.INT32)), Arrays.asList(5), Arrays.asList(1, 5), Arrays.asList(1, 5), Arrays.asList(1), Arrays.asList(1, 4), Arrays.asList(1, 3, 4), Arrays.asList(3, 4), Arrays.asList(3, 4), Arrays.asList(null, null), Arrays.asList(6, null, null), Arrays.asList(6, 7, null), Arrays.asList(6, 7))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expected, resultSorted); } // Primitive type: FLOAT64 // a) excluding NULLs try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollect.onColumn(4).overWindow(winOpts)); + ColumnVector resultSorted = windowAggResults.getColumn(0).listSortRows(false, false); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.FLOAT64)), Arrays.asList(1.1), Arrays.asList(1.1), Arrays.asList(1.1, 2.2), Arrays.asList(2.2), @@ -4815,11 +4819,12 @@ void testWindowingCollectSet() { Arrays.asList(-3.0, 1.3e-7, Double.NaN), Arrays.asList(-3.0, Double.NaN), Arrays.asList(1e-3), Arrays.asList(1e-3, Double.NaN), Arrays.asList(Double.NaN, Double.NaN), Arrays.asList(Double.NaN, Double.NaN))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expected, resultSorted); } // b) including NULLs AND NULLs are equal try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollectWithEqNulls.onColumn(4).overWindow(winOpts)); + ColumnVector resultSorted = windowAggResults.getColumn(0).listSortRows(false, false); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.FLOAT64)), Arrays.asList(1.1), Arrays.asList(1.1, null), Arrays.asList(1.1, 2.2, null), Arrays.asList(2.2, null), @@ -4827,11 +4832,12 @@ void testWindowingCollectSet() { Arrays.asList(-3.0, 1.3e-7, Double.NaN), Arrays.asList(-3.0, Double.NaN), Arrays.asList(1e-3, null), Arrays.asList(1e-3, Double.NaN, null), Arrays.asList(Double.NaN, Double.NaN, null), Arrays.asList(Double.NaN, Double.NaN))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expected, resultSorted); } // c) including NULLs AND NULLs are equal AND NaNs are equal try (Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(aggCollectWithEqNaNs.onColumn(4).overWindow(winOpts)); + ColumnVector resultSorted = windowAggResults.getColumn(0).listSortRows(false, false); ColumnVector expected = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.FLOAT64)), Arrays.asList(1.1), Arrays.asList(1.1, null), Arrays.asList(1.1, 2.2, null), Arrays.asList(2.2, null), @@ -4839,7 +4845,7 @@ void testWindowingCollectSet() { Arrays.asList(-3.0, 1.3e-7, Double.NaN), Arrays.asList(-3.0, Double.NaN), Arrays.asList(1e-3, null), Arrays.asList(1e-3, Double.NaN, null), Arrays.asList(Double.NaN, null), Arrays.asList(Double.NaN))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expected, resultSorted); } } } @@ -7045,8 +7051,10 @@ void testGroupByCollectSetIncludeNulls() { Arrays.asList(13, null, null), Arrays.asList(14, 15, null, null), Arrays.asList(1, 4), Arrays.asList(0)) .build(); - Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { - assertTablesAreEqual(expected, found); + Table found = input.groupBy(0).aggregate(collectSet.onColumn(1)); + ColumnVector listsSorted = found.getColumn(1).listSortRows(false, false)) { + assertColumnsAreEqual(expected.getColumn(0), found.getColumn(0)); + assertColumnsAreEqual(expected.getColumn(1), listsSorted); } // test with null equal and nan unequal collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, @@ -7066,8 +7074,10 @@ void testGroupByCollectSetIncludeNulls() { Arrays.asList(1.0, Double.NaN, null), Arrays.asList((Integer) null)) .build(); - Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { - assertTablesAreEqual(expected, found); + Table found = input.groupBy(0).aggregate(collectSet.onColumn(1)); + ColumnVector listsSorted = found.getColumn(1).listSortRows(false, false)) { + assertColumnsAreEqual(expected.getColumn(0), found.getColumn(0)); + assertColumnsAreEqual(expected.getColumn(1), listsSorted); } // test with null equal and nan equal collectSet = GroupByAggregation.collectSet(NullPolicy.INCLUDE, @@ -7087,8 +7097,10 @@ void testGroupByCollectSetIncludeNulls() { Arrays.asList(0.0), Arrays.asList(Double.NaN, (Integer) null)) .build(); - Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { - assertTablesAreEqual(expected, found); + Table found = input.groupBy(0).aggregate(collectSet.onColumn(1)); + ColumnVector listsSorted = found.getColumn(1).listSortRows(false, false)) { + assertColumnsAreEqual(expected.getColumn(0), found.getColumn(0)); + assertColumnsAreEqual(expected.getColumn(1), listsSorted); } } @@ -7134,10 +7146,18 @@ void testGroupByMergeSets() { Table retListOfInts = input.groupBy(0).aggregate(GroupByAggregation.mergeSets().onColumn(1)); Table retListOfDoubles = input.groupBy(0).aggregate(GroupByAggregation.mergeSets().onColumn(2)); Table retListOfDoublesNaNEq = input.groupBy(0).aggregate( - GroupByAggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2))) { - assertTablesAreEqual(expectedListOfInts, retListOfInts); - assertTablesAreEqual(expectedListOfDoubles, retListOfDoubles); - assertTablesAreEqual(expectedListOfDoublesNaNEq, retListOfDoublesNaNEq); + GroupByAggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL).onColumn(2)); + ColumnVector listsIntsSorted = retListOfInts.getColumn(1).listSortRows(false, false); + ColumnVector listsDoublesSorted = retListOfDoubles.getColumn(1).listSortRows(false, false); + ColumnVector listsDoublesNaNEqSorted = retListOfDoublesNaNEq.getColumn(1).listSortRows(false, false)) { + assertColumnsAreEqual(expectedListOfInts.getColumn(0), retListOfInts.getColumn(0)); + assertColumnsAreEqual(expectedListOfDoubles.getColumn(0), retListOfDoubles.getColumn(0)); + assertColumnsAreEqual(expectedListOfDoublesNaNEq.getColumn(0), retListOfDoublesNaNEq.getColumn(0)); + + assertColumnsAreEqual(expectedListOfInts.getColumn(1), listsIntsSorted); + assertColumnsAreEqual(expectedListOfDoubles.getColumn(1), listsDoublesSorted); + assertColumnsAreEqual(expectedListOfDoublesNaNEq.getColumn(1), listsDoublesNaNEqSorted); + } }