diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index a776ce2586..b3c4818e70 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -56,7 +56,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, size_t m, size_t n, size_t d, - int k, + size_t k, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, @@ -79,7 +79,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } // tile_cols must be at least k items - tile_cols = std::max(tile_cols, static_cast(k)); + tile_cols = std::max(tile_cols, k); // stores pairwise distances for the current tile rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); @@ -90,13 +90,34 @@ void tiled_brute_force_knn(const raft::device_resources& handle, rmm::device_uvector search_norms(0, stream); rmm::device_uvector index_norms(0, stream); if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { search_norms.resize(m, stream); index_norms.resize(n, stream); - raft::linalg::rowNorm( - search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - raft::linalg::rowNorm( - index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + // cosine needs the l2norm, whereas l2 distances need the squared norm + if (metric == raft::distance::DistanceType::CosineExpanded) { + raft::linalg::rowNorm(search_norms.data(), + search, + d, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + raft::linalg::rowNorm(index_norms.data(), + index, + d, + n, + 
raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } else { + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + } pairwise_metric = raft::distance::DistanceType::InnerProduct; } @@ -109,20 +130,17 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // in which case the number of columns here is too high in the temp output. // adjust if necessary auto last_col_tile_size = n % tile_cols; - if (last_col_tile_size && (last_col_tile_size < static_cast(k))) { - temp_out_cols -= k - last_col_tile_size; - } + if (last_col_tile_size && (last_col_tile_size < k)) { temp_out_cols -= k - last_col_tile_size; } // if we have less than k items in the index, we should fill out the result // to indicate that we are missing items (and match behaviour in faiss) - if (n < static_cast(k)) { + if (n < k) { raft::matrix::fill(handle, - raft::make_device_matrix_view(distances, m, static_cast(k)), + raft::make_device_matrix_view(distances, m, k), std::numeric_limits::lowest()); if constexpr (std::is_signed_v) { - raft::matrix::fill( - handle, raft::make_device_matrix_view(indices, m, static_cast(k)), IndexType{-1}); + raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1}); } } @@ -136,7 +154,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, for (size_t j = 0; j < n; j += tile_cols) { size_t current_centroid_size = std::min(tile_cols, n - j); - size_t current_k = std::min(current_centroid_size, static_cast(k)); + size_t current_k = std::min(current_centroid_size, k); // calculate the top-k elements for the current tile, by calculating the // full pairwise distance for the tile - and then selecting the top-k from that @@ -176,6 +194,21 @@ void tiled_brute_force_knn(const raft::device_resources& handle, val = distance_epilogue(val, row, col); 
return val; }); + } else if (metric == raft::distance::DistanceType::CosineExpanded) { + auto row_norms = search_norms.data(); + auto col_norms = index_norms.data(); + auto dist = temp_distances.data(); + + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(dist, current_query_size * current_centroid_size), + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]); + val = distance_epilogue(val, row, col); + return val; + }); } else { // if we're not l2 distance, and we have a distance epilogue - run it now if constexpr (!std::is_same_v) { @@ -310,18 +343,6 @@ void brute_force_knn_impl( id_ranges = translations; } - // perform preprocessing - std::unique_ptr> query_metric_processor = - create_processor(metric, n, D, k, rowMajorQuery, userStream); - query_metric_processor->preprocess(search_items); - - std::vector>> metric_processors(input.size()); - for (size_t i = 0; i < input.size(); i++) { - metric_processors[i] = - create_processor(metric, sizes[i], D, k, rowMajorQuery, userStream); - metric_processors[i]->preprocess(input[i]); - } - int device; RAFT_CUDA_TRY(cudaGetDevice(&device)); @@ -430,14 +451,6 @@ void brute_force_knn_impl( raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); } - // cosine/correlation are handled by metric processor, use IP distance - // for brute force knn call. 
- auto tiled_metric = metric; - if (metric == raft::distance::DistanceType::CosineExpanded || - metric == raft::distance::DistanceType::CorrelationExpanded) { - tiled_metric = raft::distance::DistanceType::InnerProduct; - } - tiled_brute_force_knn(stream_pool_handle, search, index, @@ -447,7 +460,7 @@ void brute_force_knn_impl( k, out_d_ptr, out_i_ptr, - tiled_metric, + metric, metricArg, 0, 0, @@ -470,12 +483,6 @@ void brute_force_knn_impl( knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); } - query_metric_processor->revert(search_items); - query_metric_processor->postprocess(out_D); - for (size_t i = 0; i < input.size(); i++) { - metric_processors[i]->revert(input[i]); - } - if (translations == nullptr) delete id_ranges; }; diff --git a/python/pylibraft/pylibraft/test/test_brute_force.py b/python/pylibraft/pylibraft/test/test_brute_force.py index f349be892d..0bd5e6eaaf 100644 --- a/python/pylibraft/pylibraft/test/test_brute_force.py +++ b/python/pylibraft/pylibraft/test/test_brute_force.py @@ -90,9 +90,6 @@ def test_knn( expected_indices = argsort[i] gpu_dists = actual_distances[i] - if metric == "correlation" or metric == "cosine": - gpu_dists = gpu_dists[::-1] - cpu_ordered = pw_dists[i, expected_indices] np.testing.assert_allclose( cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4