Skip to content

Commit

Permalink
Fix sparse KNN for large batches (#1640)
Browse files Browse the repository at this point in the history
Answers #1187

Authors:
  - Victor Lafargue (https://github.com/viclafargue)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1640
  • Loading branch information
viclafargue authored Jul 26, 2023
1 parent f99a418 commit f25907b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
// stop compiler warning; not all algorithms support multi-GPU so it may not be used
(void)dev_list;

- raft::bench::ann::Metric metric = parse_metric(distance);
std::unique_ptr<raft::bench::ann::ANN<T>> ann;

if constexpr (std::is_same_v<T, float>) {
+ raft::bench::ann::Metric metric = parse_metric(distance);
if (algo == "faiss_gpu_ivf_flat") {
ann = make_algo<T, raft::bench::ann::FaissGpuIVFFlat>(metric, dim, conf, dev_list);
} else if (algo == "faiss_gpu_ivf_pq") {
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/sparse/detail/utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -90,7 +90,8 @@ __global__ void iota_fill_block_kernel(value_idx* indices, value_idx ncols)
int tid = threadIdx.x;

for (int i = tid; i < ncols; i += blockDim.x) {
- indices[row * ncols + i] = i;
+ uint64_t idx = (uint64_t)row * (uint64_t)ncols;
+ indices[idx + i] = i;
}
}

Expand Down
12 changes: 4 additions & 8 deletions cpp/include/raft/sparse/distance/detail/coo_spmv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ inline void balanced_coo_pairwise_generalized_spmv(
strategy_t strategy,
int chunk_size = 500000)
{
- RAFT_CUDA_TRY(cudaMemsetAsync(out_dists,
- 0,
- sizeof(value_t) * config_.a_nrows * config_.b_nrows,
- resource::get_cuda_stream(config_.handle)));
+ uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows;
+ RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, resource::get_cuda_stream(config_.handle)));

strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size);
};
Expand Down Expand Up @@ -112,10 +110,8 @@ inline void balanced_coo_pairwise_generalized_spmv(
write_f write_func,
int chunk_size = 500000)
{
- RAFT_CUDA_TRY(cudaMemsetAsync(out_dists,
- 0,
- sizeof(value_t) * config_.a_nrows * config_.b_nrows,
- resource::get_cuda_stream(config_.handle)));
+ uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows;
+ RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, resource::get_cuda_stream(config_.handle)));

int max_cols = max_cols_per_block<value_idx, value_t>();

Expand Down
16 changes: 11 additions & 5 deletions cpp/include/raft/sparse/distance/detail/lp_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,13 @@ class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_t<value_id
void compute(value_t* out_dists)
{
l2_unexpanded_distances_t<value_idx, value_t>::compute(out_dists);
+
+ uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
// Sqrt Post-processing
raft::linalg::unaryOp<value_t>(
out_dists,
out_dists,
- this->config_->a_nrows * this->config_->b_nrows,
+ n,
[] __device__(value_t input) {
int neg = input < 0 ? -1 : 1;
return raft::sqrt(abs(input) * neg);
Expand Down Expand Up @@ -203,10 +205,11 @@ class lp_unexpanded_distances_t : public distances_t<value_t> {
raft::add_op(),
raft::atomic_add_op());

+ uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
value_t one_over_p = value_t{1} / p;
raft::linalg::unaryOp<value_t>(out_dists,
out_dists,
- config_->a_nrows * config_->b_nrows,
+ n,
raft::pow_const_op<value_t>(one_over_p),
resource::get_cuda_stream(config_->handle));
}
Expand All @@ -229,10 +232,11 @@ class hamming_unexpanded_distances_t : public distances_t<value_t> {
unexpanded_lp_distances<value_idx, value_t>(
out_dists, config_, raft::notequal_op(), raft::add_op(), raft::atomic_add_op());

+ uint64_t n = (uint64_t)config_->a_nrows * (uint64_t)config_->b_nrows;
value_t n_cols = 1.0 / config_->a_ncols;
raft::linalg::unaryOp<value_t>(out_dists,
out_dists,
- config_->a_nrows * config_->b_nrows,
+ n,
raft::mul_const_op<value_t>(n_cols),
resource::get_cuda_stream(config_->handle));
}
Expand Down Expand Up @@ -271,10 +275,11 @@ class jensen_shannon_unexpanded_distances_t : public distances_t<value_t> {
raft::add_op(),
raft::atomic_add_op());

+ uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
raft::linalg::unaryOp<value_t>(
out_dists,
out_dists,
- config_->a_nrows * config_->b_nrows,
+ n,
[=] __device__(value_t input) { return raft::sqrt(0.5 * input); },
resource::get_cuda_stream(config_->handle));
}
Expand Down Expand Up @@ -311,9 +316,10 @@ class kl_divergence_unexpanded_distances_t : public distances_t<value_t> {
raft::add_op(),
raft::atomic_add_op());

+ uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows;
raft::linalg::unaryOp<value_t>(out_dists,
out_dists,
- config_->a_nrows * config_->b_nrows,
+ n,
raft::mul_const_op<value_t>(0.5),
resource::get_cuda_stream(config_->handle));
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/sparse/neighbors/detail/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ class sparse_knn_t {
/**
* Compute distances
*/
- size_t dense_size = idx_batcher.batch_rows() * query_batcher.batch_rows();
+ uint64_t dense_size =
+ (uint64_t)idx_batcher.batch_rows() * (uint64_t)query_batcher.batch_rows();
rmm::device_uvector<value_t> batch_dists(dense_size, resource::get_cuda_stream(handle));

RAFT_CUDA_TRY(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t)));
Expand Down

0 comments on commit f25907b

Please sign in to comment.