Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace CUB iterators by Thrust ones #3480

Merged
merged 2 commits on Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions cub/cub/device/device_run_length_encode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
#include <cub/device/dispatch/dispatch_reduce_by_key.cuh>
#include <cub/device/dispatch/dispatch_rle.cuh>
#include <cub/device/dispatch/tuning/tuning_run_length_encode.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <thrust/iterator/constant_iterator.h>

#include <iterator>

Expand Down Expand Up @@ -200,17 +201,14 @@ struct DeviceRunLengthEncode
using length_t = cub::detail::non_void_value_t<LengthsOutputIteratorT, offset_t>;

// Generator type for providing 1s values for run-length reduction
_CCCL_SUPPRESS_DEPRECATED_PUSH
using lengths_input_iterator_t = ConstantInputIterator<length_t, offset_t>;
_CCCL_SUPPRESS_DEPRECATED_POP
using lengths_input_iterator_t = THRUST_NS_QUALIFIER::constant_iterator<length_t, offset_t>;

using accum_t = ::cuda::std::__accumulator_t<reduction_op, length_t, length_t>;

using key_t = cub::detail::non_void_value_t<UniqueOutputIteratorT, cub::detail::value_t<InputIteratorT>>;

using policy_t = detail::rle::encode::policy_hub<accum_t, key_t>;

_CCCL_SUPPRESS_DEPRECATED_PUSH
return DispatchReduceByKey<
InputIteratorT,
UniqueOutputIteratorT,
Expand All @@ -232,7 +230,6 @@ struct DeviceRunLengthEncode
reduction_op(),
num_items,
stream);
_CCCL_SUPPRESS_DEPRECATED_POP
}

//! @rst
Expand Down
17 changes: 2 additions & 15 deletions cub/cub/device/dispatch/dispatch_streaming_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

#include <cub/device/dispatch/dispatch_reduce.cuh>
#include <cub/iterator/arg_index_input_iterator.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/tabulate_output_iterator.h>

Expand All @@ -25,8 +25,6 @@

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document

// suppress deprecation warnings for ConstantInputIterator
_CCCL_SUPPRESS_DEPRECATED_PUSH
CUB_NAMESPACE_BEGIN

namespace detail::reduce
Expand Down Expand Up @@ -189,12 +187,6 @@ template <typename InputIteratorT,
detail::reduce::policy_hub<KeyValuePair<PerPartitionOffsetT, InitT>, PerPartitionOffsetT, ReductionOpT>>
struct dispatch_streaming_arg_reduce_t
{
# if _CCCL_COMPILER(NVHPC)
// NVHPC fails to suppress a deprecation when the alias is inside the function below, so we put it here and span a
// deprecation suppression region across the entire file as well
using constant_offset_it_t = ConstantInputIterator<GlobalOffsetT>;
# endif // _CCCL_COMPILER(NVHPC)

// Internal dispatch routine for computing a device-wide argument extremum, like `ArgMin` and `ArgMax`
//
// @param[in] d_temp_storage
Expand Down Expand Up @@ -234,11 +226,7 @@ struct dispatch_streaming_arg_reduce_t
cudaStream_t stream)
{
// Constant iterator to provide the offset of the current partition for the user-provided input iterator
# if !_CCCL_COMPILER(NVHPC)
_CCCL_SUPPRESS_DEPRECATED_PUSH
using constant_offset_it_t = ConstantInputIterator<GlobalOffsetT>;
_CCCL_SUPPRESS_DEPRECATED_POP
# endif
using constant_offset_it_t = THRUST_NS_QUALIFIER::constant_iterator<GlobalOffsetT>;

// Wrapped input iterator to produce index-value tuples, i.e., <PerPartitionOffsetT, InputT>-tuples
// We make sure to offset the user-provided input iterator by the current partition's offset
Expand Down Expand Up @@ -382,7 +370,6 @@ struct dispatch_streaming_arg_reduce_t
};

} // namespace detail::reduce
_CCCL_SUPPRESS_DEPRECATED_POP
CUB_NAMESPACE_END

#endif // !_CCCL_DOXYGEN_INVOKED
10 changes: 9 additions & 1 deletion cub/cub/util_type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

#include <cub/detail/uninitialized_copy.cuh>

#include <thrust/iterator/discard_iterator.h>

#include <cuda/std/cstdint>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
Expand Down Expand Up @@ -107,7 +109,13 @@ struct non_void_value_impl
template <typename It, typename FallbackT>
struct non_void_value_impl<It, FallbackT, false>
{
using type = ::cuda::std::_If<::cuda::std::is_void<value_t<It>>::value, FallbackT, value_t<It>>;
// we consider thrust::discard_iterator's value_type as `void` as well, so users can switch from
// cub::DiscardOutputIterator to thrust::discard_iterator.
using type =
::cuda::std::_If<::cuda::std::is_void<value_t<It>>::value
|| ::cuda::std::is_same<value_t<It>, THRUST_NS_QUALIFIER::discard_iterator<>::value_type>::value,
FallbackT,
value_t<It>>;
};

