Skip to content

Commit

Permalink
Use NV_IF_TARGET to select between host/device/sm implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
alliepiper committed Feb 4, 2022
1 parent 808b8ed commit f5abb6a
Show file tree
Hide file tree
Showing 15 changed files with 514 additions and 553 deletions.
33 changes: 24 additions & 9 deletions cub/agent/agent_sub_warp_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <cub/warp/warp_merge_sort.cuh>
#include <cub/warp/warp_store.cuh>

#include <cub/detail/target.cuh>

#include <thrust/system/cuda/detail/core/util.h>


Expand Down Expand Up @@ -108,6 +110,23 @@ class AgentSubWarpSort
{
// Generic comparison entry point: forwards both operands unchanged to the
// private impl() helper and returns its result.
template <typename T>
__device__ bool operator()(T lhs, T rhs)
{
return this->impl(lhs, rhs);
}

#if defined(__CUDA_FP16_TYPES_EXIST__)
// Comparison overload for __half operands.
//
// The native path hands the __half values straight to impl(); the fallback
// path widens both operands to float first. The fallback must cover every
// target up to and including SM 5.2 (the guard this replaces was
// `CUB_PTX_ARCH < 530`), so the native path is selected with
// NV_PROVIDES_SM_53. Selecting on NV_PROVIDES_SM_52 would route SM 5.2
// itself down the native path.
__device__ bool operator()(__half lhs, __half rhs)
{
  // Need to explicitly cast to float for SM <= 52.
  NV_IF_TARGET(NV_PROVIDES_SM_53,
               (return this->impl(lhs, rhs);),
               (return this->impl(__half2float(lhs), __half2float(rhs));));
}
#endif

private:
template <typename T>
__device__ bool impl(T lhs, T rhs)
{
if (IS_DESCENDING)
{
Expand All @@ -118,19 +137,15 @@ class AgentSubWarpSort
return lhs < rhs;
}
}

#if defined(__CUDA_FP16_TYPES_EXIST__) && (CUB_PTX_ARCH < 530)
__device__ bool operator()(__half lhs, __half rhs)
{
return (*this)(__half2float(lhs), __half2float(rhs));
}
#endif
};

#if defined(__CUDA_FP16_TYPES_EXIST__) && (CUB_PTX_ARCH < 530)
#if defined(__CUDA_FP16_TYPES_EXIST__)
// Equality test for __half operands.
//
// Targets that provide SM 5.3 or newer compare the __half values directly;
// all older targets (SM <= 5.2, matching the previous `CUB_PTX_ARCH < 530`
// guard) convert both operands to float and compare those. The selection is
// made by NV_IF_TARGET, so no unconditional return may precede it.
__device__ static bool equal(__half lhs, __half rhs)
{
  // Need to explicitly cast to float for SM <= 52.
  NV_IF_TARGET(NV_PROVIDES_SM_53,
               (return lhs == rhs;),
               (return __half2float(lhs) == __half2float(rhs);));
}
#endif

Expand Down
32 changes: 16 additions & 16 deletions cub/detail/device_synchronize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
#pragma once

#include <cub/detail/exec_check_disable.cuh>
#include <cub/detail/detect_cuda_runtime.cuh>
#include <cub/detail/target.cuh>
#include <cub/util_arch.cuh>

#include <cub/util_namespace.cuh>

#include <cuda_runtime_api.h>
Expand All @@ -36,31 +39,28 @@ CUB_RUNTIME_FUNCTION inline cudaError_t device_synchronize()
{
cudaError_t result = cudaErrorUnknown;

if (CUB_IS_HOST_CODE)
{
#if CUB_INCLUDE_HOST_CODE
result = cudaDeviceSynchronize();
#endif
}
else
{
// Device code with the CUDA runtime.
#if defined(CUB_INCLUDE_DEVICE_CODE) && defined(CUB_RUNTIME_ENABLED)
#ifdef CUB_RUNTIME_ENABLED

#if defined(__CUDACC__) && \
((__CUDACC_VER_MAJOR__ > 11) || \
((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 6)))
// CUDA >= 11.6
result = __cudaDeviceSynchronizeDeprecationAvoidance();
// CUDA >= 11.6
#define CUB_TMP_DEVICE_SYNC_IMPL \
result = __cudaDeviceSynchronizeDeprecationAvoidance();
#else // CUDA < 11.6
result = cudaDeviceSynchronize();
#define CUB_TMP_DEVICE_SYNC_IMPL result = cudaDeviceSynchronize();
#endif

#else // Device code without the CUDA runtime.
// Device side CUDA API calls are not supported in this configuration.
result = cudaErrorInvalidConfiguration;
// Device side CUDA API calls are not supported in this configuration.
#define CUB_TMP_DEVICE_SYNC_IMPL result = cudaErrorInvalidConfiguration;
#endif
}

NV_IF_TARGET(NV_IS_HOST,
(result = cudaDeviceSynchronize();),
(CUB_TMP_DEVICE_SYNC_IMPL));

#undef CUB_TMP_DEVICE_SYNC_IMPL

return result;
}
Expand Down
48 changes: 24 additions & 24 deletions cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <limits>

