diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 9b69d437f4..7424a5ff81 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -105,8 +105,8 @@ DI void storeWarpQGmem(myWarpSelect** heapArr, for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { const auto idx = j * warpSize + lid; if (idx < numOfNN) { - out_dists[gmemRowId * numOfNN + idx] = heapArr[i]->warpK[j]; - out_inds[gmemRowId * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; + out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j]; + out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; } } } @@ -130,8 +130,8 @@ DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr, for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { const auto idx = j * warpSize + lid; if (idx < numOfNN) { - heapArr[i]->warpK[j] = out_dists[gmemRowId * numOfNN + idx]; - heapArr[i]->warpV[j] = (uint32_t)out_inds[gmemRowId * numOfNN + idx]; + heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx]; + heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx]; } } static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1; @@ -490,7 +490,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x } } - if (((gridStrideX + Policy::Nblk * gridDim.x) > n) && gridDim.x == 1) { + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { // This is last iteration of grid stride X loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu index 2ec4e86d1f..70b83fad35 100644 --- a/cpp/test/spatial/fused_l2_knn.cu +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -212,12 +212,14 @@ class FusedL2KNNTest : public ::testing::TestWithParam { const std::vector inputs = { {100, 1000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {256, 256, 30, 10, raft::distance::DistanceType::L2Expanded}, {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, {100, 1000, 16, 50, raft::distance::DistanceType::L2Expanded}, {20, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, {1000, 10000, 16, 50, raft::distance::DistanceType::L2Expanded}, {1000, 10000, 32, 50, raft::distance::DistanceType::L2Expanded}, {10000, 40000, 32, 30, raft::distance::DistanceType::L2Expanded}, + {131072, 131072, 8, 60, raft::distance::DistanceType::L2Expanded}, // L2 unexpanded {100, 1000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, {1000, 10000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, @@ -225,7 +227,8 @@ const std::vector inputs = { {20, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, {1000, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, {1000, 10000, 32, 50, raft::distance::DistanceType::L2Unexpanded}, - {10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded}}; + {10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded}, + {131072, 131072, 8, 60, raft::distance::DistanceType::L2Unexpanded}}; typedef FusedL2KNNTest FusedL2KNNTestF; TEST_P(FusedL2KNNTestF, FusedBruteForce) { this->testBruteForce(); }