Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Modernize scan_by_key functors / type deductions. #1376

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions testing/scan_by_key.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <unittest/unittest.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/retag.h>
#include <thrust/random.h>
Expand Down Expand Up @@ -540,6 +541,74 @@ void TestScanByKeyMixedTypes(void)
DECLARE_UNITTEST(TestScanByKeyMixedTypes);


template <typename T>
void TestScanByKeyDiscardOutput(std::size_t n)
{
thrust::host_vector<T> h_keys(n);
thrust::default_random_engine rng;

for (size_t i = 0, k = 0; i < n; i++)
{
h_keys[i] = static_cast<T>(k);
if (rng() % 10 == 0)
{
k++;
}
}
thrust::device_vector<T> d_keys = h_keys;

thrust::host_vector<T> h_vals(n);
for(size_t i = 0; i < n; i++)
{
h_vals[i] = static_cast<T>(i % 10);
}
thrust::device_vector<T> d_vals = h_vals;

auto out = thrust::make_discard_iterator();

// These are no-ops, but they should compile.
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out);
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
T{});
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
T{},
thrust::equal_to<T>{});
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
T{},
thrust::equal_to<T>{},
thrust::multiplies<T>{});

thrust::inclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out);
thrust::inclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
thrust::equal_to<T>{});
thrust::inclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
thrust::equal_to<T>{},
thrust::multiplies<T>{});
}
DECLARE_VARIABLE_UNITTEST(TestScanByKeyDiscardOutput);


void TestScanByKeyLargeInput()
{
const unsigned int N = 1 << 20;
Expand Down
14 changes: 7 additions & 7 deletions thrust/system/cuda/detail/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -844,14 +844,14 @@ inclusive_scan_by_key(execution_policy<Derived> &policy,
ValOutputIt value_result,
BinaryPred binary_pred)
{
typedef typename thrust::iterator_traits<ValOutputIt>::value_type value_type;
typedef typename thrust::iterator_traits<ValInputIt>::value_type value_type;
return cuda_cub::inclusive_scan_by_key(policy,
key_first,
key_last,
value_first,
value_result,
binary_pred,
plus<value_type>());
thrust::plus<>());
}

template <class Derived,
Expand All @@ -871,7 +871,7 @@ inclusive_scan_by_key(execution_policy<Derived> &policy,
key_last,
value_first,
value_result,
equal_to<key_type>());
thrust::equal_to<>());
}


Expand Down Expand Up @@ -948,7 +948,7 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
value_result,
init,
binary_pred,
plus<Init>());
plus<>());
}

template <class Derived,
Expand All @@ -971,7 +971,7 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
value_first,
value_result,
init,
equal_to<key_type>());
equal_to<>());
}


Expand All @@ -986,13 +986,13 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
ValInputIt value_first,
ValOutputIt value_result)
{
typedef typename iterator_traits<ValOutputIt>::value_type value_type;
typedef typename iterator_traits<ValInputIt>::value_type value_type;
return cuda_cub::exclusive_scan_by_key(policy,
key_first,
key_last,
value_first,
value_result,
value_type(0));
value_type{});
}


Expand Down
19 changes: 9 additions & 10 deletions thrust/system/detail/generic/scan_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


#include <thrust/detail/config.h>
#include <thrust/detail/cstdint.h>
#include <thrust/system/detail/generic/scan_by_key.h>
#include <thrust/functional.h>
#include <thrust/transform.h>
Expand Down Expand Up @@ -71,8 +72,7 @@ __host__ __device__
InputIterator2 first2,
OutputIterator result)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type InputType1;
return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, thrust::equal_to<InputType1>());
return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, thrust::equal_to<>());
}


Expand Down Expand Up @@ -108,8 +108,8 @@ __host__ __device__
BinaryPredicate binary_pred,
AssociativeOperator binary_op)
{
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
typedef unsigned int HeadFlagType;
using OutputType = typename thrust::iterator_traits<InputIterator2>::value_type;
using HeadFlagType = thrust::detail::uint32_t;

const size_t n = last1 - first1;

Expand Down Expand Up @@ -146,8 +146,8 @@ __host__ __device__
InputIterator2 first2,
OutputIterator result)
{
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, OutputType(0));
typedef typename thrust::iterator_traits<InputIterator2>::value_type InitType;
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, InitType{});
}


Expand All @@ -164,8 +164,7 @@ __host__ __device__
OutputIterator result,
T init)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type InputType1;
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, thrust::equal_to<InputType1>());
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, thrust::equal_to<>());
}


Expand Down Expand Up @@ -205,8 +204,8 @@ __host__ __device__
BinaryPredicate binary_pred,
AssociativeOperator binary_op)
{
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
typedef unsigned int HeadFlagType;
using OutputType = T;
using HeadFlagType = thrust::detail::uint32_t;

const size_t n = last1 - first1;

Expand Down
8 changes: 4 additions & 4 deletions thrust/system/detail/sequential/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ __host__ __device__
BinaryPredicate binary_pred,
BinaryFunction binary_op)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
typedef typename thrust::iterator_traits<OutputIterator>::value_type ValueType;
using KeyType = typename thrust::iterator_traits<InputIterator1>::value_type;
using ValueType = typename thrust::iterator_traits<InputIterator2>::value_type;

// wrap binary_op
thrust::detail::wrapped_function<
Expand Down Expand Up @@ -105,8 +105,8 @@ __host__ __device__
BinaryPredicate binary_pred,
BinaryFunction binary_op)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
typedef typename thrust::iterator_traits<OutputIterator>::value_type ValueType;
using KeyType = typename thrust::iterator_traits<InputIterator1>::value_type;
using ValueType = T;

if(first1 != last1)
{
Expand Down