Use NV_IF_TARGET to select between host/device/sm implementations.
alliepiper committed May 10, 2022
1 parent 8ba5c13 commit 7999303
Showing 18 changed files with 795 additions and 802 deletions.
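NV_IF_TARGET, from libcu++'s <nv/target> header, expands exactly one of its parenthesized code arguments depending on what the current compilation pass targets: host code, device code, or a minimum SM architecture. This lets a single source expression replace the CUB_IS_HOST_CODE / CUB_INCLUDE_*_CODE / CUB_PTX_ARCH preprocessor checks removed below. A minimal sketch of the pattern, using an illustrative function that is not part of this diff:

    #include <nv/target>

    __host__ __device__ const char *target_name()
    {
      // Exactly one branch survives each compilation pass; the other is
      // discarded before code generation, so each branch may freely use
      // target-specific constructs.
      NV_IF_TARGET(NV_IS_HOST,
                   (return "host";),
                   (return "device";));
    }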
33 changes: 24 additions & 9 deletions cub/agent/agent_sub_warp_merge_sort.cuh
@@ -33,6 +33,8 @@
 #include <cub/warp/warp_merge_sort.cuh>
 #include <cub/warp/warp_store.cuh>

+#include <nv/target>
+
 #include <thrust/system/cuda/detail/core/util.h>


@@ -108,6 +110,23 @@ class AgentSubWarpSort
 {
   template <typename T>
   __device__ bool operator()(T lhs, T rhs)
+  {
+    return this->impl(lhs, rhs);
+  }
+
+#if defined(__CUDA_FP16_TYPES_EXIST__)
+  __device__ bool operator()(__half lhs, __half rhs)
+  {
+    // Need to explicitly cast to float for SM <= 52.
+    NV_IF_TARGET(NV_PROVIDES_SM_53,
+                 (return this->impl(lhs, rhs);),
+                 (return this->impl(__half2float(lhs), __half2float(rhs));));
+  }
+#endif
+
+private:
+  template <typename T>
+  __device__ bool impl(T lhs, T rhs)
   {
     if (IS_DESCENDING)
     {
@@ -118,19 +137,15 @@
       return lhs < rhs;
     }
   }
-
-#if defined(__CUDA_FP16_TYPES_EXIST__) && (CUB_PTX_ARCH < 530)
-  __device__ bool operator()(__half lhs, __half rhs)
-  {
-    return (*this)(__half2float(lhs), __half2float(rhs));
-  }
-#endif
 };

-#if defined(__CUDA_FP16_TYPES_EXIST__) && (CUB_PTX_ARCH < 530)
+#if defined(__CUDA_FP16_TYPES_EXIST__)
 __device__ static bool equal(__half lhs, __half rhs)
 {
-  return __half2float(lhs) == __half2float(rhs);
+  // Need to explicitly cast to float for SM <= 52.
+  NV_IF_TARGET(NV_PROVIDES_SM_53,
+               (return lhs == rhs;),
+               (return __half2float(lhs) == __half2float(rhs);));
 }
 #endif
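
The NV_PROVIDES_SM_53 split above exists because native __half comparison operators are only available when compiling for SM 5.3 or newer; older architectures must widen to float first. The same guard as a standalone sketch, with a hypothetical helper name:

    #include <cuda_fp16.h>
    #include <nv/target>

    #if defined(__CUDA_FP16_TYPES_EXIST__)
    // half_less is illustrative only; it mirrors the comparator above.
    __device__ bool half_less(__half lhs, __half rhs)
    {
      NV_IF_TARGET(NV_PROVIDES_SM_53,
                   (return lhs < rhs;),
                   (return __half2float(lhs) < __half2float(rhs);));
    }
    #endif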

31 changes: 15 additions & 16 deletions cub/detail/device_synchronize.cuh
@@ -20,6 +20,8 @@
 #include <cub/util_arch.cuh>
 #include <cub/util_namespace.cuh>

+#include <nv/target>
+
 #include <cuda_runtime_api.h>

 CUB_NAMESPACE_BEGIN
@@ -36,31 +38,28 @@ CUB_RUNTIME_FUNCTION inline cudaError_t device_synchronize()
 {
   cudaError_t result = cudaErrorUnknown;

-  if (CUB_IS_HOST_CODE)
-  {
-#if CUB_INCLUDE_HOST_CODE
-    result = cudaDeviceSynchronize();
-#endif
-  }
-  else
-  {
-    // Device code with the CUDA runtime.
-#if defined(CUB_INCLUDE_DEVICE_CODE) && defined(CUB_RUNTIME_ENABLED)
+#ifdef CUB_RUNTIME_ENABLED

 #if defined(__CUDACC__) && \
     ((__CUDACC_VER_MAJOR__ > 11) || \
      ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 6)))
-    // CUDA >= 11.6
-    result = __cudaDeviceSynchronizeDeprecationAvoidance();
+// CUDA >= 11.6
+#define CUB_TMP_DEVICE_SYNC_IMPL \
+  result = __cudaDeviceSynchronizeDeprecationAvoidance();
 #else // CUDA < 11.6
-    result = cudaDeviceSynchronize();
+#define CUB_TMP_DEVICE_SYNC_IMPL result = cudaDeviceSynchronize();
 #endif

 #else // Device code without the CUDA runtime.
-    // Device side CUDA API calls are not supported in this configuration.
-    result = cudaErrorInvalidConfiguration;
+// Device side CUDA API calls are not supported in this configuration.
+#define CUB_TMP_DEVICE_SYNC_IMPL result = cudaErrorInvalidConfiguration;
 #endif
-  }
+
+  NV_IF_TARGET(NV_IS_HOST,
+               (result = cudaDeviceSynchronize();),
+               (CUB_TMP_DEVICE_SYNC_IMPL));
+
+#undef CUB_TMP_DEVICE_SYNC_IMPL

   return result;
 }
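
The CUB_TMP_DEVICE_SYNC_IMPL macro is needed because preprocessor directives cannot appear inside a macro argument: the #if/#else selection must happen before NV_IF_TARGET is invoked, so the chosen statement is captured in a temporary macro and expanded inside the device branch. A reduced sketch of the same hoisting trick; the feature macro and function name are made up for illustration:

    #include <cuda_runtime_api.h>
    #include <nv/target>

    // Hypothetical feature test standing in for CUB_RUNTIME_ENABLED, etc.
    #ifdef MY_DEVICE_RUNTIME_AVAILABLE
    #define TMP_SYNC_IMPL result = cudaDeviceSynchronize();
    #else
    #define TMP_SYNC_IMPL result = cudaErrorInvalidConfiguration;
    #endif

    __host__ __device__ cudaError_t sync_sketch()
    {
      cudaError_t result = cudaErrorUnknown;
      // The host branch always calls the runtime; the device branch
      // expands to whatever the preprocessor selected above.
      NV_IF_TARGET(NV_IS_HOST,
                   (result = cudaDeviceSynchronize();),
                   (TMP_SYNC_IMPL));
      return result;
    }
    #undef TMP_SYNC_IMPL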
71 changes: 36 additions & 35 deletions cub/device/dispatch/dispatch_histogram.cuh
@@ -34,20 +34,22 @@

 #pragma once

-#include <stdio.h>
-#include <iterator>
-#include <limits>
-
-#include "../../agent/agent_histogram.cuh"
-#include "../../util_debug.cuh"
-#include "../../util_device.cuh"
-#include "../../util_math.cuh"
-#include "../../thread/thread_search.cuh"
-#include "../../grid/grid_queue.cuh"
-#include "../../config.cuh"
+#include <cub/agent/agent_histogram.cuh>
+#include <cub/util_debug.cuh>
+#include <cub/util_device.cuh>
+#include <cub/util_math.cuh>
+#include <cub/thread/thread_search.cuh>
+#include <cub/grid/grid_queue.cuh>
+#include <cub/config.cuh>

 #include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

+#include <nv/target>
+
+#include <cstdio>
+#include <iterator>
+#include <limits>
+
 CUB_NAMESPACE_BEGIN


@@ -401,32 +403,31 @@ struct DispatchHistogram
         int ptx_version,
         KernelConfig &histogram_sweep_config)
     {
-        cudaError_t result = cudaErrorNotSupported;
-        if (CUB_IS_DEVICE_CODE)
-        {
-            #if CUB_INCLUDE_DEVICE_CODE
-                // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
-                result = histogram_sweep_config.template Init<PtxHistogramSweepPolicy>();
-            #endif
-        }
-        else
-        {
-            #if CUB_INCLUDE_HOST_CODE
-                // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version
-                if (ptx_version >= 500)
-                {
-                    result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
-                }
-                else
-                {
-                    result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
-                }
-            #endif
-        }
-        return result;
+        cudaError_t result = cudaErrorNotSupported;
+        NV_IF_TARGET(
+          NV_IS_DEVICE,
+          (
+            // We're on the device, so initialize the kernel dispatch
+            // configurations with the current PTX policy
+            result = histogram_sweep_config.template Init<PtxHistogramSweepPolicy>();
+          ),
+          ( // NV_IS_HOST:
+            // We're on the host, so lookup and initialize the kernel dispatch
+            // configurations with the policies that match the device's PTX
+            // version
+            if (ptx_version >= 500)
+            {
+              result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
+            }
+            else
+            {
+              result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
+            }
+          ));
+
+        return result;
     }


     /**
      * Kernel dispatch configuration
      */
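
As the branch comments above note, a device pass can bind directly to the policy chosen for the architecture being compiled, while the host must map a runtime-queried PTX version onto a policy. A simplified host-side sketch with stand-in policies rather than CUB's real tuning types:

    // Illustrative stand-ins for CUB's tuning policies.
    struct Policy350 { static constexpr int BLOCK_THREADS = 128; };
    struct Policy500 { static constexpr int BLOCK_THREADS = 384; };

    // Host-side selection: map the device's PTX version to a policy.
    inline int block_threads_for(int ptx_version)
    {
      return (ptx_version >= 500) ? Policy500::BLOCK_THREADS
                                  : Policy350::BLOCK_THREADS;
    }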
37 changes: 24 additions & 13 deletions cub/device/dispatch/dispatch_radix_sort.cuh
@@ -1332,9 +1332,15 @@ struct DispatchRadixSort :
                 MaxPolicyT, IS_DESCENDING, KeyT, OffsetT>;
             if (CubDebug(error = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                 &histo_blocks_per_sm, histogram_kernel, HISTO_BLOCK_THREADS, 0))) break;
-            histogram_kernel<<<histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream>>>
-                (d_bins, d_keys.Current(), num_items, begin_bit, end_bit);
-            if (CubDebug(error = cudaPeekAtLastError())) break;
+
+            error = THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
+              histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream
+            ).doit(histogram_kernel,
+                   d_bins, d_keys.Current(), num_items, begin_bit, end_bit);
+            if (CubDebug(error))
+            {
+              break;
+            }

             // exclusive sums to determine starts
             const int SCAN_BLOCK_THREADS = ActivePolicyT::ExclusiveSumPolicy::BLOCK_THREADS;
@@ -1368,17 +1374,22 @@
                 stream))) break;
             auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
                 MaxPolicyT, IS_DESCENDING, KeyT, ValueT, OffsetT, PortionOffsetT>;
-            onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
-                (d_lookback, d_ctrs + portion * num_passes + pass,
-                 portion < num_portions - 1 ?
-                     d_bins + ((portion + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
-                 d_bins + (portion * num_passes + pass) * RADIX_DIGITS,
-                 d_keys.Alternate(),
-                 d_keys.Current() + portion * PORTION_SIZE,
-                 d_values.Alternate(),
-                 d_values.Current() + portion * PORTION_SIZE,
-                 portion_num_items, current_bit, num_bits);
-            if (CubDebug(error = cudaPeekAtLastError())) break;
+            error = THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
+              num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream
+            ).doit(onesweep_kernel,
+                   d_lookback, d_ctrs + portion * num_passes + pass,
+                   portion < num_portions - 1 ?
+                     d_bins + ((portion + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+                   d_bins + (portion * num_passes + pass) * RADIX_DIGITS,
+                   d_keys.Alternate(),
+                   d_keys.Current() + portion * PORTION_SIZE,
+                   d_values.Alternate(),
+                   d_values.Current() + portion * PORTION_SIZE,
+                   portion_num_items, current_bit, num_bits);
+            if (CubDebug(error))
+            {
+              break;
+            }
         }

         // use the temporary buffers if no overwrite is allowed
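
Thrust's triple_chevron helper takes the place of the <<<...>>> launch syntax above; it lets the same launch expression compile in both host and device passes of CUB's dispatch layer, and it returns the launch status directly, removing the separate cudaPeekAtLastError() call. A minimal sketch with a made-up kernel:

    #include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

    #include <cuda_runtime_api.h>

    // scale_kernel exists only for this illustration.
    __global__ void scale_kernel(float *data, float factor, int n)
    {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) { data[i] *= factor; }
    }

    cudaError_t launch_scale(float *d_data, float factor, int n, cudaStream_t stream)
    {
      const int block = 256;
      const int grid  = (n + block - 1) / block;
      // Equivalent to: scale_kernel<<<grid, block, 0, stream>>>(d_data, factor, n);
      return THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid, block, 0, stream)
        .doit(scale_kernel, d_data, factor, n);
    }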
51 changes: 22 additions & 29 deletions cub/device/dispatch/dispatch_reduce_by_key.cuh
@@ -1,4 +1,3 @@
-
 /******************************************************************************
  * Copyright (c) 2011, Duane Merrill. All rights reserved.
  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
@@ -34,19 +33,21 @@

 #pragma once

-#include <stdio.h>
-#include <iterator>
-
-#include "dispatch_scan.cuh"
-#include "../../config.cuh"
-#include "../../agent/agent_reduce_by_key.cuh"
-#include "../../thread/thread_operators.cuh"
-#include "../../grid/grid_queue.cuh"
-#include "../../util_device.cuh"
-#include "../../util_math.cuh"
+#include <cub/agent/agent_reduce_by_key.cuh>
+#include <cub/config.cuh>
+#include <cub/device/dispatch/dispatch_scan.cuh>
+#include <cub/grid/grid_queue.cuh>
+#include <cub/thread/thread_operators.cuh>
+#include <cub/util_device.cuh>
+#include <cub/util_math.cuh>
+
+#include <nv/target>
+
+#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
+
+#include <cstdio>
+#include <iterator>

 CUB_NAMESPACE_BEGIN

 /******************************************************************************
@@ -193,27 +194,19 @@ struct DispatchReduceByKey
     template <typename KernelConfig>
     CUB_RUNTIME_FUNCTION __forceinline__
     static void InitConfigs(
-        int ptx_version,
+        int /*ptx_version*/,
         KernelConfig &reduce_by_key_config)
     {
-        if (CUB_IS_DEVICE_CODE)
-        {
-            #if CUB_INCLUDE_DEVICE_CODE
-                (void)ptx_version;
-                // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
-                reduce_by_key_config.template Init<PtxReduceByKeyPolicy>();
-            #endif
-        }
-        else
-        {
-            #if CUB_INCLUDE_HOST_CODE
-                // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version
-
-                // (There's only one policy right now)
-                (void)ptx_version;
-                reduce_by_key_config.template Init<typename Policy350::ReduceByKeyPolicyT>();
-            #endif
-        }
+        NV_IF_TARGET(NV_IS_DEVICE,
+          (
+            // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
+            reduce_by_key_config.template Init<PtxReduceByKeyPolicy>();
+          ), (
+            // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version
+
+            // (There's only one policy right now)
+            reduce_by_key_config.template Init<typename Policy350::ReduceByKeyPolicyT>();
+          ));
     }

