From c58d00a4b99b049a2d57ab0dce2c0d2da75f84cc Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 27 Jan 2023 17:23:41 +0100 Subject: [PATCH] Move contractions tiling logic outside of Contractions_NT (#837) The main functionality of Contractions_NT involves loading tiles of data into shared memory to enable fast GEMM-like kernels. In practice, this requires keeping track of tiles of data (2D submatrices of a bigger matrix) and distributing the data in the tiles over shared memory and the registers of the threads in a thread block. Currently, Contractions_NT performs indexing logic for both: 1. The distribution of data in a tile over registers and shared memory; 2. Looping over tiles of data in a 2D matrix. In this PR, we move functionality 2 out of Contractions_NT. Moving over the tiles of data and keeping track of the grid stride loop is now the responsibility of the calling code. Splitting these responsibilities is helpful when non-trivial tiling logic is required, as in the upcoming sparseL2NN functionality. **Note**: This PR also cleans up one unfortunate wart in the current implementation. Depending on which of the two overloaded constructors was called, the tiling logic was transposed, leading to bugs that were extremely difficult to track down. Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/837 --- cpp/CMakeLists.txt | 15 +- cpp/bench/distance/distance_common.cuh | 10 +- cpp/bench/neighbors/knn.cuh | 3 + cpp/bench/neighbors/refine.cu | 23 +- .../detail/pairwise_distance_base.cuh | 207 ++++++++---------- .../raft/linalg/detail/contractions.cuh | 51 ++--- .../raft/neighbors/specializations.cuh | 7 +- .../raft/neighbors/specializations/refine.cuh | 51 +++++ .../knn/detail/epsilon_neighborhood.cuh | 12 +- cpp/src/distance/neighbors/refine.cu | 52 ----- .../neighbors/refine_d_uint64_t_float.cu | 34 +++ .../neighbors/refine_d_uint64_t_int8_t.cu | 34 +++ .../neighbors/refine_d_uint64_t_uint8_t.cu | 34 +++ .../neighbors/refine_h_uint64_t_float.cu | 34 +++ .../neighbors/refine_h_uint64_t_int8_t.cu | 33 +++ .../neighbors/refine_h_uint64_t_uint8_t.cu | 34 +++ .../refine_d_uint64_t_float.cu | 30 +++ .../refine_d_uint64_t_int8_t.cu | 30 +++ .../refine_d_uint64_t_uint8_t.cu | 30 +++ .../refine_h_uint64_t_float.cu | 30 +++ .../refine_h_uint64_t_int8_t.cu | 29 +++ .../refine_h_uint64_t_uint8_t.cu | 30 +++ cpp/test/neighbors/refine.cu | 22 +- 23 files changed, 605 insertions(+), 230 deletions(-) create mode 100644 cpp/include/raft/neighbors/specializations/refine.cuh delete mode 100644 cpp/src/distance/neighbors/refine.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu create mode 100644 
cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c6850b290f..a6341f6dda 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -284,7 +284,18 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/update_centroids_double.cu src/distance/cluster/cluster_cost_float.cu src/distance/cluster/cluster_cost_double.cu - src/distance/neighbors/refine.cu + src/distance/neighbors/refine_d_uint64_t_float.cu + src/distance/neighbors/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/refine_h_uint64_t_float.cu + src/distance/neighbors/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/refine_h_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_float.cu + src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_float.cu + src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu diff --git a/cpp/bench/distance/distance_common.cuh 
b/cpp/bench/distance/distance_common.cuh index 73faacce37..1be00ec0c7 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -24,14 +24,14 @@ namespace raft::bench::distance { -struct distance_inputs { +struct distance_params { int m, n, k; bool isRowMajor; -}; // struct distance_inputs +}; // struct distance_params template struct distance : public fixture { - distance(const distance_inputs& p) + distance(const distance_params& p) : params(p), x(p.m * p.k, stream), y(p.n * p.k, stream), @@ -63,13 +63,13 @@ struct distance : public fixture { } private: - distance_inputs params; + distance_params params; rmm::device_uvector x, y, out; rmm::device_uvector workspace; size_t worksize; }; // struct Distance -const std::vector dist_input_vecs{ +const std::vector dist_input_vecs{ {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true}, diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 60eb8c257d..eec1cba99e 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -32,6 +32,9 @@ #include #if defined RAFT_DISTANCE_COMPILED #include +#include +#else +#pragma message("NN / Distance specializations are not enabled; expect very long building times.") #endif #endif diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index 3349b8b6ae..f32af3a57e 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -27,6 +27,7 @@ #if defined RAFT_DISTANCE_COMPILED #include +#include #endif #if defined RAFT_NN_COMPILED @@ -52,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::os return os; } -RefineInputs p; +RefineInputs p; template class RefineAnn : public fixture { @@ -98,24 +99,24 @@ class RefineAnn : public fixture { RefineHelper data; }; 
-std::vector> getInputs() +std::vector> getInputs() { - std::vector> out; + std::vector> out; raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; for (bool host_data : {true, false}) { - for (int64_t n_queries : {1000, 10000}) { - for (int64_t dim : {128, 512}) { - out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); - out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + for (uint64_t n_queries : {1000, 10000}) { + for (uint64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); } } } return out; } -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); +using refine_float_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs()); -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +using refine_uint8_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 69bb83d29a..d849b23999 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + // Main loop: + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. 
+ this->switch_read_buffer(); + + // Epilog: + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void reset_accumulator() { - if 
(gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); } - + // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -184,28 +199,6 @@ struct PairwiseDistances : public BaseClass { acc[i][j] = BaseClass::Zero; } } - - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - } - - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; - } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->pageRd ^= 1; } DI void accumulate() @@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (&regxn)[P::AccRowsPerTh], + DataT (&regyn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? 
xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, 
as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..f2d71117f7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,14 +40,10 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; /** current thread's smem row id */ int srowid; @@ -94,10 +90,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -133,6 +127,8 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -142,17 +138,6 @@ struct Contractions_NT { pageWr(0), pageRd(0) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * 
P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -160,10 +145,10 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) { - ldgX(kidx); - ldgY(kidx); + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); } /** @@ -186,9 +171,16 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: - DI void ldgX(IdxT kidx) + DI void ldgX(IdxT tile_idx_m, IdxT kidx) { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -220,8 +212,11 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) + DI void ldgY(IdxT tile_idx_n, IdxT kidx) { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; @@ -315,4 +310,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 0511bbbf6c..d17467c8a7 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,9 +20,8 @@ #pragma once #include +#include #include #include - -#include - +#include #endif diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh new file mode 100644 index 0000000000..71e83a26f3 --- /dev/null +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::neighbors { + +#ifdef RAFT_INST +#undef RAFT_INST +#endif + +#define RAFT_INST(T, IdxT) \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric); \ + \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +RAFT_INST(float, uint64_t); +RAFT_INST(uint8_t, uint64_t); +RAFT_INST(int8_t, uint64_t); + +#undef RAFT_INST +} // namespace raft::neighbors diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh 
b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 19862d743d..7616083796 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration } diff --git a/cpp/src/distance/neighbors/refine.cu b/cpp/src/distance/neighbors/refine.cu deleted file mode 100644 index 83e3383cba..0000000000 --- a/cpp/src/distance/neighbors/refine.cu +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::runtime::neighbors { - -#define RAFT_INST_REFINE(IDX_T, DATA_T) \ - void refine(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_device( \ - handle, dataset, queries, neighbor_candidates, indices, distances, metric); \ - } \ - \ - void refine(raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_host( \ - dataset, queries, neighbor_candidates, indices, distances, metric); \ - } - -RAFT_INST_REFINE(uint64_t, float); -RAFT_INST_REFINE(uint64_t, uint8_t); -RAFT_INST_REFINE(uint64_t, int8_t); - -#undef RAFT_INST_REFINE - -} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..d7b460180a --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..3db07f0cdb --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..2ce43d5800 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..2a2dcff3bf --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..d7c60b62a5 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft::runtime::neighbors { +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..e9c4345e97 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..6bb1985d94 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..7e70ee5e29 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..53de106ef9 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..b473924741 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..c8b0e4c1c2 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..b9e0f58ef6 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 98933046b9..a78f5cfe5c 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,7 +31,7 @@ #include -#if defined RAFT_NN_COMPILED +#if defined RAFT_DISTANCE_COMPILED #include #endif @@ -107,26 +107,26 @@ class RefineTest : public ::testing::TestWithParam> { RefineHelper data; }; -const std::vector> inputs = - raft::util::itertools::product>( - {137}, - {1000}, - {16}, - {1, 10, 33}, - {33}, +const std::vector> inputs = + raft::util::itertools::product>( + {static_cast(137)}, + {static_cast(1000)}, + {static_cast(16)}, + {static_cast(1), static_cast(10), static_cast(33)}, + {static_cast(33)}, {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); -typedef RefineTest RefineTestF; +typedef RefineTest RefineTestF; TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_uint8; +typedef RefineTest RefineTestF_uint8; TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_int8; +typedef RefineTest RefineTestF_int8; TEST_P(RefineTestF_int8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_int8, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors