From e8f1862e36072ff867f59ed3e38e8dcb7bb02fd3 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Wed, 2 Jun 2021 20:18:28 +0530 Subject: [PATCH] Add Grid stride pairwise dist and fused L2 NN kernels (#232) This PR addresses issues mentioned in https://github.com/rapidsai/raft/issues/221 -- Adds grid stride based fusedL2NN kernel, this gives approx 1.85x speed up over previous version of this kernel. -- Adds support in pairwise dist base class to work for any input size by adding support for grid stride based work distribution. Authors: - Mahesh Doijade (https://github.com/mdoijade) Approvers: - Thejaswi. N. S (https://github.com/teju85) - Divye Gala (https://github.com/divyegala) - Alex Fender (https://github.com/afender) URL: https://github.com/rapidsai/raft/pull/232 --- cpp/include/raft/distance/cosine.cuh | 35 +- cpp/include/raft/distance/euclidean.cuh | 77 ++-- cpp/include/raft/distance/fused_l2_nn.cuh | 366 +++++++----------- cpp/include/raft/distance/l1.cuh | 36 +- .../raft/distance/pairwise_distance_base.cuh | 163 ++++++-- cpp/include/raft/linalg/contractions.cuh | 8 +- cpp/test/distance/fused_l2_nn.cu | 2 - 7 files changed, 359 insertions(+), 328 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index 5a212ce64c..ed9bd28b7f 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -61,8 +61,6 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -73,7 +71,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll @@ -83,20 +82,26 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, } }; + constexpr size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto cosineRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); + cosineRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } else { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto cosineColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); + cosineColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index f3f946ad7b..484da0e5bf 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -60,8 +60,6 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -72,7 +70,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll @@ -91,20 +90,29 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, } }; + constexpr size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto euclideanExpRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); + + euclideanExpRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } else { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto euclideanExpColMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); + euclideanExpColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); @@ -229,8 +237,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); + dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -242,7 +249,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { if (sqrt) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { @@ -255,19 +263,28 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, }; if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto euclideanUnExpRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + euclideanUnExpRowMajor); + + euclideanUnExpRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto euclideanUnExpColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + euclideanUnExpColMajor); + + euclideanUnExpColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index 000d856841..b96a536e38 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace raft { @@ -68,117 +69,81 @@ struct MinReduceOp { DI void init(DataT* out, DataT maxVal) { *out = maxVal; } }; -template > -struct FusedL2NN : public BaseClass { - private: - typedef Policy P; - - const DataT* xn; - const DataT* yn; - OutT* min; - int* mutex; - - DataT *sxNorm, *syNorm; - cub::KeyValuePair* sRed; - - DataT maxVal; - - DataT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - ReduceOpT redOp; - KVPReduceOpT pairRedOp; - -#if (ENABLE_MEMCPY_ASYNC == 1) - DataT zeros[P::Veclen]; - nvcuda::experimental::pipeline pipe; -#endif - - static const DataT Two = (DataT)2.0; - static constexpr size_t SizeAndAlign = P::Veclen * sizeof(DataT); - - public: - DI FusedL2NN(OutT* _min, const DataT* _x, const DataT* _y, const DataT* _xn, - const DataT* _yn, IdxT _m, IdxT _n, IdxT _k, char* _smem, - DataT _mv, int* _mut, ReduceOpT op, KVPReduceOpT pair_op) - : BaseClass(_x, _y, _m, _n, _k, _smem), - xn(_xn), - yn(_yn), - min(_min), - mutex(_mut), - sxNorm((DataT*)_smem), - syNorm(&(sxNorm[P::Mblk])), - sRed((cub::KeyValuePair*)_smem), - maxVal(_mv), - redOp(op), - pairRedOp(pair_op) { -#if (ENABLE_MEMCPY_ASYNC == 1) -#pragma unroll - for (int i = 0; i < P::Veclen; ++i) { - zeros[i] = BaseClass::Zero; - } -#endif - } - - DI void run() { - prolog(); - loop(); - __syncthreads(); // so that we can safely reuse smem - epilog(); +template +__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + redOp.init(min + tid, maxVal); } +} - private: - DI void prolog() { - this->ldgXY(0); +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal(int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, + IdxT m, IdxT gridStrideY) { + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // for now have first lane from each warp update a unique output row. This + // will resolve hang issues with pre-Volta architectures #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == 0) { #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + j + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } } } - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - } - - DI void loop() { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + if (j < (raft::WarpSize / P::AccThCols) - 1) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols); + val[i] = {tmpkey, tmpvalue}; + } } - accumulate(); // last iteration } +} - DI void epilog() { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = blockIdx.x * P::Mblk + i; - sxNorm[i] = idx < this->m ? xn[idx] : maxVal; - } - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = blockIdx.y * P::Nblk + i; - syNorm[i] = idx < this->n ? yn[idx] : maxVal; - } - __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + this->accrowid]; - } +template +__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( + OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, + IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, + KVPReduceOpT pairRedOp, CoreLambda core_op, FinalLambda fin_op) { + extern __shared__ char smem[]; + + typedef cub::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + this->acccolid]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - Two * acc[i][j]; + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } if (Sqrt) { @@ -190,175 +155,112 @@ struct FusedL2NN : public BaseClass { } } } - // reduce - cub::KeyValuePair val[P::AccRowsPerTh]; - auto lid = raft::laneId(); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = this->acccolid + j * P::AccThCols + blockIdx.y * P::Nblk; - cub::KeyValuePair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < this->n) + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { val[i] = - pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, - tmp, val[i]); + pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } } - __syncthreads(); + } + }; + + auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, + maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { auto tmpkey = raft::shfl(val[i].key, lid + j); auto tmpvalue = raft::shfl(val[i].value, lid + j); - cub::KeyValuePair tmp = {tmpkey, tmpvalue}; + KVPair tmp = {tmpkey, tmpvalue}; val[i] = - pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, - tmp, val[i]); - } - } - if (lid % P::AccThCols == 0) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - sRed[i * P::AccThCols + this->accrowid] = val[i]; + pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); } } - __syncthreads(); - updateResults(); - } - /* - * todo: From Volta onwards see if "coalesced" atomicCAS approach as - * written below helps improve perf - * ``` - * auto tid = threadIdx.x; - * auto rid = IdxT(blockIdx.x) * P::Mblk + tid; - * if (rid < m) { - * auto val = sRed[i]; - * while (atomicCAS(mutex + rid, 0, 1) == 1) - * ; - * __threadfence(); - * redOp(rid, min + rid, val); - * __threadfence(); - * atomicCAS(mutex + rid, 1, 0); - * } - * ``` - */ - DI void updateResults() { - // for now have first lane from each warp update a unique output row. This - // will resolve hang issues with pre-Volta architectures - auto nWarps = blockDim.x / raft::WarpSize; - auto lid = raft::laneId(); - auto ridx = IdxT(blockIdx.x) * P::Mblk; - if (lid == 0) { - for (int i = threadIdx.x / raft::WarpSize; i < P::Mblk; i += nWarps) { - auto rid = ridx + i; - if (rid < this->m) { - auto val = sRed[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - redOp(rid, min + rid, val); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - } + updateReducedVal(mutex, min, val, red_op, + m, gridStrideY); - DI void accumulate() { + // reset the val array. #pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - acc[i][j] += this->regx[i][v] * this->regy[j][v]; - } - } - } - } - } - -#if (ENABLE_MEMCPY_ASYNC == 1) - DI void ldgXY(IdxT kidx) { - auto koffset = kidx + this->scolid; - auto offset = - this->pageWr * P::SmemPage + this->srowid * P::SmemStride + this->scolid; - auto* saddrx = this->sx + offset; - for (int i = 0; i < P::LdgPerThX; ++i) { - auto* sax = saddrx + i * P::LdgRowsX * P::SmemStride; - auto* gax = this->x + i * P::LdgRowsX * this->k + koffset; - auto inside = - koffset < this->k && (this->xrowid + i * P::LdgRowsX) < this->m; - __pipeline_memcpy_async(sax, inside ? gax : nullptr, SizeAndAlign, - inside ? 0 : SizeAndAlign); - } - auto* saddry = this->sy + offset; - for (int i = 0; i < P::LdgPerThY; ++i) { - auto* say = saddry + i * P::LdgRowsY * P::SmemStride; - auto* gay = this->y + i * P::LdgRowsY * this->k + koffset; - auto inside = - koffset < this->k && (this->yrowid + i * P::LdgRowsY) < this->n; - __pipeline_memcpy_async(say, inside ? gay : nullptr, SizeAndAlign, - inside ? 0 : SizeAndAlign); + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; } - pipe.commit(); - } - - DI void stsXY() { pipe.wait_prior<0>(); } -#endif // ENABLE_MEMCPY_ASYNC -}; // struct FusedL2NN - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2NNkernel( - OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, - IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, - KVPReduceOpT pairRedOp) { - extern __shared__ char smem[]; - FusedL2NN obj( - min, x, y, xn, yn, m, n, k, smem, maxVal, mutex, redOp, pairRedOp); + }; + + IdxT lda = k, ldb = k, ldd = n; + PairwiseDistances + obj(x, y, m, n, k, lda, ldb, ldd, xn, yn, nullptr, smem, core_op, + epilog_lambda, fin_op, rowEpilog_lambda); obj.run(); } -template -__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { - redOp.init(min + tid, maxVal); - } -} - template void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, bool initOutBuffer, cudaStream_t stream) { - typedef typename linalg::Policy4x4::Policy Policy; - dim3 grid(raft::ceildiv(m, Policy::Mblk), - raft::ceildiv(n, Policy::Nblk)); - dim3 blk(Policy::Nthreads); - auto nblks = raft::ceildiv(m, Policy::Nthreads); - auto maxVal = std::numeric_limits::max(); + typedef typename linalg::Policy4x4::Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef cub::KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { + acc += x * y; + }; + CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel - <<>>(min, m, maxVal, redOp); + <<>>(min, m, maxVal, redOp); CUDA_CHECK(cudaGetLastError()); } + + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + + constexpr size_t shmemSize = + P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { - fusedL2NNkernel - <<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); + auto fusedL2NNSqrt = + fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); + + fusedL2NNSqrt<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, + core_lambda, fin_op); } else { - fusedL2NNkernel - <<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); + auto fusedL2NN = + fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); + fusedL2NN<<>>(min, x, y, xn, yn, m, n, k, + maxVal, workspace, redOp, + pairRedOp, core_lambda, fin_op); } + CUDA_CHECK(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index ce4fbb33e3..6ab084f041 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -53,8 +53,6 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -66,22 +64,30 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { return; }; + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { return; }; if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto l1RowMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, l1RowMajor); + + l1RowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } else { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto l1ColMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, l1ColMajor); + l1ColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 4e1605b887..503397bac9 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -14,8 +14,11 @@ * limitations under the License. */ #pragma once +#include +#include #include #include +#include namespace raft { namespace distance { @@ -53,9 +56,12 @@ namespace distance { * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda */ + template > struct PairwiseDistances : public BaseClass { @@ -63,13 +69,13 @@ struct PairwiseDistances : public BaseClass { typedef Policy P; const DataT* xn; const DataT* yn; - DataT* sxNorm; - DataT* syNorm; + const DataT* const yBase; OutT* dOutput; char* smem; CoreLambda core_op; EpilogueLambda epilog_op; FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; @@ -79,34 +85,95 @@ struct PairwiseDistances : public BaseClass { IdxT _k, IdxT _lda, IdxT _ldb, IdxT _ldd, const DataT* _xn, const DataT* _yn, OutT* _dOutput, char* _smem, CoreLambda _core_op, - EpilogueLambda _epilog_op, FinalLambda _fin_op) + EpilogueLambda _epilog_op, FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - sxNorm((DataT*)_smem), - syNorm(&(sxNorm[P::Mblk])), xn(_xn), yn(_yn), + yBase(_y), dOutput(_dOutput), smem(_smem), core_op(_core_op), epilog_op(_epilog_op), - fin_op(_fin_op) {} + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op) {} DI void run() { - prolog(); - loop(); - epilog(); + for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; + gridStrideY += P::Mblk * gridDim.y) { + for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; + gridStrideX += P::Nblk * gridDim.x) { + prolog(gridStrideX, gridStrideY); + loop(); + epilog(gridStrideX, gridStrideY); + } + rowEpilog_op(gridStrideY); + } } private: - DI void prolog() { + DI void updateIndicesY() { + const auto stride = P::Nblk * gridDim.x; + if (isRowMajor) { + this->y += stride * this->ldb; + } else { + this->y += stride; + } + this->yrowid += stride; + this->pageWr = 0; + this->pageRd = 0; + } + + DI void updateIndicesXY() { + const auto stride = P::Mblk * gridDim.y; + if (isRowMajor) { + this->x += stride * this->lda; + this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; + this->y = yBase + this->yrowid * this->ldb; + } else { + this->x += stride; + this->yrowid = IdxT(blockIdx.x) * P::Nblk; + this->y = yBase + this->yrowid + this->srowid * this->ldb; + } + this->xrowid += stride; + this->pageWr = 0; + this->pageRd = 0; + } + + DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { + if (gridStrideX > blockIdx.x * P::Nblk) { + updateIndicesY(); + } else if (gridStrideY > blockIdx.y * P::Mblk) { + updateIndicesXY(); + } + + typedef TxN_t VecType; + VecType zeros; + zeros.fill(BaseClass::Zero); +#pragma unroll + for (int j = 0; j < P::LdgPerThX; ++j) { + zeros.store(&this->ldgDataX[j][0], 0); + } +#pragma unroll + for (int j = 0; j < P::LdgPerThY; ++j) { + zeros.store(&this->ldgDataY[j][0], 0); + } + this->ldgXY(0); + #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { + zeros.store(&this->regx[i][0], 0); #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { acc[i][j] = BaseClass::Zero; } } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + zeros.store(&this->regy[j][0], 0); + } + this->stsXY(); __syncthreads(); this->pageWr ^= 1; @@ -141,19 +208,24 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog() { + DI void epilog(IdxT gridStrideX, IdxT gridStrideY) { if (useNorms) { - __syncthreads(); // so that we can safely reuse smem + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); // Load x & y norms required by this threadblock in shmem buffer - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = blockIdx.x * P::Mblk + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; + if (gridStrideX == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = gridStrideY + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } } + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = blockIdx.y * P::Nblk + i; + auto idx = gridStrideX + i; syNorm[i] = idx < this->n ? yn[idx] : 0; } + __syncthreads(); DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; @@ -166,21 +238,24 @@ struct PairwiseDistances : public BaseClass { regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } - epilog_op(acc, regxn, regyn); + epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); } else { - epilog_op(acc, nullptr, nullptr); + epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); } - IdxT startx = blockIdx.x * P::Mblk + this->accrowid; - IdxT starty = blockIdx.y * P::Nblk + this->acccolid; + if (writeOut) { + IdxT starty = gridStrideY + this->accrowid; + IdxT startx = gridStrideX + this->acccolid; + #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = startx + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = starty + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0); + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0); + } } } } @@ -217,9 +292,11 @@ struct PairwiseDistances : public BaseClass { * @param epilog_op the epilogue lambda * @param fin_op the final gemm epilogue lambda */ + template + typename EpilogueLambda, typename FinalLambda, bool isRowMajor = true, + bool writeOut = true> __global__ __launch_bounds__( Policy::Nthreads, 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, @@ -229,13 +306,39 @@ __global__ __launch_bounds__( EpilogueLambda epilog_op, FinalLambda fin_op) { extern __shared__ char smem[]; + auto rowEpilog = [] __device__(IdxT starty) { return; }; PairwiseDistances + EpilogueLambda, FinalLambda, decltype(rowEpilog), + isRowMajor, writeOut> obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, - epilog_op, fin_op); + epilog_op, fin_op, rowEpilog); obj.run(); } +template +dim3 launchConfigGenerator(IdxT m, IdxT n, size_t sMemSize, T func) { + const auto numSMs = raft::getMultiProcessorCount(); + int numBlocksPerSm = 0; + dim3 grid; + + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &numBlocksPerSm, func, P::Nthreads, sMemSize)); + int minGridSize = numSMs * numBlocksPerSm; + int yChunks = raft::ceildiv(m, P::Mblk); + int xChunks = raft::ceildiv(n, P::Nblk); + grid.y = yChunks > minGridSize ? minGridSize : yChunks; + grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; + if (grid.x != 1) { + int i = 1; + while (grid.y * i < minGridSize) { + i++; + } + grid.x = i >= xChunks ? xChunks : i; + } + + return grid; +} + }; // namespace distance }; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh index 86d608ea87..c590abb142 100644 --- a/cpp/include/raft/linalg/contractions.cuh +++ b/cpp/include/raft/linalg/contractions.cuh @@ -293,13 +293,13 @@ struct Contractions_NT { pageWr(0), pageRd(0) { if (isRowMajor) { - xrowid = IdxT(blockIdx.x) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.y) * P::Nblk + srowid; + xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; + yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; x = _x + xrowid * lda; y = _y + yrowid * ldb; } else { - xrowid = IdxT(blockIdx.x) * P::Mblk; - yrowid = IdxT(blockIdx.y) * P::Nblk; + xrowid = IdxT(blockIdx.y) * P::Mblk; + yrowid = IdxT(blockIdx.x) * P::Nblk; x = _x + xrowid + srowid * lda; y = _y + yrowid + srowid * ldb; } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index d4e39a0b5e..4573a070b6 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -164,7 +164,6 @@ struct CompareApproxAbsKVP { typedef typename cub::KeyValuePair KVP; CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP &a, const KVP &b) const { - if (a.key != b.key) return false; T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); T m = std::max(raft::abs(a.value), raft::abs(b.value)); T ratio = m >= eps ? diff / m : diff; @@ -179,7 +178,6 @@ template struct CompareExactKVP { typedef typename cub::KeyValuePair KVP; bool operator()(const KVP &a, const KVP &b) const { - if (a.key != b.key) return false; if (a.value != b.value) return false; return true; }