Skip to content

Commit

Permalink
Support min_by agg sort based
Browse files Browse the repository at this point in the history
    Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven committed Jul 3, 2024
1 parent 25febbc commit 101a929
Show file tree
Hide file tree
Showing 13 changed files with 291 additions and 38 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ add_library(
src/groupby/sort/group_m2.cu
src/groupby/sort/group_max.cu
src/groupby/sort/group_min.cu
src/groupby/sort/group_min_by.cu
src/groupby/sort/group_merge_lists.cu
src/groupby/sort/group_merge_m2.cu
src/groupby/sort/group_nth_element.cu
Expand Down
86 changes: 49 additions & 37 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,43 +83,44 @@ class aggregation {
* @brief Possible aggregation operations
*/
enum Kind {
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation,
MIN_BY ///< min reduction by another column
};

aggregation() = delete;
Expand Down Expand Up @@ -381,6 +382,17 @@ std::unique_ptr<Base> make_argmax_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_argmin_aggregation();

/**
* @brief Factory to create a MIN_BY aggregation
*
* `MIN_BY` returns the value of the element in the group that is the minimum
* according to the order_by column.
*
* @return A MIN_BY aggregation object
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_min_by_aggregation();

/**
* @brief Factory to create a NUNIQUE aggregation
*
Expand Down
29 changes: 29 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
class product_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class min_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class min_by_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class max_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
Expand Down Expand Up @@ -217,6 +219,25 @@ class min_aggregation final : public rolling_aggregation,
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived class for specifying a min_by aggregation
*/
class min_by_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
min_by_aggregation() : aggregation(MIN_BY) {}

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<min_by_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived class for specifying a max aggregation
*/
Expand Down Expand Up @@ -1219,6 +1240,12 @@ struct target_type_impl<Source, aggregation::MIN> {
using type = Source;
};

// Computing MIN_BY of Source, use Source accumulator
template <typename Source>
struct target_type_impl<Source, aggregation::MIN_BY> {
using type = struct_view;
};

// Computing MAX of Source, use Source accumulator
template <typename Source>
struct target_type_impl<Source, aggregation::MAX> {
Expand Down Expand Up @@ -1517,6 +1544,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()<aggregation::PRODUCT>(std::forward<Ts>(args)...);
case aggregation::MIN:
return f.template operator()<aggregation::MIN>(std::forward<Ts>(args)...);
case aggregation::MIN_BY:
return f.template operator()<aggregation::MIN_BY>(std::forward<Ts>(args)...);
case aggregation::MAX:
return f.template operator()<aggregation::MAX>(std::forward<Ts>(args)...);
case aggregation::COUNT_VALID:
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, min_by_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, max_aggregation const& agg)
{
Expand Down Expand Up @@ -637,6 +643,15 @@ template std::unique_ptr<aggregation> make_argmin_aggregation<aggregation>();
template std::unique_ptr<rolling_aggregation> make_argmin_aggregation<rolling_aggregation>();
template std::unique_ptr<groupby_aggregation> make_argmin_aggregation<groupby_aggregation>();

/// Factory to create a MIN_BY aggregation
template <typename Base>
std::unique_ptr<Base> make_min_by_aggregation()
{
return std::make_unique<detail::min_by_aggregation>();
}
template std::unique_ptr<aggregation> make_min_by_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_min_by_aggregation<groupby_aggregation>();

/// Factory to create an NUNIQUE aggregation
template <typename Base>
std::unique_ptr<Base> make_nunique_aggregation(null_policy null_handling)
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,18 @@ void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& a
cache.add_result(values, agg, std::move(result));
}

template <>
void aggregate_result_functor::operator()<aggregation::MIN_BY>(aggregation const& agg)
{
if (cache.has_result(values, agg)) return;

cache.add_result(
values,
agg,
detail::group_min_by(
get_grouped_values(), helper.group_labels(stream), helper.num_groups(stream), stream, mr));
}

template <>
void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
{
Expand Down
61 changes: 61 additions & 0 deletions cpp/src/groupby/sort/group_min_by.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "groupby/sort/group_single_pass_reduction_util.cuh"

#include <cudf/detail/gather.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/gather.h>

namespace cudf {
namespace groupby {
namespace detail {
std::unique_ptr<column> group_min_by(column_view const& structs_column,
cudf::device_span<size_type const> group_labels,
size_type num_groups,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto structs_view = cudf::structs_column_view{structs_column};
auto const orders = structs_view.get_sliced_child(1);

auto indices = type_dispatcher(orders.type(),
group_reduction_dispatcher<aggregation::ARGMIN>{},
orders,
num_groups,
group_labels,
stream,
mr);

auto indices_view = indices->mutable_view();

auto res = cudf::detail::gather(table_view{{structs_column}},
indices_view,
out_of_bounds_policy::NULLIFY,
cudf::detail::negative_index_policy::NOT_ALLOWED,
stream,
mr);

return std::move(res->release()[0]);
}

} // namespace detail
} // namespace groupby
} // namespace cudf
6 changes: 6 additions & 0 deletions cpp/src/groupby/sort/group_reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ std::unique_ptr<column> group_min(column_view const& values,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

std::unique_ptr<column> group_min_by(column_view const& structs_column,
cudf::device_span<size_type const> group_labels,
size_type num_groups,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

/**
* @brief Internal API to calculate groupwise maximum value
*
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ ConfigureTest(
groupby/keys_tests.cpp
groupby/lists_tests.cpp
groupby/m2_tests.cpp
groupby/min_by_tests.cpp
groupby/min_tests.cpp
groupby/max_scan_tests.cpp
groupby/max_tests.cpp
Expand Down
77 changes: 77 additions & 0 deletions cpp/tests/groupby/min_by_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <tests/groupby/groupby_test_util.hpp>

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/iterator_utilities.hpp>
#include <cudf_test/type_lists.hpp>

#include <cudf/detail/aggregation/aggregation.hpp>

using namespace cudf::test::iterators;

template <typename V>
struct groupby_min_by_test : public cudf::test::BaseFixture {};
using K = int32_t;

TYPED_TEST_SUITE(groupby_min_by_test, cudf::test::FixedWidthTypes);

TYPED_TEST(groupby_min_by_test, basic)
{
using V = TypeParam;

if (std::is_same_v<V, bool>) return;

cudf::test::fixed_width_column_wrapper<K> keys{1, 2, 3, 1, 2, 2, 1, 3, 3, 2};
cudf::test::fixed_width_column_wrapper<K> values{4, 1, 2, 3, 4, 5, 6, 7, 8, 9};
cudf::test::fixed_width_column_wrapper<V> orders{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
cudf::test::structs_column_wrapper vals{values, orders};

cudf::test::fixed_width_column_wrapper<K> expect_keys{1, 2, 3};
cudf::test::fixed_width_column_wrapper<K> expect_values{4, 1, 2};
cudf::test::fixed_width_column_wrapper<V> expect_orders{1, 2, 3};
cudf::test::structs_column_wrapper expect_vals{expect_values, expect_orders};

auto agg = cudf::make_min_by_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg));

auto agg2 = cudf::make_min_by_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES);
}

struct groupby_min_by_string_test : public cudf::test::BaseFixture {};

TEST_F(groupby_min_by_string_test, basic)
{
cudf::test::fixed_width_column_wrapper<K> keys{1, 2, 3, 1, 2, 2, 1, 3, 3, 2};
cudf::test::fixed_width_column_wrapper<K> values{4, 1, 2, 3, 4, 5, 6, 7, 8, 9};
cudf::test::strings_column_wrapper orders{
"año", "bit", "₹1", "aaa", "zit", "bat", "aab", "$1", "€1", "wut"};
cudf::test::structs_column_wrapper vals{values, orders};

cudf::test::fixed_width_column_wrapper<K> expect_keys{1, 2, 3};
cudf::test::fixed_width_column_wrapper<K> expect_values{3, 5, 7};
cudf::test::strings_column_wrapper expect_orders{"aaa", "bat", "$1"};
cudf::test::structs_column_wrapper expect_vals{expect_values, expect_orders};

auto agg = cudf::make_min_by_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg));

auto agg2 = cudf::make_min_by_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES);
}
Loading

0 comments on commit 101a929

Please sign in to comment.