diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh
index 039ac8854a..51e4037c60 100644
--- a/cpp/include/raft/cluster/detail/kmeans.cuh
+++ b/cpp/include/raft/cluster/detail/kmeans.cuh
@@ -15,6 +15,7 @@
  */
 
 #pragma once
 
+#include <algorithm>
 #include
 #include
 #include
@@ -28,6 +29,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -404,8 +406,8 @@ static int chooseNewCentroid(handle_t const& handle,
   //}
   RAFT_CHECK_CUDA(stream);
-  obsIndex = max(obsIndex, 0);
-  obsIndex = min(obsIndex, n - 1);
+  obsIndex = std::max(obsIndex, static_cast<index_type_t>(0));
+  obsIndex = std::min(obsIndex, n - 1);
 
   // Record new centroid position
   RAFT_CUDA_TRY(cudaMemcpyAsync(centroid,
@@ -467,7 +469,7 @@ static int initializeCentroids(handle_t const& handle,
   auto stream             = handle.get_stream();
   auto thrust_exec_policy = handle.get_thrust_policy();
 
-  constexpr index_type_t grid_lower_bound{65535};
+  constexpr unsigned grid_lower_bound{65535};
 
   // -------------------------------------------------------
   // Implementation
@@ -477,12 +479,12 @@ static int initializeCentroids(handle_t const& handle,
   dim3 blockDim_warp{WARP_SIZE, 1, BSIZE_DIV_WSIZE};
 
   // CUDA grid dimensions
-  dim3 gridDim_warp{min((d + WARP_SIZE - 1) / WARP_SIZE, grid_lower_bound),
+  dim3 gridDim_warp{std::min(ceildiv<unsigned>(d, WARP_SIZE), grid_lower_bound),
                     1,
-                    min((n + BSIZE_DIV_WSIZE - 1) / BSIZE_DIV_WSIZE, grid_lower_bound)};
+                    std::min(ceildiv<unsigned>(n, BSIZE_DIV_WSIZE), grid_lower_bound)};
 
   // CUDA grid dimensions
-  dim3 gridDim_block{min((n + BLOCK_SIZE - 1) / BLOCK_SIZE, grid_lower_bound), 1, 1};
+  dim3 gridDim_block{std::min(ceildiv<unsigned>(n, BLOCK_SIZE), grid_lower_bound), 1, 1};
 
   // Assign observation vectors to code 0
   RAFT_CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream));
@@ -574,10 +576,10 @@ static int assignCentroids(handle_t const& handle,
   dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE};
 
   dim3 gridDim;
-  constexpr index_type_t grid_lower_bound{65535};
-  gridDim.x = min((d + WARP_SIZE - 1) / WARP_SIZE, grid_lower_bound);
-  gridDim.y = min(k, grid_lower_bound);
-  gridDim.z = min((n + BSIZE_DIV_WSIZE - 1) / BSIZE_DIV_WSIZE, grid_lower_bound);
+  constexpr unsigned grid_lower_bound{65535};
+  gridDim.x = std::min(ceildiv<unsigned>(d, WARP_SIZE), grid_lower_bound);
+  gridDim.y = std::min(static_cast<unsigned>(k), grid_lower_bound);
+  gridDim.z = std::min(ceildiv<unsigned>(n, BSIZE_DIV_WSIZE), grid_lower_bound);
 
   computeDistances<<<gridDim, blockDim, 0, stream>>>(n, d, k, obs, centroids, dists);
   RAFT_CHECK_CUDA(stream);
@@ -587,7 +589,7 @@ static int assignCentroids(handle_t const& handle,
   blockDim.x = BLOCK_SIZE;
   blockDim.y = 1;
   blockDim.z = 1;
-  gridDim.x  = min((n + BLOCK_SIZE - 1) / BLOCK_SIZE, grid_lower_bound);
+  gridDim.x  = std::min(ceildiv<unsigned>(n, BLOCK_SIZE), grid_lower_bound);
   gridDim.y  = 1;
   gridDim.z  = 1;
 
   minDistances<<<gridDim, blockDim, 0, stream>>>(n, k, dists, codes, clusterSizes);
@@ -644,7 +646,7 @@ static int updateCentroids(handle_t const& handle,
   const value_type_t one  = 1;
   const value_type_t zero = 0;
 
-  constexpr index_type_t grid_lower_bound{65535};
+  constexpr unsigned grid_lower_bound{65535};
 
   auto stream   = handle.get_stream();
   auto cublas_h = handle.get_cublas_handle();
@@ -717,8 +719,8 @@ static int updateCentroids(handle_t const& handle,
   dim3 blockDim{WARP_SIZE, BLOCK_SIZE / WARP_SIZE, 1};
 
   // CUDA grid dimensions
-  dim3 gridDim{min((d + WARP_SIZE - 1) / WARP_SIZE, grid_lower_bound),
-               min((k + BSIZE_DIV_WSIZE - 1) / BSIZE_DIV_WSIZE, grid_lower_bound),
+  dim3 gridDim{std::min(ceildiv<unsigned>(d, WARP_SIZE), grid_lower_bound),
+               std::min(ceildiv<unsigned>(k, BSIZE_DIV_WSIZE), grid_lower_bound),
                1};
 
   divideCentroids<<<gridDim, blockDim, 0, stream>>>(d, k, clusterSizes, centroids);
@@ -791,7 +793,7 @@ int kmeans(handle_t const& handle,
   // Current iteration
   index_type_t iter;
 
-  constexpr index_type_t grid_lower_bound{65535};
+  constexpr unsigned grid_lower_bound{65535};
 
   // Residual sum of squares at previous iteration
   value_type_t residualPrev = 0;
@@ -818,10 +820,9 @@ int kmeans(handle_t const& handle,
 
     dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE};
 
-    dim3 gridDim{
-      min((d + WARP_SIZE - 1) / WARP_SIZE, grid_lower_bound),
-      1,
-      min((n + BLOCK_SIZE / WARP_SIZE - 1) / (BLOCK_SIZE / WARP_SIZE), grid_lower_bound)};
+    dim3 gridDim{std::min(ceildiv<unsigned>(d, WARP_SIZE), grid_lower_bound),
+                 1,
+                 std::min(ceildiv<unsigned>(n, BLOCK_SIZE / WARP_SIZE), grid_lower_bound)};
 
     CUDA_TRY(cudaMemsetAsync(work, 0, n * k * sizeof(value_type_t), stream));
     computeDistances<<<gridDim, blockDim, 0, stream>>>(n, d, 1, obs, centroids, work);
@@ -958,7 +959,7 @@ int kmeans(handle_t const& handle,
   // Allocate memory
   raft::spectral::matrix::vector_t<index_type_t> clusterSizes(handle, k);
   raft::spectral::matrix::vector_t<value_type_t> centroids(handle, d * k);
-  raft::spectral::matrix::vector_t<value_type_t> work(handle, n * max(k, d));
+  raft::spectral::matrix::vector_t<value_type_t> work(handle, n * std::max(k, d));
   raft::spectral::matrix::vector_t<index_type_t> work_int(handle, 2 * d * n);
 
   // Perform k-means
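
The kmeans.cuh hunks all follow one pattern: qualify `min`/`max` as `std::min`/`std::max` (the unqualified CUDA device overloads are not visible to clang in host code), replace hand-rolled rounding with `raft::ceildiv`, and make the 65535 grid-dimension cap an `unsigned` so both arguments of `std::min` have the same type. A minimal standalone sketch of the pattern (names hypothetical):

    #include <algorithm>
    #include <cuda_runtime.h>

    // integer ceil(a / b), which is what raft::ceildiv computes
    template <typename I>
    constexpr I ceildiv(I a, I b) { return (a + b - 1) / b; }

    dim3 capped_grid_1d(unsigned n, unsigned block) {
      // 65535 is the hardware limit for gridDim.y and gridDim.z
      constexpr unsigned grid_lower_bound{65535};
      return dim3{std::min(ceildiv(n, block), grid_lower_bound), 1u, 1u};
    }
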
diff --git a/cpp/include/raft/cuda_utils.cuh b/cpp/include/raft/cuda_utils.cuh
index 8a66eff242..be995ea824 100644
--- a/cpp/include/raft/cuda_utils.cuh
+++ b/cpp/include/raft/cuda_utils.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -109,7 +109,7 @@ static const int WarpSize = 32;
 DI int laneId()
 {
   int id;
-  asm("mov.s32 %0, %laneid;" : "=r"(id));
+  asm("mov.s32 %0, %%laneid;" : "=r"(id));
   return id;
 }
@@ -228,13 +228,13 @@ DI T myAtomicMax(T* address, T val);
 
 DI float myAtomicMin(float* address, float val)
 {
-  myAtomicReduce(address, val, fminf);
+  myAtomicReduce<float(float, float)>(address, val, fminf);
   return *address;
 }
 
 DI float myAtomicMax(float* address, float val)
 {
-  myAtomicReduce(address, val, fmaxf);
+  myAtomicReduce<float(float, float)>(address, val, fmaxf);
   return *address;
 }
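
In inline PTX, `%` introduces an operand placeholder (`%0`), so a literal PTX special register such as `%laneid` has to be written `%%laneid`; nvcc tolerated the single `%`, clang does not. The fixed function, as a standalone sketch:

    __device__ int lane_id() {
      int id;
      // "%0" is operand 0; "%%laneid" emits the literal special register %laneid
      asm("mov.s32 %0, %%laneid;" : "=r"(id));
      return id;
    }

The explicit `myAtomicReduce<float(float, float)>` template argument pins down the reducer's function type, so passing `fminf`/`fmaxf` no longer requires the compiler to deduce which declaration (host or device) is meant.
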
diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh
index c805860759..53657a5dfa 100644
--- a/cpp/include/raft/label/detail/classlabels.cuh
+++ b/cpp/include/raft/label/detail/classlabels.cuh
@@ -24,6 +24,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace label {
 namespace detail {
@@ -56,7 +58,7 @@ int getUniquelabels(rmm::device_uvector<value_t>& unique, value_t* y, size_t n,
     NULL, bytes, y, workspace.data(), n, 0, sizeof(value_t) * 8, stream);
   cub::DeviceSelect::Unique(
     NULL, bytes2, workspace.data(), workspace.data(), d_num_selected.data(), n, stream);
-  bytes = max(bytes, bytes2);
+  bytes = std::max(bytes, bytes2);
   rmm::device_uvector<char> cub_storage(bytes, stream);
 
   // Select Unique classes
diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh
index 40d0839f60..0261d1967e 100644
--- a/cpp/include/raft/linalg/detail/contractions.cuh
+++ b/cpp/include/raft/linalg/detail/contractions.cuh
@@ -76,7 +76,7 @@ struct Contractions_NT {
   /** block of Y data loaded from global mem after `ldgXY()` */
   DataT ldgDataY[P::LdgPerThY][P::Veclen];
 
-  static const DataT Zero = (DataT)0;
+  static constexpr DataT Zero = (DataT)0;
 
  public:
  /**
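
The `Contractions_NT` change is needed because C++ only permits in-class initializers on `static const` members of integral or enumeration type; once `DataT` may be `float` or `double`, the initializer requires `static constexpr`. For example:

    struct Example {
      // static const float Zero = 0.f;   // ill-formed: non-integral static const
      static constexpr float Zero = 0.f;  // OK, and usable in constant expressions
    };
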
diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh
index a250dd3578..81b1867a82 100644
--- a/cpp/include/raft/linalg/detail/qr.cuh
+++ b/cpp/include/raft/linalg/detail/qr.cuh
@@ -22,6 +22,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace linalg {
 namespace detail {
@@ -37,7 +39,7 @@ void qrGetQ(const raft::handle_t& handle,
   cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
 
   int m = n_rows, n = n_cols;
-  int k = min(m, n);
+  int k = std::min(m, n);
   RAFT_CUDA_TRY(cudaMemcpyAsync(Q, M, sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream));
 
   rmm::device_uvector<math_t> tau(k, stream);
@@ -70,8 +72,8 @@ void qrGetQR(const raft::handle_t& handle,
   int m = n_rows, n = n_cols;
 
   rmm::device_uvector<math_t> R_full(m * n, stream);
-  rmm::device_uvector<math_t> tau(min(m, n), stream);
-  RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * min(m, n), stream));
+  rmm::device_uvector<math_t> tau(std::min(m, n), stream);
+  RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * std::min(m, n), stream));
   int R_full_nrows = m, R_full_ncols = n;
   RAFT_CUDA_TRY(
     cudaMemcpyAsync(R_full.data(), M, sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream));
@@ -100,12 +102,12 @@ void qrGetQR(const raft::handle_t& handle,
   int Q_nrows = m, Q_ncols = n;
 
   RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize(
-    cusolverH, Q_nrows, Q_ncols, min(Q_ncols, Q_nrows), Q, Q_nrows, tau.data(), &Lwork));
+    cusolverH, Q_nrows, Q_ncols, std::min(Q_ncols, Q_nrows), Q, Q_nrows, tau.data(), &Lwork));
   workspace.resize(Lwork, stream);
   RAFT_CUSOLVER_TRY(cusolverDnorgqr(cusolverH,
                                     Q_nrows,
                                     Q_ncols,
-                                    min(Q_ncols, Q_nrows),
+                                    std::min(Q_ncols, Q_nrows),
                                     Q,
                                     Q_nrows,
                                     tau.data(),
diff --git a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh
index aa0b1545d3..7550ce2093 100644
--- a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh
+++ b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh
index 88436eda64..3dc22a7e89 100644
--- a/cpp/include/raft/linalg/detail/rsvd.cuh
+++ b/cpp/include/raft/linalg/detail/rsvd.cuh
@@ -26,6 +26,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace linalg {
 namespace detail {
@@ -386,9 +388,9 @@ void rsvdPerc(const raft::handle_t& handle,
               int max_sweeps,
               cudaStream_t stream)
 {
-  int k = max((int)(min(n_rows, n_cols) * PC_perc),
-              1);  // Number of singular values to be computed
-  int p = max((int)(min(n_rows, n_cols) * UpS_perc), 1);  // Upsamples
+  int k = std::max((int)(std::min(n_rows, n_cols) * PC_perc),
+                   1);  // Number of singular values to be computed
+  int p = std::max((int)(std::min(n_rows, n_cols) * UpS_perc), 1);  // Upsamples
   rsvdFixedRank(handle,
                 M,
                 n_rows,
diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh
index 63fa872f9d..81204bfe66 100644
--- a/cpp/include/raft/matrix/detail/linewise_op.cuh
+++ b/cpp/include/raft/matrix/detail/linewise_op.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace matrix {
 namespace detail {
@@ -312,7 +314,7 @@ __global__ void __launch_bounds__(BlockSize)
   typedef Linewise<Type, IdxType, VecBytes, BlockSize> L;
   constexpr uint workSize = L::VecElems * BlockSize;
   uint workOffset         = workSize;
-  __shared__ alignas(sizeof(Type) * L::VecElems)
+  __shared__ __align__(sizeof(Type) * L::VecElems)
     Type shm[workSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
   const IdxType blockOffset = (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen;
   return L::vectorRows(
@@ -422,7 +424,7 @@ void matrixLinewiseVecCols(Type* out,
     const uint occupy = getOptimalGridSize();
     // does not make sense to have more blocks than this
     const uint maxBlocks = raft::ceildiv(uint(alignedLen), bs.x * VecElems);
-    const dim3 gs(min(maxBlocks, occupy), 1, 1);
+    const dim3 gs(std::min(maxBlocks, occupy), 1, 1);
     // The work arrangement is blocked on the block and warp levels;
     // see more details at Linewise::vectorCols.
     // The value below determines how many scalar elements are processed by on thread in total.
@@ -482,7 +484,7 @@ void matrixLinewiseVecRows(Type* out,
       const uint expected_grid_size = rowLen / raft::gcd(block_work_size, uint(rowLen));
       // Minimum size of the grid to make the device well occupied
       const uint occupy = getOptimalGridSize();
-      const dim3 gs(min(
+      const dim3 gs(std::min(
         // does not make sense to have more blocks than this
         raft::ceildiv(uint(totalLen), block_work_size),
         // increase the grid size to be not less than `occupy` while
diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh
index 6d631b4f4f..f057ba283c 100644
--- a/cpp/include/raft/matrix/detail/matrix.cuh
+++ b/cpp/include/raft/matrix/detail/matrix.cuh
@@ -220,7 +220,7 @@ template <typename m_t, typename idx_t>
 void copyUpperTriangular(m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
 {
   idx_t m = n_rows, n = n_cols;
-  idx_t k = min(m, n);
+  idx_t k = std::min(m, n);
   dim3 block(64);
   dim3 grid((m * n + block.x - 1) / block.x);
   getUpperTriangular<<<grid, block, 0, stream>>>(src, dst, m, n, k);
@@ -246,7 +246,7 @@ template <typename m_t, typename idx_t>
 void initializeDiagonalMatrix(
   m_t* vec, m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
 {
-  idx_t k = min(n_rows, n_cols);
+  idx_t k = std::min(n_rows, n_cols);
   dim3 block(64);
   dim3 grid((k + block.x - 1) / block.x);
   copyVectorToMatrixDiagonal<<<grid, block, 0, stream>>>(vec, matrix, n_rows, n_cols, k);
@@ -285,4 +285,4 @@ m_t getL2Norm(const raft::handle_t& handle, m_t* in, idx_t size, cudaStream_t st
 
 }  // end namespace detail
 }  // end namespace matrix
-}  // end namespace raft
\ No newline at end of file
+}  // end namespace raft
diff --git a/cpp/include/raft/random/detail/make_blobs.cuh b/cpp/include/raft/random/detail/make_blobs.cuh
index fff1ab835b..b79178567b 100644
--- a/cpp/include/raft/random/detail/make_blobs.cuh
+++ b/cpp/include/raft/random/detail/make_blobs.cuh
@@ -245,4 +245,4 @@
 
 }  // end namespace detail
 }  // end namespace random
-}  // end namespace raft
\ No newline at end of file
+}  // end namespace raft
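
The `linewise_op.cuh` hunk above swaps C++ `alignas` for CUDA's `__align__` attribute on a `__shared__` array, presumably because clang as a CUDA compiler rejects the `alignas` spelling in that position while nvcc accepts both. The pattern in isolation (sizes hypothetical):

    __global__ void kernel() {
      // CUDA-specific alignment attribute on shared memory
      __shared__ __align__(16) float tile[256];
      (void)tile;
    }
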
diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh
index 0624674e81..e6dd396f2d 100644
--- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh
+++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh
@@ -31,6 +31,8 @@
 
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace sparse {
 namespace distance {
@@ -411,7 +413,7 @@ class hellinger_expanded_distances_t : public distances_t<value_t> {
 
   void compute(value_t* out_dists)
   {
-    rmm::device_uvector<value_idx> coo_rows(max(config_->b_nnz, config_->a_nnz),
+    rmm::device_uvector<value_idx> coo_rows(std::max(config_->b_nnz, config_->a_nnz),
                                             config_->handle.get_stream());
 
     raft::sparse::convert::csr_to_coo(config_->b_indptr,
diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh
index de9049ced7..96d51f2e75 100644
--- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh
+++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh
@@ -32,6 +32,8 @@
 
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace sparse {
 namespace distance {
@@ -48,7 +50,7 @@ void unexpanded_lp_distances(value_t* out_dists,
                              accum_f accum_func,
                              write_f write_func)
 {
-  rmm::device_uvector<value_idx> coo_rows(max(config_->b_nnz, config_->a_nnz),
+  rmm::device_uvector<value_idx> coo_rows(std::max(config_->b_nnz, config_->a_nnz),
                                           config_->handle.get_stream());
 
   raft::sparse::convert::csr_to_coo(config_->b_indptr,
@@ -283,7 +285,7 @@ class kl_divergence_unexpanded_distances_t : public distances_t<value_t> {
 
   void compute(value_t* out_dists)
   {
-    rmm::device_uvector<value_idx> coo_rows(max(config_->b_nnz, config_->a_nnz),
+    rmm::device_uvector<value_idx> coo_rows(std::max(config_->b_nnz, config_->a_nnz),
                                             config_->handle.get_stream());
 
     raft::sparse::convert::csr_to_coo(config_->b_indptr,
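
In all three sparse-distance hunks, the scratch buffer is sized for the larger of the two inputs so a single allocation can serve either side of the computation; the change merely makes the host-side `std::max` explicit. A self-contained sketch of the idiom (helper name hypothetical):

    #include <algorithm>
    #include <rmm/device_uvector.hpp>

    // one COO-row scratch buffer, reused for whichever side has more nonzeros
    template <typename value_idx>
    rmm::device_uvector<value_idx> make_coo_rows_scratch(std::size_t a_nnz,
                                                         std::size_t b_nnz,
                                                         rmm::cuda_stream_view stream)
    {
      return rmm::device_uvector<value_idx>(std::max(a_nnz, b_nnz), stream);
    }
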
diff --git a/cpp/include/raft/sparse/selection/detail/knn.cuh b/cpp/include/raft/sparse/selection/detail/knn.cuh
index 82a689fe00..d263f2409f 100644
--- a/cpp/include/raft/sparse/selection/detail/knn.cuh
+++ b/cpp/include/raft/sparse/selection/detail/knn.cuh
@@ -31,6 +31,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace sparse {
 namespace selection {
@@ -354,7 +356,7 @@ class sparse_knn_t {
 
     // in the case where the number of idx rows in the batch is < k, we
     // want to adjust k.
-    value_idx n_neighbors = min(k, batch_cols);
+    value_idx n_neighbors = std::min(static_cast<value_idx>(k), batch_cols);
 
     bool ascending = true;
     if (metric == raft::distance::DistanceType::InnerProduct) ascending = false;
diff --git a/cpp/include/raft/sparse/selection/detail/knn_graph.cuh b/cpp/include/raft/sparse/selection/detail/knn_graph.cuh
index 6ac96e1324..b222dfd9bd 100644
--- a/cpp/include/raft/sparse/selection/detail/knn_graph.cuh
+++ b/cpp/include/raft/sparse/selection/detail/knn_graph.cuh
@@ -32,6 +32,7 @@
 #include
 #include
 
+#include <algorithm>
 #include
 
 namespace raft {
@@ -59,7 +60,7 @@ value_idx build_k(value_idx n_samples, int c)
 {
   // from "kNN-MST-Agglomerative: A fast & scalable graph-based data clustering
   // approach on GPU"
-  return min(n_samples, max((value_idx)2, (value_idx)floor(log2(n_samples)) + c));
+  return std::min(n_samples, std::max((value_idx)2, (value_idx)floor(log2(n_samples)) + c));
 }
 
 template
diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
index 6b5df01a97..e3e33e6642 100644
--- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
+++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
@@ -29,7 +29,7 @@ namespace knn {
 namespace detail {
 
 template <typename Policy, typename Pair, typename myWarpSelect, typename IdxT>
-DI void loadAllWarpQShmem(myWarpSelect& heapArr,
+DI void loadAllWarpQShmem(myWarpSelect** heapArr,
                           Pair* shDumpKV,
                           const IdxT m,
                           const unsigned int numOfNN)
@@ -40,7 +40,7 @@ DI void loadAllWarpQShmem(myWarpSelect& heapArr,
     const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
     if (rowId < m) {
 #pragma unroll
-      for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
+      for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
         const int idx = j * warpSize + lid;
         if (idx < numOfNN) {
           Pair KVPair = shDumpKV[rowId * numOfNN + idx];
@@ -53,14 +53,14 @@ DI void loadAllWarpQShmem(myWarpSelect& heapArr,
 }
 
 template <typename Pair, typename myWarpSelect>
-DI void loadWarpQShmem(myWarpSelect& heapArr,
+DI void loadWarpQShmem(myWarpSelect* heapArr,
                        Pair* shDumpKV,
                        const int rowId,
                        const unsigned int numOfNN)
 {
   const int lid = raft::laneId();
 #pragma unroll
-  for (int j = 0; j < heapArr->kNumWarpQRegisters; ++j) {
+  for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
     const int idx = j * warpSize + lid;
     if (idx < numOfNN) {
       Pair KVPair = shDumpKV[rowId * numOfNN + idx];
@@ -71,7 +71,7 @@ DI void loadWarpQShmem(myWarpSelect& heapArr,
 }
 
 template <typename Pair, typename IdxT, typename myWarpSelect>
-DI void storeWarpQShmem(myWarpSelect& heapArr,
+DI void storeWarpQShmem(myWarpSelect* heapArr,
                         Pair* shDumpKV,
                         const IdxT rowId,
                         const unsigned int numOfNN)
@@ -79,7 +79,7 @@ DI void storeWarpQShmem(myWarpSelect& heapArr,
   const int lid = raft::laneId();
 
 #pragma unroll
-  for (int j = 0; j < heapArr->kNumWarpQRegisters; ++j) {
+  for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
     const int idx = j * warpSize + lid;
     if (idx < numOfNN) {
       Pair otherKV = Pair(heapArr->warpV[j], heapArr->warpK[j]);
@@ -89,7 +89,7 @@ DI void storeWarpQShmem(myWarpSelect& heapArr,
 }
 
 template <typename Policy, typename OutT, typename IdxT, typename myWarpSelect>
-DI void storeWarpQGmem(myWarpSelect& heapArr,
+DI void storeWarpQGmem(myWarpSelect** heapArr,
                        volatile OutT* out_dists,
                        volatile IdxT* out_inds,
                        const IdxT m,
@@ -102,7 +102,7 @@ DI void storeWarpQGmem(myWarpSelect& heapArr,
     const auto gmemRowId = starty + i * Policy::AccThRows;
     if (gmemRowId < m) {
 #pragma unroll
-      for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
+      for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
         const auto idx = j * warpSize + lid;
         if (idx < numOfNN) {
           out_dists[gmemRowId * numOfNN + idx] = heapArr[i]->warpK[j];
@@ -114,7 +114,7 @@ DI void storeWarpQGmem(myWarpSelect& heapArr,
 }
 
 template <typename Policy, typename OutT, typename IdxT, typename myWarpSelect>
-DI void loadPrevTopKsGmemWarpQ(myWarpSelect& heapArr,
+DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr,
                                volatile OutT* out_dists,
                                volatile IdxT* out_inds,
                                const IdxT m,
@@ -127,14 +127,14 @@ DI void loadPrevTopKsGmemWarpQ(myWarpSelect& heapArr,
     const auto gmemRowId = starty + i * Policy::AccThRows;
     if (gmemRowId < m) {
 #pragma unroll
-      for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
+      for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
        const auto idx = j * warpSize + lid;
         if (idx < numOfNN) {
           heapArr[i]->warpK[j] = out_dists[gmemRowId * numOfNN + idx];
           heapArr[i]->warpV[j] = (uint32_t)out_inds[gmemRowId * numOfNN + idx];
         }
       }
-      auto constexpr kLaneWarpKTop = heapArr[i]->kNumWarpQRegisters - 1;
+      static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1;
       heapArr[i]->warpKTop = raft::shfl(heapArr[i]->warpK[kLaneWarpKTop], heapArr[i]->kLane);
     }
   }
@@ -261,7 +261,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
     const auto rowId = starty + i * Policy::AccThRows;
     if (rowId < m) {
 #pragma unroll
-      for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
+      for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
         Pair otherKV;
         otherKV.value = identity;
         otherKV.key   = keyMax;
@@ -287,7 +287,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
     const auto rowId = starty + i * Policy::AccThRows;
     if (rowId < m) {
 #pragma unroll
-      for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
+      for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
         Pair otherKV;
         otherKV.value = identity;
         otherKV.key   = keyMax;
@@ -341,7 +341,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
   };
 
   // epilogue operation lambda for final value calculation
-  auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds] __device__(
+  auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__(
                          AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
                          DataT * regxn,
                          DataT * regyn,
@@ -448,7 +448,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
       }
       const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31);
       loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN);
-      updateSortedWarpQ<Pair, heapArr[i]->kNumWarpQRegisters>(
+      updateSortedWarpQ<Pair, myWarpSelect::kNumWarpQRegisters>(
         heapArr[i], &allWarpTopKs[0], rowId, finalNumVals);
     }
   }
diff --git a/cpp/include/raft/spectral/detail/spectral_util.cuh b/cpp/include/raft/spectral/detail/spectral_util.cuh
index c7a0f0c5ef..c1796cbbc1 100644
--- a/cpp/include/raft/spectral/detail/spectral_util.cuh
+++ b/cpp/include/raft/spectral/detail/spectral_util.cuh
@@ -25,6 +25,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace spectral {
@@ -96,7 +98,7 @@ cudaError_t scale_obs(index_type_t m, index_type_t n, value_type_t* obs)
   // find next power of 2
   p2m = next_pow2(m);
   // setup launch configuration
-  unsigned int xsize = max(2, min(p2m, 32));
+  unsigned int xsize = std::max(2, std::min(p2m, 32));
   dim3 nthreads{xsize, 256 / xsize, 1};
   dim3 nblocks{1, (n + nthreads.y - 1) / nthreads.y, 1};
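
The `fused_l2_knn.cuh` changes replace instance-based accesses like `heapArr[i]->kNumWarpQRegisters` with the class-qualified `myWarpSelect::kNumWarpQRegisters`: the member is a `static constexpr`, and clang wants a spelling that is clearly a constant expression when it feeds `#pragma unroll` trip counts and template arguments. A reduced illustration (type and constant hypothetical):

    struct WarpQ {
      static constexpr int kNumWarpQRegisters = 4;
    };

    __device__ void consume(WarpQ* q) {
      // q->kNumWarpQRegisters compiled under nvcc, but the class-qualified
      // spelling is the portable constant expression
    #pragma unroll
      for (int j = 0; j < WarpQ::kNumWarpQRegisters; ++j) { /* ... */ }
    }

Likewise, the added lambda captures (`keyMax`, `identity`) in `epilog_lambda` name variables explicitly, presumably because clang rejects their use without a capture while nvcc let them through.
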
diff --git a/cpp/include/raft/stats/detail/meanvar.cuh b/cpp/include/raft/stats/detail/meanvar.cuh
index 7d4c68e364..075e7fe170 100644
--- a/cpp/include/raft/stats/detail/meanvar.cuh
+++ b/cpp/include/raft/stats/detail/meanvar.cuh
@@ -199,13 +199,14 @@ void meanvar(
   if (rowMajor) {
     static_assert(BlockSize >= WarpSize, "Block size must be not smaller than the warp size.");
     const dim3 bs(WarpSize, BlockSize / WarpSize, 1);
-    dim3 gs(raft::ceildiv(D, bs.x), raft::ceildiv(N, bs.y), 1);
+    dim3 gs(raft::ceildiv<uint>(D, bs.x), raft::ceildiv<uint>(N, bs.y), 1);
 
     // Don't create more blocks than necessary to occupy the GPU
     int occupancy;
     RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &occupancy, meanvar_kernel_rowmajor<T, I, BlockSize>, BlockSize, 0));
-    gs.y = min(gs.y, raft::ceildiv(occupancy * getMultiProcessorCount(), gs.x));
+    gs.y =
+      std::min(gs.y, raft::ceildiv<uint>(occupancy * getMultiProcessorCount(), gs.x));
 
     // Global memory: one mean_var<T> for each column
     //                one lock per all blocks working on the same set of columns
diff --git a/cpp/scripts/__clang_cuda_additional_intrinsics.h b/cpp/scripts/__clang_cuda_additional_intrinsics.h
new file mode 100644
index 0000000000..8964d210bf
--- /dev/null
+++ b/cpp/scripts/__clang_cuda_additional_intrinsics.h
@@ -0,0 +1,391 @@
+#ifndef __CLANG_CUDA_ADDITIONAL_INTRINSICS_H__
+#define __CLANG_CUDA_ADDITIONAL_INTRINSICS_H__
+#ifndef __CUDA__
+#error "This file is for CUDA compilation only."
+#endif
+
+// for some of these macros, see cuda_fp16.hpp
+#if defined(__cplusplus) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 320))
+#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
+#define __LDG_PTR "l"
+#define __LBITS "64"
+#else
+#define __LDG_PTR "r"
+#define __LBITS "32"
+#endif  // (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
+
+#define __NOARG
+
+#define __MAKE_LD(cop, c_typ, int_typ, ptx_typ, inl_typ, mem) \
+  __device__ __forceinline__ c_typ __ld ## cop (const c_typ* addr) { \
+    int_typ out; \
+    asm("ld." #cop "." ptx_typ " %0, [%1];" \
+        : "=" inl_typ(out) : __LDG_PTR(addr)mem); \
+    return (c_typ)out; \
+  }
+
+#define __MAKE_LD2(cop, c_typ, int_typ, ptx_typ, inl_typ, mem) \
+  __device__ __forceinline__ c_typ __ld ## cop (const c_typ* addr) { \
+    int_typ out1, out2; \
+    asm("ld." #cop ".v2." ptx_typ " {%0, %1}, [%2];" \
+        : "=" inl_typ(out1), "=" inl_typ(out2) : __LDG_PTR(addr)mem); \
+    c_typ out; \
+    out.x = out1; \
+    out.y = out2; \
+    return out; \
+  }
+
+#define __MAKE_LD4(cop, c_typ, int_typ, ptx_typ, inl_typ, mem) \
+  __device__ __forceinline__ c_typ __ld ## cop (const c_typ* addr) { \
+    int_typ out1, out2, out3, out4; \
+    asm("ld." #cop".v4." ptx_typ " {%0, %1, %2, %3}, [%4];" \
+        : "=" inl_typ(out1), "=" inl_typ(out2), \
+          "=" inl_typ(out3), "=" inl_typ(out4) : __LDG_PTR(addr)mem); \
+    c_typ out; \
+    out.x = out1; \
+    out.y = out2; \
+    out.z = out3; \
+    out.w = out4; \
+    return out; \
+  }
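+
+// For illustration, on a 64-bit target (__LDG_PTR == "l"), the invocation
+// __MAKE_LD(cg, int, int, "s32", "r", __NOARG) below expands to roughly:
+//   __device__ __forceinline__ int __ldcg(const int* addr) {
+//     int out;
+//     asm("ld.cg.s32 %0, [%1];" : "=r"(out) : "l"(addr));
+//     return (int)out;
+//   }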
ptx_typ " {%0, %1, %2, %3}, [%4];" \ + : "=" inl_typ(out1), "=" inl_typ(out2), \ + "=" inl_typ(out3), "=" inl_typ(out4) : __LDG_PTR(addr)mem); \ + c_typ out; \ + out.x = out1; \ + out.y = out2; \ + out.z = out3; \ + out.w = out4; \ + return out; \ + } + +__MAKE_LD(cg, char, short, "s8", "h", __NOARG) +__MAKE_LD(cg, signed char, short, "s8", "h", __NOARG) +__MAKE_LD(cg, unsigned char, short, "u8", "h", __NOARG) +__MAKE_LD(cg, short, short, "s16", "h", __NOARG) +__MAKE_LD(cg, unsigned short, unsigned short, "u16", "h", __NOARG) +__MAKE_LD(cg, int, int, "s32", "r", __NOARG) +__MAKE_LD(cg, unsigned int, unsigned int, "u32", "r", __NOARG) +__MAKE_LD(cg, long, long, "s" __LBITS, __LDG_PTR, __NOARG) +__MAKE_LD(cg, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR, __NOARG) +__MAKE_LD(cg, long long, long long, "s64", "l", __NOARG) +__MAKE_LD(cg, unsigned long long, unsigned long long, "u64", "l", __NOARG) +__MAKE_LD(cg, float, float, "f32", "f", __NOARG) +__MAKE_LD(cg, double, double, "f64", "d", __NOARG) + +__MAKE_LD2(cg, char2, short, "s8", "h", __NOARG) +__MAKE_LD2(cg, uchar2, short, "u8", "h", __NOARG) +__MAKE_LD2(cg, short2, short, "s16", "h", __NOARG) +__MAKE_LD2(cg, ushort2, unsigned short, "u16", "h", __NOARG) +__MAKE_LD2(cg, int2, int, "s32", "r", __NOARG) +__MAKE_LD2(cg, uint2, unsigned int, "u32", "r", __NOARG) +__MAKE_LD2(cg, longlong2, long long, "s64", "l", __NOARG) +__MAKE_LD2(cg, ulonglong2, unsigned long long, "u64", "l", __NOARG) +__MAKE_LD2(cg, float2, float, "f32", "f", __NOARG) +__MAKE_LD2(cg, double2, double, "f64", "d", __NOARG) + +__MAKE_LD4(cg, char4, short, "s8", "h", __NOARG) +__MAKE_LD4(cg, uchar4, short, "u8", "h", __NOARG) +__MAKE_LD4(cg, short4, short, "s16", "h", __NOARG) +__MAKE_LD4(cg, ushort4, unsigned short, "u16", "h", __NOARG) +__MAKE_LD4(cg, int4, int, "s32", "r", __NOARG) +__MAKE_LD4(cg, uint4, unsigned int, "u32", "r", __NOARG) +__MAKE_LD4(cg, float4, float, "f32", "f", __NOARG) + + +__MAKE_LD(ca, char, short, "s8", "h", __NOARG) +__MAKE_LD(ca, signed char, short, "s8", "h", __NOARG) +__MAKE_LD(ca, unsigned char, short, "u8", "h", __NOARG) +__MAKE_LD(ca, short, short, "s16", "h", __NOARG) +__MAKE_LD(ca, unsigned short, unsigned short, "u16", "h", __NOARG) +__MAKE_LD(ca, int, int, "s32", "r", __NOARG) +__MAKE_LD(ca, unsigned int, unsigned int, "u32", "r", __NOARG) +__MAKE_LD(ca, long, long, "s" __LBITS, __LDG_PTR, __NOARG) +__MAKE_LD(ca, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR, __NOARG) +__MAKE_LD(ca, long long, long long, "s64", "l", __NOARG) +__MAKE_LD(ca, unsigned long long, unsigned long long, "u64", "l", __NOARG) +__MAKE_LD(ca, float, float, "f32", "f", __NOARG) +__MAKE_LD(ca, double, double, "f64", "d", __NOARG) + +__MAKE_LD2(ca, char2, short, "s8", "h", __NOARG) +__MAKE_LD2(ca, uchar2, short, "u8", "h", __NOARG) +__MAKE_LD2(ca, short2, short, "s16", "h", __NOARG) +__MAKE_LD2(ca, ushort2, unsigned short, "u16", "h", __NOARG) +__MAKE_LD2(ca, int2, int, "s32", "r", __NOARG) +__MAKE_LD2(ca, uint2, unsigned int, "u32", "r", __NOARG) +__MAKE_LD2(ca, longlong2, long long, "s64", "l", __NOARG) +__MAKE_LD2(ca, ulonglong2, unsigned long long, "u64", "l", __NOARG) +__MAKE_LD2(ca, float2, float, "f32", "f", __NOARG) +__MAKE_LD2(ca, double2, double, "f64", "d", __NOARG) + +__MAKE_LD4(ca, char4, short, "s8", "h", __NOARG) +__MAKE_LD4(ca, uchar4, short, "u8", "h", __NOARG) +__MAKE_LD4(ca, short4, short, "s16", "h", __NOARG) +__MAKE_LD4(ca, ushort4, unsigned short, "u16", "h", __NOARG) +__MAKE_LD4(ca, int4, int, "s32", "r", __NOARG) +__MAKE_LD4(ca, 
uint4, unsigned int, "u32", "r", __NOARG) +__MAKE_LD4(ca, float4, float, "f32", "f", __NOARG) + + +__MAKE_LD(cs, char, short, "s8", "h", __NOARG) +__MAKE_LD(cs, signed char, short, "s8", "h", __NOARG) +__MAKE_LD(cs, unsigned char, short, "u8", "h", __NOARG) +__MAKE_LD(cs, short, short, "s16", "h", __NOARG) +__MAKE_LD(cs, unsigned short, unsigned short, "u16", "h", __NOARG) +__MAKE_LD(cs, int, int, "s32", "r", __NOARG) +__MAKE_LD(cs, unsigned int, unsigned int, "u32", "r", __NOARG) +__MAKE_LD(cs, long, long, "s" __LBITS, __LDG_PTR, __NOARG) +__MAKE_LD(cs, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR, __NOARG) +__MAKE_LD(cs, long long, long long, "s64", "l", __NOARG) +__MAKE_LD(cs, unsigned long long, unsigned long long, "u64", "l", __NOARG) +__MAKE_LD(cs, float, float, "f32", "f", __NOARG) +__MAKE_LD(cs, double, double, "f64", "d", __NOARG) + +__MAKE_LD2(cs, char2, short, "s8", "h", __NOARG) +__MAKE_LD2(cs, uchar2, short, "u8", "h", __NOARG) +__MAKE_LD2(cs, short2, short, "s16", "h", __NOARG) +__MAKE_LD2(cs, ushort2, unsigned short, "u16", "h", __NOARG) +__MAKE_LD2(cs, int2, int, "s32", "r", __NOARG) +__MAKE_LD2(cs, uint2, unsigned int, "u32", "r", __NOARG) +__MAKE_LD2(cs, longlong2, long long, "s64", "l", __NOARG) +__MAKE_LD2(cs, ulonglong2, unsigned long long, "u64", "l", __NOARG) +__MAKE_LD2(cs, float2, float, "f32", "f", __NOARG) +__MAKE_LD2(cs, double2, double, "f64", "d", __NOARG) + +__MAKE_LD4(cs, char4, short, "s8", "h", __NOARG) +__MAKE_LD4(cs, uchar4, short, "u8", "h", __NOARG) +__MAKE_LD4(cs, short4, short, "s16", "h", __NOARG) +__MAKE_LD4(cs, ushort4, unsigned short, "u16", "h", __NOARG) +__MAKE_LD4(cs, int4, int, "s32", "r", __NOARG) +__MAKE_LD4(cs, uint4, unsigned int, "u32", "r", __NOARG) +__MAKE_LD4(cs, float4, float, "f32", "f", __NOARG) + + +__MAKE_LD(lu, char, short, "s8", "h", : "memory") +__MAKE_LD(lu, signed char, short, "s8", "h", : "memory") +__MAKE_LD(lu, unsigned char, short, "u8", "h", : "memory") +__MAKE_LD(lu, short, short, "s16", "h", : "memory") +__MAKE_LD(lu, unsigned short, unsigned short, "u16", "h", : "memory") +__MAKE_LD(lu, int, int, "s32", "r", : "memory") +__MAKE_LD(lu, unsigned int, unsigned int, "u32", "r", : "memory") +__MAKE_LD(lu, long, long, "s" __LBITS, __LDG_PTR, : "memory") +__MAKE_LD(lu, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR, : "memory") +__MAKE_LD(lu, long long, long long, "s64", "l", : "memory") +__MAKE_LD(lu, unsigned long long, unsigned long long, "u64", "l", : "memory") +__MAKE_LD(lu, float, float, "f32", "f", : "memory") +__MAKE_LD(lu, double, double, "f64", "d", : "memory") + +__MAKE_LD2(lu, char2, short, "s8", "h", : "memory") +__MAKE_LD2(lu, uchar2, short, "u8", "h", : "memory") +__MAKE_LD2(lu, short2, short, "s16", "h", : "memory") +__MAKE_LD2(lu, ushort2, unsigned short, "u16", "h", : "memory") +__MAKE_LD2(lu, int2, int, "s32", "r", : "memory") +__MAKE_LD2(lu, uint2, unsigned int, "u32", "r", : "memory") +__MAKE_LD2(lu, longlong2, long long, "s64", "l", : "memory") +__MAKE_LD2(lu, ulonglong2, unsigned long long, "u64", "l", : "memory") +__MAKE_LD2(lu, float2, float, "f32", "f", : "memory") +__MAKE_LD2(lu, double2, double, "f64", "d", : "memory") + +__MAKE_LD4(lu, char4, short, "s8", "h", : "memory") +__MAKE_LD4(lu, uchar4, short, "u8", "h", : "memory") +__MAKE_LD4(lu, short4, short, "s16", "h", : "memory") +__MAKE_LD4(lu, ushort4, unsigned short, "u16", "h", : "memory") +__MAKE_LD4(lu, int4, int, "s32", "r", : "memory") +__MAKE_LD4(lu, uint4, unsigned int, "u32", "r", : "memory") +__MAKE_LD4(lu, float4, 
float, "f32", "f", : "memory") + + +__MAKE_LD(cv, char, short, "s8", "h", : "memory") +__MAKE_LD(cv, signed char, short, "s8", "h", : "memory") +__MAKE_LD(cv, unsigned char, short, "u8", "h", : "memory") +__MAKE_LD(cv, short, short, "s16", "h", : "memory") +__MAKE_LD(cv, unsigned short, unsigned short, "u16", "h", : "memory") +__MAKE_LD(cv, int, int, "s32", "r", : "memory") +__MAKE_LD(cv, unsigned int, unsigned int, "u32", "r", : "memory") +__MAKE_LD(cv, long, long, "s" __LBITS, __LDG_PTR, : "memory") +__MAKE_LD(cv, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR, : "memory") +__MAKE_LD(cv, long long, long long, "s64", "l", : "memory") +__MAKE_LD(cv, unsigned long long, unsigned long long, "u64", "l", : "memory") +__MAKE_LD(cv, float, float, "f32", "f", : "memory") +__MAKE_LD(cv, double, double, "f64", "d", : "memory") + +__MAKE_LD2(cv, char2, short, "s8", "h", : "memory") +__MAKE_LD2(cv, uchar2, short, "u8", "h", : "memory") +__MAKE_LD2(cv, short2, short, "s16", "h", : "memory") +__MAKE_LD2(cv, ushort2, unsigned short, "u16", "h", : "memory") +__MAKE_LD2(cv, int2, int, "s32", "r", : "memory") +__MAKE_LD2(cv, uint2, unsigned int, "u32", "r", : "memory") +__MAKE_LD2(cv, longlong2, long long, "s64", "l", : "memory") +__MAKE_LD2(cv, ulonglong2, unsigned long long, "u64", "l", : "memory") +__MAKE_LD2(cv, float2, float, "f32", "f", : "memory") +__MAKE_LD2(cv, double2, double, "f64", "d", : "memory") + +__MAKE_LD4(cv, char4, short, "s8", "h", : "memory") +__MAKE_LD4(cv, uchar4, short, "u8", "h", : "memory") +__MAKE_LD4(cv, short4, short, "s16", "h", : "memory") +__MAKE_LD4(cv, ushort4, unsigned short, "u16", "h", : "memory") +__MAKE_LD4(cv, int4, int, "s32", "r", : "memory") +__MAKE_LD4(cv, uint4, unsigned int, "u32", "r", : "memory") +__MAKE_LD4(cv, float4, float, "f32", "f", : "memory") + + +#define __MAKE_ST(cop, c_typ, int_typ, ptx_typ, inl_typ) \ + __device__ __forceinline__ void __st ## cop (c_typ* addr, c_typ v) { \ + asm("st." #cop "." ptx_typ " [%0], %1;" \ + :: __LDG_PTR(addr), inl_typ((int_typ)v) : "memory"); \ + } + +#define __MAKE_ST2(cop, c_typ, int_typ, ptx_typ, inl_typ) \ + __device__ __forceinline__ void __st ## cop (c_typ* addr, c_typ v) { \ + int_typ v1 = v.x, v2 = v.y; \ + asm("st." #cop ".v2." ptx_typ " [%0], {%1, %2};" \ + :: __LDG_PTR(addr), inl_typ(v1), inl_typ(v2) : "memory"); \ + } + +#define __MAKE_ST4(cop, c_typ, int_typ, ptx_typ, inl_typ) \ + __device__ __forceinline__ c_typ __st ## cop (c_typ* addr, c_typ v) { \ + int_typ v1 = v.x, v2 = v.y, v3 = v.z, v4 = v.w; \ + asm("st." #cop ".v4." 
ptx_typ " [%0], {%1, %2, %3, %4};" \ + :: __LDG_PTR(addr), inl_typ(v1), inl_typ(v2), \ + inl_typ(v3), inl_typ(v4) : "memory"); \ + } + +__MAKE_ST(wb, char, short, "s8", "h") +__MAKE_ST(wb, signed char, short, "s8", "h") +__MAKE_ST(wb, unsigned char, short, "u8", "h") +__MAKE_ST(wb, short, short, "s16", "h") +__MAKE_ST(wb, unsigned short, unsigned short, "u16", "h") +__MAKE_ST(wb, int, int, "s32", "r") +__MAKE_ST(wb, unsigned int, unsigned int, "u32", "r") +__MAKE_ST(wb, long, long, "s" __LBITS, __LDG_PTR) +__MAKE_ST(wb, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR) +__MAKE_ST(wb, long long, long long, "s64", "l") +__MAKE_ST(wb, unsigned long long, unsigned long long, "u64", "l") +__MAKE_ST(wb, float, float, "f32", "f") +__MAKE_ST(wb, double, double, "f64", "d") + +__MAKE_ST2(wb, char2, short, "s8", "h") +__MAKE_ST2(wb, uchar2, short, "u8", "h") +__MAKE_ST2(wb, short2, short, "s16", "h") +__MAKE_ST2(wb, ushort2, unsigned short, "u16", "h") +__MAKE_ST2(wb, int2, int, "s32", "r") +__MAKE_ST2(wb, uint2, unsigned int, "u32", "r") +__MAKE_ST2(wb, longlong2, long long, "s64", "l") +__MAKE_ST2(wb, ulonglong2, unsigned long long, "u64", "l") +__MAKE_ST2(wb, float2, float, "f32", "f") +__MAKE_ST2(wb, double2, double, "f64", "d") + +__MAKE_ST4(wb, char4, short, "s8", "h") +__MAKE_ST4(wb, uchar4, short, "u8", "h") +__MAKE_ST4(wb, short4, short, "s16", "h") +__MAKE_ST4(wb, ushort4, unsigned short, "u16", "h") +__MAKE_ST4(wb, int4, int, "s32", "r") +__MAKE_ST4(wb, uint4, unsigned int, "u32", "r") +__MAKE_ST4(wb, float4, float, "f32", "f") + + +__MAKE_ST(cg, char, short, "s8", "h") +__MAKE_ST(cg, signed char, short, "s8", "h") +__MAKE_ST(cg, unsigned char, short, "u8", "h") +__MAKE_ST(cg, short, short, "s16", "h") +__MAKE_ST(cg, unsigned short, unsigned short, "u16", "h") +__MAKE_ST(cg, int, int, "s32", "r") +__MAKE_ST(cg, unsigned int, unsigned int, "u32", "r") +__MAKE_ST(cg, long, long, "s" __LBITS, __LDG_PTR) +__MAKE_ST(cg, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR) +__MAKE_ST(cg, long long, long long, "s64", "l") +__MAKE_ST(cg, unsigned long long, unsigned long long, "u64", "l") +__MAKE_ST(cg, float, float, "f32", "f") +__MAKE_ST(cg, double, double, "f64", "d") + +__MAKE_ST2(cg, char2, short, "s8", "h") +__MAKE_ST2(cg, uchar2, short, "u8", "h") +__MAKE_ST2(cg, short2, short, "s16", "h") +__MAKE_ST2(cg, ushort2, unsigned short, "u16", "h") +__MAKE_ST2(cg, int2, int, "s32", "r") +__MAKE_ST2(cg, uint2, unsigned int, "u32", "r") +__MAKE_ST2(cg, longlong2, long long, "s64", "l") +__MAKE_ST2(cg, ulonglong2, unsigned long long, "u64", "l") +__MAKE_ST2(cg, float2, float, "f32", "f") +__MAKE_ST2(cg, double2, double, "f64", "d") + +__MAKE_ST4(cg, char4, short, "s8", "h") +__MAKE_ST4(cg, uchar4, short, "u8", "h") +__MAKE_ST4(cg, short4, short, "s16", "h") +__MAKE_ST4(cg, ushort4, unsigned short, "u16", "h") +__MAKE_ST4(cg, int4, int, "s32", "r") +__MAKE_ST4(cg, uint4, unsigned int, "u32", "r") +__MAKE_ST4(cg, float4, float, "f32", "f") + + +__MAKE_ST(cs, char, short, "s8", "h") +__MAKE_ST(cs, signed char, short, "s8", "h") +__MAKE_ST(cs, unsigned char, short, "u8", "h") +__MAKE_ST(cs, short, short, "s16", "h") +__MAKE_ST(cs, unsigned short, unsigned short, "u16", "h") +__MAKE_ST(cs, int, int, "s32", "r") +__MAKE_ST(cs, unsigned int, unsigned int, "u32", "r") +__MAKE_ST(cs, long, long, "s" __LBITS, __LDG_PTR) +__MAKE_ST(cs, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR) +__MAKE_ST(cs, long long, long long, "s64", "l") +__MAKE_ST(cs, unsigned long long, unsigned long long, "u64", "l") 
+__MAKE_ST(cs, float, float, "f32", "f") +__MAKE_ST(cs, double, double, "f64", "d") + +__MAKE_ST2(cs, char2, short, "s8", "h") +__MAKE_ST2(cs, uchar2, short, "u8", "h") +__MAKE_ST2(cs, short2, short, "s16", "h") +__MAKE_ST2(cs, ushort2, unsigned short, "u16", "h") +__MAKE_ST2(cs, int2, int, "s32", "r") +__MAKE_ST2(cs, uint2, unsigned int, "u32", "r") +__MAKE_ST2(cs, longlong2, long long, "s64", "l") +__MAKE_ST2(cs, ulonglong2, unsigned long long, "u64", "l") +__MAKE_ST2(cs, float2, float, "f32", "f") +__MAKE_ST2(cs, double2, double, "f64", "d") + +__MAKE_ST4(cs, char4, short, "s8", "h") +__MAKE_ST4(cs, uchar4, short, "u8", "h") +__MAKE_ST4(cs, short4, short, "s16", "h") +__MAKE_ST4(cs, ushort4, unsigned short, "u16", "h") +__MAKE_ST4(cs, int4, int, "s32", "r") +__MAKE_ST4(cs, uint4, unsigned int, "u32", "r") +__MAKE_ST4(cs, float4, float, "f32", "f") + + +__MAKE_ST(wt, char, short, "s8", "h") +__MAKE_ST(wt, signed char, short, "s8", "h") +__MAKE_ST(wt, unsigned char, short, "u8", "h") +__MAKE_ST(wt, short, short, "s16", "h") +__MAKE_ST(wt, unsigned short, unsigned short, "u16", "h") +__MAKE_ST(wt, int, int, "s32", "r") +__MAKE_ST(wt, unsigned int, unsigned int, "u32", "r") +__MAKE_ST(wt, long, long, "s" __LBITS, __LDG_PTR) +__MAKE_ST(wt, unsigned long, unsigned long, "u" __LBITS, __LDG_PTR) +__MAKE_ST(wt, long long, long long, "s64", "l") +__MAKE_ST(wt, unsigned long long, unsigned long long, "u64", "l") +__MAKE_ST(wt, float, float, "f32", "f") +__MAKE_ST(wt, double, double, "f64", "d") + +__MAKE_ST2(wt, char2, short, "s8", "h") +__MAKE_ST2(wt, uchar2, short, "u8", "h") +__MAKE_ST2(wt, short2, short, "s16", "h") +__MAKE_ST2(wt, ushort2, unsigned short, "u16", "h") +__MAKE_ST2(wt, int2, int, "s32", "r") +__MAKE_ST2(wt, uint2, unsigned int, "u32", "r") +__MAKE_ST2(wt, longlong2, long long, "s64", "l") +__MAKE_ST2(wt, ulonglong2, unsigned long long, "u64", "l") +__MAKE_ST2(wt, float2, float, "f32", "f") +__MAKE_ST2(wt, double2, double, "f64", "d") + +__MAKE_ST4(wt, char4, short, "s8", "h") +__MAKE_ST4(wt, uchar4, short, "u8", "h") +__MAKE_ST4(wt, short4, short, "s16", "h") +__MAKE_ST4(wt, ushort4, unsigned short, "u16", "h") +__MAKE_ST4(wt, int4, int, "s32", "r") +__MAKE_ST4(wt, uint4, unsigned int, "u32", "r") +__MAKE_ST4(wt, float4, float, "f32", "f") + + +#undef __MAKE_ST4 +#undef __MAKE_ST2 +#undef __MAKE_ST +#undef __MAKE_LD4 +#undef __MAKE_LD2 +#undef __MAKE_LD +#undef __NOARG +#undef __LBITS +#undef __LDG_PTR + +#endif // defined(__cplusplus) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 320)) + +#endif // defined(__CLANG_CUDA_ADDITIONAL_INTRINSICS_H__) diff --git a/cpp/scripts/run-clang-compile.py b/cpp/scripts/run-clang-compile.py new file mode 100644 index 0000000000..4edbde84b3 --- /dev/null +++ b/cpp/scripts/run-clang-compile.py @@ -0,0 +1,331 @@ +# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
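
The script below compiles every entry of a CMake-exported compilation database with clang++ instead of nvcc, translating flags on the fly. A typical invocation might look like this (build paths are hypothetical):

    cmake -B cpp/build -DCMAKE_EXPORT_COMPILE_COMMANDS=ON cpp
    python cpp/scripts/run-clang-compile.py -cdb cpp/build/compile_commands.json -j 8
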
diff --git a/cpp/scripts/run-clang-compile.py b/cpp/scripts/run-clang-compile.py
new file mode 100644
index 0000000000..4edbde84b3
--- /dev/null
+++ b/cpp/scripts/run-clang-compile.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2020-2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# IMPORTANT DISCLAIMER:                                                        #
+# This file is experimental and may not run successfully on the entire repo!  #
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+#
+
+from __future__ import print_function
+import argparse
+import glob
+import json
+import multiprocessing as mp
+import os
+import re
+import shutil
+import subprocess
+
+
+CLANG_COMPILER = "clang++"
+GPU_ARCH_REGEX = re.compile(r"sm_(\d+)")
+SPACES = re.compile(r"\s+")
+XCOMPILER_FLAG = re.compile(r"-((Xcompiler)|(-compiler-options))=?")
+XPTXAS_FLAG = re.compile(r"-((Xptxas)|(-ptxas-options))=?")
+# any options that may have equal signs in nvcc but not in clang
+# add those options here if you find any
+OPTIONS_NO_EQUAL_SIGN = ['-isystem']
+SEPARATOR = "-" * 8
+END_SEPARATOR = "*" * 64
+
+
+def parse_args():
+    argparser = argparse.ArgumentParser("Runs clang++ on a project instead of nvcc")
+    argparser.add_argument(
+        "-cdb", type=str, default="compile_commands.json",
+        help="Path to cmake-generated compilation database")
+    argparser.add_argument(
+        "-ignore", type=str, default=None,
+        help="Regex used to ignore files from checking")
+    argparser.add_argument(
+        "-select", type=str, default=None,
+        help="Regex used to select files for checking")
+    argparser.add_argument(
+        "-j", type=int, default=-1, help="Number of parallel jobs to launch.")
+    args = argparser.parse_args()
+    if args.j <= 0:
+        args.j = mp.cpu_count()
+    args.ignore_compiled = re.compile(args.ignore) if args.ignore else None
+    args.select_compiled = re.compile(args.select) if args.select else None
+    # we don't check clang's version, it should be OK with any clang
+    # recent enough to handle CUDA >= 11
+    if not os.path.exists(args.cdb):
+        raise Exception("Compilation database '%s' missing" % args.cdb)
+    return args
+
+
+def list_all_cmds(cdb):
+    with open(cdb, "r") as fp:
+        return json.load(fp)
+
+
+def get_gpu_archs(command):
+    archs = []
+    for loc in range(len(command)):
+        if (command[loc] != "-gencode" and command[loc] != "--generate-code"
+                and not command[loc].startswith("--generate-code=")):
+            continue
+        if command[loc].startswith("--generate-code="):
+            arch_flag = command[loc][len("--generate-code="):]
+        else:
+            arch_flag = command[loc + 1]
+        match = GPU_ARCH_REGEX.search(arch_flag)
+        if match is not None:
+            archs.append("--cuda-gpu-arch=sm_%s" % match.group(1))
+    return archs
+
+
+def get_index(arr, item_options):
+    return set(i for i, s in enumerate(arr) for item in item_options
+               if s == item)
+
+
+def remove_items(arr, item_options):
+    for i in sorted(get_index(arr, item_options), reverse=True):
+        del arr[i]
+
+
+def remove_items_plus_one(arr, item_options):
+    for i in sorted(get_index(arr, item_options), reverse=True):
+        if i < len(arr) - 1:
+            del arr[i + 1]
+        del arr[i]
+    idx = set(i for i, s in enumerate(arr) for item in item_options
+              if s.startswith(item + "="))
+    for i in sorted(idx, reverse=True):
+        del arr[i]
+
+
+def add_cuda_path(command, nvcc):
+    nvcc_path = shutil.which(nvcc)
+    if not nvcc_path:
+        raise Exception("Command %s has invalid compiler %s" % (command, nvcc))
+    cuda_root = os.path.dirname(os.path.dirname(nvcc_path))
+    # make sure that cuda root has version.txt
+    if not os.path.isfile(os.path.join(cuda_root, "version.txt")):
+        raise Exception(
+            "clang++ expects a `version.txt` file in your CUDA root path with "
+            "content `CUDA Version <major>.<minor>.<patch>`")
+    command.append('--cuda-path=%s' % cuda_root)
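+
+
+# Illustrative example (hypothetical flags): an nvcc entry such as
+#   nvcc -c x.cu -o x.o -gencode arch=compute_80,code=sm_80 -Xcompiler=-Wall
+# is rewritten below into roughly
+#   clang++ --cuda-gpu-arch=sm_80 --cuda-path=<...> -Wall -x cuda ...
+# before run_clang() appends `-c <file> -o os.devnull`.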
+def get_clang_args(cmd):
+    command, file = cmd["command"], cmd["file"]
+    is_cuda = file.endswith(".cu")
+    command = re.split(SPACES, command)
+    # get original compiler
+    cc_orig = command[0]
+    # compiler is always clang++!
+    command[0] = "clang++"
+    # remove compilation and output targets from the original command
+    remove_items_plus_one(command, ["--compile", "-c"])
+    remove_items_plus_one(command, ["--output-file", "-o"])
+    if is_cuda:
+        # replace nvcc's "-gencode ..." with clang's "--cuda-gpu-arch ..."
+        archs = get_gpu_archs(command)
+        command.extend(archs)
+        # provide proper cuda path to clang
+        add_cuda_path(command, cc_orig)
+        # remove all kinds of nvcc flags clang doesn't know about
+        remove_items_plus_one(command, [
+            "--generate-code",
+            "-gencode",
+            "--x",
+            "-x",
+            "--compiler-bindir",
+            "-ccbin",
+            "--diag_suppress",
+            "-diag-suppress",
+            "--default-stream",
+            "-default-stream",
+        ])
+        remove_items(command, [
+            "-extended-lambda",
+            "--extended-lambda",
+            "-expt-extended-lambda",
+            "--expt-extended-lambda",
+            "-expt-relaxed-constexpr",
+            "--expt-relaxed-constexpr",
+            "--device-debug",
+            "-G",
+            "--generate-line-info",
+            "-lineinfo",
+        ])
+        # "-x cuda" is the right usage in clang
+        command.extend(["-x", "cuda"])
+        # we remove -Xcompiler flags: here we basically have to hope for the
+        # best that clang++ will accept any flags which nvcc passed to gcc
+        for i, c in reversed(list(enumerate(command))):
+            new_c = XCOMPILER_FLAG.sub('', c)
+            if new_c == c:
+                continue
+            command[i:i + 1] = new_c.split(',')
+        # we also change -Xptxas to -Xcuda-ptxas, always adding space here
+        for i, c in reversed(list(enumerate(command))):
+            if XPTXAS_FLAG.search(c):
+                if not c.endswith("=") and i < len(command) - 1:
+                    del command[i + 1]
+                command[i] = '-Xcuda-ptxas'
+                command.insert(i + 1, XPTXAS_FLAG.sub('', c))
+        # several options like isystem don't expect `=`
+        for opt in OPTIONS_NO_EQUAL_SIGN:
+            opt_eq = opt + '='
+            # make sure that we iterate from back to front here for insert
+            for i, c in reversed(list(enumerate(command))):
+                if not c.startswith(opt_eq):
+                    continue
+                x = c.split('=')
+                # we only care about the first `=`
+                command[i] = x[0]
+                command.insert(i + 1, '='.join(x[1:]))
+        # use extensible whole program, to avoid ptx resolution/linking
+        command.extend(["-Xcuda-ptxas", "-ewp"])
+        # for libcudacxx, we need to allow variadic functions
+        command.extend(["-Xclang", "-fcuda-allow-variadic-functions"])
+        # add some additional CUDA intrinsics
+        cuda_intrinsics_file = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)),
+            "__clang_cuda_additional_intrinsics.h")
+        command.extend(["-include", cuda_intrinsics_file])
+    # somehow this option gets onto the commandline, it is unrecognized by clang
+    remove_items(command, [
+        "--forward-unknown-to-host-compiler",
+        "-forward-unknown-to-host-compiler"
+    ])
+    # do not treat warnings as errors here !
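+    # (clang typically emits warnings that nvcc does not, so keeping
+    # -Werror would fail builds that are otherwise fine)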
+    for i, x in reversed(list(enumerate(command))):
+        if x.startswith("-Werror"):
+            del command[i]
+    # add GCC headers if we can find GCC
+    gcc_path = shutil.which("gcc")
+    if gcc_path:
+        gcc_base = os.path.dirname(os.path.dirname(gcc_path))
+        gcc_glob1 = os.path.join(gcc_base, "lib", "gcc", "*", "*", "include")
+        gcc_glob2 = os.path.join(gcc_base, "lib64", "gcc", "*", "*", "include")
+        inc_dirs = glob.glob(gcc_glob1) + glob.glob(gcc_glob2)
+        for d in inc_dirs:
+            command.extend(["-isystem", d])
+    return command
+
+
+def run_clang_command(clang_cmd, cwd):
+    cmd = " ".join(clang_cmd)
+    result = subprocess.run(cmd, check=False, shell=True, cwd=cwd,
+                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    result.stdout = result.stdout.decode("utf-8").strip()
+    out = "CMD: " + cmd + "\n"
+    out += "CWD: " + cwd + "\n"
+    out += "EXIT-CODE: %d\n" % result.returncode
+    status = result.returncode == 0
+    out += result.stdout
+    return status, out
+
+
+class LockContext(object):
+    def __init__(self, lock=None) -> None:
+        self._lock = lock
+
+    def __enter__(self):
+        if self._lock:
+            self._lock.acquire()
+        return self
+
+    def __exit__(self, _, __, ___):
+        if self._lock:
+            self._lock.release()
+        return False  # we don't handle exceptions
+
+
+def print_result(passed, stdout, file):
+    status_str = "PASSED" if passed else "FAILED"
+    print("%s File:%s %s %s" % (SEPARATOR, file, status_str, SEPARATOR))
+    if not passed and stdout:
+        print(stdout)
+    print("%s\n" % END_SEPARATOR)
+
+
+def run_clang(cmd, args):
+    command = get_clang_args(cmd)
+    cwd = os.path.dirname(args.cdb)
+    # compile only and dump output to /dev/null
+    command.extend(["-c", cmd["file"], "-o", os.devnull])
+    status, out = run_clang_command(command, cwd)
+    # we immediately print the result since this is more interactive for user
+    with lock:
+        print_result(status, out, cmd["file"])
+    return status
+
+
+# mostly used for debugging purposes
+def run_sequential(args, all_files):
+    # lock must be defined as in `run_parallel`
+    global lock
+    lock = LockContext()
+    results = []
+    for cmd in all_files:
+        # skip files that we don't want to look at
+        if args.ignore_compiled is not None and \
+                re.search(args.ignore_compiled, cmd["file"]) is not None:
+            continue
+        if args.select_compiled is not None and \
+                re.search(args.select_compiled, cmd["file"]) is None:
+            continue
+        results.append(run_clang(cmd, args))
+    return all(results)
+
+
+def copy_lock(init_lock):
+    # this is required to pass locks to pool workers
+    # see https://stackoverflow.com/questions/25557686/
+    # python-sharing-a-lock-between-processes
+    global lock
+    lock = init_lock
+
+
+def run_parallel(args, all_files):
+    init_lock = LockContext(mp.Lock())
+    pool = mp.Pool(args.j, initializer=copy_lock, initargs=(init_lock,))
+    results = []
+    for cmd in all_files:
+        # skip files that we don't want to look at
+        if args.ignore_compiled is not None and \
+                re.search(args.ignore_compiled, cmd["file"]) is not None:
+            continue
+        if args.select_compiled is not None and \
+                re.search(args.select_compiled, cmd["file"]) is None:
+            continue
+        results.append(pool.apply_async(run_clang, args=(cmd, args)))
+    results_final = [r.get() for r in results]
+    pool.close()
+    pool.join()
+    return all(results_final)
+
+
+def main():
+    args = parse_args()
+    all_files = list_all_cmds(args.cdb)
+    # ensure that we use only the real paths
+    for cmd in all_files:
+        cmd["file"] = os.path.realpath(os.path.expanduser(cmd["file"]))
+    if args.j == 1:
+        status = run_sequential(args, all_files)
+    else:
+        status = run_parallel(args, all_files)
+    if not status:
+        raise Exception("clang++ failed! Refer to the errors above.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cpp/scripts/run-clang-tidy.py b/cpp/scripts/run-clang-tidy.py
index 23260d2f4d..ed1a633232 100644
--- a/cpp/scripts/run-clang-tidy.py
+++ b/cpp/scripts/run-clang-tidy.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# IMPORTANT DISCLAIMER:                                                        #
+# This file is experimental and may not run successfully on the entire repo!  #
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+#
 
 from __future__ import print_function
 import sys
diff --git a/cpp/test/linalg/rsvd.cu b/cpp/test/linalg/rsvd.cu
index da38464bf7..7b0bb7c928 100644
--- a/cpp/test/linalg/rsvd.cu
+++ b/cpp/test/linalg/rsvd.cu
@@ -23,6 +23,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace linalg {
 
@@ -111,8 +113,8 @@ class RsvdTest : public ::testing::TestWithParam<RsvdInputs<T>> {
     raft::update_host(A_backup_cpu.data(), A.data(), m * n, stream);
 
     if (params.k == 0) {
-      params.k = max((int)(min(m, n) * params.PC_perc), 1);
-      params.p = max((int)(min(m, n) * params.UpS_perc), 1);
+      params.k = std::max((int)(std::min(m, n) * params.PC_perc), 1);
+      params.p = std::max((int)(std::min(m, n) * params.UpS_perc), 1);
     }
 
     U.resize(m * params.k, stream);
diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu
index caad627d49..48e8986947 100644
--- a/cpp/test/random/make_blobs.cu
+++ b/cpp/test/random/make_blobs.cu
@@ -170,7 +170,6 @@
   {0.011, 1024, 8, 3, 1.f, false, true, raft::random::GenPC, 1234ULL},
   {0.0055, 5003, 32, 5, 1.f, true, false, raft::random::GenPhilox, 1234ULL},
   {0.011, 5003, 8, 5, 1.f, true, false, raft::random::GenPhilox, 1234ULL},
-  {0.0055, 5003, 32, 5, 1.f, true, false, raft::random::GenPC, 1234ULL},
   {0.011, 5003, 8, 5, 1.f, true, false, raft::random::GenPC, 1234ULL},
   {0.0055, 5003, 32, 5, 1.f, false, false, raft::random::GenPhilox, 1234ULL},
@@ -230,4 +229,4 @@ TEST_P(MakeBlobsTestD, Result) { check(); }
 INSTANTIATE_TEST_CASE_P(MakeBlobsTests, MakeBlobsTestD, ::testing::ValuesIn(inputsd_t));
 
 }  // end namespace random
-}  // end namespace raft
\ No newline at end of file
+}  // end namespace raft
diff --git a/cpp/test/stats/meanvar.cu b/cpp/test/stats/meanvar.cu
index a5bb5b0b0d..b0efe1c7dd 100644
--- a/cpp/test/stats/meanvar.cu
+++ b/cpp/test/stats/meanvar.cu
@@ -21,6 +21,8 @@
 #include
 #include
 
+#include <algorithm>
+
 namespace raft {
 namespace stats {
 
@@ -34,7 +36,10 @@ struct MeanVarInputs {
 
   T mean_tol() const { return T(N_SIGMAS) * stddev / sqrt(T(rows)); }
 
-  T var_tol() const { return T(N_SIGMAS) * stddev * stddev * sqrt(T(2.0) / T(max(1, rows - 1))); }
+  T var_tol() const
+  {
+    return T(N_SIGMAS) * stddev * stddev * sqrt(T(2.0) / T(std::max(1, rows - 1)));
+  }
 };
 
 template