From 1417a2eece0a90552b455064e1a5be96c19c2647 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 1 Feb 2024 02:01:36 -0800 Subject: [PATCH 01/12] Add fused cosine 1-NN kernel and unify the fused distance 1-NN kernels fix doc issue in fused_distance_nn runtime API --- cpp/CMakeLists.txt | 4 +- cpp/bench/prims/CMakeLists.txt | 12 +- .../distance/detail/fused_distance_nn.cuh | 89 ++++ .../custom_epilogue_with_broadcast.h | 3 +- .../detail/fused_distance_nn/cutlass_base.cuh | 19 +- .../epilogue_elementwise.cuh | 10 +- .../fused_distance_nn/fused_cosine_nn.cuh | 135 ++++++ .../fused_distance_nn/helper_structs.cuh | 145 ++++++ .../fused_distance_nn/persistent_gemm.h | 7 +- .../predicated_tile_iterator_reduced_vec.h | 101 +++-- .../detail/fused_distance_nn/simt_kernel.cuh | 186 ++++++++ .../raft/distance/detail/fused_l2_nn.cuh | 262 +---------- .../raft/distance/fused_distance_nn-ext.cuh | 91 ++++ .../raft/distance/fused_distance_nn-inl.cuh | 325 ++++++++++++++ .../raft/distance/fused_distance_nn.cuh | 24 + ...pers.cuh => fused_distance_nn_helpers.cuh} | 5 +- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 12 +- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 4 +- .../distance/fused_distance_nn.hpp | 74 ++++ cpp/src/distance/fused_distance_nn.cu | 60 +++ .../distance/fused_distance_min_arg.cu | 137 ++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/fused_cosine_nn.cu | 416 ++++++++++++++++++ 23 files changed, 1789 insertions(+), 333 deletions(-) create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-ext.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-inl.cuh create mode 100755 cpp/include/raft/distance/fused_distance_nn.cuh rename cpp/include/raft/distance/{fused_l2_nn_helpers.cuh => fused_distance_nn_helpers.cuh} (89%) create mode 100644 cpp/include/raft_runtime/distance/fused_distance_nn.hpp create mode 100644 cpp/src/distance/fused_distance_nn.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.cu create mode 100644 cpp/test/distance/fused_cosine_nn.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 650bc1a059..61ebbe9978 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -259,7 +259,7 @@ endif() if(RAFT_NVTX) # This enables NVTX within the project with no option to disable it downstream. - target_link_libraries(raft INTERFACE CUDA::nvToolsExt) + target_link_libraries(raft INTERFACE CUDA::nvtx3) target_compile_definitions(raft INTERFACE NVTX_ENABLED) else() # Allow enable NVTX downstream if not set here. This creates a new option at build/install time, @@ -327,6 +327,7 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu src/distance/fused_l2_nn.cu + src/distance/fused_distance_nn.cu src/linalg/detail/coalesced_reduction.cu src/matrix/detail/select_k_double_int64_t.cu src/matrix/detail/select_k_double_uint32_t.cu @@ -425,6 +426,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/cluster/update_centroids.cuh src/raft_runtime/cluster/update_centroids_double.cu src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_distance_min_arg.cu src/raft_runtime/distance/fused_l2_min_arg.cu src/raft_runtime/distance/pairwise_distance.cu src/raft_runtime/matrix/select_k_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 3a2431cd34..d031431946 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME - MATRIX_BENCH - PATH - bench/prims/matrix/argmin.cu - bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu - bench/prims/matrix/main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh new file mode 100644 index 0000000000..94f199275d --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include +#include // PairwiseDistances +#include +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedDistanceNNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedDistanceNN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + switch (metric) { + case DistanceType::CosineExpanded: + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + break; + default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index f659ed256d..ac20578083 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -615,6 +615,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase #include +#include #include // FusedDistanceNNEpilogueElementwise #include // FusedDistanceNNGemm #include // getMultiProcessorCount @@ -46,6 +47,14 @@ namespace raft { namespace distance { namespace detail { +template +RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + + if (tid < m) { mut[tid].release(); } +} + template ; constexpr int batch_count = 1; + rmm::device_uvector> bin_mutex(m, stream); + + int blks_ = (m / 256) + 1; + + initBinMutexKernel<<>>(bin_mutex.data(), m); + typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); // Number of pipelines you want to use constexpr int NumStages = 3; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index a21f3d60e0..d65d2df4a4 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -29,7 +29,7 @@ * **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,6 +62,7 @@ #include #include +#include #include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -122,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise { KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; int* mutexes_; + cuda::binary_semaphore* bin_mutex_; using CGReduceT = CGReduceOp_; // // Methods @@ -131,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise { CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - int* mutexes) + int* mutexes, + cuda::binary_semaphore* bin_mutex) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), - mutexes_(mutexes) + mutexes_(mutexes), + bin_mutex_(bin_mutex) { } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh new file mode 100644 index 0000000000..e86db734a5 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedCosineNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::cosine_distance_op distance_op{}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using cosineOp = raft::distance::detail::ops::cosine_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + cosineOp cosine_dist_op; + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + cosine_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh new file mode 100644 index 0000000000..e88ea9cfc8 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl +#include + +namespace raft { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const + { + out->value = maxVal; + out->key = 0xfffffff0; + } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 3a8d6c8655..a04fe36b79 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -181,8 +181,7 @@ struct FusedDistanceNNPersistent { /// Default ctor CUTLASS_HOST_DEVICE Arguments() - : // problem_count(0), - threadblock_count(0), + : threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), @@ -206,6 +205,7 @@ struct FusedDistanceNNPersistent { void const* ptr_B, void const* ptr_C, void* ptr_Vector, + // volatile void* ptr_Tensor, void* ptr_Tensor, typename LayoutA::Stride::Index lda, typename LayoutB::Stride::Index ldb, @@ -236,7 +236,6 @@ struct FusedDistanceNNPersistent { /// Parameters structure struct Params { - // typename ProblemVisitor::Params problem_visitor; temp_problem_visitor problem_visitor; int threadblock_count; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index dc224c5c96..4591fa7855 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -322,9 +322,10 @@ class PredicatedTileIteratorReducedVec { Params params_; /// Byte-level pointer - uint8_t* byte_pointer_; + // uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. - uint8_t* first_tile_byte_pointer_; + volatile uint8_t* first_tile_byte_pointer_; + // uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -349,6 +350,8 @@ class PredicatedTileIteratorReducedVec { /// Scatter indices int const* indices_; + const int do_gmem_reduction_; + // // Static asserts about internal strides // @@ -359,7 +362,6 @@ class PredicatedTileIteratorReducedVec { protected: SharedStorage& shared_storage_; - const bool& do_gmem_reduction_; private: // @@ -373,10 +375,10 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, Params const& params, - Element* pointer, + volatile Element* pointer, TensorCoord extent, int thread_idx, - const bool& do_gmem_reduction, + const bool do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) : params_(params), @@ -408,6 +410,7 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; shared_storage_.initSmem(user_params); } + __syncthreads(); // Null pointer performs no accesses if (!pointer) { mask_.clear(); } @@ -415,65 +418,61 @@ class PredicatedTileIteratorReducedVec { if (ScatterD && !indices) { mask_.clear(); } // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + + first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - + // first_tile_byte_pointer_ = reinterpret_cast(pointer) + + // LongIndex(block_offset.row()) * LongIndex(params_.stride); // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() + CUTLASS_DEVICE void dumpToGmem() { + if (block_start_row_first_tile_ >= extent_row_) { return; } + if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); - bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - // If this is not optimal grid size perform mutex based gmem reduce. - if (useGmemMutex) { - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. - unsigned int ns = 8; - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { ns *= 2; } - } + const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + int row = threadIdx.x; + Element* shared_elem_arr = shared_storage_.data(); + Element row_local_min; + if (row < total_rows) { row_local_min = shared_elem_arr[row]; } + + // single lock per block for multiple rows + if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } + __syncthreads(); + + if (row < total_rows) { + volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + if ((block_start_row_first_tile_ + row) < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); } } __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } - } + __threadfence(); - if (useGmemMutex) { - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicExch(user_params.mutexes_ + mutex_id, 0); - } + if (useGmemMutex && (threadIdx.x == 0)) { + // release mutex lock. + user_params.bin_mutex_[mutex_id].release(); } + shared_storage_.initSmem(user_params); + __syncthreads(); } } + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() {} + /// Adds a pointer offset in units of Element CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset) { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + // byte_pointer_ += pointer_offset * sizeof_bits::value / 8; } /// Performs reduction and Stores a reduced output to memory @@ -514,9 +513,6 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_.init(&red_val, maxVal); if (row_guard) { - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { @@ -535,6 +531,10 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(row_id, &red_val, this_val); } } + } + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + if (row_guard) { // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, // this satisfies the requirement of mst/single linkage of checking colors buffer. select_reduce red_obj( @@ -543,6 +543,7 @@ class PredicatedTileIteratorReducedVec { } } } + __syncthreads(); } /// Stores a fragment to memory @@ -573,15 +574,14 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } + // if (!ScatterD) { byte_pointer_ += params_.advance_row; } thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - byte_pointer_ += params_.advance_group; + // byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -589,18 +589,17 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - byte_pointer_ += params_.advance_cluster; + // byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; - byte_pointer_ += params_.advance_tile; + // byte_pointer_ += params_.advance_tile; } } } - return *this; } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh new file mode 100644 index 0000000000..f5e4c725d6 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy + +namespace raft { +namespace distance { +namespace detail { + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 2468dcd740..75275d40b3 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ #include // raft::identity_op #include // ops::l2_exp_distance_op #include +#include +#include #include // PairwiseDistances #include // Policy #include // raft::util::arch::SM_* @@ -32,248 +34,6 @@ namespace distance { namespace detail { -template -struct KVPMinReduceImpl { - typedef raft::KeyValuePair KVP; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -struct MinAndDistanceReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) const - { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI void operator()(LabelT rid, DataT* out, const KVP& other) const - { - if (other.value < *out) { *out = other.value; } - } - - DI void operator()(LabelT rid, DataT* out, const DataT& other) const - { - if (other < *out) { *out = other; } - } - - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } - - DI void init_key(DataT& out, LabelT idx) const { return; } - DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } - - DI DataT get_value(KVP& out) const - { - return out.value; - ; - } - DI DataT get_value(DataT& out) const { return out; } -}; - -template -struct MinReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } -}; - -template -RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { redOp.init(min + tid, maxVal); } -} - -template -void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) -{ - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); -} - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal( - int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) -{ - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // Update each output row in order within a warp. This will resolve hang - // issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == j * P::AccThCols) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - } -} - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - OpT distance_op, - FinalLambda fin_op) -{ -// compile only if below non-ampere arch. -#if __CUDA_ARCH__ < 800 - extern __shared__ char smem[]; - - typedef KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, - // but the shfl op applies the modulo internally. - auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, gridStrideY); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - constexpr bool row_major = true; - constexpr bool write_out = false; - PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - nullptr, // Output pointer - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -#endif -} - -// cg::reduce functor for FusedDistanceNN used in its cutlass version -// to output the min distance value & key(loc id). -// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template -struct kvp_cg_min_reduce_op { - typedef typename raft::KeyValuePair KVP; - - __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; - - using AccTypeT = AccType; - using IndexT = Index; - // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } - - __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } - - __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } -}; - template ; + auto kernel = fusedDistanceNNkernel; // Get pointer to fp32 SIMT kernel to determine the best compute architecture // out of all for which the kernel was compiled for that matches closely diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh new file mode 100644 index 0000000000..9dd236a3bd --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh new file mode 100644 index 0000000000..342bde828d --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_DISTANCE_NN_H +#define __FUSED_DISTANCE_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. + bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else { + if (is_skinny) { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } +} + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedDistanceNN(min, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh new file mode 100755 index 0000000000..0c22df72f1 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_distance_nn-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "fused_distance_nn-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh similarity index 89% rename from cpp/include/raft/distance/fused_l2_nn_helpers.cuh rename to cpp/include/raft/distance/fused_distance_nn_helpers.cuh index 996f696ef6..e70d098d09 100644 --- a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ #pragma once #include -#include +// #include +#include namespace raft::distance { diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index c99c1eb015..66e9960f1d 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,11 @@ #pragma once -#include // int64_t -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT +#include // int64_t +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 17373e3bcc..4cb6b367a5 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp new file mode 100644 index 0000000000..6580cfa639 --- /dev/null +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::distance { + +/** + * @defgroup fused_distance_nn_min_arg_runtime Fused Distance 1NN Runtime API + * @{ + */ + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @param[in] handle raft handle + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + */ +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +/** @} */ // end group fused_distance_nn_min_arg_runtime + +} // end namespace raft::runtime::distance diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu new file mode 100644 index 0000000000..c3d1301e29 --- /dev/null +++ b/cpp/src/distance/fused_distance_nn.cu @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // int64_t +#include // raft::KeyValuePair +#include + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu new file mode 100644 index 0000000000..90d00d9f6b --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::runtime::distance { + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const raft::KeyValuePair& a) const + { + return a.key; + } +}; + +template +void compute_fused_cosine_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + constexpr bool is_row_major = true; + raft::linalg::rowNorm(x_norms.data(), + x, + k, + m, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + raft::linalg::rowNorm(y_norms.data(), + y, + k, + n, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + + raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + is_row_major, + raft::distance::DistanceType::CosineExpanded, + 0.0f, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; + } +} + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; + } +} + +} // end namespace raft::runtime::distance diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index fe29409d9b..2a1384e96e 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -170,6 +170,7 @@ if(BUILD_TESTS) test/distance/masked_nn.cu test/distance/masked_nn_compress_to_bits.cu test/distance/fused_l2_nn.cu + test/distance/fused_cosine_nn.cu test/distance/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu new file mode 100644 index 0000000000..5a89e71608 --- /dev/null +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -0,0 +1,416 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +template +struct RaftKVPMinReduce { + typedef raft::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +__global__ void naiveCosKernel(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + DataT maxVal) +{ + int midx = threadIdx.y + blockIdx.y * blockDim.y; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + DataT acc_a = DataT(0); + DataT acc_b = DataT(0); + DataT acc_ab = DataT(0); + // if (midx >= m || nidx >= n) { return; } + + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a * a; + acc_b += b * b; + acc_ab += a * b; + } + + // Use 1.0 - (cosine similarity) to calc the distance + DataT acc = maxVal; + if (midx < m || nidx < n) { acc = (DataT)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } + + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + raft::KeyValuePair tmp; + tmp.key = nidx; + tmp.value = midx >= m || nidx >= n ? maxVal : acc; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } +} + +template +void naive(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + cudaStream_t stream) +{ + static const dim3 TPB(32, 16, 1); + dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + detail::initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + naiveCosKernel, 16> + <<>>(min, x, y, m, n, k, workspace, std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +struct Inputs { + DataT tolerance; + int m, n, k; + unsigned long long int seed; + + friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + { + return os << "m: " << p.m + << ", " + "n: " + << p.n + << ", " + "k: " + << p.k + << ", " + "seed: " + << p.seed + << ", " + "tol: " + << p.tolerance; + } +}; + +template +class FusedCosineNNTest : public ::testing::TestWithParam> { + public: + FusedCosineNNTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + xn(params.m, stream), + yn(params.n, stream), + min(params.m, stream), + min_ref(params.m, stream), + workspace(params.m * sizeof(int), stream) + { + } + + protected: + void SetUp() override + { + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); + generateGoldenResult(); + raft::linalg::rowNorm( + xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + Inputs params; + rmm::device_uvector x; + rmm::device_uvector y; + rmm::device_uvector xn; + rmm::device_uvector yn; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; + rmm::device_uvector workspace; + + virtual void generateGoldenResult() + { + int m = params.m; + int n = params.n; + int k = params.k; + naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); + } + + void runTest(raft::KeyValuePair* out) + { + int m = params.m; + int n = params.n; + int k = params.k; + raft::distance::DistanceType metric = raft::distance::DistanceType::CosineExpanded; + constexpr bool init_out_buffer = true; + fusedDistanceNNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + false, + init_out_buffer, + true, + metric, + 0.0f, + stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } +}; + +template +struct CompareApproxAbsKVP { + typedef typename raft::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = std::abs(std::abs(a.value) - std::abs(b.value)); + T m = std::max(std::abs(a.value), std::abs(b.value)); + T ratio = m >= eps ? diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + +template +struct CompareExactKVP { + typedef typename raft::KeyValuePair KVP; + bool operator()(const KVP& a, const KVP& b) const + { + if (a.value != b.value) return false; + return true; + } +}; + +template +::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, + const raft::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename raft::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 1234ULL}, + {0.001f, 32, 64, 32, 1234ULL}, + {0.001f, 64, 32, 32, 1234ULL}, + {0.001f, 64, 64, 32, 1234ULL}, + {0.001f, 128, 32, 32, 1234ULL}, + {0.001f, 128, 64, 32, 1234ULL}, + {0.001f, 128, 128, 64, 1234ULL}, + {0.001f, 64, 128, 128, 1234ULL}, + + {0.001f, 32, 32, 34, 1234ULL}, + {0.001f, 32, 64, 34, 1234ULL}, + {0.001f, 64, 32, 34, 1234ULL}, + {0.001f, 64, 64, 34, 1234ULL}, + {0.001f, 128, 32, 34, 1234ULL}, + {0.001f, 128, 64, 34, 1234ULL}, + {0.001f, 128, 128, 66, 1234ULL}, + {0.001f, 64, 128, 130, 1234ULL}, + + {0.001f, 32, 32, 33, 1234ULL}, + {0.001f, 32, 64, 33, 1234ULL}, + {0.001f, 64, 32, 33, 1234ULL}, + {0.001f, 64, 64, 33, 1234ULL}, + {0.001f, 128, 32, 33, 1234ULL}, + {0.001f, 128, 64, 33, 1234ULL}, + {0.001f, 128, 128, 65, 1234ULL}, + {0.001f, 64, 128, 129, 1234ULL}, + {0.006f, 1805, 134, 2, 1234ULL}, + {0.006f, 8192, 1024, 64, 1234ULL}, + {0.006f, 8192, 1025, 64, 1234ULL}, + + // Repeat with smaller values of k + {0.006f, 32, 32, 1, 1234ULL}, + {0.001f, 32, 64, 2, 1234ULL}, + {0.001f, 64, 32, 3, 1234ULL}, + {0.001f, 64, 64, 4, 1234ULL}, + {0.001f, 128, 32, 5, 1234ULL}, + {0.001f, 128, 64, 6, 1234ULL}, + {0.001f, 128, 128, 7, 1234ULL}, + {0.001f, 64, 128, 8, 1234ULL}, + + {0.001f, 32, 32, 9, 1234ULL}, + {0.001f, 32, 64, 10, 1234ULL}, + {0.001f, 64, 32, 11, 1234ULL}, + {0.001f, 64, 64, 12, 1234ULL}, + {0.001f, 128, 32, 13, 1234ULL}, + {0.001f, 128, 64, 14, 1234ULL}, + {0.001f, 128, 128, 15, 1234ULL}, + {0.001f, 64, 128, 16, 1234ULL}, + + {0.001f, 32, 32, 17, 1234ULL}, + {0.001f, 32, 64, 18, 1234ULL}, + {0.001f, 64, 32, 19, 1234ULL}, + {0.001f, 64, 64, 20, 1234ULL}, + {0.001f, 128, 32, 21, 1234ULL}, + {0.001f, 128, 64, 22, 1234ULL}, + {0.001f, 128, 128, 23, 1234ULL}, + {0.00001, 64, 128, 24, 1234ULL}, + {0.001f, 1805, 134, 25, 1234ULL}, + {0.006f, 8192, 1024, 25, 1234ULL}, + {0.006f, 8192, 1024, 66, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestF; +TEST_P(FusedCosineNNTestF, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, + {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, + {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, + {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, + + {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, + {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, + {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, + {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, + + {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, + {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, + {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, + {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, + + {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestD; +TEST_P(FusedCosineNNTestD, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestD, ::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim +template +class FusedCosineNNDetTest : public FusedCosineNNTest { + public: + FusedCosineNNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {} + + void SetUp() override + { + FusedCosineNNTest::SetUp(); + int m = this->params.m; + min1.resize(m, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void TearDown() override { FusedCosineNNTest::TearDown(); } + + protected: + raft::resources handle; + cudaStream_t stream; + + rmm::device_uvector> min1; + + static const int NumRepeats = 3; + + void generateGoldenResult() override {} +}; + +typedef FusedCosineNNDetTest FusedCosineNNDetTestF; +TEST_P(FusedCosineNNDetTestF, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + cudaMemsetAsync(min1.data(), 0, sizeof(*min.data()) * params.m, stream); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestF, ::testing::ValuesIn(inputsf)); + +typedef FusedCosineNNDetTest FusedCosineNNDetTestD; +TEST_P(FusedCosineNNDetTestD, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft From 538440897bae7c967c842da37476daaf534993f6 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 16 Feb 2024 07:15:54 -0800 Subject: [PATCH 02/12] remove double datatype API, code cleanup and other review comments --- .../fused_distance_nn/persistent_gemm.h | 1 - .../predicated_tile_iterator_reduced_vec.h | 16 --------------- .../raft/distance/fused_distance_nn-ext.cuh | 10 +--------- .../raft/distance/fused_distance_nn-inl.cuh | 2 +- .../raft/distance/fused_distance_nn.cuh | 2 +- .../distance/fused_distance_nn_helpers.cuh | 1 - .../distance/fused_distance_nn.hpp | 12 ----------- cpp/src/distance/fused_distance_nn.cu | 8 -------- .../distance/fused_distance_min_arg.cu | 20 ------------------- cpp/test/distance/fused_cosine_nn.cu | 2 ++ 10 files changed, 5 insertions(+), 69 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index a04fe36b79..223af7eb58 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -205,7 +205,6 @@ struct FusedDistanceNNPersistent { void const* ptr_B, void const* ptr_C, void* ptr_Vector, - // volatile void* ptr_Tensor, void* ptr_Tensor, typename LayoutA::Stride::Index lda, typename LayoutB::Stride::Index ldb, diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 4591fa7855..81e7819223 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -321,11 +321,8 @@ class PredicatedTileIteratorReducedVec { /// Parameters structure containing reference and precomputed state. Params params_; - /// Byte-level pointer - // uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. volatile uint8_t* first_tile_byte_pointer_; - // uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -421,8 +418,6 @@ class PredicatedTileIteratorReducedVec { first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - // first_tile_byte_pointer_ = reinterpret_cast(pointer) + - // LongIndex(block_offset.row()) * LongIndex(params_.stride); // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } @@ -468,13 +463,6 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE ~PredicatedTileIteratorReducedVec() {} - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - // byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const @@ -574,14 +562,12 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - // if (!ScatterD) { byte_pointer_ += params_.advance_row; } thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - // byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -589,14 +575,12 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - // byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; - // byte_pointer_ += params_.advance_tile; } } } diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh index 9dd236a3bd..0b9096423a 100644 --- a/cpp/include/raft/distance/fused_distance_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,20 +67,12 @@ void fusedDistanceNNMinReduce(OutT* min, float metric_arg, \ cudaStream_t stream) -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); // We can't have comma's in the macro expansion, so we use the COMMA macro: #define COMMA , -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh index 342bde828d..5ec4b8c5cf 100644 --- a/cpp/include/raft/distance/fused_distance_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh index 0c22df72f1..04c42e49a1 100755 --- a/cpp/include/raft/distance/fused_distance_nn.cuh +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/fused_distance_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh index e70d098d09..3a570c681c 100644 --- a/cpp/include/raft/distance/fused_distance_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -17,7 +17,6 @@ #pragma once #include -// #include #include namespace raft::distance { diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp index 6580cfa639..09d8d401e4 100644 --- a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -57,18 +57,6 @@ void fused_distance_nn_min_arg(raft::resources const& handle, bool isRowMajor, float metric_arg); -void fused_distance_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); - /** @} */ // end group fused_distance_nn_min_arg_runtime } // end namespace raft::runtime::distance diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu index c3d1301e29..fc8a6cb26d 100644 --- a/cpp/src/distance/fused_distance_nn.cu +++ b/cpp/src/distance/fused_distance_nn.cu @@ -36,20 +36,12 @@ float metric_arg, \ cudaStream_t stream) -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); // We can't have comma's in the macro expansion, so we use the COMMA macro: #define COMMA , -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu index 90d00d9f6b..1899b1616f 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -114,24 +114,4 @@ void fused_distance_nn_min_arg(raft::resources const& handle, } } -void fused_distance_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) -{ - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); - break; - default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; - } -} - } // end namespace raft::runtime::distance diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu index 5a89e71608..e87692094f 100644 --- a/cpp/test/distance/fused_cosine_nn.cu +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + #include "../test_utils.cuh" #include #include From 0ab7a842fea505be07d78f817f549a1ff6a17553 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 16 Feb 2024 07:35:49 -0800 Subject: [PATCH 03/12] fix formatting issues --- .../fused_distance_nn/predicated_tile_iterator_reduced_vec.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 81e7819223..5ceb0dbaf5 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -579,9 +579,7 @@ class PredicatedTileIteratorReducedVec { thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - } + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } } } return *this; From e9090d6ebd1159c8c146452708ca3b2d9a5e6e31 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 22 Feb 2024 10:15:23 -0800 Subject: [PATCH 04/12] unify fusedl2nn with fuseddistanceNN, add deprecation warning for fused_l2_nn_min_arg, support only float for fused_distance_nn --- .../distance/detail/fused_distance_nn.cuh | 11 +- .../{ => fused_distance_nn}/fused_l2_nn.cuh | 0 cpp/include/raft/distance/fused_l2_nn-inl.cuh | 2 +- .../raft_runtime/distance/fused_l2_nn.hpp | 36 ++--- .../distance/fused_distance_min_arg.cu | 75 +-------- .../distance/fused_distance_min_arg.hpp | 143 ++++++++++++++++++ .../raft_runtime/distance/fused_l2_min_arg.cu | 88 +++-------- cpp/test/distance/fused_l2_nn.cu | 3 +- 8 files changed, 199 insertions(+), 159 deletions(-) rename cpp/include/raft/distance/detail/{ => fused_distance_nn}/fused_l2_nn.cuh (100%) create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh index 94f199275d..f679c3584b 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -23,6 +23,7 @@ #include // ops::l2_exp_distance_op #include #include +#include #include #include #include // PairwiseDistances @@ -76,11 +77,17 @@ void fusedDistanceNNImpl(OutT* min, } switch (metric) { - case DistanceType::CosineExpanded: + case raft::distance::DistanceType::CosineExpanded: fusedCosineNN( min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); break; - default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + break; + default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; } } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh similarity index 100% rename from cpp/include/raft/distance/detail/fused_l2_nn.cuh rename to cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 4cb6b367a5..10af89e051 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp index 6154e03f4c..e46b3c5271 100644 --- a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,23 +42,25 @@ namespace raft::runtime::distance { * @param[in] k gemm k * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt */ -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt); -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt); /** @} */ // end group fused_l2_nn_min_arg_runtime diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu index 1899b1616f..6c6caf8687 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -14,86 +14,19 @@ * limitations under the License. */ +#include "fused_distance_min_arg.hpp" #include #include #include #include #include #include -#include #include #include #include namespace raft::runtime::distance { -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -template -void compute_fused_cosine_nn_min_arg(raft::resources const& handle, - idx_t* min, - const value_t* x, - const value_t* y, - idx_t m, - idx_t n, - idx_t k, - bool sqrt) -{ - rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); - auto kvp = raft::make_device_vector>(handle, m); - - rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); - rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); - constexpr bool is_row_major = true; - raft::linalg::rowNorm(x_norms.data(), - x, - k, - m, - raft::linalg::L2Norm, - is_row_major, - resource::get_cuda_stream(handle), - raft::sqrt_op{}); - raft::linalg::rowNorm(y_norms.data(), - y, - k, - n, - raft::linalg::L2Norm, - is_row_major, - resource::get_cuda_stream(handle), - raft::sqrt_op{}); - - raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(), - x, - y, - x_norms.data(), - y_norms.data(), - m, - n, - k, - (void*)workspace.data(), - sqrt, - true, - is_row_major, - raft::distance::DistanceType::CosineExpanded, - 0.0f, - resource::get_cuda_stream(handle)); - - KeyValueIndexOp conversion_op; - thrust::transform(resource::get_thrust_policy(handle), - kvp.data_handle(), - kvp.data_handle() + m, - min, - conversion_op); - resource::sync_stream(handle); -} - void fused_distance_nn_min_arg(raft::resources const& handle, int* min, const float* x, @@ -110,7 +43,11 @@ void fused_distance_nn_min_arg(raft::resources const& handle, case raft::distance::DistanceType::CosineExpanded: compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); break; - default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + default: assert("only Cosine/L2 metric is supported with fusedDistanceNN\n"); break; } } diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp new file mode 100644 index 0000000000..d348e7755b --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::runtime::distance { + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const raft::KeyValuePair& a) const + { + return a.key; + } +}; + +template +void compute_fused_l2_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + constexpr bool is_row_major = true; + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + x_norms.data(), x, k, m, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + y_norms.data(), y, k, n, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + + raft::distance::fusedL2NNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +template +void compute_fused_cosine_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + constexpr bool is_row_major = true; + raft::linalg::rowNorm(x_norms.data(), + x, + k, + m, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + raft::linalg::rowNorm(y_norms.data(), + y, + k, + n, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + + raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + is_row_major, + raft::distance::DistanceType::CosineExpanded, + 0.0f, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index d8949a645b..82a225aca4 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "fused_distance_min_arg.hpp" #include #include #include @@ -27,77 +28,28 @@ namespace raft::runtime::distance { -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -template -void compute_fused_l2_nn_min_arg(raft::resources const& handle, - idx_t* min, - const value_t* x, - const value_t* y, - idx_t m, - idx_t n, - idx_t k, - bool sqrt) -{ - rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); - auto kvp = raft::make_device_vector>(handle, m); - - rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); - rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - x_norms.data(), x, k, m, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - y_norms.data(), y, k, n, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - - raft::distance::fusedL2NNMinReduce(kvp.data_handle(), - x, - y, - x_norms.data(), - y_norms.data(), - m, - n, - k, - (void*)workspace.data(), - sqrt, - true, - resource::get_cuda_stream(handle)); - - KeyValueIndexOp conversion_op; - thrust::transform(resource::get_thrust_policy(handle), - kvp.data_handle(), - kvp.data_handle() + m, - min, - conversion_op); - resource::sync_stream(handle); -} - -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 565895565f..2a99acd0e4 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include From 5a48625d2b6a89ba8d716ee2c827132aaebad8b4 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 23 Feb 2024 01:01:23 -0800 Subject: [PATCH 05/12] expose fused_distance_nn in pylibraft and add unit test for it with all supported distance metrics --- .../pylibraft/distance/CMakeLists.txt | 4 +- .../pylibraft/pylibraft/distance/__init__.py | 9 +- .../pylibraft/distance/fused_distance_nn.pyx | 218 ++++++++++++++++++ .../test/test_fused_distance_argmin.py | 69 ++++++ 4 files changed, 296 insertions(+), 4 deletions(-) create mode 100755 python/pylibraft/pylibraft/distance/fused_distance_nn.pyx create mode 100755 python/pylibraft/pylibraft/test/test_fused_distance_argmin.py diff --git a/python/pylibraft/pylibraft/distance/CMakeLists.txt b/python/pylibraft/pylibraft/distance/CMakeLists.txt index 14f0cc441a..2530e07a98 100644 --- a/python/pylibraft/pylibraft/distance/CMakeLists.txt +++ b/python/pylibraft/pylibraft/distance/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx) +set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx fused_distance_nn.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py index f059b5f3dd..d16ab30b2f 100644 --- a/python/pylibraft/pylibraft/distance/__init__.py +++ b/python/pylibraft/pylibraft/distance/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,12 @@ # limitations under the License. # +from .fused_distance_nn import fused_distance_nn_argmin from .fused_l2_nn import fused_l2_nn_argmin from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance -__all__ = ["fused_l2_nn_argmin", "pairwise_distance"] +__all__ = [ + "fused_distance_nn_argmin", + "fused_l2_nn_argmin", + "pairwise_distance", +] diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx new file mode 100755 index 0000000000..f98b9e710a --- /dev/null +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -0,0 +1,218 @@ +# +# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t +from libcpp cimport bool + +from .distance_type cimport DistanceType + +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + + +cdef extern from "raft_runtime/distance/fused_distance_nn.hpp" \ + namespace "raft::runtime::distance" nogil: + + void fused_distance_nn_min_arg( + const device_resources &handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + DistanceType metric, + bool isRowMajor, + float metric_arg) except + + + +DISTANCE_TYPES = { + "l2": DistanceType.L2SqrtExpanded, + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, + "l1": DistanceType.L1, + "cityblock": DistanceType.L1, + "inner_product": DistanceType.InnerProduct, + "chebyshev": DistanceType.Linf, + "canberra": DistanceType.Canberra, + "cosine": DistanceType.CosineExpanded, + "lp": DistanceType.LpUnexpanded, + "correlation": DistanceType.CorrelationExpanded, + "jaccard": DistanceType.JaccardExpanded, + "hellinger": DistanceType.HellingerExpanded, + "braycurtis": DistanceType.BrayCurtis, + "jensenshannon": DistanceType.JensenShannon, + "hamming": DistanceType.HammingUnexpanded, + "kl_divergence": DistanceType.KLDivergence, + "minkowski": DistanceType.LpUnexpanded, + "russellrao": DistanceType.RusselRaoExpanded, + "dice": DistanceType.DiceExpanded, +} + +SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"] + + +@auto_sync_handle +@auto_convert_output +def fused_distance_nn_argmin(X, Y, out=None, sqrt=True, metric="euclidean", + handle=None): + """ + Compute the 1-nearest neighbors between X and Y using the L2 distance + + Parameters + ---------- + + X : CUDA array interface compliant matrix shape (m, k) + Y : CUDA array interface compliant matrix shape (n, k) + out : Writable CUDA array interface matrix shape (m, 1) + metric : string denoting the metric type (default="euclidean") + + {handle_docstring} + + Examples + -------- + To compute the 1-nearest neighbors argmin: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> output = fused_distance_nn_argmin(in1, in2, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + The output can also be computed in-place on a preallocated + array: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> output = cp.empty((n_samples, 1), dtype=cp.int32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> fused_distance_nn_argmin(in1, in2, out=output, handle=handle) + array(...) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + + x_cai = cai_wrapper(X) + y_cai = cai_wrapper(Y) + + x_dt = x_cai.dtype + y_dt = y_cai.dtype + + m = x_cai.shape[0] + n = y_cai.shape[0] + + if out is None: + output = device_ndarray.empty((m,), dtype="int32") + else: + output = out + + output_cai = cai_wrapper(output) + + x_k = x_cai.shape[1] + y_k = y_cai.shape[1] + + if x_k != y_k: + raise ValueError("Inputs must have same number of columns. " + "a=%s, b=%s" % (x_k, y_k)) + + if metric not in SUPPORTED_DISTANCES: + raise ValueError("metric %s is not supported" % metric) + + cdef DistanceType distance_type = DISTANCE_TYPES[metric] + + x_ptr = x_cai.data + y_ptr = y_cai.data + + d_ptr = output_cai.data + + handle = handle if handle is not None else Handle() + cdef device_resources *h = handle.getHandle() + + d_dt = output_cai.dtype + + x_c_contiguous = x_cai.c_contiguous + y_c_contiguous = y_cai.c_contiguous + + if x_c_contiguous != y_c_contiguous: + raise ValueError("Inputs must have matching strides") + + if not x_c_contiguous: + raise ValueError("Inputs must be C contiguous") + + if x_dt != y_dt: + raise ValueError("Inputs must have the same dtypes") + if d_dt != np.int32: + raise ValueError("Output array must be int32") + # unused arg for now. + metric_arg = 0.0 + if x_dt == np.float32: + fused_distance_nn_min_arg(deref(h), + d_ptr, + x_ptr, + y_ptr, + m, + n, + x_k, + sqrt, + distance_type, + x_c_contiguous, + metric_arg) + else: + raise ValueError("dtype %s not supported" % x_dt) + + return output diff --git a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py new file mode 100755 index 0000000000..26aa3f5ab2 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, device_ndarray +from pylibraft.distance import fused_distance_nn_argmin + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("n_rows", [10, 100]) +@pytest.mark.parametrize("n_clusters", [50, 100]) +@pytest.mark.parametrize("n_cols", [128, 31]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cosine", + "sqeuclidean", + ], +) +def test_fused_distance_nn_minarg( + n_rows, n_cols, n_clusters, dtype, inplace, metric +): + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(dtype) + + input2 = np.random.random_sample((n_clusters, n_cols)) + input2 = np.asarray(input2, order="C").astype(dtype) + + output = np.zeros((n_rows), dtype="int32") + expected = cdist(input1, input2, metric) + + expected = expected.argmin(axis=1) + + input1_device = device_ndarray(input1) + input2_device = device_ndarray(input2) + output_device = device_ndarray(output) if inplace else None + + is_sqrt = True if metric == "sqeuclidean" else False + handle = DeviceResources() + ret_output = fused_distance_nn_argmin( + input1_device, + input2_device, + output_device, + is_sqrt, + metric, + handle=handle, + ) + handle.sync() + output_device = ret_output if not inplace else output_device + actual = output_device.copy_to_host() + + assert np.allclose(expected, actual, rtol=1e-4) From 63261c6e82d8e38bcdd81577fd7c94f84b317355 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 23 Feb 2024 01:45:19 -0800 Subject: [PATCH 06/12] correct the description for fused distance nn arg min pylibraft API --- python/pylibraft/pylibraft/distance/fused_distance_nn.pyx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx index f98b9e710a..e19c96a896 100755 --- a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -85,7 +85,10 @@ SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"] def fused_distance_nn_argmin(X, Y, out=None, sqrt=True, metric="euclidean", handle=None): """ - Compute the 1-nearest neighbors between X and Y using the L2 distance + Compute the 1-nearest neighbors between X and Y using the distance metrics + + Valid values for metric: + ["euclidean", "l2", "cosine", "sqeuclidean"] Parameters ---------- From f53e43917feed3c4dbff8605e250813fa388ef2d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 23 Feb 2024 04:50:13 -0800 Subject: [PATCH 07/12] fix the fused_l2_nn header name in masked_nn.cuh --- cpp/include/raft/distance/detail/masked_nn.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 4de9f4764a..4ff83dce89 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include #include From 958a2b3495326b725511bfbf7d824a880a93906b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 23 Feb 2024 05:17:08 -0800 Subject: [PATCH 08/12] fix copyright year in masked_nn.cuh --- cpp/include/raft/distance/detail/masked_nn.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 4ff83dce89..7bbb1ae789 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 7a0a6db2cdc8164918e231fb13bfcae7fab278cd Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 28 Feb 2024 02:13:32 -0800 Subject: [PATCH 09/12] fix copyright year for newly added source files --- cpp/include/raft/distance/detail/fused_distance_nn.cuh | 2 +- .../raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh | 2 +- cpp/include/raft_runtime/distance/fused_distance_nn.hpp | 2 +- cpp/src/distance/fused_distance_nn.cu | 2 +- cpp/src/raft_runtime/distance/fused_distance_min_arg.cu | 2 +- cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp | 2 +- cpp/test/distance/fused_cosine_nn.cu | 2 +- python/pylibraft/pylibraft/distance/fused_distance_nn.pyx | 2 +- python/pylibraft/pylibraft/test/test_fused_distance_argmin.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh index f679c3584b..181cc71fc1 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh index e86db734a5..c6d4fe18a5 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp index 09d8d401e4..7c309d6fc7 100644 --- a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu index fc8a6cb26d..0e4514ac86 100644 --- a/cpp/src/distance/fused_distance_nn.cu +++ b/cpp/src/distance/fused_distance_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu index 6c6caf8687..5f2171d92c 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp index d348e7755b..c9fb202e6c 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu index e87692094f..4e545e9aea 100644 --- a/cpp/test/distance/fused_cosine_nn.cu +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx index e19c96a896..256b632c81 100755 --- a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py index 26aa3f5ab2..6736128242 100755 --- a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py +++ b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From f4974db2d95e6fcafd2537974f2b2f95c903cb1c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 29 Feb 2024 06:40:30 -0800 Subject: [PATCH 10/12] fix fused_distance_nn.pyx file permission to be rw instead of rwx --- python/pylibraft/pylibraft/distance/fused_distance_nn.pyx | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 python/pylibraft/pylibraft/distance/fused_distance_nn.pyx diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx old mode 100755 new mode 100644 From 757145957ffa8070decd93823e86e63105edbd59 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 7 Mar 2024 00:54:34 -0800 Subject: [PATCH 11/12] fix clang formatting issues --- .../raft/distance/detail/fused_distance_nn.cuh | 5 +++-- .../detail/fused_distance_nn/cutlass_base.cuh | 14 ++------------ .../fused_distance_nn/epilogue_elementwise.cuh | 5 ++--- .../detail/fused_distance_nn/fused_cosine_nn.cuh | 5 +++-- .../detail/fused_distance_nn/fused_l2_nn.cuh | 5 +++-- .../detail/fused_distance_nn/helper_structs.cuh | 5 +++-- .../detail/fused_distance_nn/simt_kernel.cuh | 5 +++-- .../raft/distance/fused_distance_nn-ext.cuh | 3 ++- .../raft/distance/fused_distance_nn-inl.cuh | 7 +++++-- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 6 +++--- cpp/src/distance/fused_distance_nn.cu | 3 ++- .../distance/fused_distance_min_arg.cu | 2 ++ .../distance/fused_distance_min_arg.hpp | 1 + cpp/src/raft_runtime/distance/fused_l2_min_arg.cu | 1 + cpp/test/distance/fused_cosine_nn.cu | 4 +++- 15 files changed, 38 insertions(+), 33 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh index 181cc71fc1..4fbfdc8755 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -16,8 +16,6 @@ #pragma once -#include // size_t -#include // std::numeric_limits #include // raft::KeyValuePair #include // raft::identity_op #include // ops::l2_exp_distance_op @@ -32,6 +30,9 @@ #include // raft::util::arch::SM_* #include // raft::ceildiv, raft::shfl +#include // size_t +#include // std::numeric_limits + namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh index 3f359c2f54..b2fc5e0cc7 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -26,18 +26,6 @@ #define cutlass raft_cutlass #endif -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include #include // FusedDistanceNNEpilogueElementwise #include // FusedDistanceNNGemm #include // getMultiProcessorCount @@ -45,6 +33,8 @@ #include +#include + #include #include #include diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index 993cf2f086..e69b2486df 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -56,6 +56,8 @@ #pragma once +#include + #include #include #include @@ -63,9 +65,6 @@ #include #include -#include -#include - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh index c6d4fe18a5..f29c8b4d4c 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -16,8 +16,6 @@ #pragma once -#include // size_t -#include // std::numeric_limits #include // raft::KeyValuePair #include // raft::identity_op #include // ops::l2_exp_distance_op @@ -29,6 +27,9 @@ #include // raft::util::arch::SM_* #include // raft::ceildiv, raft::shfl +#include // size_t +#include // std::numeric_limits + namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh index 75275d40b3..65475e73c7 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -16,8 +16,6 @@ #pragma once -#include // size_t -#include // std::numeric_limits #include // raft::KeyValuePair #include // raft::identity_op #include // ops::l2_exp_distance_op @@ -29,6 +27,9 @@ #include // raft::util::arch::SM_* #include // raft::ceildiv, raft::shfl +#include // size_t +#include // std::numeric_limits + namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh index e88ea9cfc8..e056c5d397 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh @@ -16,8 +16,6 @@ #pragma once -#include // size_t -#include // std::numeric_limits #include // raft::KeyValuePair #include // raft::identity_op #include // ops::l2_exp_distance_op @@ -29,6 +27,9 @@ #include // raft::ceildiv, raft::shfl #include +#include // size_t +#include // std::numeric_limits + namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh index f5e4c725d6..7417fd5dac 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -16,13 +16,14 @@ #pragma once -#include // size_t -#include // std::numeric_limits #include // raft::KeyValuePair #include // ops::l2_exp_distance_op #include // PairwiseDistances #include // Policy +#include // size_t +#include // std::numeric_limits + namespace raft { namespace distance { namespace detail { diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh index 0b9096423a..263bbcea81 100644 --- a/cpp/include/raft/distance/fused_distance_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -16,12 +16,13 @@ #pragma once -#include // int64_t #include // raft::KeyValuePair #include // raft::resources #include // include initialize and reduce operations #include // RAFT_EXPLICIT +#include // int64_t + #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft { diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh index 5ec4b8c5cf..ffe86a1c04 100644 --- a/cpp/include/raft/distance/fused_distance_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -19,14 +19,17 @@ #pragma once -#include -#include #include #include #include #include #include + +#include + #include + +#include #include namespace raft { diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index ffc6cccab5..d0ac83cd51 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -16,10 +16,10 @@ #pragma once -#include // raft::KeyValuePair -#include // raft::resources +#include // raft::KeyValuePair +#include // raft::resources #include // include initialize and reduce operations -#include // RAFT_EXPLICIT +#include // RAFT_EXPLICIT #include // int64_t diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu index 0e4514ac86..dc722d929c 100644 --- a/cpp/src/distance/fused_distance_nn.cu +++ b/cpp/src/distance/fused_distance_nn.cu @@ -14,10 +14,11 @@ * limitations under the License. */ -#include // int64_t #include // raft::KeyValuePair #include +#include // int64_t + #define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ template void raft::distance::fusedDistanceNNMinReduce( \ OutT * min, \ diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu index 5f2171d92c..dfdff4e94b 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -15,6 +15,7 @@ */ #include "fused_distance_min_arg.hpp" + #include #include #include @@ -22,6 +23,7 @@ #include #include #include + #include #include diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp index c9fb202e6c..6452752a79 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp @@ -25,6 +25,7 @@ #include #include #include + #include #include diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index f421b2fe39..870757dca1 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -15,6 +15,7 @@ */ #include "fused_distance_min_arg.hpp" + #include #include #include diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu index 4e545e9aea..d4d632e1dc 100644 --- a/cpp/test/distance/fused_cosine_nn.cu +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -17,7 +17,7 @@ #undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation #include "../test_utils.cuh" -#include + #include #include #include @@ -27,6 +27,8 @@ #include #include +#include + namespace raft { namespace distance { From 9f63c564f5c7d0feb9d41fb21ef0b281147ba352 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Mon, 18 Mar 2024 23:34:58 +0530 Subject: [PATCH 12/12] Update python/pylibraft/pylibraft/distance/fused_distance_nn.pyx Co-authored-by: Ben Frederickson --- .../pylibraft/distance/fused_distance_nn.pyx | 23 +------------------ 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx index 256b632c81..0e9fa4b366 100644 --- a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -54,28 +54,7 @@ cdef extern from "raft_runtime/distance/fused_distance_nn.hpp" \ float metric_arg) except + -DISTANCE_TYPES = { - "l2": DistanceType.L2SqrtExpanded, - "sqeuclidean": DistanceType.L2Expanded, - "euclidean": DistanceType.L2SqrtExpanded, - "l1": DistanceType.L1, - "cityblock": DistanceType.L1, - "inner_product": DistanceType.InnerProduct, - "chebyshev": DistanceType.Linf, - "canberra": DistanceType.Canberra, - "cosine": DistanceType.CosineExpanded, - "lp": DistanceType.LpUnexpanded, - "correlation": DistanceType.CorrelationExpanded, - "jaccard": DistanceType.JaccardExpanded, - "hellinger": DistanceType.HellingerExpanded, - "braycurtis": DistanceType.BrayCurtis, - "jensenshannon": DistanceType.JensenShannon, - "hamming": DistanceType.HammingUnexpanded, - "kl_divergence": DistanceType.KLDivergence, - "minkowski": DistanceType.LpUnexpanded, - "russellrao": DistanceType.RusselRaoExpanded, - "dice": DistanceType.DiceExpanded, -} +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"]