Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix device-side launch
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Dec 4, 2021
1 parent 40b41e6 commit 20e1ff4
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 53 deletions.
4 changes: 2 additions & 2 deletions cub/device/dispatch/dispatch_segmented_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ private:
return error;
}

if (CubDebug(error = cudaStreamSynchronize(stream)))
if (CubDebug(error = SyncStream(stream)))
{
return error;
}
Expand All @@ -1395,7 +1395,7 @@ private:
else
{
#if CUB_INCLUDE_DEVICE_CODE
if (CubDebug(error = cudaStreamSynchronize(stream)))
if (CubDebug(error = SyncStream(stream)))
{
return error;
}
Expand Down
241 changes: 190 additions & 51 deletions test/test_device_segmented_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ private:
}
};

template <typename KeyT>
template <typename KeyT,
bool IsIntegralType = std::is_integral<KeyT>::value>
class InputDescription
{
thrust::host_vector<int> segment_sizes;
Expand All @@ -433,8 +434,8 @@ public:
}
};

template <>
class InputDescription<float>
template <typename KeyT>
class InputDescription<KeyT, false>
{
thrust::host_vector<int> segment_sizes;

Expand All @@ -450,9 +451,9 @@ public:
}

template <typename ValueT = cub::NullType>
Input<float, ValueT> gen(bool reverse)
Input<KeyT, ValueT> gen(bool reverse)
{
return Input<float, ValueT>(reverse, segment_sizes);
return Input<KeyT, ValueT>(reverse, segment_sizes);
}
};

