Fix bugs in NN-Descent #3

Merged: 2 commits, Aug 22, 2023
89 changes: 50 additions & 39 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -24,6 +24,7 @@
#include <limits>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <queue>

#include "../nn_descent_types.hpp"

@@ -372,7 +373,7 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
const int lane_id) {
if constexpr (std::is_same_v<Data_t, float> or std::is_same_v<Data_t, uint8_t> or std::is_same_v<Data_t, int8_t>) {
constexpr int num_load_elems_per_warp = WARP_SIZE;
for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) {
for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) {
int idx = step * num_load_elems_per_warp + lane_id;
if (idx < load_dims) {
vec_buffer[idx] = d_vec[idx];
@@ -381,12 +382,12 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
}
}
}
if constexpr (std::is_same<Data_t, __half>::value) {
if ((size_t)vec_buffer % sizeof(float2) == 0 && load_dims % 4 == 0 &&
padding_dims % 4 == 0) {
if constexpr (std::is_same_v<Data_t, __half>) {
if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 &&
load_dims % 4 == 0 && padding_dims % 4 == 0) {
constexpr int num_load_elems_per_warp = WARP_SIZE * 4;
#pragma unroll
for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) {
for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) {
int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4;
if (idx_in_vec + 4 <= load_dims) {
*(float2 *)(vec_buffer + idx_in_vec) = *(float2 *)(d_vec + idx_in_vec);
@@ -396,7 +397,7 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
}
} else {
constexpr int num_load_elems_per_warp = WARP_SIZE;
for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) {
for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) {
int idx = step * num_load_elems_per_warp + lane_id;
if (idx < load_dims) {
vec_buffer[idx] = d_vec[idx];
@@ -408,15 +409,15 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
}
}
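The load_vec changes above all replace load_dims with padding_dims in the loop bound: the per-warp loop has to cover the padded width, otherwise the region of the shared-memory buffer between load_dims and padding_dims is never written. The zero-filling branch is collapsed in this diff view, but the intended pattern is roughly the sketch below (assuming a 32-lane warp and a div_up helper; this is an illustration, not the library code):

// Minimal sketch (not the library code): warp-cooperative load with zero padding.
// The loop must cover padding_dims so lanes past load_dims still write the padding.
constexpr int WARP_SIZE = 32;  // assumed warp width

__host__ __device__ constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

template <typename Data_t>
__device__ void padded_warp_load(Data_t* vec_buffer, const Data_t* d_vec,
                                 int load_dims, int padding_dims, int lane_id) {
  for (int step = 0; step < div_up(padding_dims, WARP_SIZE); step++) {
    int idx = step * WARP_SIZE + lane_id;
    if (idx < load_dims) {
      vec_buffer[idx] = d_vec[idx];   // copy real data
    } else if (idx < padding_dims) {
      vec_buffer[idx] = Data_t(0);    // zero the padded tail the old bound skipped
    }
  }
}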

template <typename Data_t, typename Index_t>
__global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_data, Index_t nrow,
int dim, DistData_t *l2_norms) {
template <typename Data_t>
__global__ void preprocess_data_kernel(const Data_t* input_data, __half* output_data, int dim,
DistData_t* l2_norms, size_t list_offset = 0) {
extern __shared__ char buffer[];
__shared__ float l2_norm;
Data_t *s_vec = (Data_t *)buffer;
size_t list_id = blockIdx.x;
size_t list_id = list_offset + blockIdx.x;

load_vec(s_vec, input_data + list_id * dim, dim, dim, threadIdx.x % WARP_SIZE);
load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % WARP_SIZE);
if (threadIdx.x == 0) {
l2_norm = 0;
}
@@ -443,9 +444,10 @@ __global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_
int idx = step * WARP_SIZE + threadIdx.x;
if (idx < dim) {
if (l2_norms == nullptr) {
output_data[list_id * dim + idx] = (float)input_data[list_id * dim + idx] / sqrt(l2_norm);
output_data[list_id * dim + idx] =
(float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm);
} else {
output_data[list_id * dim + idx] = input_data[list_id * dim + idx];
output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx];
if (idx == 0) {
l2_norms[list_id] = l2_norm;
}
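The new list_offset parameter splits the kernel's indexing into two spaces: the input row is addressed with the batch-local blockIdx.x (the staged buffer holds only one batch), while output_data and l2_norms are addressed with the global row list_id = list_offset + blockIdx.x. A standalone sketch of that mapping, not the library kernel:

#include <cuda_runtime.h>

// Standalone sketch of the batch-local vs. global indexing; not the library kernel.
__global__ void copy_batch(const float* batch_in, float* global_out, int dim,
                           size_t list_offset) {
  size_t global_row = list_offset + blockIdx.x;           // where the row lands globally
  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    global_out[global_row * (size_t)dim + i] =            // write at the global row
        batch_in[(size_t)blockIdx.x * dim + i];           // read at the batch-local row
  }
}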
@@ -475,8 +477,7 @@ __global__ void add_rev_edges_kernel(const Index_t *graph, Index_t *rev_graph, i

template <typename Index_t, typename ID_t = InternalID_t<Index_t>>
__device__ void insert_to_global_graph(ResultItem<Index_t> elem, size_t list_id, ID_t *graph,
DistData_t *dists, int node_degree, int *locks,
bool new_new = true) {
DistData_t *dists, int node_degree, int *locks) {
int tx = threadIdx.x;
int lane_id = tx % WARP_SIZE;
size_t global_idx_base = list_id * node_degree;
@@ -760,8 +761,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4)
if (idx_in_list >= list_new_size) continue;
auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances);
if (min_elem.id() < gridDim.x) {
insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks,
true);
insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
}
}

@@ -851,8 +851,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4)
}

if (min_elem.id() < gridDim.x) {
insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks,
false);
insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
}
}
}
@@ -945,18 +944,13 @@ void GnndGraph<Index_t>::init_random_graph() {

#pragma omp parallel for
for (size_t i = 0; i < nrow; i++) {
for (size_t j = 0; j < NUM_SAMPLES; j++) {
size_t idx = i * NUM_SAMPLES + j;
for (size_t j = 0; j < node_degree; j++) {
size_t idx = i * node_degree + j;
Index_t id = rand_seq[idx % nrow];
if ((size_t)id == i) {
id = rand_seq[(idx + NUM_SAMPLES) % nrow];
id = rand_seq[(idx + node_degree) % nrow];
}
h_graph[i * node_degree + j].id_with_flag() = id;
}
for (size_t j = NUM_SAMPLES; j < node_degree; j++) {
h_graph[i * node_degree + j].id_with_flag() = std::numeric_limits<Index_t>::max();
}
for (size_t j = 0; j < node_degree; j++) {
h_dists[i * node_degree + j] = std::numeric_limits<DistData_t>::max();
}
}
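The revised init_random_graph fills every one of the node_degree slots with a random neighbour (avoiding self-edges) instead of seeding only the first NUM_SAMPLES slots and padding the rest with an invalid id. A simplified host-only sketch of the per-row fill, with rand_seq assumed to be a pre-shuffled permutation of the row ids (illustration only):

#include <vector>

// Simplified sketch of the new per-row initialization; not the library code.
// rand_seq is assumed to be a pre-shuffled permutation of 0..nrow-1.
std::vector<int> random_row_init(size_t i, size_t nrow, size_t node_degree,
                                 const std::vector<int>& rand_seq) {
  std::vector<int> row(node_degree);
  for (size_t j = 0; j < node_degree; j++) {
    size_t idx = i * node_degree + j;
    int id = rand_seq[idx % nrow];
    if ((size_t)id == i) {                        // avoid a self-edge
      id = rand_seq[(idx + node_degree) % nrow];
    }
    row[j] = id;                                  // every slot now gets a neighbour
  }
  return row;
}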
@@ -1113,7 +1107,7 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev
list_sizes);
RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, d_rev_graph_ptr,
sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES,
cudaMemcpyDeviceToHost, stream));
cudaMemcpyDefault, stream));
}
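Several cudaMemcpyAsync calls in this diff switch their direction argument to cudaMemcpyDefault, which asks the runtime to infer the direction from the two pointers (this relies on unified virtual addressing, available on current GPUs), so the same call works whether the destination is pinned host memory or device memory. A minimal self-contained illustration, separate from the code above:

#include <cuda_runtime.h>

// Minimal illustration: cudaMemcpyDefault lets the runtime infer the copy
// direction from the pointer kinds, so this single call performs a device-to-host
// copy here without naming the direction explicitly.
int main() {
  const size_t n = 1024;
  float *d_buf = nullptr, *h_buf = nullptr;
  cudaMalloc(&d_buf, n * sizeof(float));
  cudaMallocHost(&h_buf, n * sizeof(float));      // pinned host memory
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemcpyAsync(h_buf, d_buf, n * sizeof(float), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFreeHost(h_buf);
  cudaFree(d_buf);
  return 0;
}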

template <typename Data_t, typename Index_t>
@@ -1136,14 +1130,31 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
cudaPointerAttributes data_ptr_attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
if (data_ptr_attr.type == cudaMemoryTypeUnregistered) {
RAFT_CUDA_TRY(cudaHostRegister(const_cast<std::remove_const_t<Data_t>*>(data), sizeof(Data_t) * nrow * build_config_.dataset_dim,
cudaHostRegisterDefault));
typename std::remove_const<Data_t>::type* input_data;
size_t batch_size = 100000;
RAFT_CUDA_TRY(cudaMallocAsync(&input_data,
sizeof(Data_t) * batch_size * build_config_.dataset_dim, stream));
for (size_t step = 0; step < div_up(nrow_, batch_size); step++) {
size_t list_offset = step * batch_size;
size_t num_lists =
step != div_up(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset;
RAFT_CUDA_TRY(cudaMemcpyAsync(
input_data, data + list_offset * build_config_.dataset_dim,
sizeof(Data_t) * num_lists * build_config_.dataset_dim, cudaMemcpyDefault, stream));
preprocess_data_kernel<<<num_lists, WARP_SIZE,
sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) *
WARP_SIZE,
stream>>>(input_data, d_data_, build_config_.dataset_dim,
l2_norms_, list_offset);
}
RAFT_CUDA_TRY(cudaFreeAsync(input_data, stream));
} else {
preprocess_data_kernel<<<
nrow_, WARP_SIZE,
sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, stream>>>(
data, d_data_, build_config_.dataset_dim, l2_norms_);
}

preprocess_data_kernel<<<
nrow_, WARP_SIZE, sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE,
stream>>>(data, d_data_, nrow_, build_config_.dataset_dim, l2_norms_);
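When the input sits in unregistered host memory, the new code stages it through a device buffer of batch_size rows instead of registering the whole dataset with cudaHostRegister; each iteration copies one batch and launches the preprocessing kernel with the matching list_offset. The batch arithmetic is easy to check in isolation (a tiny host-only sketch with assumed sizes, not the library's values):

#include <cstdio>

// Host-only sketch of the batching arithmetic: batch count, per-batch offset,
// and the shorter tail batch. nrow and batch_size are assumed example values.
constexpr size_t div_up(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  const size_t nrow = 250000, batch_size = 100000;
  for (size_t step = 0; step < div_up(nrow, batch_size); step++) {
    size_t list_offset = step * batch_size;
    size_t num_lists = step != div_up(nrow, batch_size) - 1 ? batch_size : nrow - list_offset;
    std::printf("batch %zu: offset %zu, rows %zu\n", step, list_offset, num_lists);
  }
  return 0;  // prints 0/100000, 100000/100000, 200000/50000
}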

thrust::fill(thrust::device.on(stream), (Index_t*)graph_buffer_,
(Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE,
std::numeric_limits<Index_t>::max());
@@ -1168,13 +1179,13 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
for (size_t it = 0; it < build_config_.max_iterations; it++) {
RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, graph_.h_list_sizes_new,
sizeof(*d_list_sizes_new_) * nrow_,
cudaMemcpyHostToDevice, stream));
cudaMemcpyDefault, stream));
RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, graph_.h_graph_old,
sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES,
cudaMemcpyHostToHost, stream));
cudaMemcpyDefault, stream));
RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, graph_.h_list_sizes_old,
sizeof(*d_list_sizes_old_) * nrow_,
cudaMemcpyHostToDevice, stream));
cudaMemcpyDefault, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

std::thread update_and_sample_thread(update_and_sample, it);
@@ -1201,11 +1212,11 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out

RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, graph_buffer_,
sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE,
cudaMemcpyDeviceToHost, stream));
cudaMemcpyDefault, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, dists_buffer_,
sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE,
cudaMemcpyDeviceToHost, stream));
cudaMemcpyDefault, stream));
graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE);
}