/**
Expand Down
4 changes: 0 additions & 4 deletions cub/test/catch2_test_device_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@
#include <c2h/custom_type.h>
#include <c2h/extended_types.h>

// need to suppress deprecation warnings for ConstantInputIterator in the cudafe1.stub.c file, so there is no matching
// _CCCL_SUPPRESS_DEPRECATED_POP at the end of this file
_CCCL_SUPPRESS_DEPRECATED_PUSH

DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Reduce, device_reduce);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Sum, device_sum);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Min, device_min);
Expand Down
3 changes: 0 additions & 3 deletions cub/test/catch2_test_device_reduce_fp_inf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMin, device_arg_min_old);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMax, device_arg_max_old);
_CCCL_SUPPRESS_DEPRECATED_POP

// suppress deprecation of ConstantInputIterator in cudafe1.stub.c file
_CCCL_SUPPRESS_DEPRECATED_PUSH

// %PARAM% TEST_LAUNCH lid 0:1

C2H_TEST("Device reduce arg{min,max} works with inf items", "[reduce][device]")
Expand Down
3 changes: 0 additions & 3 deletions cub/test/catch2_test_device_reduce_large_offsets.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMin, device_arg_min);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Max, device_max);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMax, device_arg_max);

// suppress deprecation of ConstantInputIterator in cudafe1.stub.c file
_CCCL_SUPPRESS_DEPRECATED_PUSH

// %PARAM% TEST_LAUNCH lid 0:1:2

// List of offset types to test
Expand Down
14 changes: 0 additions & 14 deletions cub/test/catch2_test_device_run_length_encode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@
*
******************************************************************************/

#include <cuda/__cccl_config>

#if _CCCL_COMPILER(NVHPC)
// to suppress warnings for CountingInputIterator
_CCCL_SUPPRESS_DEPRECATED_PUSH
#endif // _CCCL_COMPILER(NVHPC)

#include "insert_nested_NVTX_range_guard.h"
// above header needs to be included first

Expand All @@ -50,9 +43,6 @@ _CCCL_SUPPRESS_DEPRECATED_PUSH

DECLARE_LAUNCH_WRAPPER(cub::DeviceRunLengthEncode::Encode, run_length_encode);

// suppress deprecation of ConstantInputIterator in cudafe1.stub.c file
_CCCL_SUPPRESS_DEPRECATED_PUSH

// %PARAM% TEST_LAUNCH lid 0:1:2

using all_types =
Expand Down Expand Up @@ -274,7 +264,3 @@ C2H_TEST("DeviceRunLengthEncode::Encode can handle leading NaN", "[device][run_l
REQUIRE(out_counts == reference_counts);
REQUIRE(out_num_runs == reference_num_runs);
}

#if _CCCL_COMPILER(NVHPC)
_CCCL_SUPPRESS_DEPRECATED_POP
#endif // _CCCL_COMPILER(NVHPC)
11 changes: 5 additions & 6 deletions cub/test/catch2_test_util_type.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
*
******************************************************************************/

#include <cub/iterator/counting_input_iterator.cuh>
#include <cub/iterator/discard_output_iterator.cuh>
#include <cub/util_type.cuh>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>

#include <cuda/std/type_traits>

#include <c2h/catch2_test_helper.h>
#include <c2h/extended_types.h>

C2H_TEST("Tests non_void_value_t", "[util][type]")
{
_CCCL_SUPPRESS_DEPRECATED_PUSH
using fallback_t = float;
using void_fancy_it = cub::DiscardOutputIterator<std::size_t>;
using non_void_fancy_it = cub::CountingInputIterator<int>;
using void_fancy_it = thrust::discard_iterator<std::size_t>;
using non_void_fancy_it = thrust::counting_iterator<int>;

// falls back for const void*
STATIC_REQUIRE(::cuda::std::is_same<fallback_t, //
Expand All @@ -62,7 +62,6 @@ C2H_TEST("Tests non_void_value_t", "[util][type]")
// works for a fancy iterator that has int as value type
STATIC_REQUIRE(::cuda::std::is_same<int, //
cub::detail::non_void_value_t<non_void_fancy_it, fallback_t>>::value);
_CCCL_SUPPRESS_DEPRECATED_POP
}

CUB_DEFINE_DETECT_NESTED_TYPE(cat_detect, cat);
Expand Down
Loading