Skip to content

Commit

Permalink
[caffe2] fix no matching function min/max Clang errors (pytorch#33563)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#33563

When either NVCC or Clang drives CUDA compilation, many math functions are declared by default, with one small difference: Clang marks them as `__device__` only, while NVCC marks them as both `__host__` and `__device__`. As a result, every unqualified `min` or `max` call from a `__host__` function fails to compile under Clang with a "no matching function" error.

Fix the errors by using `std::min` and `std::max` from `<algorithm>`. Since C++14 they are `constexpr`, so they can also be used in `__device__` code [1].

1. https://llvm.org/docs/CompileCudaWithLLVM.html#algorithm

Test Plan:
```lang=bash
buck build mode/opt -c fbcode.cuda_use_clang=true //fblearner/flow/projects/dper:workflow
buck build mode/opt //fblearner/flow/projects/dper:workflow
```
Execute tests on devgpu:
```
buck test mode/dev-nosan -j 8 //caffe2/caffe2/python/operator_test/... //caffe2/test:cuda
```

Reviewed By: ngimel

Differential Revision: D20005795

fbshipit-source-id: 98a3f35e8a96c15d3ad3d2066396591f5cca1696
  • Loading branch information
Igor Sugak authored and facebook-github-bot committed Feb 28, 2020
1 parent c6d3012 commit 5dde8cd
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 21 deletions.
19 changes: 14 additions & 5 deletions aten/src/THC/generic/THCTensorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "THC/generic/THCTensorMath.cu"
#else

#include <algorithm>

#include "ATen/cuda/CUDAContext.h"
#include <ATen/MemoryOverlap.h>

Expand Down Expand Up @@ -149,12 +151,16 @@ void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_
int64_t stride1 = THCTensor_(stride)(state, src_, 1);
int64_t size0 = THCTensor_(size)(state, src_, 0);
int64_t size1 = THCTensor_(size)(state, src_, 1);
int64_t size = (k > 0) ? min((int64_t)size0, (int64_t)size1 - k) : min((int64_t)size0 + k, (int64_t)size1);
int64_t size = (k > 0) ? std::min((int64_t)size0, (int64_t)size1 - k)
: std::min((int64_t)size0 + k, (int64_t)size1);
THCTensor_(resize1d)(state, self_, size);
if (size > 0) {
int64_t strideSelf = THCTensor_(stride)(state, self_, 0);
const dim3 threads(min((int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, (int64_t)size));
dim3 grid(min((int64_t)1024, (int64_t)THCCeilDiv(size, (int64_t)threads.x)));
const dim3 threads(std::min(
(int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
(int64_t)size));
dim3 grid(std::min(
(int64_t)1024, (int64_t)THCCeilDiv(size, (int64_t)threads.x)));
int64_t start = (k >= 0 ? k * stride1 : -k * stride0);
THCTensor_copyFromDiagonal<scalar_t><<<grid, threads, 0, c10::cuda::getCurrentCUDAStream()>>>
(THCTensor_(data)(state, self_), THCTensor_(data)(state, src_), start, size, stride0 + stride1, strideSelf);
Expand All @@ -168,8 +174,11 @@ void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_
if (size > 0) {
int64_t stride0 = THCTensor_(stride)(state, self_, 0);
int64_t stride1 = THCTensor_(stride)(state, self_, 1);
const dim3 threads(min((int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, (int64_t)size));
dim3 grid(min((int64_t)1024, (int64_t)THCCeilDiv(size, (ptrdiff_t)threads.x)));
const dim3 threads(std::min(
(int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
(int64_t)size));
dim3 grid(std::min(
(int64_t)1024, (int64_t)THCCeilDiv(size, (ptrdiff_t)threads.x)));
ptrdiff_t start = (k >= 0 ? k * stride1 : -k * stride0);
THCTensor_copyToDiagonal<scalar_t><<<grid, threads, 0, c10::cuda::getCurrentCUDAStream()>>>
(THCTensor_(data)(state, self_), THCTensor_(data)(state, src_), start, totalElements, stride0 + stride1, strideSrc);
Expand Down
10 changes: 7 additions & 3 deletions aten/src/THC/generic/THCTensorMathScan.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THC/generic/THCTensorMathScan.cu"
#else
Expand Down Expand Up @@ -41,9 +43,11 @@ __host__ void THCTensor_(scanOuterDim)(THCState *state, THCTensor *tgt,
num_irows *= THCTensor_(sizeLegacyNoScalars)(state, src, dim);
}

dim3 threads(min(512, num_irows));
dim3 threads(std::min(512u, num_irows));
unsigned maxGridDim = 1024;
dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
dim3 grid(
std::min(maxGridDim, num_orows),
std::min(maxGridDim, THCCeilDiv(num_irows, threads.x)));

THCTensor_kernel_scanOuterDim<scalar_t><<<grid, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
THCTensor_(data)(state, tgt), THCTensor_(data)(state, src),
Expand All @@ -66,7 +70,7 @@ __host__ void THCTensor_(scanInnermostDim)(THCState *state, THCTensor *tgt,
unsigned row_size = THCTensor_(sizeLegacyNoScalars)(state, src, ndim - 1);

dim3 threads(16, 32);
dim3 grid(min(1024, THCCeilDiv(num_rows, threads.y)));
dim3 grid(std::min(1024u, THCCeilDiv(num_rows, threads.y)));

THCTensor_kernel_scanInnermostDim<scalar_t, 16, 32><<<grid, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
THCTensor_(data)(state, tgt), THCTensor_(data)(state, src), num_rows, row_size, init, binary_op);
Expand Down
7 changes: 5 additions & 2 deletions aten/src/THCUNN/RReLU.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#include <algorithm>
#include <utility>

#include <THCUNN/THCUNN.h>
#include <TH/THHalf.h>
#include <THC/THCNumerics.cuh>
Expand All @@ -7,12 +10,12 @@
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>

// copied from cutorch/lib/THC/THCTensorRandom.cu
#define MAX_NUM_BLOCKS 64
#define BLOCK_SIZE 256
#define NUM_BLOCKS(n) min((int)THCCeilDiv(n, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)
#define NUM_BLOCKS(n) \
(std::min((int)THCCeilDiv(n, (ptrdiff_t)BLOCK_SIZE), MAX_NUM_BLOCKS))

template<typename T>
inline T __device__ curand_uniform_type(curandStatePhilox4_32_10_t *state);
Expand Down
2 changes: 1 addition & 1 deletion caffe2/core/context_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ void TrackMemoryAlloc(size_t nbytes) {
int this_gpu = CaffeCudaGetDevice();
g_total_by_gpu_map[this_gpu] += nbytes;
g_max_by_gpu_map[this_gpu] =
max(g_max_by_gpu_map[this_gpu], g_total_by_gpu_map[this_gpu]);
std::max(g_max_by_gpu_map[this_gpu], g_total_by_gpu_map[this_gpu]);
g_total_mem += nbytes;
if (g_total_mem - g_last_rep >
FLAGS_caffe2_gpu_memory_report_interval_mb * 1024 * 1024) {
Expand Down
4 changes: 3 additions & 1 deletion caffe2/operators/boolean_mask_ops.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/boolean_mask_ops.h"

Expand Down Expand Up @@ -91,7 +93,7 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {

if (numOfOutput > 0) {
BooleanMaskCopyKernel<<<
min(numOfOutput, static_cast<int64_t>(CAFFE_MAXIMUM_NUM_BLOCKS)),
std::min(numOfOutput, static_cast<int64_t>(CAFFE_MAXIMUM_NUM_BLOCKS)),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Expand Down
6 changes: 4 additions & 2 deletions caffe2/operators/boolean_unmask_ops.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/boolean_unmask_ops.h"

Expand Down Expand Up @@ -87,15 +89,15 @@ class BooleanUnmaskOp<CUDAContext> final : public Operator<CUDAContext> {
auto* indicesData = indices_.mutable_data<int>();

ComputeIndicesKernel<<<
min(maskSize, CAFFE_MAXIMUM_NUM_BLOCKS),
std::min(maskSize, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
numMasks, maskSize, indicesData, masks_.data<bool*>());
auto* valueSizesData = valueSizes_.mutable_data<int>();
FillValuesKernel<<<
min(numMasks, CAFFE_MAXIMUM_NUM_BLOCKS),
std::min(numMasks, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Expand Down
8 changes: 5 additions & 3 deletions caffe2/operators/normalize_ops.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include <cub/block/block_reduce.cuh>

#include "caffe2/core/context_gpu.h"
Expand Down Expand Up @@ -89,7 +91,7 @@ void NormalizeOp<float, CUDAContext>::DoNormalize(
const int n,
const int sf) {
NormalizeKernel<<<
min(n, CAFFE_MAXIMUM_NUM_BLOCKS),
std::min(n, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(m, n, sf, xData, yData, kEps_);
Expand All @@ -108,7 +110,7 @@ bool NormalizeGradientOp<float, CUDAContext>::RunOnDevice() {
int M = X.numel() / N;
const int SF = X.size_from_dim(canonical_axis + 1);
NormalizeGradientKernel<<<
min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Expand Down Expand Up @@ -165,7 +167,7 @@ void NormalizeL1Op<float, CUDAContext>::DoNormalize(
const int n,
const int sf) {
NormalizeL1Kernel<<<
min(n, CAFFE_MAXIMUM_NUM_BLOCKS),
std::min(n, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(m, n, sf, xData, yData);
Expand Down
4 changes: 3 additions & 1 deletion caffe2/operators/scale_blobs_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/scale_blobs_op.h"

Expand Down Expand Up @@ -47,7 +49,7 @@ bool ScaleBlobsOp<CUDAContext>::DoRunWithType() {
for (int i = 0; i < numBlobs; ++i) {
hostBlobSizesData[i] = Input(i).numel();
totalSize += hostBlobSizesData[i];
maxSize = max(maxSize, hostBlobSizesData[i]);
maxSize = std::max(maxSize, hostBlobSizesData[i]);
hostInputsData[i] = Input(i).template data<T>();
hostOutputsData[i] = Output(i)->template mutable_data<T>();
}
Expand Down
4 changes: 3 additions & 1 deletion caffe2/operators/segment_reduction_op_gpu.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include <cub/block/block_reduce.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_scan.cuh>
Expand Down Expand Up @@ -1177,7 +1179,7 @@ class SortedSegmentRangeMeanOp : public Operator<Context> {
K,
context_.cuda_stream());
sorted_segment_mean_kernel<T, SIndex, LOGEXP>
<<<min(K, CAFFE_MAXIMUM_NUM_BLOCKS),
<<<std::min(K, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Expand Down
4 changes: 3 additions & 1 deletion caffe2/operators/sequence_ops.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include <cub/cub.cuh>

#include "caffe2/core/context_gpu.h"
Expand Down Expand Up @@ -350,7 +352,7 @@ void GatherPaddingOp<CUDAContext>::GatherPadding(
&lengths_prefix_sum_,
&context_);
gather_padding_kernel<T>
<<<min(block_size, CAFFE_MAXIMUM_NUM_BLOCKS),
<<<std::min(block_size, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
Expand Down
4 changes: 3 additions & 1 deletion caffe2/sgd/adagrad_op_gpu.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <algorithm>

#include <cub/block/block_reduce.cuh>
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
Expand Down Expand Up @@ -203,7 +205,7 @@ bool RowWiseSparseAdagradOp<float, CUDAContext>::DoRunWithType() {
// each thread block will handle multiple rows of the input and output
RowWiseSparseAdagradKernel<<<
min(GRAD_M, CAFFE_MAXIMUM_NUM_BLOCKS),
std::min(GRAD_M, CAFFE_MAXIMUM_NUM_BLOCKS),
num_threads,
0,
context_.cuda_stream()>>>(
Expand Down

0 comments on commit 5dde8cd

Please sign in to comment.