Skip to content

Commit

Permalink
fix issue in fusedL2knn which happens when rows are a multiple of 256 (#604)
Browse files Browse the repository at this point in the history

This PR fixes issues #568 and rapidsai/cuml#4624.
-- Fix an issue in fusedL2knn which occurs when the number of rows is a multiple of 256.
-- Make the index value size_t to avoid int overflow; this does not cause the issues above, but it could for larger input sizes.
-- Add some additional test cases to the fusedL2knn test.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #604
  • Loading branch information
mdoijade authored Mar 31, 2022
1 parent 4a3dfb9 commit 36329c1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 5 additions & 5 deletions cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ DI void storeWarpQGmem(myWarpSelect** heapArr,
for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
const auto idx = j * warpSize + lid;
if (idx < numOfNN) {
out_dists[gmemRowId * numOfNN + idx] = heapArr[i]->warpK[j];
out_inds[gmemRowId * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j];
out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j];
out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j];
}
}
}
Expand All @@ -130,8 +130,8 @@ DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr,
for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) {
const auto idx = j * warpSize + lid;
if (idx < numOfNN) {
heapArr[i]->warpK[j] = out_dists[gmemRowId * numOfNN + idx];
heapArr[i]->warpV[j] = (uint32_t)out_inds[gmemRowId * numOfNN + idx];
heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx];
heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx];
}
}
static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1;
Expand Down Expand Up @@ -490,7 +490,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x
}
}

if (((gridStrideX + Policy::Nblk * gridDim.x) > n) && gridDim.x == 1) {
if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) {
// This is last iteration of grid stride X
loadAllWarpQShmem<Policy, Pair>(heapArr, &shDumpKV[0], m, numOfNN);
storeWarpQGmem<Policy, Pair>(heapArr, out_dists, out_inds, m, numOfNN, starty);
Expand Down
5 changes: 4 additions & 1 deletion cpp/test/spatial/fused_l2_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -212,20 +212,23 @@ class FusedL2KNNTest : public ::testing::TestWithParam<FusedL2KNNInputs> {

const std::vector<FusedL2KNNInputs> inputs = {
{100, 1000, 16, 10, raft::distance::DistanceType::L2Expanded},
{256, 256, 30, 10, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded},
{100, 1000, 16, 50, raft::distance::DistanceType::L2Expanded},
{20, 10000, 16, 10, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 16, 50, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 32, 50, raft::distance::DistanceType::L2Expanded},
{10000, 40000, 32, 30, raft::distance::DistanceType::L2Expanded},
{131072, 131072, 8, 60, raft::distance::DistanceType::L2Expanded},
// L2 unexpanded
{100, 1000, 16, 10, raft::distance::DistanceType::L2Unexpanded},
{1000, 10000, 16, 10, raft::distance::DistanceType::L2Unexpanded},
{100, 1000, 16, 50, raft::distance::DistanceType::L2Unexpanded},
{20, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded},
{1000, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded},
{1000, 10000, 32, 50, raft::distance::DistanceType::L2Unexpanded},
{10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded}};
{10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded},
{131072, 131072, 8, 60, raft::distance::DistanceType::L2Unexpanded}};

typedef FusedL2KNNTest<float> FusedL2KNNTestF;
TEST_P(FusedL2KNNTestF, FusedBruteForce) { this->testBruteForce(); }
Expand Down

0 comments on commit 36329c1

Please sign in to comment.