diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index d5a434f2fa..43abc9eb65 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -136,39 +136,29 @@ struct PairwiseDistances : public BaseClass { this->xrowid += stride; } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { - if (gridStrideX > blockIdx.x * P::Nblk) { + DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) { + // Fetch next grid stride ldg if within range + if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { updateIndicesY(); - } else if (gridStrideY > blockIdx.y * P::Mblk) { + this->ldgXY(0); + } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { updateIndicesXY(); + this->ldgXY(0); } + } - typedef TxN_t VecType; - VecType zeros; - zeros.fill(BaseClass::Zero); -#pragma unroll - for (int j = 0; j < P::LdgPerThX; ++j) { - zeros.store(&this->ldgDataX[j][0], 0); - } -#pragma unroll - for (int j = 0; j < P::LdgPerThY; ++j) { - zeros.store(&this->ldgDataY[j][0], 0); + DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { + if (gridStrideX == blockIdx.x * P::Nblk) { + this->ldgXY(0); } - this->ldgXY(0); - #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - zeros.store(&this->regx[i][0], 0); #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { acc[i][j] = BaseClass::Zero; } } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - zeros.store(&this->regy[j][0], 0); - } this->stsXY(); __syncthreads(); @@ -239,8 +229,12 @@ struct PairwiseDistances : public BaseClass { regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + // Overlap ldg with epilog computation + ldgNextGridStride(gridStrideX, gridStrideY); epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); } else { + // Overlap ldg with epilog computation + ldgNextGridStride(gridStrideX, gridStrideY); epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); }