From a3b8587978db1556c27d41d4d2a30ba2ac75e91e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 22:21:22 +0200 Subject: [PATCH 01/15] contractions: Concentrate tile index calculations The calculation of the tile indices is now performed in ldgXY(). This will make it possible to remove all state related to the tile index from the class in the next commit. Note that the calculation of the tile index can depend on which overloaded constructor is called(!) --- .../detail/pairwise_distance_base.cuh | 27 ++---- .../raft/linalg/detail/contractions.cuh | 84 +++++++++++++------ .../knn/detail/epsilon_neighborhood.cuh | 10 +-- 3 files changed, 72 insertions(+), 49 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 69bb83d29a..f6e66d068e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -138,27 +138,14 @@ struct PairwiseDistances : public BaseClass { DI void updateIndicesY() { const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; + this->increment_grid_idx_n(stride); } DI void updateIndicesXY() { const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; + this->increment_grid_idx_m(stride); + this->reset_grid_idx_n(); } DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) @@ -187,7 +174,7 @@ struct PairwiseDistances : public BaseClass { this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() @@ -197,15
+184,15 @@ struct PairwiseDistances : public BaseClass { accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of // non-norm based metrics uses previously accumulated buffer so // it doesn't make shmem dirty until previous iteration // is complete. - this->pageRd ^= 1; + this->switch_read_buffer(); } DI void accumulate() diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..a6efdec49e 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -40,14 +40,15 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; + + /** Support variables to provide backward compatibility **/ + IdxT grid_idx_m = 0; + IdxT grid_idx_n = 0; + bool first_constructor_called; /** current thread's smem row id */ int srowid; @@ -94,10 +95,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -105,7 +104,8 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(true) { } @@ -133,6 +133,8 @@ struct Contractions_NT { lda(_lda), 
ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -140,19 +142,9 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(false) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -166,6 +158,12 @@ struct Contractions_NT { ldgY(kidx); } + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); + } + /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -186,9 +184,35 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } + + DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } + + DI void reset_grid_idx_n() { grid_idx_n = 0; } + + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: DI void ldgX(IdxT kidx) { + // Backward compatible way to determine the tile index. This depends on + // whether the first or the second constructor was called. The first + // constructor is called in epsilon_neighborhood.cuh and the second + // constructor is called in pairwise_distance_base.cuh. + if (first_constructor_called) { + ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); + } else { + ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); + } + } + + DI void ldgX(IdxT tile_idx_m, IdxT kidx) + { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? 
x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -222,6 +246,18 @@ struct Contractions_NT { DI void ldgY(IdxT kidx) { + if (first_constructor_called) { + ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); + } else { + ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); + } + } + + DI void ldgY(IdxT tile_idx_n, IdxT kidx) + { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; @@ -315,4 +351,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 19862d743d..cd0e005921 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration } From 99e65a5d93c8fbca1afea76fbcda019346d5d1df Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 
2022 22:51:49 +0200 Subject: [PATCH 02/15] pairwise_distance_base: Remove all ldgXY(0) calls This commit moves all grid and tile indexing logic into the caller. Contractions_NT is now only responsible for *intra*-tile indexing. Due to the complexity of the epilog function, the ldgNextGridStride function is not yet called from within the main loop. That is the next goal so that we have all the grid and tile indexing localized in the loop. --- .../detail/pairwise_distance_base.cuh | 121 ++++++++++-------- .../raft/linalg/detail/contractions.cuh | 45 +------ 2 files changed, 67 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index f6e66d068e..fefb964f3d 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -87,6 +87,12 @@ struct PairwiseDistances : public BaseClass { FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; + + const IdxT grid_stride_m; + const IdxT grid_stride_n; + const IdxT grid_offset_m; + const IdxT grid_offset_n; + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; public: @@ -116,53 +122,63 @@ struct PairwiseDistances : public BaseClass { core_op(_core_op), epilog_op(_epilog_op), fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) + rowEpilog_op(_rowEpilog_op), + grid_stride_m(P::Nblk * gridDim.y), + grid_stride_n(P::Mblk * gridDim.x), + grid_offset_m(P::Mblk * blockIdx.y), + grid_offset_n(P::Nblk * blockIdx.x) { } DI void run() { - for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 
0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->switch_read_buffer(); + + epilog(tile_idx_n, tile_idx_m); } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - this->increment_grid_idx_n(stride); - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - this->increment_grid_idx_m(stride); - this->reset_grid_idx_n(); - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m) { - if (gridStrideX == blockIdx.x * P::Nblk) { 
this->ldgXY(0); } + if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); } #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -177,22 +193,15 @@ struct PairwiseDistances : public BaseClass { this->switch_write_buffer(); } - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + DI void reset_accumulator() { + // Reset accumulator registers to zero. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->switch_read_buffer(); } DI void accumulate() @@ -213,22 +222,22 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m) { if (useNorms) { DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); DataT* syNorm = (&sxNorm[P::Mblk]); // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { + if (tile_idx_n == blockIdx.x * P::Nblk) { for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; + auto idx = tile_idx_m + i; sxNorm[i] = idx < this->m ? xn[idx] : 0; } } for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; + auto idx = tile_idx_n + i; syNorm[i] = idx < this->n ? 
yn[idx] : 0; } @@ -245,17 +254,17 @@ struct PairwiseDistances : public BaseClass { } // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); + ldgNextGridStride(tile_idx_n, tile_idx_m); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + ldgNextGridStride(tile_idx_n, tile_idx_m); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index a6efdec49e..6d7a8e2292 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -45,11 +45,6 @@ struct Contractions_NT { /** global memory pointer to Y matrix */ const DataT* y_base; - /** Support variables to provide backward compatibility **/ - IdxT grid_idx_m = 0; - IdxT grid_idx_n = 0; - bool first_constructor_called; - /** current thread's smem row id */ int srowid; /** current thread's smem column id */ @@ -104,8 +99,7 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0), - first_constructor_called(true) + pageRd(0) { } @@ -142,8 +136,7 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0), - first_constructor_called(false) + pageRd(0) { } @@ -152,12 +145,6 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) - { - ldgX(kidx); - ldgY(kidx); - 
} - DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) { ldgX(tile_idx_m, kidx); @@ -184,30 +171,11 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } - DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } - - DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } - - DI void reset_grid_idx_n() { grid_idx_n = 0; } - DI void switch_read_buffer() { this->pageRd ^= 1; } DI void switch_write_buffer() { this->pageWr ^= 1; } private: - DI void ldgX(IdxT kidx) - { - // Backward compatible way to determine the tile index. This depends on - // whether the first or the second constructor was called. The first - // constructor is called in epsilon_neighborhood.cuh and the second - // constructor is called in pairwise_distance_base.cuh. - if (first_constructor_called) { - ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); - } else { - ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); - } - } - DI void ldgX(IdxT tile_idx_m, IdxT kidx) { IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; @@ -244,15 +212,6 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) - { - if (first_constructor_called) { - ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); - } else { - ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); - } - } - DI void ldgY(IdxT tile_idx_n, IdxT kidx) { IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; From e6d5078aa126f5612525a548a063610136f98293 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 23:40:32 +0200 Subject: [PATCH 03/15] pairwise_distance_base: Move all logic into run loop This commit removes the epilog function and moves its functionality into the run loop. The next step might be to see if the ldgNextGridStride() method has to be called at the current location, or if performance is the same if it's called at the start of the loop.
--- .../detail/pairwise_distance_base.cuh | 128 ++++++++---------- 1 file changed, 57 insertions(+), 71 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index fefb964f3d..a2dffad808 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -87,7 +87,6 @@ struct PairwiseDistances : public BaseClass { FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; - const IdxT grid_stride_m; const IdxT grid_stride_n; const IdxT grid_offset_m; @@ -141,14 +140,14 @@ struct PairwiseDistances : public BaseClass { this->switch_write_buffer(); for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - accumulate(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of @@ -157,14 +156,25 @@ struct PairwiseDistances : public BaseClass { // is complete. 
this->switch_read_buffer(); - epilog(tile_idx_n, tile_idx_m); + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } rowEpilog_op(tile_idx_m); } } private: - DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; @@ -176,24 +186,8 @@ struct PairwiseDistances : public BaseClass { } } - DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m) + DI void reset_accumulator() { - if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); } - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - } - - DI void reset_accumulator() { // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -222,60 +216,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (&regxn)[P::AccRowsPerTh], + DataT (&regyn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (tile_idx_n == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ?
xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_n, tile_idx_m); - epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_n, tile_idx_m); - epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = tile_idx_m + this->accrowid; - IdxT startx = tile_idx_n + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma 
unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } From 995d2ae5a550060c27e3426570a7f9e8e7addc01 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 5 Oct 2022 16:17:56 +0200 Subject: [PATCH 04/15] pairwise_distance_base: Fix typo This results in subtle issues with non-square KernelPolicy, as found in fusedL2KNN. --- cpp/include/raft/distance/detail/pairwise_distance_base.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index a2dffad808..b28c3a3de4 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -122,8 +122,8 @@ struct PairwiseDistances : public BaseClass { epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), - grid_stride_m(P::Nblk * gridDim.y), - grid_stride_n(P::Mblk * gridDim.x), + grid_stride_m(P::Mblk * gridDim.y), + grid_stride_n(P::Nblk * gridDim.x), grid_offset_m(P::Mblk * blockIdx.y), grid_offset_n(P::Nblk * blockIdx.x) { From e6976c53ab559befef9019123ae33379bb54733e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 15:40:04 +0100 Subject: [PATCH 05/15] Implement reviewer feedback --- .../raft/distance/detail/pairwise_distance_base.cuh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git 
a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index b28c3a3de4..d849b23999 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; tile_idx_m += grid_stride_m) { this->ldgXY(tile_idx_m, grid_offset_n, 0); for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: reset_accumulator(); this->stsXY(); __syncthreads(); this->switch_write_buffer(); + // Main loop: for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { this->ldgXY(tile_idx_m, tile_idx_n, kidx); // Process all data in shared memory (previous k-block) and @@ -150,12 +153,12 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); } accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. 
this->switch_read_buffer(); + // Epilog: if (useNorms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); From e52b0f94afc1f1b6f4071eee2ac96b349030554d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 10:10:11 -0500 Subject: [PATCH 06/15] Forcing sccache reinit. --- build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/build.sh b/build.sh index b47e1ed862..f34c032204 100755 --- a/build.sh +++ b/build.sh @@ -75,6 +75,8 @@ COMPILE_DIST_LIBRARY=OFF ENABLE_NN_DEPENDENCIES=OFF INSTALL_TARGET=install +SCCACHE_RECACHE=1 + TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" ENABLE_thrust_DEPENDENCY=ON From 85c6294d6f3adc10ef623332ee4ca7d6509afa42 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 11:19:43 -0500 Subject: [PATCH 07/15] Breaking specializations for refine into individual files --- cpp/CMakeLists.txt | 9 +++- .../raft/linalg/detail/contractions.cuh | 2 +- .../knn/detail/epsilon_neighborhood.cuh | 2 +- cpp/src/distance/neighbors/refine.cu | 52 ------------------- .../neighbors/refine_d_uint64_t_float.cu | 33 ++++++++++++ .../neighbors/refine_d_uint64_t_int8_t.cu | 33 ++++++++++++ .../neighbors/refine_d_uint64_t_uint8_t.cu | 33 ++++++++++++ .../neighbors/refine_h_uint64_t_float.cu | 33 ++++++++++++ .../neighbors/refine_h_uint64_t_int8_t.cu | 33 ++++++++++++ .../neighbors/refine_h_uint64_t_uint8_t.cu | 33 ++++++++++++ 10 files changed, 207 insertions(+), 56 deletions(-) delete mode 100644 cpp/src/distance/neighbors/refine.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c6850b290f..a45c5b0cc8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at @@ -284,7 +284,12 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/update_centroids_double.cu src/distance/cluster/cluster_cost_float.cu src/distance/cluster/cluster_cost_double.cu - src/distance/neighbors/refine.cu + src/distance/neighbors/refine_d_uint64_t_float.cu + src/distance/neighbors/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/refine_h_uint64_t_float.cu + src/distance/neighbors/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/refine_h_uint64_t_uint8_t.cu src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 6d7a8e2292..f2d71117f7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index cd0e005921..7616083796 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/cpp/src/distance/neighbors/refine.cu b/cpp/src/distance/neighbors/refine.cu deleted file mode 100644 index 83e3383cba..0000000000 --- a/cpp/src/distance/neighbors/refine.cu +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::runtime::neighbors { - -#define RAFT_INST_REFINE(IDX_T, DATA_T) \ - void refine(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_device( \ - handle, dataset, queries, neighbor_candidates, indices, distances, metric); \ - } \ - \ - void refine(raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_host( \ - dataset, queries, neighbor_candidates, indices, distances, metric); \ - } - -RAFT_INST_REFINE(uint64_t, float); -RAFT_INST_REFINE(uint64_t, uint8_t); -RAFT_INST_REFINE(uint64_t, int8_t); - -#undef RAFT_INST_REFINE - -} // namespace raft::runtime::neighbors diff --git 
a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..32819099d9 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..ff45e74ba1 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..0a1590194b --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..2d734ac5bf --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..9749499298 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..dbc2100635 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors From 0fad8425730ef3d4a28b4d945952833dd0732809 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 12:45:44 -0500 Subject: [PATCH 08/15] Checking in --- cpp/bench/neighbors/refine.cu | 11 ++-- .../raft/neighbors/specializations.cuh | 4 +- .../raft/neighbors/specializations/refine.cuh | 51 +++++++++++++++++++ .../neighbors/refine_d_uint64_t_float.cu | 1 + .../neighbors/refine_d_uint64_t_int8_t.cu | 1 + .../neighbors/refine_d_uint64_t_uint8_t.cu | 1 + .../neighbors/refine_h_uint64_t_float.cu | 1 + .../neighbors/refine_h_uint64_t_int8_t.cu | 2 +- .../neighbors/refine_h_uint64_t_uint8_t.cu | 1 + .../refine_d_uint64_t_float.cu | 30 +++++++++++ .../refine_d_uint64_t_int8_t.cu | 30 +++++++++++ .../refine_d_uint64_t_uint8_t.cu | 30 +++++++++++ .../refine_h_uint64_t_float.cu | 30 +++++++++++ .../refine_h_uint64_t_int8_t.cu | 29 +++++++++++ .../refine_h_uint64_t_uint8_t.cu | 30 +++++++++++ cpp/test/neighbors/refine.cu | 12 ++--- 16 files changed, 250 insertions(+), 14 deletions(-) create mode 100644 cpp/include/raft/neighbors/specializations/refine.cuh create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index 3349b8b6ae..16b115cab4 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -27,6 +27,7 @@ #if defined RAFT_DISTANCE_COMPILED #include +#include #endif #if defined RAFT_NN_COMPILED @@ -52,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::os return os; } -RefineInputs p; +RefineInputs p; template class RefineAnn : public 
fixture { @@ -113,9 +114,9 @@ std::vector> getInputs() return out; } -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); +using refine_float_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs()); -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +using refine_uint8_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 0511bbbf6c..77c49b70e6 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ #include #include #include +#include #include - #endif diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh new file mode 100644 index 0000000000..71e83a26f3 --- /dev/null +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft::neighbors { + +#ifdef RAFT_INST +#undef RAFT_INST +#endif + +#define RAFT_INST(T, IdxT) \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric); \ + \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +RAFT_INST(float, uint64_t); +RAFT_INST(uint8_t, uint64_t); +RAFT_INST(int8_t, uint64_t); + +#undef RAFT_INST +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu index 32819099d9..75fe526b07 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu index ff45e74ba1..aaf05ca3cb 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu index 0a1590194b..574ed7cf29 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git 
a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu index 2d734ac5bf..d03c082329 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu index 9749499298..01982ada95 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -15,9 +15,9 @@ */ #include +#include namespace raft::runtime::neighbors { - void refine(raft::device_resources const& handle, raft::host_matrix_view dataset, raft::host_matrix_view queries, diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu index dbc2100635..08a9ff410e 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..6bb1985d94 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..7e70ee5e29 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..53de106ef9 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..b473924741 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..c8b0e4c1c2 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..b9e0f58ef6 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 98933046b9..6f9e8210be 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,7 +31,7 @@ #include -#if defined RAFT_NN_COMPILED +#if defined RAFT_DISTANCE_COMPILED #include #endif @@ -107,8 +107,8 @@ class RefineTest : public ::testing::TestWithParam> { RefineHelper data; }; -const std::vector> inputs = - raft::util::itertools::product>( +const std::vector> inputs = + raft::util::itertools::product>( {137}, {1000}, {16}, @@ -117,16 +117,16 @@ const std::vector> inputs = {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); -typedef RefineTest RefineTestF; +typedef RefineTest RefineTestF; TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_uint8; +typedef RefineTest RefineTestF_uint8; TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_int8; +typedef RefineTest RefineTestF_int8; TEST_P(RefineTestF_int8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_int8, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors From f7788af34ee8e17efd0e0c408b42b959b024c72b Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 15:10:53 -0500 Subject: [PATCH 09/15] Including just the refine specialization --- cpp/include/raft/neighbors/specializations.cuh | 3 +-- cpp/src/distance/neighbors/refine_d_uint64_t_float.cu | 2 +- cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu | 2 +- cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu | 2 +- cpp/src/distance/neighbors/refine_h_uint64_t_float.cu | 2 +- cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu | 2 +- cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu | 2 +- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 77c49b70e6..d17467c8a7 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -20,9 +20,8 @@ #pragma once #include +#include #include #include #include - -#include #endif diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu index 75fe526b07..d7b460180a 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu index aaf05ca3cb..3db07f0cdb 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu index 574ed7cf29..2ce43d5800 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -15,7 +15,7 @@ */ #include -#include 
+#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu index d03c082329..2a2dcff3bf 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu index 01982ada95..d7c60b62a5 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { void refine(raft::device_resources const& handle, diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu index 08a9ff410e..e9c4345e97 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { From 9e7b7298cf444ccbfec4df61d7b40b796c83227b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 15:30:46 -0500 Subject: [PATCH 10/15] Proper import of specializations --- cpp/bench/neighbors/knn.cuh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 60eb8c257d..eec1cba99e 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -32,6 +32,9 @@ #include #if defined RAFT_DISTANCE_COMPILED #include +#include +#else +#pragma message("NN / Distance specializations are not enabled; expect very long building times.") #endif #endif From 060e62cd9363d8d8831c6bd619a8a303442bb596 Mon Sep 17 00:00:00 2001 From: "Corey J.
Nolet" Date: Thu, 26 Jan 2023 15:40:24 -0500 Subject: [PATCH 11/15] Remove SCCACHE_RECACHE from build.sh --- build.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/build.sh b/build.sh index f34c032204..b47e1ed862 100755 --- a/build.sh +++ b/build.sh @@ -75,8 +75,6 @@ COMPILE_DIST_LIBRARY=OFF ENABLE_NN_DEPENDENCIES=OFF INSTALL_TARGET=install -SCCACHE_RECACHE=1 - TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" ENABLE_thrust_DEPENDENCY=ON From 2370c18c5bb09d3184fe9c0e9a349e7469112794 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 17:12:14 -0500 Subject: [PATCH 12/15] Fixing build error --- cpp/test/neighbors/refine.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 6f9e8210be..a78f5cfe5c 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -109,11 +109,11 @@ class RefineTest : public ::testing::TestWithParam> { const std::vector> inputs = raft::util::itertools::product>( - {137}, - {1000}, - {16}, - {1, 10, 33}, - {33}, + {static_cast(137)}, + {static_cast(1000)}, + {static_cast(16)}, + {static_cast(1), static_cast(10), static_cast(33)}, + {static_cast(33)}, {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); From d0a5ea49e27223183bd70455c400f666df2e01dc Mon Sep 17 00:00:00 2001 From: "Corey J.
Nolet" Date: Thu, 26 Jan 2023 18:06:04 -0500 Subject: [PATCH 13/15] Fixing remaining compile errors --- cpp/bench/neighbors/refine.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index 16b115cab4..f32af3a57e 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -99,15 +99,15 @@ class RefineAnn : public fixture { RefineHelper data; }; -std::vector> getInputs() +std::vector> getInputs() { - std::vector> out; + std::vector> out; raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; for (bool host_data : {true, false}) { - for (int64_t n_queries : {1000, 10000}) { - for (int64_t dim : {128, 512}) { - out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); - out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + for (uint64_t n_queries : {1000, 10000}) { + for (uint64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); } } } From 62917236b59e221dfecc32e2628497ef843aaf08 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 18:29:23 -0500 Subject: [PATCH 14/15] Adding specializations to cmakelists --- cpp/CMakeLists.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a45c5b0cc8..a6341f6dda 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -290,6 +290,12 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/neighbors/refine_h_uint64_t_float.cu src/distance/neighbors/refine_h_uint64_t_int8_t.cu src/distance/neighbors/refine_h_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_float.cu + src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_float.cu + src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu From e3ea7edd7d7f2e558a175a0d94d00ad0dbe81f24 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 27 Jan 2023 12:08:53 +0100 Subject: [PATCH 15/15] Rename distance_inputs to distance_params Force a rerun of CI. 
--- cpp/bench/distance/distance_common.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 73faacce37..1be00ec0c7 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -24,14 +24,14 @@ namespace raft::bench::distance { -struct distance_inputs { +struct distance_params { int m, n, k; bool isRowMajor; -}; // struct distance_inputs +}; // struct distance_params template struct distance : public fixture { - distance(const distance_inputs& p) + distance(const distance_params& p) : params(p), x(p.m * p.k, stream), y(p.n * p.k, stream), @@ -63,13 +63,13 @@ struct distance : public fixture { } private: - distance_inputs params; + distance_params params; rmm::device_uvector x, y, out; rmm::device_uvector workspace; size_t worksize; }; // struct Distance -const std::vector dist_input_vecs{ +const std::vector dist_input_vecs{ {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true},