Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1376 from allisonvacanti/bug/scan_by_key_moderniz…
Browse files Browse the repository at this point in the history
…e/gh.1374

Modernize scan_by_key functors / type deductions.
  • Loading branch information
alliepiper authored Feb 9, 2021
2 parents 730c3bb + f7f2129 commit 1ef0374
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 21 deletions.
69 changes: 69 additions & 0 deletions testing/scan_by_key.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <unittest/unittest.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/retag.h>
#include <thrust/random.h>
Expand Down Expand Up @@ -540,6 +541,74 @@ void TestScanByKeyMixedTypes(void)
DECLARE_UNITTEST(TestScanByKeyMixedTypes);


template <typename T>
void TestScanByKeyDiscardOutput(std::size_t n)
{
thrust::host_vector<T> h_keys(n);
thrust::default_random_engine rng;

for (size_t i = 0, k = 0; i < n; i++)
{
h_keys[i] = static_cast<T>(k);
if (rng() % 10 == 0)
{
k++;
}
}
thrust::device_vector<T> d_keys = h_keys;

thrust::host_vector<T> h_vals(n);
for(size_t i = 0; i < n; i++)
{
h_vals[i] = static_cast<T>(i % 10);
}
thrust::device_vector<T> d_vals = h_vals;

auto out = thrust::make_discard_iterator();

// These are no-ops, but they should compile.
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out);
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
T{});
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
T{},
thrust::equal_to<T>{});
thrust::exclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
T{},
thrust::equal_to<T>{},
thrust::multiplies<T>{});

thrust::inclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out);
thrust::inclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
thrust::equal_to<T>{});
thrust::inclusive_scan_by_key(d_keys.cbegin(),
d_keys.cend(),
d_vals.cbegin(),
out,
thrust::equal_to<T>{},
thrust::multiplies<T>{});
}
DECLARE_VARIABLE_UNITTEST(TestScanByKeyDiscardOutput);


void TestScanByKeyLargeInput()
{
const unsigned int N = 1 << 20;
Expand Down
14 changes: 7 additions & 7 deletions thrust/system/cuda/detail/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -844,14 +844,14 @@ inclusive_scan_by_key(execution_policy<Derived> &policy,
ValOutputIt value_result,
BinaryPred binary_pred)
{
typedef typename thrust::iterator_traits<ValOutputIt>::value_type value_type;
typedef typename thrust::iterator_traits<ValInputIt>::value_type value_type;
return cuda_cub::inclusive_scan_by_key(policy,
key_first,
key_last,
value_first,
value_result,
binary_pred,
plus<value_type>());
thrust::plus<>());
}

template <class Derived,
Expand All @@ -871,7 +871,7 @@ inclusive_scan_by_key(execution_policy<Derived> &policy,
key_last,
value_first,
value_result,
equal_to<key_type>());
thrust::equal_to<>());
}


Expand Down Expand Up @@ -948,7 +948,7 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
value_result,
init,
binary_pred,
plus<Init>());
plus<>());
}

template <class Derived,
Expand All @@ -971,7 +971,7 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
value_first,
value_result,
init,
equal_to<key_type>());
equal_to<>());
}


Expand All @@ -986,13 +986,13 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
ValInputIt value_first,
ValOutputIt value_result)
{
typedef typename iterator_traits<ValOutputIt>::value_type value_type;
typedef typename iterator_traits<ValInputIt>::value_type value_type;
return cuda_cub::exclusive_scan_by_key(policy,
key_first,
key_last,
value_first,
value_result,
value_type(0));
value_type{});
}


Expand Down
19 changes: 9 additions & 10 deletions thrust/system/detail/generic/scan_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


#include <thrust/detail/config.h>
#include <thrust/detail/cstdint.h>
#include <thrust/system/detail/generic/scan_by_key.h>
#include <thrust/functional.h>
#include <thrust/transform.h>
Expand Down Expand Up @@ -71,8 +72,7 @@ __host__ __device__
InputIterator2 first2,
OutputIterator result)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type InputType1;
return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, thrust::equal_to<InputType1>());
return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, thrust::equal_to<>());
}


Expand Down Expand Up @@ -108,8 +108,8 @@ __host__ __device__
BinaryPredicate binary_pred,
AssociativeOperator binary_op)
{
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
typedef unsigned int HeadFlagType;
using OutputType = typename thrust::iterator_traits<InputIterator2>::value_type;
using HeadFlagType = thrust::detail::uint32_t;

const size_t n = last1 - first1;

Expand Down Expand Up @@ -146,8 +146,8 @@ __host__ __device__
InputIterator2 first2,
OutputIterator result)
{
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, OutputType(0));
typedef typename thrust::iterator_traits<InputIterator2>::value_type InitType;
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, InitType{});
}


Expand All @@ -164,8 +164,7 @@ __host__ __device__
OutputIterator result,
T init)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type InputType1;
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, thrust::equal_to<InputType1>());
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, thrust::equal_to<>());
}


Expand Down Expand Up @@ -205,8 +204,8 @@ __host__ __device__
BinaryPredicate binary_pred,
AssociativeOperator binary_op)
{
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
typedef unsigned int HeadFlagType;
using OutputType = T;
using HeadFlagType = thrust::detail::uint32_t;

const size_t n = last1 - first1;

Expand Down
8 changes: 4 additions & 4 deletions thrust/system/detail/sequential/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ __host__ __device__
BinaryPredicate binary_pred,
BinaryFunction binary_op)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
typedef typename thrust::iterator_traits<OutputIterator>::value_type ValueType;
using KeyType = typename thrust::iterator_traits<InputIterator1>::value_type;
using ValueType = typename thrust::iterator_traits<InputIterator2>::value_type;

// wrap binary_op
thrust::detail::wrapped_function<
Expand Down Expand Up @@ -105,8 +105,8 @@ __host__ __device__
BinaryPredicate binary_pred,
BinaryFunction binary_op)
{
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
typedef typename thrust::iterator_traits<OutputIterator>::value_type ValueType;
using KeyType = typename thrust::iterator_traits<InputIterator1>::value_type;
using ValueType = T;

if(first1 != last1)
{
Expand Down

0 comments on commit 1ef0374

Please sign in to comment.