diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index d154e5f92a..c20582df72 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -294,67 +294,66 @@ void brute_force_knn_impl( cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); // // TODO: Enable this once we figure out why it's causing pytest failures in cuml. - // if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - // (metric == raft::distance::DistanceType::L2Unexpanded || - // metric == raft::distance::DistanceType::L2SqrtUnexpanded //|| - // // metric == raft::distance::DistanceType::L2Expanded || - // // metric == raft::distance::DistanceType::L2SqrtExpanded) - // )) { - // fusedL2Knn(D, - // out_i_ptr, - // out_d_ptr, - // input[i], - // search_items, - // sizes[i], - // n, - // k, - // rowMajorIndex, - // rowMajorQuery, - // stream, - // metric); - // } else { - switch (metric) { - case raft::distance::DistanceType::Haversine: - - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); - break; - default: - faiss::MetricType m = build_faiss_metric(metric); - - faiss::gpu::StandardGpuResources gpu_res; - - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_d_ptr; - args.outIndices = out_i_ptr; - - /** - * @todo: Until FAISS supports pluggable allocation strategies, - * we will not reap the benefits of the pool allocator for - * avoiding device-wide synchronizations from cudaMalloc/cudaFree - */ - bfKnn(&gpu_res, args); + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { + fusedL2Knn(D, + out_i_ptr, + out_d_ptr, + input[i], + search_items, + sizes[i], + n, + k, + rowMajorIndex, + rowMajorQuery, + stream, + metric); + } else { + switch (metric) { + case raft::distance::DistanceType::Haversine: + + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); + break; + default: + faiss::MetricType m = build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + gpu_res.setDefaultStream(device, stream); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = metricArg; + args.k = k; + args.dims = D; + args.vectors = input[i]; + args.vectorsRowMajor = rowMajorIndex; + args.numVectors = sizes[i]; + args.queries = search_items; + args.queriesRowMajor = rowMajorQuery; + args.numQueries = n; + args.outDistances = out_d_ptr; + args.outIndices = out_i_ptr; + + /** + * @todo: Until FAISS supports pluggable allocation strategies, + * we will not reap the benefits of the pool allocator for + * avoiding device-wide synchronizations from cudaMalloc/cudaFree + */ + bfKnn(&gpu_res, args); + } } - } - CUDA_CHECK(cudaPeekAtLastError()); - // } + CUDA_CHECK(cudaPeekAtLastError()); + } // Sync internal streams if used. We don't need to // sync the user stream because we'll already have @@ -379,7 +378,11 @@ void brute_force_knn_impl( float p = 0.5; // standard l2 if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; raft::linalg::unaryOp( - res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream); + res_D, + res_D, + n * k, + [p] __device__(float input) { return powf(fabsf(input), p); }, + userStream); } query_metric_processor->revert(search_items);