diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 69bb83d29a..f6e66d068e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -138,27 +138,14 @@ struct PairwiseDistances : public BaseClass { DI void updateIndicesY() { const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; + this->increment_grid_idx_n(stride); } DI void updateIndicesXY() { const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; + this->increment_grid_idx_m(stride); + this->reset_grid_idx_n(); } DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) @@ -187,7 +174,7 @@ struct PairwiseDistances : public BaseClass { this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() @@ -197,15 +184,15 @@ struct PairwiseDistances : public BaseClass { accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of // non-norm based metrics uses previously accumulated buffer so // it doesn't make shmem dirty until previous iteration // is complete. - this->pageRd ^= 1; + this->switch_read_buffer(); } DI void accumulate() diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..a6efdec49e 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -40,14 +40,15 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; + + /** Support variables to provide backward compatibility **/ + IdxT grid_idx_m = 0; + IdxT grid_idx_n = 0; + bool first_constructor_called; /** current thread's smem row id */ int srowid; @@ -94,10 +95,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -105,7 +104,8 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(true) { } @@ -133,6 +133,8 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -140,19 +142,9 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(false) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -166,6 +158,12 @@ struct Contractions_NT { ldgY(kidx); } + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); + } + /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -186,9 +184,35 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } + + DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } + + DI void reset_grid_idx_n() { grid_idx_n = 0; } + + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: DI void ldgX(IdxT kidx) { + // Backward compatible way to determine the tile index. This depends on + // whether the first or the second constructor was called. The first + // constructor is called in epsilon_neighborhood.cuh and the second + // constructor is called in pairwise_distance_base.cuh. + if (first_constructor_called) { + ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); + } else { + ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); + } + } + + DI void ldgX(IdxT tile_idx_m, IdxT kidx) + { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -222,6 +246,18 @@ struct Contractions_NT { DI void ldgY(IdxT kidx) { + if (first_constructor_called) { + ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); + } else { + ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); + } + } + + DI void ldgY(IdxT tile_idx_n, IdxT kidx) + { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; @@ -315,4 +351,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 19862d743d..cd0e005921 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration }