From 58dea2c23dd86adaee3b7bcd42cf5c24125f019f Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 22 Jan 2025 14:27:57 +0100 Subject: [PATCH 01/11] Fix UMAP issues with large inputs --- cpp/include/cuml/common/callback.hpp | 4 +- cpp/include/cuml/manifold/common.hpp | 6 +-- cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 34 ++++++++------- cpp/src/umap/init_embed/spectral_algo.cuh | 4 +- cpp/src/umap/knn_graph/algo.cuh | 7 ++-- cpp/src/umap/runner.cuh | 8 ++-- cpp/src/umap/simpl_set_embed/algo.cuh | 30 ++++++------- .../simpl_set_embed/optimize_batch_kernel.cuh | 42 +++++++++---------- cpp/src/umap/simpl_set_embed/runner.cuh | 4 +- cpp/src/umap/supervised.cuh | 6 +-- 11 files changed, 76 insertions(+), 71 deletions(-) diff --git a/cpp/include/cuml/common/callback.hpp b/cpp/include/cuml/common/callback.hpp index c2b99e1f6b..2f32811137 100644 --- a/cpp/include/cuml/common/callback.hpp +++ b/cpp/include/cuml/common/callback.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ class GraphBasedDimRedCallback : public Callback { protected: int n; - int n_components; + uint64_t n_components; bool isFloat; }; diff --git a/cpp/include/cuml/manifold/common.hpp b/cpp/include/cuml/manifold/common.hpp index 3346f9127e..12232d15ff 100644 --- a/cpp/include/cuml/manifold/common.hpp +++ b/cpp/include/cuml/manifold/common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,8 +55,8 @@ struct knn_graph { template struct manifold_inputs_t { T* y; - int n; - int d; + uint64_t n; + uint64_t d; manifold_inputs_t(T* y_, int n_, int d_) : y(y_), n(n_), d(d_) {} diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index cdbfdd2674..9485ae0fd6 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -167,7 +167,7 @@ class TSNE_runner { { distance_and_perplexity(); - const auto NNZ = COO_Matrix.nnz; + const auto NNZ = (value_idx)COO_Matrix.nnz; auto* VAL = COO_Matrix.vals(); const auto* COL = COO_Matrix.cols(); const auto* ROW = COO_Matrix.rows(); diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index 41e54f1f63..b791d3d7f0 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -91,8 +91,8 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists, float bandwidth = 1.0) { // row-based matrix 1 thread per row - int row = (blockIdx.x * TPB_X) + threadIdx.x; - int i = row * n_neighbors; // each thread processes one row of the dist matrix + int row = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t i = (uint64_t)row * n_neighbors; // each thread processes one row of the dist matrix if (row < n) { float target = __log2f(n_neighbors) * bandwidth; @@ -190,7 +190,7 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists, * * Descriptions adapted from: https://github.com/lmcinnes/umap/blob/master/umap/umap_.py */ -template +template CUML_KERNEL void compute_membership_strength_kernel( const value_idx* knn_indices, const float* knn_dists, // nn outputs @@ -199,14 +199,14 @@ CUML_KERNEL void compute_membership_strength_kernel( value_t* vals, int* rows, int* cols, // result coo - int n, - int n_neighbors) + int n_neighbors, + uint64_t to_process) { // model params // row-based matrix is best - int idx = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t idx = (blockIdx.x * TPB_X) + threadIdx.x; - if (idx < n * n_neighbors) { + if (idx < to_process) { int row = idx / n_neighbors; // one neighbor per thread double cur_rho = rhos[row]; @@ -237,8 +237,8 @@ CUML_KERNEL void compute_membership_strength_kernel( /* * Sets up and runs the knn dist smoothing */ -template -void smooth_knn_dist(int n, +template +void smooth_knn_dist(uint64_t n, const value_idx* knn_indices, const float* knn_dists, value_t* rhos, @@ -253,7 +253,8 @@ void smooth_knn_dist(int n, rmm::device_uvector dist_means_dev(n_neighbors, stream); - raft::stats::mean(dist_means_dev.data(), knn_dists, 1, n_neighbors * n, false, false, stream); + raft::stats::mean( + dist_means_dev.data(), knn_dists, (uint64_t)1, n * n_neighbors, false, false, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); value_t mean_dist = 0.0; @@ -284,8 +285,8 @@ void smooth_knn_dist(int n, * @param params UMAPParams config object * @param stream cuda stream to use for device operations */ -template -void launcher(int n, +template +void launcher(uint64_t n, const value_idx* knn_indices, const value_t* knn_dists, int n_neighbors, @@ -328,7 +329,8 @@ void launcher(int n, * Compute graph of membership strengths */ - dim3 grid_elm(raft::ceildiv(n * n_neighbors, TPB_X), 1, 1); + uint64_t to_process = n * n_neighbors; + dim3 grid_elm(raft::ceildiv(to_process, TPB_X), 1, 1); dim3 blk_elm(TPB_X, 1, 1); compute_membership_strength_kernel<<>>(knn_indices, @@ -338,8 +340,8 @@ void launcher(int n, in.vals(), in.rows(), in.cols(), - in.n_rows, - n_neighbors); + n_neighbors, + to_process); RAFT_CUDA_TRY(cudaPeekAtLastError()); if (ML::default_logger().should_log(ML::level_enum::debug)) { @@ -365,7 +367,7 @@ void launcher(int n, }, stream); - raft::sparse::op::coo_sort(out, stream); + // raft::sparse::op::coo_sort(out, stream); } } // namespace Naive } // namespace FuzzySimplSet diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index cac6e8dcc0..814ef48af7 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ using namespace ML; */ template void launcher(const raft::handle_t& handle, - int n, + uint64_t n, int d, raft::sparse::COO* coo, UMAPParams* params, diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 6617d72c00..b75ca1a103 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,11 +126,12 @@ inline void launcher(const raft::handle_t& handle, RAFT_EXPECTS(graph.distances().has_value(), "return_distances for nn descent should be set to true to be used for UMAP"); - auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors); + auto out_knn_dists_view = + raft::make_device_matrix_view(out.knn_dists, inputsA.n, (uint64_t)n_neighbors); raft::matrix::slice( handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); auto out_knn_indices_view = - raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors); + raft::make_device_matrix_view(out.knn_indices, inputsA.n, (uint64_t)n_neighbors); raft::matrix::slice( handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); } diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 01aa6f62c7..2a69421194 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -348,7 +348,7 @@ void _fit_supervised(const raft::handle_t& handle, /** * */ -template +template void _transform(const raft::handle_t& handle, const umap_inputs& inputs, umap_inputs& orig_x_inputs, @@ -425,7 +425,7 @@ void _transform(const raft::handle_t& handle, * Compute graph of membership strengths */ - int nnz = inputs.n * params->n_neighbors; + uint64_t nnz = inputs.n * params->n_neighbors; dim3 grid_nnz(raft::ceildiv(nnz, TPB_X), 1, 1); @@ -449,14 +449,14 @@ void _transform(const raft::handle_t& handle, params->n_neighbors); RAFT_CUDA_TRY(cudaPeekAtLastError()); - rmm::device_uvector row_ind(inputs.n, stream); + rmm::device_uvector row_ind(inputs.n, stream); rmm::device_uvector ia(inputs.n, stream); raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), stream); raft::sparse::linalg::coo_degree(&graph_coo, ia.data(), stream); rmm::device_uvector vals_normed(graph_coo.nnz, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(value_t), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, nnz * sizeof(value_t), stream)); CUML_LOG_DEBUG("Performing L1 normalization"); diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index 6be8b0235b..6d08874411 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -18,8 +18,6 @@ #include "optimize_batch_kernel.cuh" -#include - #include #include @@ -59,7 +57,8 @@ using namespace ML; * @param stream cuda stream */ template -void make_epochs_per_sample(T* weights, int weights_n, int n_epochs, T* result, cudaStream_t stream) +void make_epochs_per_sample( + T* weights, uint64_t weights_n, int n_epochs, T* result, cudaStream_t stream) { thrust::device_ptr d_weights = thrust::device_pointer_cast(weights); T weights_max = @@ -102,10 +101,10 @@ void optimization_iteration_finalization( template void apply_embedding_updates(T* head_embedding, T* head_buffer, - int head_n, + uint64_t head_n, T* tail_embedding, T* tail_buffer, - int tail_n, + uint64_t tail_n, UMAPParams* params, bool move_other, rmm::cuda_stream_view stream) @@ -195,14 +194,14 @@ T create_gradient_rounding_factor( * positive weights (neighbors in the 1-skeleton) and repelling * negative weights (non-neighbors in the 1-skeleton). */ -template +template void optimize_layout(T* head_embedding, - int head_n, + uint64_t head_n, T* tail_embedding, - int tail_n, + uint64_t tail_n, const int* head, const int* tail, - int nnz, + uint64_t nnz, T* epochs_per_sample, float gamma, UMAPParams* params, @@ -252,14 +251,13 @@ void optimize_layout(T* head_embedding, T rounding = create_gradient_rounding_factor(head, nnz, head_n, alpha, stream_view); - MLCommon::FastIntDiv tail_n_fast(tail_n); for (int n = 0; n < n_epochs; n++) { call_optimize_batch_kernel(head_embedding, d_head_buffer, head_n, tail_embedding, d_tail_buffer, - tail_n_fast, + tail_n, head, tail, nnz, @@ -298,10 +296,14 @@ void optimize_layout(T* head_embedding, * and their 1-skeletons. */ template -void launcher( - int m, int n, raft::sparse::COO* in, UMAPParams* params, T* embedding, cudaStream_t stream) +void launcher(uint64_t m, + int n, + raft::sparse::COO* in, + UMAPParams* params, + T* embedding, + cudaStream_t stream) { - int nnz = in->nnz; + uint64_t nnz = in->nnz; /** * Find vals.max() diff --git a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh index 5fd34d2f3b..f36656f92d 100644 --- a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh +++ b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2021-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -96,16 +96,16 @@ DI T truncate_gradient(T const rounding_factor, T const x) return (rounding_factor + x) - rounding_factor; } -template +template CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, T* head_buffer, int head_n, T const* tail_embedding, T* tail_buffer, - const MLCommon::FastIntDiv tail_n, + uint64_t tail_n, const int* head, const int* tail, - int nnz, + uint64_t nnz, T const* epochs_per_sample, T* epoch_of_next_negative_sample, T* epoch_of_next_sample, @@ -118,7 +118,7 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, T nsr_inv, T rounding) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t row = (blockIdx.x * TPB_X) + threadIdx.x; if (row >= nnz) return; auto _epoch_of_next_sample = epoch_of_next_sample[row]; if (_epoch_of_next_sample > epoch) return; @@ -127,8 +127,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, /** * Positive sample stage (attractive forces) */ - int j = head[row]; - int k = tail[row]; + uint64_t j = head[row]; + uint64_t k = tail[row]; T const* current = head_embedding + (j * n_components); T const* other = tail_embedding + (k * n_components); @@ -172,9 +172,9 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, */ raft::random::detail::PhiloxGenerator gen((uint64_t)seed, (uint64_t)row, 0); for (int p = 0; p < n_neg_samples; p++) { - int r; + uint64_t r; gen.next(r); - int t = r % tail_n; + uint64_t t = r % tail_n; T const* negative_sample = tail_embedding + (t * n_components); T negative_sample_reg[n_components]; for (int i = 0; i < n_components; ++i) { @@ -210,16 +210,16 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, _epoch_of_next_negative_sample + n_neg_samples * epochs_per_negative_sample; } -template +template CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, T* head_buffer, - int head_n, + uint64_t head_n, T const* tail_embedding, T* tail_buffer, - const MLCommon::FastIntDiv tail_n, + uint64_t tail_n, const int* head, const int* tail, - int nnz, + uint64_t nnz, T const* epochs_per_sample, T* epoch_of_next_negative_sample, T* epoch_of_next_sample, @@ -233,7 +233,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, T rounding) { extern __shared__ T embedding_shared_mem_updates[]; - int row = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t row = (blockIdx.x * TPB_X) + threadIdx.x; if (row >= nnz) return; auto _epoch_of_next_sample = epoch_of_next_sample[row]; if (_epoch_of_next_sample > epoch) return; @@ -242,8 +242,8 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, /** * Positive sample stage (attractive forces) */ - int j = head[row]; - int k = tail[row]; + uint64_t j = head[row]; + uint64_t k = tail[row]; T const* current = head_embedding + (j * params.n_components); T const* other = tail_embedding + (k * params.n_components); @@ -293,9 +293,9 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, */ raft::random::detail::PhiloxGenerator gen((uint64_t)seed, (uint64_t)row, 0); for (int p = 0; p < n_neg_samples; p++) { - int r; + uint64_t r; gen.next(r); - int t = r % tail_n; + uint64_t t = r % tail_n; T const* negative_sample = tail_embedding + (t * params.n_components); dist_squared = rdist(current, negative_sample, params.n_components); // repulsive force between two vertices @@ -348,16 +348,16 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, * @param rounding: Floating rounding factor used to truncate the gradient update for * deterministic result. */ -template +template void call_optimize_batch_kernel(T const* head_embedding, T* head_buffer, int head_n, T const* tail_embedding, T* tail_buffer, - const MLCommon::FastIntDiv& tail_n, + const uint64_t tail_n, const int* head, const int* tail, - int nnz, + uint64_t nnz, T const* epochs_per_sample, T* epoch_of_next_negative_sample, T* epoch_of_next_sample, diff --git a/cpp/src/umap/simpl_set_embed/runner.cuh b/cpp/src/umap/simpl_set_embed/runner.cuh index dabedd4bdd..c7f64bbe80 100644 --- a/cpp/src/umap/simpl_set_embed/runner.cuh +++ b/cpp/src/umap/simpl_set_embed/runner.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ namespace SimplSetEmbed { using namespace ML; template -void run(int m, +void run(uint64_t m, int n, raft::sparse::COO* coo, UMAPParams* params, diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index 1a9739f280..fad8652033 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -101,7 +101,7 @@ void reset_local_connectivity(raft::sparse::COO* in_coo, * and this will update the fuzzy simplicial set to respect that label * data. */ -template +template void categorical_simplicial_set_intersection(raft::sparse::COO* graph_coo, value_t* target, cudaStream_t stream, @@ -119,7 +119,7 @@ void categorical_simplicial_set_intersection(raft::sparse::COO* graph_c far_dist); } -template +template CUML_KERNEL void sset_intersection_kernel(int* row_ind1, int* cols1, value_t* vals1, @@ -177,7 +177,7 @@ CUML_KERNEL void sset_intersection_kernel(int* row_ind1, * Computes the CSR column index pointer and values * for the general simplicial set intersecftion. */ -template +template void general_simplicial_set_intersection(int* row1_ind, raft::sparse::COO* in1, int* row2_ind, From ea9a4766ac6df7fcafa2988570f71a2ecd2410bf Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 22 Jan 2025 16:28:36 +0100 Subject: [PATCH 02/11] Re-enable coo_sort before removing zeroes --- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index b791d3d7f0..29fdcfc1a8 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -367,7 +367,7 @@ void launcher(uint64_t n, }, stream); - // raft::sparse::op::coo_sort(out, stream); + raft::sparse::op::coo_sort(out, stream); } } // namespace Naive } // namespace FuzzySimplSet From c8db94bb3bc377a3a0270ab7b5a8f49bcb61e8e5 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 23 Jan 2025 16:39:59 +0000 Subject: [PATCH 03/11] updates --- cpp/include/cuml/common/callback.hpp | 4 +-- cpp/src/umap/init_embed/random_algo.cuh | 4 +-- cpp/src/umap/runner.cuh | 9 ++--- cpp/src/umap/simpl_set_embed/algo.cuh | 36 +++++++++---------- .../simpl_set_embed/optimize_batch_kernel.cuh | 30 ++++++++-------- 5 files changed, 40 insertions(+), 43 deletions(-) diff --git a/cpp/include/cuml/common/callback.hpp b/cpp/include/cuml/common/callback.hpp index 2f32811137..c2b99e1f6b 100644 --- a/cpp/include/cuml/common/callback.hpp +++ b/cpp/include/cuml/common/callback.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ class GraphBasedDimRedCallback : public Callback { protected: int n; - uint64_t n_components; + int n_components; bool isFloat; }; diff --git a/cpp/src/umap/init_embed/random_algo.cuh b/cpp/src/umap/init_embed/random_algo.cuh index 217531e548..76ad23e784 100644 --- a/cpp/src/umap/init_embed/random_algo.cuh +++ b/cpp/src/umap/init_embed/random_algo.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ namespace RandomInit { using namespace ML; template -void launcher(int n, int d, UMAPParams* params, T* embedding, cudaStream_t stream) +void launcher(uint64_t n, int d, UMAPParams* params, T* embedding, cudaStream_t stream) { uint64_t seed = params->random_state; diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 2a69421194..500a9658fa 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -425,7 +425,7 @@ void _transform(const raft::handle_t& handle, * Compute graph of membership strengths */ - uint64_t nnz = inputs.n * params->n_neighbors; + uint64_t nnz = (uint64_t)inputs.n * params->n_neighbors; dim3 grid_nnz(raft::ceildiv(nnz, TPB_X), 1, 1); @@ -450,13 +450,11 @@ void _transform(const raft::handle_t& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); rmm::device_uvector row_ind(inputs.n, stream); - rmm::device_uvector ia(inputs.n, stream); raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), stream); - raft::sparse::linalg::coo_degree(&graph_coo, ia.data(), stream); rmm::device_uvector vals_normed(graph_coo.nnz, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, nnz * sizeof(value_t), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(value_t), stream)); CUML_LOG_DEBUG("Performing L1 normalization"); @@ -471,9 +469,6 @@ void _transform(const raft::handle_t& handle, params->n_components, transformed, params->n_neighbors); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - RAFT_CUDA_TRY(cudaMemsetAsync(ia.data(), 0.0, ia.size() * sizeof(int), stream)); RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index 6d08874411..fd97d4f2ac 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -110,27 +110,26 @@ void apply_embedding_updates(T* head_embedding, rmm::cuda_stream_view stream) { ASSERT(params->deterministic, "Only used when deterministic is set to true."); + uint64_t n_components = params->n_components; if (move_other) { - auto n_components = params->n_components; - thrust::for_each( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0u), - thrust::make_counting_iterator(0u) + std::max(head_n, tail_n) * params->n_components, - [=] __device__(uint32_t i) { - if (i < head_n * n_components) { - head_embedding[i] += head_buffer[i]; - head_buffer[i] = 0.0f; - } - if (i < tail_n * n_components) { - tail_embedding[i] += tail_buffer[i]; - tail_buffer[i] = 0.0f; - } - }); + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0u), + thrust::make_counting_iterator(0u) + std::max(head_n, tail_n) * n_components, + [=] __device__(uint32_t i) { + if (i < head_n * n_components) { + head_embedding[i] += head_buffer[i]; + head_buffer[i] = 0.0f; + } + if (i < tail_n * n_components) { + tail_embedding[i] += tail_buffer[i]; + tail_buffer[i] = 0.0f; + } + }); } else { // No need to update reference embedding thrust::for_each(rmm::exec_policy(stream), thrust::make_counting_iterator(0u), - thrust::make_counting_iterator(0u) + head_n * params->n_components, + thrust::make_counting_iterator(0u) + head_n * n_components, [=] __device__(uint32_t i) { head_embedding[i] += head_buffer[i]; head_buffer[i] = 0.0f; @@ -232,12 +231,13 @@ void optimize_layout(T* head_embedding, T* d_head_buffer = head_embedding; T* d_tail_buffer = tail_embedding; if (params->deterministic) { - head_buffer.resize(head_n * params->n_components, stream_view); + uint64_t n_components = params->n_components; + head_buffer.resize(head_n * n_components, stream_view); RAFT_CUDA_TRY( cudaMemsetAsync(head_buffer.data(), '\0', sizeof(T) * head_buffer.size(), stream)); // No need for tail if it's not being written. if (move_other) { - tail_buffer.resize(tail_n * params->n_components, stream_view); + tail_buffer.resize(tail_n * n_components, stream_view); RAFT_CUDA_TRY( cudaMemsetAsync(tail_buffer.data(), '\0', sizeof(T) * tail_buffer.size(), stream)); } diff --git a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh index f36656f92d..3abf51c785 100644 --- a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh +++ b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh @@ -99,7 +99,7 @@ DI T truncate_gradient(T const rounding_factor, T const x) template CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, T* head_buffer, - int head_n, + uint64_t head_n, T const* tail_embedding, T* tail_buffer, uint64_t tail_n, @@ -242,17 +242,19 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, /** * Positive sample stage (attractive forces) */ + uint64_t n_components = params.n_components; + uint64_t j = head[row]; uint64_t k = tail[row]; - T const* current = head_embedding + (j * params.n_components); - T const* other = tail_embedding + (k * params.n_components); + T const* current = head_embedding + (j * n_components); + T const* other = tail_embedding + (k * n_components); - T* cur_write = head_buffer + (j * params.n_components); - T* oth_write = tail_buffer + (k * params.n_components); + T* cur_write = head_buffer + (j * n_components); + T* oth_write = tail_buffer + (k * n_components); T* current_buffer{nullptr}; if (use_shared_mem) { current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x; } - auto dist_squared = rdist(current, other, params.n_components); + auto dist_squared = rdist(current, other, n_components); // Attractive force between the two vertices, since they // are connected by an edge in the 1-skeleton. auto attractive_grad_coeff = T(0.0); @@ -264,7 +266,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, * (update `other` embedding only if we are * performing unsupervised training). */ - for (int d = 0; d < params.n_components; d++) { + for (int d = 0; d < n_components; d++) { auto grad_d = clip(attractive_grad_coeff * (current[d] - other[d]), T(-4.0), T(4.0)); grad_d *= alpha; if (use_shared_mem) { @@ -279,7 +281,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, // storing gradients for negative samples back to global memory if (use_shared_mem && move_other) { __syncthreads(); - for (int d = 0; d < params.n_components; d++) { + for (int d = 0; d < n_components; d++) { auto grad = current_buffer[d * TPB_X]; raft::myAtomicAdd((T*)oth_write + d, truncate_gradient(rounding, -grad)); } @@ -296,8 +298,8 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, uint64_t r; gen.next(r); uint64_t t = r % tail_n; - T const* negative_sample = tail_embedding + (t * params.n_components); - dist_squared = rdist(current, negative_sample, params.n_components); + T const* negative_sample = tail_embedding + (t * n_components); + dist_squared = rdist(current, negative_sample, n_components); // repulsive force between two vertices auto repulsive_grad_coeff = T(0.0); if (dist_squared > T(0.0)) { @@ -309,7 +311,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, * (which has been negatively sampled) by updating * their 'weights' to push them farther in Euclidean space. */ - for (int d = 0; d < params.n_components; d++) { + for (int d = 0; d < n_components; d++) { auto grad_d = T(0.0); if (repulsive_grad_coeff > T(0.0)) grad_d = clip(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0)); @@ -327,7 +329,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, // storing gradients for positive samples back to global memory if (use_shared_mem) { __syncthreads(); - for (int d = 0; d < params.n_components; d++) { + for (int d = 0; d < n_components; d++) { raft::myAtomicAdd((T*)cur_write + d, truncate_gradient(rounding, current_buffer[d * TPB_X])); } @@ -351,10 +353,10 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, template void call_optimize_batch_kernel(T const* head_embedding, T* head_buffer, - int head_n, + uint64_t head_n, T const* tail_embedding, T* tail_buffer, - const uint64_t tail_n, + uint64_t tail_n, const int* head, const int* tail, uint64_t nnz, From fb2668136c4b7140caca604848473bfdbf602536 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 23 Jan 2025 19:05:50 +0000 Subject: [PATCH 04/11] fix issue --- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 4 ++-- cpp/src/umap/runner.cuh | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index 29fdcfc1a8..25b4a4fa19 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -329,10 +329,10 @@ void launcher(uint64_t n, * Compute graph of membership strengths */ - uint64_t to_process = n * n_neighbors; - dim3 grid_elm(raft::ceildiv(to_process, TPB_X), 1, 1); + dim3 grid_elm(raft::ceildiv(n * n_neighbors, TPB_X), 1, 1); dim3 blk_elm(TPB_X, 1, 1); + uint64_t to_process = (uint64_t)in.n_rows * n_neighbors; compute_membership_strength_kernel<<>>(knn_indices, knn_dists, sigmas.data(), diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 500a9658fa..cf6c89da57 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -437,6 +437,7 @@ void _transform(const raft::handle_t& handle, raft::sparse::COO graph_coo(stream, nnz, inputs.n, inputs.n); + uint64_t to_process = (uint64_t)graph_coo.n_rows * params->n_neighbors; FuzzySimplSetImpl::compute_membership_strength_kernel <<>>(knn_graph.knn_indices, knn_graph.knn_dists, @@ -445,8 +446,8 @@ void _transform(const raft::handle_t& handle, graph_coo.vals(), graph_coo.rows(), graph_coo.cols(), - graph_coo.n_rows, - params->n_neighbors); + params->n_neighbors, + to_process); RAFT_CUDA_TRY(cudaPeekAtLastError()); rmm::device_uvector row_ind(inputs.n, stream); From 038f31e12d3f8ad2fe469b8c06603da0bb8a63f1 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 24 Jan 2025 11:31:19 +0100 Subject: [PATCH 05/11] answering review --- cpp/include/cuml/manifold/common.hpp | 2 ++ cpp/src/tsne/tsne_runner.cuh | 3 ++- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 9 +++++---- cpp/src/umap/init_embed/random_algo.cuh | 2 ++ cpp/src/umap/init_embed/spectral_algo.cuh | 2 ++ cpp/src/umap/knn_graph/algo.cuh | 5 +++-- cpp/src/umap/runner.cuh | 6 ++++-- cpp/src/umap/simpl_set_embed/algo.cuh | 1 + cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh | 2 ++ cpp/src/umap/simpl_set_embed/runner.cuh | 2 ++ cpp/src/umap/supervised.cuh | 2 ++ 11 files changed, 27 insertions(+), 9 deletions(-) diff --git a/cpp/include/cuml/manifold/common.hpp b/cpp/include/cuml/manifold/common.hpp index 12232d15ff..99dd0e97d7 100644 --- a/cpp/include/cuml/manifold/common.hpp +++ b/cpp/include/cuml/manifold/common.hpp @@ -16,6 +16,8 @@ #pragma once +#include + namespace ML { // Dense input uses int64_t until FAISS is updated diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 9485ae0fd6..b1b58561be 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -37,6 +37,7 @@ #include #include +#include namespace ML { @@ -167,7 +168,7 @@ class TSNE_runner { { distance_and_perplexity(); - const auto NNZ = (value_idx)COO_Matrix.nnz; + const auto NNZ = value_idx{COO_Matrix.nnz}; auto* VAL = COO_Matrix.vals(); const auto* COL = COO_Matrix.cols(); const auto* ROW = COO_Matrix.rows(); diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index 25b4a4fa19..f1a15514fb 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -30,6 +30,7 @@ #include +#include #include #include @@ -92,7 +93,7 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists, { // row-based matrix 1 thread per row int row = (blockIdx.x * TPB_X) + threadIdx.x; - uint64_t i = (uint64_t)row * n_neighbors; // each thread processes one row of the dist matrix + uint64_t i = uint64_t{row} * n_neighbors; // each thread processes one row of the dist matrix if (row < n) { float target = __log2f(n_neighbors) * bandwidth; @@ -254,7 +255,7 @@ void smooth_knn_dist(uint64_t n, rmm::device_uvector dist_means_dev(n_neighbors, stream); raft::stats::mean( - dist_means_dev.data(), knn_dists, (uint64_t)1, n * n_neighbors, false, false, stream); + dist_means_dev.data(), knn_dists, uint64_t{1}, n * n_neighbors, false, false, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); value_t mean_dist = 0.0; @@ -329,10 +330,10 @@ void launcher(uint64_t n, * Compute graph of membership strengths */ - dim3 grid_elm(raft::ceildiv(n * n_neighbors, TPB_X), 1, 1); + uint64_t to_process = {in.n_rows} * n_neighbors; + dim3 grid_elm(raft::ceildiv(to_process, TPB_X), 1, 1); dim3 blk_elm(TPB_X, 1, 1); - uint64_t to_process = (uint64_t)in.n_rows * n_neighbors; compute_membership_strength_kernel<<>>(knn_indices, knn_dists, sigmas.data(), diff --git a/cpp/src/umap/init_embed/random_algo.cuh b/cpp/src/umap/init_embed/random_algo.cuh index 76ad23e784..d8afbfe42f 100644 --- a/cpp/src/umap/init_embed/random_algo.cuh +++ b/cpp/src/umap/init_embed/random_algo.cuh @@ -20,6 +20,8 @@ #include +#include + namespace UMAPAlgo { namespace InitEmbed { namespace RandomInit { diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index 814ef48af7..27e1350497 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -29,6 +29,8 @@ #include #include +#include + #include namespace UMAPAlgo { diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index b75ca1a103..1be525d3e6 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -35,6 +35,7 @@ #include #include +#include #include @@ -127,11 +128,11 @@ inline void launcher(const raft::handle_t& handle, RAFT_EXPECTS(graph.distances().has_value(), "return_distances for nn descent should be set to true to be used for UMAP"); auto out_knn_dists_view = - raft::make_device_matrix_view(out.knn_dists, inputsA.n, (uint64_t)n_neighbors); + raft::make_device_matrix_view(out.knn_dists, inputsA.n, uint64_t{n_neighbors}); raft::matrix::slice( handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); auto out_knn_indices_view = - raft::make_device_matrix_view(out.knn_indices, inputsA.n, (uint64_t)n_neighbors); + raft::make_device_matrix_view(out.knn_indices, inputsA.n, uint64_t{n_neighbors}); raft::matrix::slice( handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); } diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index cf6c89da57..fdfa798cdf 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -47,6 +47,8 @@ #include #include +#include + #include namespace UMAPAlgo { @@ -425,7 +427,7 @@ void _transform(const raft::handle_t& handle, * Compute graph of membership strengths */ - uint64_t nnz = (uint64_t)inputs.n * params->n_neighbors; + uint64_t nnz = uint64_t{inputs.n} * params->n_neighbors; dim3 grid_nnz(raft::ceildiv(nnz, TPB_X), 1, 1); @@ -437,7 +439,7 @@ void _transform(const raft::handle_t& handle, raft::sparse::COO graph_coo(stream, nnz, inputs.n, inputs.n); - uint64_t to_process = (uint64_t)graph_coo.n_rows * params->n_neighbors; + uint64_t to_process = uint64_t{graph_coo.n_rows} * params->n_neighbors; FuzzySimplSetImpl::compute_membership_strength_kernel <<>>(knn_graph.knn_indices, knn_graph.knn_dists, diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index fd97d4f2ac..66e432b103 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -36,6 +36,7 @@ #include #include +#include #include #include diff --git a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh index 3abf51c785..69a5b53a61 100644 --- a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh +++ b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh @@ -24,6 +24,8 @@ #include #include +#include + #include namespace UMAPAlgo { diff --git a/cpp/src/umap/simpl_set_embed/runner.cuh b/cpp/src/umap/simpl_set_embed/runner.cuh index c7f64bbe80..c64bb98113 100644 --- a/cpp/src/umap/simpl_set_embed/runner.cuh +++ b/cpp/src/umap/simpl_set_embed/runner.cuh @@ -22,6 +22,8 @@ #include +#include + namespace UMAPAlgo { namespace SimplSetEmbed { diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index fad8652033..a1afdc1d63 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -46,6 +46,8 @@ #include #include +#include + namespace UMAPAlgo { namespace Supervised { From fb70daf4a55ec14d46461818da36692a04ffac22 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 24 Jan 2025 16:30:57 +0100 Subject: [PATCH 06/11] fix small issue --- cpp/src/umap/simpl_set_embed/algo.cuh | 19 +++++++------------ .../simpl_set_embed/optimize_batch_kernel.cuh | 12 +++--------- cpp/src/umap/simpl_set_embed/runner.cuh | 2 +- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index 66e432b103..b5429b78af 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -102,10 +102,10 @@ void optimization_iteration_finalization( template void apply_embedding_updates(T* head_embedding, T* head_buffer, - uint64_t head_n, + int head_n, T* tail_embedding, T* tail_buffer, - uint64_t tail_n, + int tail_n, UMAPParams* params, bool move_other, rmm::cuda_stream_view stream) @@ -169,7 +169,7 @@ T create_rounding_factor(T max_abs, int n) template T create_gradient_rounding_factor( - const int* head, int nnz, int n_samples, T alpha, rmm::cuda_stream_view stream) + const int* head, uint64_t nnz, int n_samples, T alpha, rmm::cuda_stream_view stream) { rmm::device_uvector buffer(n_samples, stream); // calculate the maximum number of edges connected to 1 vertex. @@ -196,9 +196,9 @@ T create_gradient_rounding_factor( */ template void optimize_layout(T* head_embedding, - uint64_t head_n, + int head_n, T* tail_embedding, - uint64_t tail_n, + int tail_n, const int* head, const int* tail, uint64_t nnz, @@ -255,7 +255,6 @@ void optimize_layout(T* head_embedding, for (int n = 0; n < n_epochs; n++) { call_optimize_batch_kernel(head_embedding, d_head_buffer, - head_n, tail_embedding, d_tail_buffer, tail_n, @@ -297,12 +296,8 @@ void optimize_layout(T* head_embedding, * and their 1-skeletons. */ template -void launcher(uint64_t m, - int n, - raft::sparse::COO* in, - UMAPParams* params, - T* embedding, - cudaStream_t stream) +void launcher( + int m, int n, raft::sparse::COO* in, UMAPParams* params, T* embedding, cudaStream_t stream) { uint64_t nnz = in->nnz; diff --git a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh index 69a5b53a61..a8c2c127f9 100644 --- a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh +++ b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh @@ -101,10 +101,9 @@ DI T truncate_gradient(T const rounding_factor, T const x) template CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, T* head_buffer, - uint64_t head_n, T const* tail_embedding, T* tail_buffer, - uint64_t tail_n, + int tail_n, const int* head, const int* tail, uint64_t nnz, @@ -215,10 +214,9 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding, template CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, T* head_buffer, - uint64_t head_n, T const* tail_embedding, T* tail_buffer, - uint64_t tail_n, + int tail_n, const int* head, const int* tail, uint64_t nnz, @@ -355,10 +353,9 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding, template void call_optimize_batch_kernel(T const* head_embedding, T* head_buffer, - uint64_t head_n, T const* tail_embedding, T* tail_buffer, - uint64_t tail_n, + int tail_n, const int* head, const int* tail, uint64_t nnz, @@ -384,7 +381,6 @@ void call_optimize_batch_kernel(T const* head_embedding, // multicore implementation with registers optimize_batch_kernel_reg<<>>(head_embedding, head_buffer, - head_n, tail_embedding, tail_buffer, tail_n, @@ -407,7 +403,6 @@ void call_optimize_batch_kernel(T const* head_embedding, optimize_batch_kernel <<>>(head_embedding, head_buffer, - head_n, tail_embedding, tail_buffer, tail_n, @@ -429,7 +424,6 @@ void call_optimize_batch_kernel(T const* head_embedding, // multicore implementation without shared memory optimize_batch_kernel<<>>(head_embedding, head_buffer, - head_n, tail_embedding, tail_buffer, tail_n, diff --git a/cpp/src/umap/simpl_set_embed/runner.cuh b/cpp/src/umap/simpl_set_embed/runner.cuh index c64bb98113..6c0041fb5c 100644 --- a/cpp/src/umap/simpl_set_embed/runner.cuh +++ b/cpp/src/umap/simpl_set_embed/runner.cuh @@ -31,7 +31,7 @@ namespace SimplSetEmbed { using namespace ML; template -void run(uint64_t m, +void run(int m, int n, raft::sparse::COO* coo, UMAPParams* params, From ca863926006df4408539f65c428c8f2315dd0548 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 24 Jan 2025 19:32:51 +0100 Subject: [PATCH 07/11] typos --- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 2 +- cpp/src/umap/knn_graph/algo.cuh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index f1a15514fb..ba75152868 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -330,7 +330,7 @@ void launcher(uint64_t n, * Compute graph of membership strengths */ - uint64_t to_process = {in.n_rows} * n_neighbors; + uint64_t to_process = (uint64_t)in.n_rows * n_neighbors; dim3 grid_elm(raft::ceildiv(to_process, TPB_X), 1, 1); dim3 blk_elm(TPB_X, 1, 1); diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 1be525d3e6..551367e685 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -128,11 +128,11 @@ inline void launcher(const raft::handle_t& handle, RAFT_EXPECTS(graph.distances().has_value(), "return_distances for nn descent should be set to true to be used for UMAP"); auto out_knn_dists_view = - raft::make_device_matrix_view(out.knn_dists, inputsA.n, uint64_t{n_neighbors}); + raft::make_device_matrix_view(out.knn_dists, inputsA.n, (uint64_t)n_neighbors); raft::matrix::slice( handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); auto out_knn_indices_view = - raft::make_device_matrix_view(out.knn_indices, inputsA.n, uint64_t{n_neighbors}); + raft::make_device_matrix_view(out.knn_indices, inputsA.n, (uint64_t)n_neighbors); raft::matrix::slice( handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); } From 72a83ab1d8283d60e3dd64d006e2660a473124a7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 24 Jan 2025 19:53:15 +0100 Subject: [PATCH 08/11] typos --- cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 7 ++++--- cpp/src/umap/knn_graph/algo.cuh | 4 ++-- cpp/src/umap/runner.cuh | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index b1b58561be..b019f4e10d 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -168,7 +168,7 @@ class TSNE_runner { { distance_and_perplexity(); - const auto NNZ = value_idx{COO_Matrix.nnz}; + const auto NNZ = static_cast(COO_Matrix.nnz); auto* VAL = COO_Matrix.vals(); const auto* COL = COO_Matrix.cols(); const auto* ROW = COO_Matrix.rows(); diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index ba75152868..0c3bab81f3 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -92,8 +92,9 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists, float bandwidth = 1.0) { // row-based matrix 1 thread per row - int row = (blockIdx.x * TPB_X) + threadIdx.x; - uint64_t i = uint64_t{row} * n_neighbors; // each thread processes one row of the dist matrix + int row = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t i = + static_cast(row) * n_neighbors; // each thread processes one row of the dist matrix if (row < n) { float target = __log2f(n_neighbors) * bandwidth; @@ -330,7 +331,7 @@ void launcher(uint64_t n, * Compute graph of membership strengths */ - uint64_t to_process = (uint64_t)in.n_rows * n_neighbors; + uint64_t to_process = static_cast(in.n_rows) * n_neighbors; dim3 grid_elm(raft::ceildiv(to_process, TPB_X), 1, 1); dim3 blk_elm(TPB_X, 1, 1); diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 551367e685..7a342ee31f 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -128,11 +128,11 @@ inline void launcher(const raft::handle_t& handle, RAFT_EXPECTS(graph.distances().has_value(), "return_distances for nn descent should be set to true to be used for UMAP"); auto out_knn_dists_view = - raft::make_device_matrix_view(out.knn_dists, inputsA.n, (uint64_t)n_neighbors); + raft::make_device_matrix_view(out.knn_dists, inputsA.n, static_cast(n_neighbors)); raft::matrix::slice( handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); auto out_knn_indices_view = - raft::make_device_matrix_view(out.knn_indices, inputsA.n, (uint64_t)n_neighbors); + raft::make_device_matrix_view(out.knn_indices, inputsA.n, static_cast(n_neighbors)); raft::matrix::slice( handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); } diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index fdfa798cdf..3439851b26 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -439,7 +439,7 @@ void _transform(const raft::handle_t& handle, raft::sparse::COO graph_coo(stream, nnz, inputs.n, inputs.n); - uint64_t to_process = uint64_t{graph_coo.n_rows} * params->n_neighbors; + uint64_t to_process = static_cast(graph_coo.n_rows) * params->n_neighbors; FuzzySimplSetImpl::compute_membership_strength_kernel <<>>(knn_graph.knn_indices, knn_graph.knn_dists, From a7134dd7e391d65ef76b6c274c309a70299cc73c Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 27 Jan 2025 16:12:49 +0000 Subject: [PATCH 09/11] compilation fix --- cpp/src/umap/init_embed/spectral_algo.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index 27e1350497..288d2373fe 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -54,7 +54,8 @@ void launcher(const raft::handle_t& handle, { cudaStream_t stream = handle.get_stream(); - ASSERT(n > params->n_components, "Spectral layout requires n_samples > n_components"); + ASSERT(n > static_cast(params->n_components), + "Spectral layout requires n_samples > n_components"); rmm::device_uvector tmp_storage(n * params->n_components, stream); From 9abe8425285b9fc5306c971b5e57549ab4820f39 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 29 Jan 2025 19:38:49 +0000 Subject: [PATCH 10/11] changes so far --- cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/src/umap/init_embed/spectral_algo.cuh | 2 +- cpp/src/umap/runner.cuh | 19 ++++++++++++------- cpp/src/umap/simpl_set_embed/algo.cuh | 12 ++++++------ cpp/src/umap/supervised.cuh | 6 +++--- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index b019f4e10d..c1e63252f8 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -168,7 +168,7 @@ class TSNE_runner { { distance_and_perplexity(); - const auto NNZ = static_cast(COO_Matrix.nnz); + const auto NNZ = COO_Matrix.nnz; auto* VAL = COO_Matrix.vals(); const auto* COL = COO_Matrix.cols(); const auto* ROW = COO_Matrix.rows(); diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index 288d2373fe..86986dad90 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -65,7 +65,7 @@ void launcher(const raft::handle_t& handle, coo->rows(), coo->cols(), coo->vals(), - coo->nnz, + coo->safe_nnz, n, params->n_components, tmp_storage.data(), diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 3439851b26..845a0633b6 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -456,13 +456,18 @@ void _transform(const raft::handle_t& handle, raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), stream); - rmm::device_uvector vals_normed(graph_coo.nnz, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(value_t), stream)); + rmm::device_uvector vals_normed(graph_coo.safe_nnz, stream); + RAFT_CUDA_TRY( + cudaMemsetAsync(vals_normed.data(), 0, graph_coo.safe_nnz * sizeof(value_t), stream)); CUML_LOG_DEBUG("Performing L1 normalization"); - raft::sparse::linalg::csr_row_normalize_l1( - row_ind.data(), graph_coo.vals(), graph_coo.nnz, graph_coo.n_rows, vals_normed.data(), stream); + raft::sparse::linalg::csr_row_normalize_l1(row_ind.data(), + graph_coo.vals(), + graph_coo.safe_nnz, + graph_coo.n_rows, + vals_normed.data(), + stream); init_transform<<>>(graph_coo.cols(), vals_normed.data(), @@ -497,7 +502,7 @@ void _transform(const raft::handle_t& handle, raft::linalg::unaryOp( graph_coo.vals(), graph_coo.vals(), - graph_coo.nnz, + graph_coo.safe_nnz, [=] __device__(value_t input) { if (input < (max / float(n_epochs))) return 0.0f; @@ -520,7 +525,7 @@ void _transform(const raft::handle_t& handle, rmm::device_uvector epochs_per_sample(nnz, stream); SimplSetEmbedImpl::make_epochs_per_sample( - comp_coo.vals(), comp_coo.nnz, n_epochs, epochs_per_sample.data(), stream); + comp_coo.vals(), comp_coo.safe_nnz, n_epochs, epochs_per_sample.data(), stream); CUML_LOG_DEBUG("Performing optimization"); @@ -537,7 +542,7 @@ void _transform(const raft::handle_t& handle, embedding_n, comp_coo.rows(), comp_coo.cols(), - comp_coo.nnz, + comp_coo.safe_nnz, epochs_per_sample.data(), params->repulsion_strength, params, diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index b5429b78af..5090ed9d8e 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -299,7 +299,7 @@ template void launcher( int m, int n, raft::sparse::COO* in, UMAPParams* params, T* embedding, cudaStream_t stream) { - uint64_t nnz = in->nnz; + uint64_t nnz = in->safe_nnz; /** * Find vals.max() @@ -334,14 +334,14 @@ void launcher( raft::sparse::COO out(stream); raft::sparse::op::coo_remove_zeros(in, &out, stream); - rmm::device_uvector epochs_per_sample(out.nnz, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(epochs_per_sample.data(), 0, out.nnz * sizeof(T), stream)); + rmm::device_uvector epochs_per_sample(out.safe_nnz, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(epochs_per_sample.data(), 0, out.safe_nnz * sizeof(T), stream)); - make_epochs_per_sample(out.vals(), out.nnz, n_epochs, epochs_per_sample.data(), stream); + make_epochs_per_sample(out.vals(), out.safe_nnz, n_epochs, epochs_per_sample.data(), stream); if (ML::default_logger().should_log(ML::level_enum::debug)) { std::stringstream ss; - ss << raft::arr2Str(epochs_per_sample.data(), out.nnz, "epochs_per_sample", stream); + ss << raft::arr2Str(epochs_per_sample.data(), out.safe_nnz, "epochs_per_sample", stream); CUML_LOG_DEBUG(ss.str().c_str()); } @@ -351,7 +351,7 @@ void launcher( m, out.rows(), out.cols(), - out.nnz, + out.safe_nnz, epochs_per_sample.data(), params->repulsion_strength, params, diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index a1afdc1d63..e6087d79ac 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -103,7 +103,7 @@ void reset_local_connectivity(raft::sparse::COO* in_coo, * and this will update the fuzzy simplicial set to respect that label * data. */ -template +template void categorical_simplicial_set_intersection(raft::sparse::COO* graph_coo, value_t* target, cudaStream_t stream, @@ -121,7 +121,7 @@ void categorical_simplicial_set_intersection(raft::sparse::COO* graph_c far_dist); } -template +template CUML_KERNEL void sset_intersection_kernel(int* row_ind1, int* cols1, value_t* vals1, @@ -179,7 +179,7 @@ CUML_KERNEL void sset_intersection_kernel(int* row_ind1, * Computes the CSR column index pointer and values * for the general simplicial set intersecftion. */ -template +template void general_simplicial_set_intersection(int* row1_ind, raft::sparse::COO* in1, int* row2_ind, From 6ff906b31c175ca0ad4304d8b81676428e1f1ab2 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 30 Jan 2025 14:53:18 +0000 Subject: [PATCH 11/11] completing change --- cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/src/umap/init_embed/spectral_algo.cuh | 2 +- cpp/src/umap/runner.cuh | 19 +++++++------------ cpp/src/umap/simpl_set_embed/algo.cuh | 12 ++++++------ cpp/src/umap/supervised.cuh | 6 +++--- 5 files changed, 18 insertions(+), 23 deletions(-) diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index c1e63252f8..8102d443ad 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -168,7 +168,7 @@ class TSNE_runner { { distance_and_perplexity(); - const auto NNZ = COO_Matrix.nnz; + const auto NNZ = (value_idx)COO_Matrix.nnz; auto* VAL = COO_Matrix.vals(); const auto* COL = COO_Matrix.cols(); const auto* ROW = COO_Matrix.rows(); diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index 86986dad90..288d2373fe 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -65,7 +65,7 @@ void launcher(const raft::handle_t& handle, coo->rows(), coo->cols(), coo->vals(), - coo->safe_nnz, + coo->nnz, n, params->n_components, tmp_storage.data(), diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 845a0633b6..3439851b26 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -456,18 +456,13 @@ void _transform(const raft::handle_t& handle, raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), stream); - rmm::device_uvector vals_normed(graph_coo.safe_nnz, stream); - RAFT_CUDA_TRY( - cudaMemsetAsync(vals_normed.data(), 0, graph_coo.safe_nnz * sizeof(value_t), stream)); + rmm::device_uvector vals_normed(graph_coo.nnz, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(value_t), stream)); CUML_LOG_DEBUG("Performing L1 normalization"); - raft::sparse::linalg::csr_row_normalize_l1(row_ind.data(), - graph_coo.vals(), - graph_coo.safe_nnz, - graph_coo.n_rows, - vals_normed.data(), - stream); + raft::sparse::linalg::csr_row_normalize_l1( + row_ind.data(), graph_coo.vals(), graph_coo.nnz, graph_coo.n_rows, vals_normed.data(), stream); init_transform<<>>(graph_coo.cols(), vals_normed.data(), @@ -502,7 +497,7 @@ void _transform(const raft::handle_t& handle, raft::linalg::unaryOp( graph_coo.vals(), graph_coo.vals(), - graph_coo.safe_nnz, + graph_coo.nnz, [=] __device__(value_t input) { if (input < (max / float(n_epochs))) return 0.0f; @@ -525,7 +520,7 @@ void _transform(const raft::handle_t& handle, rmm::device_uvector epochs_per_sample(nnz, stream); SimplSetEmbedImpl::make_epochs_per_sample( - comp_coo.vals(), comp_coo.safe_nnz, n_epochs, epochs_per_sample.data(), stream); + comp_coo.vals(), comp_coo.nnz, n_epochs, epochs_per_sample.data(), stream); CUML_LOG_DEBUG("Performing optimization"); @@ -542,7 +537,7 @@ void _transform(const raft::handle_t& handle, embedding_n, comp_coo.rows(), comp_coo.cols(), - comp_coo.safe_nnz, + comp_coo.nnz, epochs_per_sample.data(), params->repulsion_strength, params, diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index 5090ed9d8e..b5429b78af 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -299,7 +299,7 @@ template void launcher( int m, int n, raft::sparse::COO* in, UMAPParams* params, T* embedding, cudaStream_t stream) { - uint64_t nnz = in->safe_nnz; + uint64_t nnz = in->nnz; /** * Find vals.max() @@ -334,14 +334,14 @@ void launcher( raft::sparse::COO out(stream); raft::sparse::op::coo_remove_zeros(in, &out, stream); - rmm::device_uvector epochs_per_sample(out.safe_nnz, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(epochs_per_sample.data(), 0, out.safe_nnz * sizeof(T), stream)); + rmm::device_uvector epochs_per_sample(out.nnz, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(epochs_per_sample.data(), 0, out.nnz * sizeof(T), stream)); - make_epochs_per_sample(out.vals(), out.safe_nnz, n_epochs, epochs_per_sample.data(), stream); + make_epochs_per_sample(out.vals(), out.nnz, n_epochs, epochs_per_sample.data(), stream); if (ML::default_logger().should_log(ML::level_enum::debug)) { std::stringstream ss; - ss << raft::arr2Str(epochs_per_sample.data(), out.safe_nnz, "epochs_per_sample", stream); + ss << raft::arr2Str(epochs_per_sample.data(), out.nnz, "epochs_per_sample", stream); CUML_LOG_DEBUG(ss.str().c_str()); } @@ -351,7 +351,7 @@ void launcher( m, out.rows(), out.cols(), - out.safe_nnz, + out.nnz, epochs_per_sample.data(), params->repulsion_strength, params, diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index e6087d79ac..a1afdc1d63 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -103,7 +103,7 @@ void reset_local_connectivity(raft::sparse::COO* in_coo, * and this will update the fuzzy simplicial set to respect that label * data. */ -template +template void categorical_simplicial_set_intersection(raft::sparse::COO* graph_coo, value_t* target, cudaStream_t stream, @@ -121,7 +121,7 @@ void categorical_simplicial_set_intersection(raft::sparse::COO* graph_c far_dist); } -template +template CUML_KERNEL void sset_intersection_kernel(int* row_ind1, int* cols1, value_t* vals1, @@ -179,7 +179,7 @@ CUML_KERNEL void sset_intersection_kernel(int* row_ind1, * Computes the CSR column index pointer and values * for the general simplicial set intersecftion. */ -template +template void general_simplicial_set_intersection(int* row1_ind, raft::sparse::COO* in1, int* row2_ind,