Use NV_IF_TARGET to select between host/device/sm implementations.
alliepiper committed Apr 15, 2021
1 parent efed9f9 commit cfa20c9
Showing 14 changed files with 435 additions and 433 deletions.
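
Every hunk below follows the same pattern: the old CUB_IS_DEVICE_CODE / #if CUB_INCLUDE_DEVICE_CODE branching is replaced by a single NV_IF_TARGET(condition, device-code, host-code) macro pulled in through detail/target.cuh. As a rough illustration of that style (the function and variable names here are hypothetical and not part of this commit, and the include paths may differ inside CUB):

    #include <cuda_runtime.h>
    #include "cub/detail/target.cuh" // assumed to provide NV_IF_TARGET / NV_IS_DEVICE

    // Hypothetical host/device-callable helper mirroring the GridQueue changes below.
    __host__ __device__ __forceinline__ cudaError_t SetCounter(int *d_counter,
                                                                int value,
                                                                cudaStream_t stream = 0)
    {
      cudaError_t result = cudaErrorUnknown;

      NV_IF_TARGET(NV_IS_DEVICE,
        (
          // Device path: the pointer is directly dereferenceable; the stream is unused.
          (void)stream;
          *d_counter = value;
          result = cudaSuccess;
        ), (
          // Host path: stage the value on the stack and copy it to device memory.
          // Copies from pageable host memory are staged synchronously, so the
          // local is safe to use here.
          result = cudaMemcpyAsync(d_counter, &value, sizeof(int),
                                   cudaMemcpyHostToDevice, stream);
        ));

      return result;
    }

Only the branch that matches the current compilation target is emitted, so the host-only runtime call never appears in device passes and no preprocessor guards are needed.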
48 changes: 24 additions & 24 deletions cub/device/dispatch/dispatch_histogram.cuh
@@ -39,6 +39,7 @@
#include <limits>

#include "../../agent/agent_histogram.cuh"
#include "../../detail/target.cuh"
#include "../../util_debug.cuh"
#include "../../util_device.cuh"
#include "../../util_math.cuh"
@@ -403,32 +404,31 @@ struct DispatchHistogram
int ptx_version,
KernelConfig &histogram_sweep_config)
{
cudaError_t result = cudaErrorNotSupported;
if (CUB_IS_DEVICE_CODE)
{
#if CUB_INCLUDE_DEVICE_CODE
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
result = histogram_sweep_config.template Init<PtxHistogramSweepPolicy>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version
if (ptx_version >= 500)
{
result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
}
else
{
result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
}
#endif
}
return result;
cudaError_t result = cudaErrorNotSupported;
NV_IF_TARGET(
NV_IS_DEVICE,
(
// We're on the device, so initialize the kernel dispatch
// configurations with the current PTX policy
result = histogram_sweep_config.template Init<PtxHistogramSweepPolicy>();
),
( // NV_IS_HOST:
// We're on the host, so lookup and initialize the kernel dispatch
// configurations with the policies that match the device's PTX
// version
if (ptx_version >= 500)
{
result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
}
else
{
result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
}
));

return result;
}


/**
* Kernel dispatch configuration
*/
33 changes: 14 additions & 19 deletions cub/device/dispatch/dispatch_reduce_by_key.cuh
@@ -1,4 +1,3 @@

/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
@@ -45,6 +44,8 @@
#include "../../util_device.cuh"
#include "../../util_math.cuh"

#include "../../detail/target.cuh"

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

/// Optional outer namespace(s)
@@ -203,24 +204,18 @@ struct DispatchReduceByKey
int ptx_version,
KernelConfig &reduce_by_key_config)
{
if (CUB_IS_DEVICE_CODE)
{
#if CUB_INCLUDE_DEVICE_CODE
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
reduce_by_key_config.template Init<PtxReduceByKeyPolicy>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
reduce_by_key_config.template Init<typename Policy350::ReduceByKeyPolicyT>();
#endif
}
NV_IF_TARGET(NV_IS_DEVICE,
(
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
reduce_by_key_config.template Init<PtxReduceByKeyPolicy>();
), (
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
reduce_by_key_config.template Init<typename Policy350::ReduceByKeyPolicyT>();
));
}


29 changes: 12 additions & 17 deletions cub/device/dispatch/dispatch_rle.cuh
@@ -1,4 +1,3 @@

/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
@@ -39,6 +38,7 @@

#include "dispatch_scan.cuh"
#include "../../config.cuh"
#include "../../detail/target.cuh"
#include "../../agent/agent_rle.cuh"
#include "../../thread/thread_operators.cuh"
#include "../../grid/grid_queue.cuh"
@@ -189,22 +189,17 @@ struct DeviceRleDispatch
int ptx_version,
KernelConfig& device_rle_config)
{
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
device_rle_config.template Init<PtxRleSweepPolicy>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
device_rle_config.template Init<typename Policy350::RleSweepPolicy>();
#endif
}
NV_IF_TARGET(NV_IS_DEVICE,
(
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
device_rle_config.template Init<PtxRleSweepPolicy>();
), (
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
device_rle_config.template Init<typename Policy350::RleSweepPolicy>();
));
}


31 changes: 13 additions & 18 deletions cub/device/dispatch/dispatch_select_if.cuh
@@ -1,4 +1,3 @@

/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
@@ -39,6 +38,7 @@

#include "dispatch_scan.cuh"
#include "../../config.cuh"
#include "../../detail/target.cuh"
#include "../../agent/agent_select_if.cuh"
#include "../../thread/thread_operators.cuh"
#include "../../grid/grid_queue.cuh"
@@ -194,23 +194,18 @@ struct DispatchSelectIf
int ptx_version,
KernelConfig &select_if_config)
{
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
select_if_config.template Init<PtxSelectIfPolicyT>();
#endif
}
else
{
#if CUB_INCLUDE_HOST_CODE
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
select_if_config.template Init<typename Policy350::SelectIfPolicyT>();
#endif
}
NV_IF_TARGET(NV_IS_DEVICE,
(
(void)ptx_version;
// We're on the device, so initialize the kernel dispatch configurations with the current PTX policy
select_if_config.template Init<PtxSelectIfPolicyT>();
), (
// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version

// (There's only one policy right now)
(void)ptx_version;
select_if_config.template Init<typename Policy350::SelectIfPolicyT>();
));
}


93 changes: 45 additions & 48 deletions cub/grid/grid_queue.cuh
@@ -34,6 +34,7 @@
#pragma once

#include "../config.cuh"
#include "../detail/target.cuh"
#include "../util_debug.cuh"

/// Optional outer namespace(s)
@@ -124,21 +125,20 @@ public:
cudaStream_t stream = 0)
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
(void)stream;
d_counters[FILL] = fill_size;
d_counters[DRAIN] = 0;
result = cudaSuccess;
#endif
} else {
#if CUB_INCLUDE_HOST_CODE
OffsetT counters[2];
counters[FILL] = fill_size;
counters[DRAIN] = 0;
result = CubDebug(cudaMemcpyAsync(d_counters, counters, sizeof(OffsetT) * 2, cudaMemcpyHostToDevice, stream));
#endif
}

NV_IF_TARGET(NV_IS_DEVICE,
(
(void)stream;
d_counters[FILL] = fill_size;
d_counters[DRAIN] = 0;
result = cudaSuccess;
), (
OffsetT counters[2];
counters[FILL] = fill_size;
counters[DRAIN] = 0;
result = CubDebug(cudaMemcpyAsync(d_counters, counters, sizeof(OffsetT) * 2, cudaMemcpyHostToDevice, stream));
));

return result;
}

@@ -147,17 +147,16 @@ public:
__host__ __device__ __forceinline__ cudaError_t ResetDrain(cudaStream_t stream = 0)
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
(void)stream;
d_counters[DRAIN] = 0;
result = cudaSuccess;
#endif
} else {
#if CUB_INCLUDE_HOST_CODE
result = CubDebug(cudaMemsetAsync(d_counters + DRAIN, 0, sizeof(OffsetT), stream));
#endif
}

NV_IF_TARGET(NV_IS_DEVICE,
(
(void)stream;
d_counters[DRAIN] = 0;
result = cudaSuccess;
), (
result = CubDebug(cudaMemsetAsync(d_counters + DRAIN, 0, sizeof(OffsetT), stream));
));

return result;
}

@@ -166,17 +165,16 @@ public:
__host__ __device__ __forceinline__ cudaError_t ResetFill(cudaStream_t stream = 0)
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
(void)stream;
d_counters[FILL] = 0;
result = cudaSuccess;
#endif
} else {
#if CUB_INCLUDE_HOST_CODE
result = CubDebug(cudaMemsetAsync(d_counters + FILL, 0, sizeof(OffsetT), stream));
#endif
}

NV_IF_TARGET(NV_IS_DEVICE,
(
(void)stream;
d_counters[FILL] = 0;
result = cudaSuccess;
), (
result = CubDebug(cudaMemsetAsync(d_counters + FILL, 0, sizeof(OffsetT), stream));
));

return result;
}

@@ -187,17 +185,16 @@ public:
cudaStream_t stream = 0)
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_DEVICE_CODE) {
#if CUB_INCLUDE_DEVICE_CODE
(void)stream;
fill_size = d_counters[FILL];
result = cudaSuccess;
#endif
} else {
#if CUB_INCLUDE_HOST_CODE
result = CubDebug(cudaMemcpyAsync(&fill_size, d_counters + FILL, sizeof(OffsetT), cudaMemcpyDeviceToHost, stream));
#endif
}

NV_IF_TARGET(NV_IS_DEVICE,
(
(void)stream;
fill_size = d_counters[FILL];
result = cudaSuccess;
), (
result = CubDebug(cudaMemcpyAsync(&fill_size, d_counters + FILL, sizeof(OffsetT), cudaMemcpyDeviceToHost, stream));
));

return result;
}
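
The hunks shown above only exercise the host/device split, but the commit message also mentions SM-specific selection. Assuming the target macros follow the libcu++ <nv/target> convention (e.g. NV_PROVIDES_SM_80), an architecture-gated branch takes the same shape; the values below are purely illustrative tuning numbers, not anything from this commit:

    // Hypothetical device helper selecting a tile size by target architecture.
    __device__ __forceinline__ int TileThreads()
    {
      int threads = 128;

      NV_IF_TARGET(NV_PROVIDES_SM_80,
        (
          // SM 8.0 and newer: pick a larger tile (illustrative value).
          threads = 256;
        ), (
          // Older architectures keep the smaller default.
          threads = 128;
        ));

      return threads;
    }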
