Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix begin_bit == end_bit == 0 for device-wide and segmented sort.
Browse files Browse the repository at this point in the history
-   Copy if begin_bit == end_bit, but overwrite not allowed
-   Fix style
-   When begin_bit == end_bit and double-buffering, don't do any sorting work
-   Uncommented segmented sort test
-   begin_bit == end_bit == 0 for upsweep/downsweep and segmented sort
-   Fixed begin_bit == end_bit == 0 case
  • Loading branch information
canonizer committed Aug 1, 2022
1 parent 81a96c9 commit 9b50753
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 25 deletions.
88 changes: 65 additions & 23 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -912,12 +912,12 @@ struct DeviceRadixSortPolicy
struct Policy800 : ChainedPolicy<800, Policy800, Policy700>
{
enum {
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
ONESWEEP = sizeof(KeyT) >= sizeof(uint32_t),
ONESWEEP_RADIX_BITS = 8,
OFFSET_64BIT = sizeof(OffsetT) == 8,
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
ONESWEEP = sizeof(KeyT) >= sizeof(uint32_t),
ONESWEEP_RADIX_BITS = 8,
OFFSET_64BIT = sizeof(OffsetT) == 8,
};

// Histogram policy
Expand Down Expand Up @@ -1366,7 +1366,7 @@ struct DispatchRadixSort :
ValueT* d_values_tmp2 = (ValueT*)allocations[3];
AtomicOffsetT* d_ctrs = (AtomicOffsetT*)allocations[4];

do {
do {
// initialization
if (CubDebug(error = cudaMemsetAsync(
d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream))) break;
Expand Down Expand Up @@ -1498,6 +1498,8 @@ struct DispatchRadixSort :
}
}

if (CubDebug(error)) break;

// use the temporary buffers if no overwrite is allowed
if (!is_overwrite_okay && pass == 0)
{
Expand Down Expand Up @@ -1671,6 +1673,42 @@ struct DispatchRadixSort :
return InvokeOnesweep<ActivePolicyT>();
}

CUB_RUNTIME_FUNCTION __forceinline__
cudaError_t InvokeCopy()
{
// is_overwrite_okay == false here
// Return the number of temporary bytes if requested
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
return cudaSuccess;
}

// Copy keys
cudaError_t error = cudaSuccess;
error = cudaMemcpyAsync(d_keys.Alternate(), d_keys.Current(), num_items * sizeof(KeyT),
cudaMemcpyDefault, stream);
if (CubDebug(error))
{
return error;
}
d_keys.selector ^= 1;

// Copy values if necessary
if (!KEYS_ONLY)
{
error = cudaMemcpyAsync(d_values.Alternate(), d_values.Current(),
num_items * sizeof(ValueT), cudaMemcpyDefault, stream);
if (CubDebug(error))
{
return error;
}
}
d_values.selector ^= 1;

return error;
}

/// Invocation
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __forceinline__
Expand All @@ -1679,15 +1717,20 @@ struct DispatchRadixSort :
typedef typename DispatchRadixSort::MaxPolicy MaxPolicyT;
typedef typename ActivePolicyT::SingleTilePolicy SingleTilePolicyT;

// Return if empty problem
if (num_items == 0)
// Return if empty problem, or if no bits to sort and double-buffering is used
if (num_items == 0 || (begin_bit == end_bit && is_overwrite_okay))
{
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}
return cudaSuccess;
}

return cudaSuccess;
// Check if simple copy suffices (is_overwrite_okay == false at this point)
if (begin_bit == end_bit)
{
return InvokeCopy();
}

// Force kernel code-generation in all compiler passes
Expand Down Expand Up @@ -2021,7 +2064,7 @@ struct DispatchSegmentedRadixSort :
int radix_bits = ActivePolicyT::SegmentedPolicy::RADIX_BITS;
int alt_radix_bits = ActivePolicyT::AltSegmentedPolicy::RADIX_BITS;
int num_bits = end_bit - begin_bit;
int num_passes = (num_bits + radix_bits - 1) / radix_bits;
int num_passes = CUB_MAX(DivideAndRoundUp(num_bits, radix_bits), 1);
bool is_num_passes_odd = num_passes & 1;
int max_alt_passes = (num_passes * radix_bits) - num_bits;
int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_radix_bits));
Expand Down Expand Up @@ -2082,15 +2125,14 @@ struct DispatchSegmentedRadixSort :
{
typedef typename DispatchSegmentedRadixSort::MaxPolicy MaxPolicyT;

// Return if empty problem
if (num_items == 0)
// Return if empty problem, or if no bits to sort and double-buffering is used
if (num_items == 0 || (begin_bit == end_bit && is_overwrite_okay))
{
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}

return cudaSuccess;
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}
return cudaSuccess;
}

// Force kernel code-generation in all compiler passes
Expand Down
8 changes: 6 additions & 2 deletions test/test_device_radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,11 @@ void TestBits(
printf("Testing key bits [%d,%d)\n", begin_bit, end_bit); fflush(stdout);
TestDirection(h_keys, num_items, num_segments, pre_sorted, h_segment_offsets, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit);

// Equal bits
begin_bit = end_bit = 0;
printf("Testing key bits [%d,%d)\n", begin_bit, end_bit); fflush(stdout);
TestDirection(h_keys, num_items, num_segments, pre_sorted, h_segment_offsets, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit);

// Across subword boundaries
int mid_bit = sizeof(KeyT) * 4;
printf("Testing key bits [%d,%d)\n", mid_bit - 1, mid_bit + 1); fflush(stdout);
Expand Down Expand Up @@ -1587,7 +1592,7 @@ void TestGen(
{
if (max_items == ~std::size_t(0))
{
max_items = 9000003;
max_items = 8000003;
}

if (max_segments < 0)
Expand Down Expand Up @@ -1650,7 +1655,6 @@ void TestGen(
TestSizes(h_keys.get(), large_num_items, max_segments, true);
fflush(stdout);
}

}


Expand Down

0 comments on commit 9b50753

Please sign in to comment.