From 4b8d827ccfe2c85f601882d4ca55efd4bc5d5ff9 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 22 Nov 2022 19:14:41 +0100 Subject: [PATCH] Solve issue of nan when computing sqrt of a distance of a point to itself --- cpp/include/raft/distance/detail/fused_l2_nn.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 1385d0aa09..e8c2648c2e 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -174,7 +174,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + auto acc_ij = acc[i][j]; + acc[i][j] = acc_ij > DataT{0} ? raft::mySqrt(acc_ij) : DataT{0}; } } }