Skip to content

Commit

Permalink
Merge pull request #171 from viclafargue/update-bf-knn
Browse files Browse the repository at this point in the history
[REVIEW] Update KNN
  • Loading branch information
cjnolet authored Mar 17, 2021
2 parents 2ef0a51 + bba8495 commit 7091ae3
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 66 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ if(BUILD_RAFT_TESTS)
test/sparse/sort.cu
test/sparse/symmetrize.cu
test/spatial/knn.cu
test/spatial/haversine.cu
test/stats/mean.cu
test/stats/mean_center.cu
test/stats/stddev.cu
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <raft/sparse/csr.cuh>
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/selection/selection.cuh>
#include <raft/spatial/knn/detail/brute_force_knn.hpp>
#include <raft/spatial/knn/detail/brute_force_knn.cuh>
#include <raft/spatial/knn/knn.hpp>

#include <raft/cudart_utils.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <raft/handle.hpp>
#include <set>

#include "haversine_distance.cuh"
#include "processing.hpp"

namespace raft {
Expand Down Expand Up @@ -166,24 +167,37 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK,
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
}

inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) {
inline faiss::MetricType build_faiss_metric(
raft::distance::DistanceType metric) {
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::CosineExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::CorrelationExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::L2Expanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2Unexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtExpanded:
return faiss::MetricType::METRIC_L2;
case distance::DistanceType::L1:
case raft::distance::DistanceType::L2SqrtUnexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L1:
return faiss::MetricType::METRIC_L1;
case distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case distance::DistanceType::LpUnexpanded:
case raft::distance::DistanceType::InnerProduct:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::LpUnexpanded:
return faiss::MetricType::METRIC_Lp;
case distance::DistanceType::Canberra:
case raft::distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case raft::distance::DistanceType::Canberra:
return faiss::MetricType::METRIC_Canberra;
case distance::DistanceType::BrayCurtis:
case raft::distance::DistanceType::BrayCurtis:
return faiss::MetricType::METRIC_BrayCurtis;
case distance::DistanceType::JensenShannon:
case raft::distance::DistanceType::JensenShannon:
return faiss::MetricType::METRIC_JensenShannon;
default:
return faiss::MetricType::METRIC_INNER_PRODUCT;
THROW("MetricType not supported: %d", metric);
}
}