#include "../../agent/agent_histogram.cuh"
#include "../../detail/target.cuh"
#include "../../util_debug.cuh"
#include "../../util_device.cuh"
#include "../../util_math.cuh"
Expand Down Expand Up @@ -401,32 +402,31 @@ struct DispatchHistogram
int ptx_version,
KernelConfig &histogram_sweep_config)
{
cudaError_t result = cudaErrorNotSupported;
if (CUB_IS_DEVICE_CODE)
{
#if CUB_INCLUDE_DEVICE_CODE
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
result = histogram_sweep_config.template Init<PtxHistogramSweepPolicy>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version
if (ptx_version >= 500)
{
result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
}
else
{
result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
}
#endif
}
return result;
cudaError_t result = cudaErrorNotSupported;
NV_IF_TARGET(
NV_IS_DEVICE,
(
// We're on the device, so initialize the kernel dispatch
// configurations with the current PTX policy
result = histogram_sweep_config.template Init<PtxHistogramSweepPolicy>();
),
( // NV_IS_HOST:
// We're on the host, so lookup and initialize the kernel dispatch
// configurations with the policies that match the device's PTX
// version
if (ptx_version >= 500)
{
result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
}
else
{
result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
}
));

return result;
}


/**
* Kernel dispatch configuration
*/
Expand Down
33 changes: 14 additions & 19 deletions cub/device/dispatch/dispatch_reduce_by_key.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
Expand Down Expand Up @@ -45,6 +44,8 @@
#include "../../util_device.cuh"
#include "../../util_math.cuh"

#include "../../detail/target.cuh"

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -196,24 +197,18 @@ struct DispatchReduceByKey
int ptx_version,
KernelConfig &reduce_by_key_config)
{
if (CUB_IS_DEVICE_CODE)
{
#if CUB_INCLUDE_DEVICE_CODE
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
reduce_by_key_config.template Init<PtxReduceByKeyPolicy>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
reduce_by_key_config.template Init<typename Policy350::ReduceByKeyPolicyT>();
#endif
}
NV_IF_TARGET(NV_IS_DEVICE,
(
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
reduce_by_key_config.template Init<PtxReduceByKeyPolicy>();
), (
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
reduce_by_key_config.template Init<typename Policy350::ReduceByKeyPolicyT>();
));
}


Expand Down
29 changes: 12 additions & 17 deletions cub/device/dispatch/dispatch_rle.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
Expand Down Expand Up @@ -39,6 +38,7 @@

#include "dispatch_scan.cuh"
#include "../../config.cuh"
#include "../../detail/target.cuh"
#include "../../agent/agent_rle.cuh"
#include "../../thread/thread_operators.cuh"
#include "../../grid/grid_queue.cuh"
Expand Down Expand Up @@ -184,22 +184,17 @@ struct DeviceRleDispatch
int ptx_version,
KernelConfig& device_rle_config)
{
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
device_rle_config.template Init<PtxRleSweepPolicy>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
device_rle_config.template Init<typename Policy350::RleSweepPolicy>();
#endif
}
NV_IF_TARGET(NV_IS_DEVICE,
(
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
device_rle_config.template Init<PtxRleSweepPolicy>();
), (
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
device_rle_config.template Init<typename Policy350::RleSweepPolicy>();
));
}


Expand Down
31 changes: 13 additions & 18 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
Expand Down Expand Up @@ -39,6 +38,7 @@

#include "dispatch_scan.cuh"
#include "../../config.cuh"
#include "../../detail/target.cuh"
#include "../../agent/agent_select_if.cuh"
#include "../../thread/thread_operators.cuh"
#include "../../grid/grid_queue.cuh"
Expand Down Expand Up @@ -190,23 +190,18 @@ struct DispatchSelectIf
int ptx_version,
KernelConfig &select_if_config)
{
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
select_if_config.template Init<PtxSelectIfPolicyT>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
select_if_config.template Init<typename Policy350::SelectIfPolicyT>();
#endif
}
NV_IF_TARGET(NV_IS_DEVICE,
(
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
select_if_config.template Init<PtxSelectIfPolicyT>();
), (
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
select_if_config.template Init<typename Policy350::SelectIfPolicyT>();
));
}


Expand Down
Loading

0 comments on commit f5abb6a

Please sign in to comment.