Expand Down Expand Up @@ -1386,7 +1387,7 @@ void InputTestRandom(Input<KeyT, ValueT> &input)
{
for (bool sort_buffers: { pointers, double_buffer })
{
for (int iteration = 0; iteration < MAX_ITERATIONS / 10; iteration++)
for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++)
{
RandomizeInput(h_keys, h_values);

Expand Down Expand Up @@ -1480,7 +1481,8 @@ AssertTrue(keys_ok);
}

template <typename KeyT,
typename ValueT>
typename ValueT,
bool IsSupportedType = std::is_integral<KeyT>::value>
struct EdgeTestDispatch
{
// Edge cases that needs to be tested
Expand All @@ -1495,61 +1497,79 @@ struct EdgeTestDispatch
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke()
{
using SmallAndMediumPolicyT =
typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
if (CUB_IS_HOST_CODE)
{
#if CUB_INCLUDE_HOST_CODE
using SmallAndMediumPolicyT =
typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;

const int small_segment_max_segment_size =
SmallAndMediumPolicyT::SmallPolicyT::ITEMS_PER_TILE;
const int small_segment_max_segment_size =
SmallAndMediumPolicyT::SmallPolicyT::ITEMS_PER_TILE;

const int items_per_small_segment =
SmallAndMediumPolicyT::SmallPolicyT::ITEMS_PER_THREAD;
const int items_per_small_segment =
SmallAndMediumPolicyT::SmallPolicyT::ITEMS_PER_THREAD;

const int medium_segment_max_segment_size =
SmallAndMediumPolicyT::MediumPolicyT::ITEMS_PER_TILE;
const int medium_segment_max_segment_size =
SmallAndMediumPolicyT::MediumPolicyT::ITEMS_PER_TILE;

const int single_thread_segment_size = items_per_small_segment;
const int single_thread_segment_size = items_per_small_segment;

const int large_cached_segment_max_segment_size =
LargeSegmentPolicyT::BLOCK_THREADS *
LargeSegmentPolicyT::ITEMS_PER_THREAD;
const int large_cached_segment_max_segment_size =
LargeSegmentPolicyT::BLOCK_THREADS *
LargeSegmentPolicyT::ITEMS_PER_THREAD;

for (bool sort_descending : {ascending, descending})
{
Input<KeyT, ValueT> edge_cases =
InputDescription<KeyT>()
.add({a_lot_of, empty_short_circuit_segment_size})
.add({a_lot_of, copy_short_circuit_segment_size})
.add({a_lot_of, swap_short_circuit_segment_size})
.add({a_lot_of, swap_short_circuit_segment_size + 1})
.add({a_lot_of, swap_short_circuit_segment_size + 1})
.add({a_lot_of, single_thread_segment_size - 1})
.add({a_lot_of, single_thread_segment_size })
.add({a_lot_of, single_thread_segment_size + 1 })
.add({a_lot_of, single_thread_segment_size * 2 - 1 })
.add({a_lot_of, single_thread_segment_size * 2 })
.add({a_lot_of, single_thread_segment_size * 2 + 1 })
.add({a_bunch_of, small_segment_max_segment_size - 1})
.add({a_bunch_of, small_segment_max_segment_size})
.add({a_bunch_of, small_segment_max_segment_size + 1})
.add({a_bunch_of, medium_segment_max_segment_size - 1})
.add({a_bunch_of, medium_segment_max_segment_size})
.add({a_bunch_of, medium_segment_max_segment_size + 1})
.add({a_bunch_of, large_cached_segment_max_segment_size - 1})
.add({a_bunch_of, large_cached_segment_max_segment_size})
.add({a_bunch_of, large_cached_segment_max_segment_size + 1})
.add({a_few, large_cached_segment_max_segment_size * 2})
.add({a_few, large_cached_segment_max_segment_size * 3})
.add({a_few, large_cached_segment_max_segment_size * 5})
.template gen<ValueT>(sort_descending);

InputTest<KeyT, ValueT>(sort_descending, edge_cases);
for (bool sort_descending : {ascending, descending})
{
Input<KeyT, ValueT> edge_cases =
InputDescription<KeyT>()
.add({a_lot_of, empty_short_circuit_segment_size})
.add({a_lot_of, copy_short_circuit_segment_size})
.add({a_lot_of, swap_short_circuit_segment_size})
.add({a_lot_of, swap_short_circuit_segment_size + 1})
.add({a_lot_of, swap_short_circuit_segment_size + 1})
.add({a_lot_of, single_thread_segment_size - 1})
.add({a_lot_of, single_thread_segment_size })
.add({a_lot_of, single_thread_segment_size + 1 })
.add({a_lot_of, single_thread_segment_size * 2 - 1 })
.add({a_lot_of, single_thread_segment_size * 2 })
.add({a_lot_of, single_thread_segment_size * 2 + 1 })
.add({a_bunch_of, small_segment_max_segment_size - 1})
.add({a_bunch_of, small_segment_max_segment_size})
.add({a_bunch_of, small_segment_max_segment_size + 1})
.add({a_bunch_of, medium_segment_max_segment_size - 1})
.add({a_bunch_of, medium_segment_max_segment_size})
.add({a_bunch_of, medium_segment_max_segment_size + 1})
.add({a_bunch_of, large_cached_segment_max_segment_size - 1})
.add({a_bunch_of, large_cached_segment_max_segment_size})
.add({a_bunch_of, large_cached_segment_max_segment_size + 1})
.add({a_few, large_cached_segment_max_segment_size * 2})
.add({a_few, large_cached_segment_max_segment_size * 3})
.add({a_few, large_cached_segment_max_segment_size * 5})
.template gen<ValueT>(sort_descending);

InputTest<KeyT, ValueT>(sort_descending, edge_cases);
}
#endif
}

return cudaSuccess;
}
};

template <typename KeyT,
typename ValueT>
struct EdgeTestDispatch<KeyT, ValueT, false>
{
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke()
{
// Edge case test is using an optimized testing approach which is
// incompatible with duplicates. RandomTest is used for other types.
return cudaSuccess;
}
};

template <typename KeyT,
typename ValueT>
void EdgePatternsTest()
Expand Down Expand Up @@ -1608,7 +1628,7 @@ void RandomTest(int min_segments,
{
const int max_items = 10000000;

for (int iteration = 0; iteration < 10 * MAX_ITERATIONS; iteration++)
for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++)
{
Input<KeyT, ValueT> edge_cases = GenRandomInput<KeyT, ValueT>(max_items,
min_segments,
Expand Down Expand Up @@ -1638,13 +1658,132 @@ void Test()
}


#ifdef CUB_CDP
template <typename KeyT>
__global__ void LauncherKernel(
void *tmp_storage,
std::size_t temp_storage_bytes,
const KeyT *in_keys,
KeyT *out_keys,
int num_items,
int num_segments,
const int *offsets)
{
CubDebug(cub::DeviceSegmentedSort::SortKeys(tmp_storage,
temp_storage_bytes,
in_keys,
out_keys,
num_items,
num_segments,
offsets,
offsets + 1));
}

template <typename KeyT,
typename ValueT>
void TestDeviceSideLaunch(Input<KeyT, ValueT> &input)
{
thrust::host_vector<KeyT> h_keys_output(input.get_num_items());
thrust::device_vector<KeyT> keys_output(input.get_num_items());

thrust::host_vector<ValueT> h_values_output(input.get_num_items());
thrust::device_vector<ValueT> values_output(input.get_num_items());

KeyT *d_keys_output = thrust::raw_pointer_cast(keys_output.data());

thrust::host_vector<KeyT> h_keys(input.get_num_items());
thrust::host_vector<ValueT> h_values(input.get_num_items());

const thrust::host_vector<int> &h_offsets = input.get_h_offsets();

for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++)
{
RandomizeInput(h_keys, h_values);

input.get_d_keys_vec() = h_keys;
input.get_d_values_vec() = h_values;

const KeyT *d_input = input.get_d_keys();

std::size_t temp_storage_bytes{};
cub::DeviceSegmentedSort::SortKeys(nullptr,
temp_storage_bytes,
d_input,
d_keys_output,
input.get_num_items(),
input.get_num_segments(),
input.get_d_offsets(),
input.get_d_offsets() + 1);

thrust::device_vector<std::uint8_t> temp_storage(temp_storage_bytes);
std::uint8_t *d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

LauncherKernel<KeyT><<<1, 1>>>(
d_temp_storage,
temp_storage_bytes,
d_input,
d_keys_output,
input.get_num_items(),
input.get_num_segments(),
input.get_d_offsets());
CubDebugExit(cudaDeviceSynchronize());
CubDebugExit(cudaPeekAtLastError());

HostReferenceSort(false,
false,
input.get_num_segments(),
h_offsets,
h_keys,
h_values);

h_keys_output = keys_output;

const bool keys_ok =
compare_two_outputs(h_offsets, h_keys, h_keys_output);

AssertTrue(keys_ok);

input.shuffle();
}
}

template <typename KeyT>
void TestDeviceSideLaunch(int min_segments, int max_segments)
{
const int max_items = 10000000;

for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++)
{
Input<KeyT, KeyT> edge_cases =
GenRandomInput<KeyT, KeyT>(max_items,
min_segments,
max_segments,
descending);

TestDeviceSideLaunch(edge_cases);
}
}

template <typename KeyT>
void TestDeviceSideLaunch()
{
TestDeviceSideLaunch<KeyT>(1 << 2, 1 << 8);
TestDeviceSideLaunch<KeyT>(1 << 9, 1 << 19);
}
#endif


int main(int argc, char** argv)
{
CommandLineArgs args(argc, argv);

// Initialize device
CubDebugExit(args.DeviceInit());

#ifdef CUB_CDP
TestDeviceSideLaunch<int>();
#endif

TestZeroSegments();
TestEmptySegments(1 << 2);
TestEmptySegments(1 << 22);
Expand Down

0 comments on commit 20e1ff4

Please sign in to comment.