Expand All @@ -209,33 +223,33 @@ inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) {
* @param[in] rowMajorQuery are the query array in row-major layout?
* @param[in] translations translation ids for indices when index rows represent
* non-contiguous partitions
* @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean)
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
* @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root
*/
template <typename IntType = int>
void brute_force_knn_impl(
std::vector<float *> &input, std::vector<int> &sizes, IntType D,
float *search_items, IntType n, int64_t *res_I, float *res_D, IntType k,
std::shared_ptr<raft::mr::device::allocator> allocator,
cudaStream_t userStream, cudaStream_t *internalStreams = nullptr,
int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
float metricArg = 2.0, bool expanded_form = false) {
void brute_force_knn_impl(std::vector<float *> &input, std::vector<int> &sizes,
IntType D, float *search_items, IntType n,
int64_t *res_I, float *res_D, IntType k,
std::shared_ptr<deviceAllocator> allocator,
cudaStream_t userStream,
cudaStream_t *internalStreams = nullptr,
int n_int_streams = 0, bool rowMajorIndex = true,
bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
float metricArg = 0) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors should be the same size");

faiss::MetricType m = detail::build_faiss_metric(metric);

std::vector<int64_t> *id_ranges;
if (translations == nullptr) {
// If we don't have explicit translations
// for offsets of the indices, build them
// from the local partitions
id_ranges = new std::vector<int64_t>();
int64_t total_n = 0;
for (size_t i = 0; i < input.size(); i++) {
for (int i = 0; i < input.size(); i++) {
id_ranges->push_back(total_n);
total_n += sizes[i];
}
Expand All @@ -252,7 +266,7 @@ void brute_force_knn_impl(

std::vector<std::unique_ptr<MetricProcessor<float>>> metric_processors(
input.size());
for (size_t i = 0; i < input.size(); i++) {
for (int i = 0; i < input.size(); i++) {
metric_processors[i] = create_processor<float>(
metric, sizes[i], D, k, rowMajorQuery, userStream, allocator);
metric_processors[i]->preprocess(input[i]);
Expand Down Expand Up @@ -283,35 +297,52 @@ void brute_force_knn_impl(
// Sync user stream only if using other streams to parallelize query
if (n_int_streams > 0) CUDA_CHECK(cudaStreamSynchronize(userStream));

for (size_t i = 0; i < input.size(); i++) {
faiss::gpu::StandardGpuResources gpu_res;
for (int i = 0; i < input.size(); i++) {
float *out_d_ptr = out_D + (i * k * n);
int64_t *out_i_ptr = out_I + (i * k * n);

cudaStream_t stream =
raft::select_stream(userStream, internalStreams, n_int_streams, i);

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_D + (i * k * n);
args.outIndices = out_I + (i * k * n);

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n,
k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
}

CUDA_CHECK(cudaPeekAtLastError());
}
Expand All @@ -326,32 +357,33 @@ void brute_force_knn_impl(
if (input.size() > 1 || translations != nullptr) {
// This is necessary for proper index translations. If there are
// no translations or partitions to combine, it can be skipped.
detail::knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k,
userStream, trans.data());
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream,
trans.data());
}

// Perform necessary post-processing
if ((m == faiss::MetricType::METRIC_L2 ||
m == faiss::MetricType::METRIC_Lp) &&
!expanded_form) {
if (metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::LpUnexpanded) {
/**
* post-processing
*/
float p = 0.5; // standard l2
if (m == faiss::MetricType::METRIC_Lp) p = 1.0 / metricArg;
if (metric == raft::distance::DistanceType::LpUnexpanded)
p = 1.0 / metricArg;
raft::linalg::unaryOp<float>(
res_D, res_D, n * k,
[p] __device__(float input) { return powf(input, p); }, userStream);
}

query_metric_processor->revert(search_items);
query_metric_processor->postprocess(out_D);
for (size_t i = 0; i < input.size(); i++) {
for (int i = 0; i < input.size(); i++) {
metric_processors[i]->revert(input[i]);
}

if (translations == nullptr) delete id_ranges;
}
};

} // namespace detail
} // namespace knn
Expand Down
140 changes: 140 additions & 0 deletions cpp/include/raft/spatial/knn/detail/haversine_distance.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cudart_utils.h>
#include <raft/cuda_utils.cuh>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <faiss/utils/Heap.h>
#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>

#include <raft/linalg/distance_type.h>
#include <raft/handle.hpp>

namespace raft {
namespace spatial {
namespace knn {
namespace detail {

template <typename value_t>
DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) {
value_t sin_0 = sin(0.5 * (x1 - y1));
value_t sin_1 = sin(0.5 * (x2 - y2));
value_t rdist = sin_0 * sin_0 + cos(x1) * cos(y1) * sin_1 * sin_1;

return 2 * asin(sqrt(rdist));
}

/**
* @tparam value_idx data type of indices
* @tparam value_t data type of values and distances
* @tparam warp_q
* @tparam thread_q
* @tparam tpb
* @param[out] out_inds output indices
* @param[out] out_dists output distances
* @param[in] index index array
* @param[in] query query array
* @param[in] n_index_rows number of rows in index array
* @param[in] k number of closest neighbors to return
*/
template <typename value_idx, typename value_t, int warp_q = 1024,
int thread_q = 8, int tpb = 128>
__global__ void haversine_knn_kernel(value_idx *out_inds, value_t *out_dists,
const value_t *index, const value_t *query,
size_t n_index_rows, int k) {
constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize;

__shared__ value_t smemK[kNumWarps * warp_q];
__shared__ value_idx smemV[kNumWarps * warp_q];

faiss::gpu::BlockSelect<value_t, value_idx, false,
faiss::gpu::Comparator<value_t>, warp_q, thread_q,
tpb>
heap(faiss::gpu::Limits<value_t>::getMax(), -1, smemK, smemV, k);

// Grid is exactly sized to rows available
int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize);

const value_t *query_ptr = query + (blockIdx.x * 2);
value_t x1 = query_ptr[0];
value_t x2 = query_ptr[1];

int i = threadIdx.x;

for (; i < limit; i += tpb) {
const value_t *idx_ptr = index + (i * 2);
value_t y1 = idx_ptr[0];
value_t y2 = idx_ptr[1];

value_t dist = compute_haversine(x1, y1, x2, y2);

heap.add(dist, i);
}

// Handle last remainder fraction of a warp of elements
if (i < n_index_rows) {
const value_t *idx_ptr = index + (i * 2);
value_t y1 = idx_ptr[0];
value_t y2 = idx_ptr[1];

value_t dist = compute_haversine(x1, y1, x2, y2);

heap.addThreadQ(dist, i);
}

heap.reduce();

for (int i = threadIdx.x; i < k; i += tpb) {
out_dists[blockIdx.x * k + i] = smemK[i];
out_inds[blockIdx.x * k + i] = smemV[i];
}
}

/**
* Conmpute the k-nearest neighbors using the Haversine
* (great circle arc) distance. Input is assumed to have
* 2 dimensions (latitude, longitude) in radians.
* @tparam value_idx
* @tparam value_t
* @param[out] out_inds output indices array on device (size n_query_rows * k)
* @param[out] out_dists output dists array on device (size n_query_rows * k)
* @param[in] index input index array on device (size n_index_rows * 2)
* @param[in] query input query array on device (size n_query_rows * 2)
* @param[in] n_index_rows number of rows in index array
* @param[in] n_query_rows number of rows in query array
* @param[in] k number of closest neighbors to return
* @param[in] stream stream to order kernel launch
*/
template <typename value_idx, typename value_t>
void haversine_knn(value_idx *out_inds, value_t *out_dists,
const value_t *index, const value_t *query,
size_t n_index_rows, size_t n_query_rows, int k,
cudaStream_t stream) {
haversine_knn_kernel<<<n_query_rows, 128, 0, stream>>>(
out_inds, out_dists, index, query, n_index_rows, k);
}

} // namespace detail
} // namespace knn
} // namespace spatial
} // namespace raft
Loading

0 comments on commit 7091ae3

Please sign in to comment.