Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initializing memory in RBC #509

Merged
merged 5 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -76,6 +76,9 @@ void sample_landmarks(const raft::handle_t& handle,
thrust::fill(
handle.get_thrust_policy(), R_1nn_ones.data(), R_1nn_ones.data() + R_1nn_ones.size(), 1.0);

thrust::fill(
handle.get_thrust_policy(), R_indices.data(), R_indices.data() + R_indices.size(), 0.0);

/**
* 1. Randomly sample sqrt(n) points from X
*/
Expand Down Expand Up @@ -234,6 +237,16 @@ void perform_rbc_query(const raft::handle_t& handle,
float weight = 1.0,
bool perform_post_filtering = true)
{
// Initialize the output indices and distances to max-value sentinels
thrust::fill(handle.get_thrust_policy(),
inds,
inds + (k * n_query_pts),
std::numeric_limits<value_idx>::max());
thrust::fill(handle.get_thrust_policy(),
dists,
dists + (k * n_query_pts),
std::numeric_limits<value_t>::max());

// Compute nearest k for each neighborhood in each closest R
rbc_low_dim_pass_one(handle,
index,
Expand Down Expand Up @@ -289,6 +302,16 @@ void rbc_build_index(const raft::handle_t& handle,
rmm::device_uvector<value_idx> R_knn_inds(index.m, handle.get_stream());
rmm::device_uvector<value_t> R_knn_dists(index.m, handle.get_stream());

// Fill R_knn_inds and R_knn_dists with max-value sentinels (device_uvector storage starts uninitialized)
thrust::fill(handle.get_thrust_policy(),
R_knn_inds.begin(),
R_knn_inds.end(),
std::numeric_limits<value_idx>::max());
thrust::fill(handle.get_thrust_policy(),
R_knn_dists.begin(),
R_knn_dists.end(),
std::numeric_limits<value_t>::max());

/**
* 1. Randomly sample sqrt(n) points from X
*/
Expand Down Expand Up @@ -340,6 +363,16 @@ void rbc_all_knn_query(const raft::handle_t& handle,
rmm::device_uvector<value_idx> R_knn_inds(k * index.m, handle.get_stream());
rmm::device_uvector<value_t> R_knn_dists(k * index.m, handle.get_stream());

// Fill R_knn_inds and R_knn_dists with max-value sentinels (device_uvector storage starts uninitialized)
thrust::fill(handle.get_thrust_policy(),
R_knn_inds.begin(),
R_knn_inds.end(),
std::numeric_limits<value_idx>::max());
thrust::fill(handle.get_thrust_policy(),
R_knn_dists.begin(),
R_knn_dists.end(),
std::numeric_limits<value_t>::max());

// For debugging / verification. Remove before releasing
rmm::device_uvector<value_int> dists_counter(index.m, handle.get_stream());
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
rmm::device_uvector<value_int> post_dists_counter(index.m, handle.get_stream());
Expand Down Expand Up @@ -396,6 +429,16 @@ void rbc_knn_query(const raft::handle_t& handle,
rmm::device_uvector<value_idx> R_knn_inds(k * index.m, handle.get_stream());
rmm::device_uvector<value_t> R_knn_dists(k * index.m, handle.get_stream());

// Fill R_knn_inds and R_knn_dists with max-value sentinels (device_uvector storage starts uninitialized)
thrust::fill(handle.get_thrust_policy(),
R_knn_inds.begin(),
R_knn_inds.end(),
std::numeric_limits<value_idx>::max());
thrust::fill(handle.get_thrust_policy(),
R_knn_dists.begin(),
R_knn_dists.end(),
std::numeric_limits<value_t>::max());

k_closest_landmarks(handle, index, query, n_query_pts, k, R_knn_inds.data(), R_knn_dists.data());

// For debugging / verification. Remove before releasing
Expand Down
25 changes: 15 additions & 10 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,14 @@ __global__ void compute_final_dists_registers(const value_t* X_index,
for (; i < limit; i += tpb) {
value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i];
value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i];
value_t z = heap.warpKTopRDist == 0.00 ? 0.0
: (abs(heap.warpKTop - heap.warpKTopRDist) *

value_t z = heap.warpKTopRDist == 0.00 ? 0.0
: (abs(heap.warpKTop - heap.warpKTopRDist) *
abs(heap.warpKTopRDist - cur_candidate_dist) -
heap.warpKTop * cur_candidate_dist) /
heap.warpKTopRDist;
z = isnan(z) ? 0.0 : z;
z = isnan(z) || isinf(z) ? 0.0 : z;

// If lower bound on distance could possibly be in
// the closest k neighbors, compute it and add to k-select
value_t dist = std::numeric_limits<value_t>::max();
Expand Down Expand Up @@ -261,7 +263,8 @@ __global__ void compute_final_dists_registers(const value_t* X_index,
heap.warpKTop * cur_candidate_dist) /
heap.warpKTopRDist;

z = isnan(z) ? 0.0 : z;
z = isnan(z) || isinf(z) ? 0.0 : z;

// If lower bound on distance could possibly be in
// the closest k neighbors, compute it and add to k-select
value_t dist = std::numeric_limits<value_t>::max();
Expand Down Expand Up @@ -361,8 +364,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
shared_memV,
k);

value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)];

value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)];
value_int n_dists_computed = 0;

/**
Expand Down Expand Up @@ -409,9 +411,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
heap.warpKTop * cur_candidate_dist) /
heap.warpKTopRDist;

z = isnan(z) ? 0.0 : z;
z = isnan(z) || isinf(z) ? 0.0 : z;
value_t dist = std::numeric_limits<value_t>::max();
if (i < k || z <= heap.warpKTop) {

if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
Expand All @@ -433,9 +436,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
heap.warpKTop * cur_candidate_dist) /
heap.warpKTopRDist;

z = isnan(z) ? 0.0 : z;
z = isnan(z) || isinf(z) ? 0.0 : z;
value_t dist = std::numeric_limits<value_t>::max();
if (i < k || z <= heap.warpKTop) {

if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
Expand Down Expand Up @@ -610,6 +614,7 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle,
const value_int bitset_size = ceil(index.n_landmarks / 32.0);

rmm::device_uvector<std::uint32_t> bitset(bitset_size * index.m, handle.get_stream());
thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0);

perform_post_filter_registers<value_idx, value_t, value_int, 128>
<<<n_query_rows, 128, bitset_size * sizeof(std::uint32_t), handle.get_stream()>>>(
Expand Down