diff --git a/cpp/include/cuml/manifold/common.hpp b/cpp/include/cuml/manifold/common.hpp
index 3346f9127e..99dd0e97d7 100644
--- a/cpp/include/cuml/manifold/common.hpp
+++ b/cpp/include/cuml/manifold/common.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 
 #pragma once
 
+#include <cstdint>
+
 namespace ML {
 
 // Dense input uses int64_t until FAISS is updated
@@ -55,8 +57,8 @@ struct knn_graph {
 template <typename T>
 struct manifold_inputs_t {
   T* y;
-  int n;
-  int d;
+  uint64_t n;
+  uint64_t d;
 
   manifold_inputs_t(T* y_, int n_, int d_) : y(y_), n(n_), d(d_) {}
diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh
index cdbfdd2674..b019f4e10d 100644
--- a/cpp/src/tsne/tsne_runner.cuh
+++ b/cpp/src/tsne/tsne_runner.cuh
@@ -37,6 +37,7 @@
 
 #include
 #include
+#include <cstdint>
 
 namespace ML {
@@ -167,7 +168,7 @@ class TSNE_runner {
   {
     distance_and_perplexity();
 
-    const auto NNZ  = COO_Matrix.nnz;
+    const auto NNZ  = static_cast<uint64_t>(COO_Matrix.nnz);
     auto* VAL       = COO_Matrix.vals();
     const auto* COL = COO_Matrix.cols();
     const auto* ROW = COO_Matrix.rows();
diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh
index 41e54f1f63..a6b82b84ab 100644
--- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh
+++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh
@@ -79,7 +79,7 @@ static const float MIN_K_DIST_SCALE = 1e-3;
  * Descriptions adapted from: https://github.com/lmcinnes/umap/blob/master/umap/umap_.py
  *
  */
-template <int TPB_X, typename value_t>
+template <int TPB_X, typename value_t, typename nnz_t>
 CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists,
                                         int n,
                                         float mean_dist,
@@ -92,7 +92,8 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists,
 {
   // row-based matrix 1 thread per row
   int row = (blockIdx.x * TPB_X) + threadIdx.x;
-  int i   = row * n_neighbors;  // each thread processes one row of the dist matrix
+  nnz_t i =
+    static_cast<nnz_t>(row) * n_neighbors;  // each thread processes one row of the dist matrix
 
   if (row < n) {
     float target = __log2f(n_neighbors) * bandwidth;
@@ -190,7 +191,7 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists,
  *
  * Descriptions adapted from: https://github.com/lmcinnes/umap/blob/master/umap/umap_.py
  */
-template <int TPB_X, typename value_idx, typename value_t>
+template <int TPB_X, typename value_idx, typename value_t, typename nnz_t>
 CUML_KERNEL void compute_membership_strength_kernel(
   const value_idx* knn_indices,
   const float* knn_dists,  // nn outputs
   const value_t* sigmas,
   const value_t* rhos,  // continuous dists to nearest neighbors
   value_t* vals,
   int* rows,
   int* cols,  // result coo
-  int n,
-  int n_neighbors)
+  int n_neighbors,
+  nnz_t to_process)
 {  // model params
   // row-based matrix is best
-  int idx = (blockIdx.x * TPB_X) + threadIdx.x;
+  nnz_t idx = (blockIdx.x * TPB_X) + threadIdx.x;
 
-  if (idx < n * n_neighbors) {
+  if (idx < to_process) {
     int row = idx / n_neighbors;  // one neighbor per thread
 
     double cur_rho = rhos[row];
@@ -237,8 +238,8 @@ CUML_KERNEL void compute_membership_strength_kernel(
 /*
  * Sets up and runs the knn dist smoothing
  */
-template <int TPB_X, typename value_idx, typename value_t>
-void smooth_knn_dist(int n,
+template <int TPB_X, typename value_idx, typename value_t, typename nnz_t>
+void smooth_knn_dist(nnz_t n,
                      const value_idx* knn_indices,
                      const float* knn_dists,
                      value_t* rhos,
@@ -253,7 +254,8 @@ void smooth_knn_dist(int n,
 
   rmm::device_uvector<value_t> dist_means_dev(n_neighbors, stream);
 
-  raft::stats::mean(dist_means_dev.data(), knn_dists, 1, n_neighbors * n, false, false, stream);
+  raft::stats::mean(
+    dist_means_dev.data(), knn_dists, nnz_t{1}, n * n_neighbors, false, false, stream);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 
   value_t mean_dist = 0.0;
@@ -263,7 +265,7 @@ void smooth_knn_dist(int n,
   /**
    * Smooth kNN distances to be continuous
    */
-  smooth_knn_dist_kernel<TPB_X><<<grid, blk, 0, stream>>>(
+  smooth_knn_dist_kernel<TPB_X, value_t, nnz_t><<<grid, blk, 0, stream>>>(
     knn_dists, n, mean_dist, sigmas, rhos, n_neighbors, local_connectivity);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
@@ -284,8 +286,8 @@
  * @param params UMAPParams config object
 * @param stream cuda stream to use for device operations
 */
-template <int TPB_X, typename value_idx, typename value_t>
-void launcher(int n,
+template <int TPB_X, typename value_idx, typename value_t, typename nnz_t>
+void launcher(nnz_t n,
               const value_idx* knn_indices,
               const value_t* knn_dists,
               int n_neighbors,
@@ -301,15 +303,15 @@ void launcher(int n,
   RAFT_CUDA_TRY(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(value_t), stream));
   RAFT_CUDA_TRY(cudaMemsetAsync(rhos.data(), 0, n * sizeof(value_t), stream));
 
-  smooth_knn_dist<TPB_X, value_idx, value_t>(n,
-                                             knn_indices,
-                                             knn_dists,
-                                             rhos.data(),
-                                             sigmas.data(),
-                                             params,
-                                             n_neighbors,
-                                             params->local_connectivity,
-                                             stream);
+  smooth_knn_dist<TPB_X, value_idx, value_t, nnz_t>(n,
+                                                    knn_indices,
+                                                    knn_dists,
+                                                    rhos.data(),
+                                                    sigmas.data(),
+                                                    params,
+                                                    n_neighbors,
+                                                    params->local_connectivity,
+                                                    stream);
 
   raft::sparse::COO<value_t> in(stream, n * n_neighbors, n, n);
@@ -328,18 +330,20 @@ void launcher(int n,
   /**
    * Compute graph of membership strengths
    */
 
-  dim3 grid_elm(raft::ceildiv(n * n_neighbors, TPB_X), 1, 1);
+  nnz_t to_process = static_cast<nnz_t>(in.n_rows) * n_neighbors;
+  dim3 grid_elm(raft::ceildiv(to_process, TPB_X), 1, 1);
   dim3 blk_elm(TPB_X, 1, 1);
 
-  compute_membership_strength_kernel<TPB_X><<<grid_elm, blk_elm, 0, stream>>>(knn_indices,
-                                                                              knn_dists,
-                                                                              sigmas.data(),
-                                                                              rhos.data(),
-                                                                              in.vals(),
-                                                                              in.rows(),
-                                                                              in.cols(),
-                                                                              in.n_rows,
-                                                                              n_neighbors);
+  compute_membership_strength_kernel<TPB_X, value_idx, value_t, nnz_t>
+    <<<grid_elm, blk_elm, 0, stream>>>(knn_indices,
+                                       knn_dists,
+                                       sigmas.data(),
+                                       rhos.data(),
+                                       in.vals(),
+                                       in.rows(),
+                                       in.cols(),
+                                       n_neighbors,
+                                       to_process);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 
   if (ML::default_logger().should_log(ML::level_enum::debug)) {
diff --git a/cpp/src/umap/fuzzy_simpl_set/runner.cuh b/cpp/src/umap/fuzzy_simpl_set/runner.cuh
index 6cfd3cd58d..03ee9f59b7 100644
--- a/cpp/src/umap/fuzzy_simpl_set/runner.cuh
+++ b/cpp/src/umap/fuzzy_simpl_set/runner.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ using namespace ML;
  * @param stream cuda stream
  * @param algorithm algo type to choose
  */
-template <int TPB_X, typename value_idx, typename T>
+template <int TPB_X, typename value_idx, typename T, typename nnz_t>
 void run(int n,
          const value_idx* knn_indices,
          const T* knn_dists,
@@ -50,7 +50,7 @@ void run(int n,
 {
   switch (algorithm) {
     case 0:
-      Naive::launcher<TPB_X, value_idx, T>(
+      Naive::launcher<TPB_X, value_idx, T, nnz_t>(
         n, knn_indices, knn_dists, n_neighbors, coo, params, stream);
       break;
   }
 }
diff --git a/cpp/src/umap/init_embed/random_algo.cuh b/cpp/src/umap/init_embed/random_algo.cuh
index 217531e548..31d6f0cb82 100644
--- a/cpp/src/umap/init_embed/random_algo.cuh
+++ b/cpp/src/umap/init_embed/random_algo.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,14 +20,16 @@
 
 #include
 
+#include <cstdint>
+
 namespace UMAPAlgo {
 namespace InitEmbed {
 namespace RandomInit {
 
 using namespace ML;
 
-template <typename T>
-void launcher(int n, int d, UMAPParams* params, T* embedding, cudaStream_t stream)
+template <typename T, typename nnz_t>
+void launcher(nnz_t n, int d, UMAPParams* params, T* embedding, cudaStream_t stream)
 {
   uint64_t seed = params->random_state;
diff --git a/cpp/src/umap/init_embed/runner.cuh b/cpp/src/umap/init_embed/runner.cuh
index f0f9961ef4..907a632790 100644
--- a/cpp/src/umap/init_embed/runner.cuh
+++ b/cpp/src/umap/init_embed/runner.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ namespace InitEmbed {
 
 using namespace ML;
 
-template <typename T>
+template <typename T, typename nnz_t>
 void run(const raft::handle_t& handle,
          int n,
          int d,
@@ -43,9 +43,9 @@ void run(const raft::handle_t& handle,
     /**
      * Initial algo uses FAISS indices
      */
-    case 0: RandomInit::launcher(n, d, params, embedding, handle.get_stream()); break;
+    case 0: RandomInit::launcher<T, nnz_t>(n, d, params, embedding, handle.get_stream()); break;
 
-    case 1: SpectralInit::launcher(handle, n, d, coo, params, embedding); break;
+    case 1: SpectralInit::launcher<T, nnz_t>(handle, n, d, coo, params, embedding); break;
   }
 }
 }  // namespace InitEmbed
diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh
index cac6e8dcc0..70308af7cf 100644
--- a/cpp/src/umap/init_embed/spectral_algo.cuh
+++ b/cpp/src/umap/init_embed/spectral_algo.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,6 +29,8 @@
 #include
 #include
 
+#include <cstdint>
+
 #include
 
 namespace UMAPAlgo {
@@ -42,9 +44,9 @@ using namespace ML;
 /**
  * Performs a spectral layout initialization
  */
-template <typename T>
+template <typename T, typename nnz_t>
 void launcher(const raft::handle_t& handle,
-              int n,
+              nnz_t n,
               int d,
               raft::sparse::COO<float>* coo,
               UMAPParams* params,
@@ -52,7 +54,8 @@ void launcher(const raft::handle_t& handle,
 {
   cudaStream_t stream = handle.get_stream();
 
-  ASSERT(n > params->n_components, "Spectral layout requires n_samples > n_components");
+  ASSERT(n > static_cast<nnz_t>(params->n_components),
+         "Spectral layout requires n_samples > n_components");
 
   rmm::device_uvector<T> tmp_storage(n * params->n_components, stream);
diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh
index f21729a695..9518a1240c 100644
--- a/cpp/src/umap/knn_graph/algo.cuh
+++ b/cpp/src/umap/knn_graph/algo.cuh
@@ -35,6 +35,7 @@
 #include
 #include
+#include <cstdint>
 
 #include
 
@@ -138,9 +139,10 @@ inline void launcher(const raft::handle_t& handle,
       target[i * n_neighbors + j] = source[i * graph_degree + j];
     }
   }
-  raft::copy(handle,
-             raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
-             temp_indices_h.view());
+  raft::copy(
+    handle,
+    raft::make_device_matrix_view(out.knn_indices, inputsA.n, static_cast<uint64_t>(n_neighbors)),
+    temp_indices_h.view());
 
   // `graph.distances()` is a device array (n x graph_degree).
   // Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
@@ -152,7 +154,7 @@ inline void launcher(const raft::handle_t& handle,
     raft::matrix::slice(
       handle,
       raft::make_const_mdspan(graph.distances().value()),
-      raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
+      raft::make_device_matrix_view(out.knn_dists, inputsA.n, static_cast<uint64_t>(n_neighbors)),
       coords);
   }
 }
diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh
index 01aa6f62c7..6ca1e697cc 100644
--- a/cpp/src/umap/runner.cuh
+++ b/cpp/src/umap/runner.cuh
@@ -91,7 +91,7 @@ inline void find_ab(UMAPParams* params, cudaStream_t stream)
   Optimize::find_params_ab(params, stream);
 }
 
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _get_graph(const raft::handle_t& handle,
                 const umap_inputs& inputs,
                 UMAPParams* params,
@@ -136,7 +136,7 @@ void _get_graph(const raft::handle_t& handle,
   raft::common::nvtx::push_range("umap::simplicial_set");
   raft::sparse::COO<value_t> fss_graph(stream);
-  FuzzySimplSet::run<TPB_X, value_idx, value_t>(
+  FuzzySimplSet::run<TPB_X, value_idx, value_t, nnz_t>(
     inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, stream);
   CUML_LOG_DEBUG("Done. Calling remove zeros");
 
@@ -148,7 +148,7 @@ void _get_graph(const raft::handle_t& handle,
   raft::common::nvtx::pop_range();
 }
 
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _get_graph_supervised(const raft::handle_t& handle,
                            const umap_inputs& inputs,
                            UMAPParams* params,
@@ -200,13 +200,13 @@ void _get_graph_supervised(const raft::handle_t& handle,
    * Run Fuzzy simplicial set
    */
   // int nnz = n*k*2;
-  FuzzySimplSet::run<TPB_X, value_idx, value_t>(inputs.n,
-                                                knn_graph.knn_indices,
-                                                knn_graph.knn_dists,
-                                                params->n_neighbors,
-                                                &fss_graph_tmp,
-                                                params,
-                                                stream);
+  FuzzySimplSet::run<TPB_X, value_idx, value_t, nnz_t>(inputs.n,
+                                                       knn_graph.knn_indices,
+                                                       knn_graph.knn_dists,
+                                                       params->n_neighbors,
+                                                       &fss_graph_tmp,
+                                                       params,
+                                                       stream);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 
   raft::sparse::op::coo_remove_zeros(&fss_graph_tmp, &fss_graph, stream);
@@ -219,7 +219,7 @@ void _get_graph_supervised(const raft::handle_t& handle,
    */
   if (params->target_metric == ML::UMAPParams::MetricType::CATEGORICAL) {
     CUML_LOG_DEBUG("Performing categorical intersection");
-    Supervised::perform_categorical_intersection<TPB_X, value_t>(
+    Supervised::perform_categorical_intersection<TPB_X, value_t, nnz_t>(
       inputs.y, &fss_graph, &ci_graph, params, stream);
 
     /**
@@ -227,7 +227,7 @@ void _get_graph_supervised(const raft::handle_t& handle,
      */
   } else {
     CUML_LOG_DEBUG("Performing general intersection");
-    Supervised::perform_general_intersection<TPB_X, value_idx, value_t>(
+    Supervised::perform_general_intersection<TPB_X, value_idx, value_t, nnz_t>(
       handle, inputs.y, &fss_graph, &ci_graph, params, stream);
   }
 
@@ -239,7 +239,7 @@ void _get_graph_supervised(const raft::handle_t& handle,
   raft::common::nvtx::pop_range();
 }
 
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _refine(const raft::handle_t& handle,
              const umap_inputs& inputs,
             UMAPParams* params,
@@ -252,10 +252,10 @@ void _refine(const raft::handle_t& handle,
   /**
    * Run simplicial set embedding to approximate low-dimensional representation
    */
-  SimplSetEmbed::run<TPB_X, value_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
+  SimplSetEmbed::run<TPB_X, value_t, nnz_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
 }
 
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _init_and_refine(const raft::handle_t& handle,
                       const umap_inputs& inputs,
                       UMAPParams* params,
@@ -266,13 +266,14 @@ void _init_and_refine(const raft::handle_t& handle,
   ML::default_logger().set_level(params->verbosity);
 
   // Initialize embeddings
-  InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);
+  InitEmbed::run<value_t, nnz_t>(
+    handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);
 
   // Run simplicial set embedding
-  SimplSetEmbed::run<TPB_X, value_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
+  SimplSetEmbed::run<TPB_X, value_t, nnz_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
 }
 
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _fit(const raft::handle_t& handle,
           const umap_inputs& inputs,
           UMAPParams* params,
@@ -284,13 +285,15 @@ void _fit(const raft::handle_t& handle,
   cudaStream_t stream = handle.get_stream();
   ML::default_logger().set_level(params->verbosity);
 
-  UMAPAlgo::_get_graph<value_idx, value_t, umap_inputs, TPB_X>(handle, inputs, params, graph);
+  UMAPAlgo::_get_graph<value_idx, value_t, umap_inputs, nnz_t, TPB_X>(
+    handle, inputs, params, graph);
 
   /**
    * Run initialization method
    */
   raft::common::nvtx::push_range("umap::embedding");
-  InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);
+  InitEmbed::run<value_t, nnz_t>(
+    handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);
 
   if (params->callback) {
     params->callback->setup<value_t>(inputs.n, params->n_components);
@@ -300,7 +303,7 @@ void _fit(const raft::handle_t& handle,
   /**
    * Run simplicial set embedding to approximate low-dimensional representation
    */
-  SimplSetEmbed::run<TPB_X, value_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
+  SimplSetEmbed::run<TPB_X, value_t, nnz_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
   raft::common::nvtx::pop_range();
 
   if (params->callback) params->callback->on_train_end(embeddings);
@@ -308,7 +311,7 @@ void _fit(const raft::handle_t& handle,
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
 
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _fit_supervised(const raft::handle_t& handle,
                      const umap_inputs& inputs,
                      UMAPParams* params,
@@ -320,14 +323,15 @@ void _fit_supervised(const raft::handle_t& handle,
   cudaStream_t stream = handle.get_stream();
   ML::default_logger().set_level(params->verbosity);
 
-  UMAPAlgo::_get_graph_supervised<value_idx, value_t, umap_inputs, TPB_X>(
+  UMAPAlgo::_get_graph_supervised<value_idx, value_t, umap_inputs, nnz_t, TPB_X>(
     handle, inputs, params, graph);
 
   /**
    * Initialize embeddings
    */
   raft::common::nvtx::push_range("umap::supervised::fit");
-  InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);
+  InitEmbed::run<value_t, nnz_t>(
+    handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init);
 
   if (params->callback) {
     params->callback->setup<value_t>(inputs.n, params->n_components);
@@ -337,7 +341,7 @@ void _fit_supervised(const raft::handle_t& handle,
   /**
    * Run simplicial set embedding to approximate low-dimensional representation
    */
-  SimplSetEmbed::run<TPB_X, value_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
+  SimplSetEmbed::run<TPB_X, value_t, nnz_t>(inputs.n, inputs.d, graph, params, embeddings, stream);
   raft::common::nvtx::pop_range();
 
   if (params->callback) params->callback->on_train_end(embeddings);
@@ -348,7 +352,7 @@ void _fit_supervised(const raft::handle_t& handle,
 /**
  *
  */
-template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
+template <typename value_idx, typename value_t, typename umap_inputs, typename nnz_t, int TPB_X>
 void _transform(const raft::handle_t& handle,
                 const umap_inputs& inputs,
                 umap_inputs& orig_x_inputs,
@@ -410,22 +414,22 @@ void _transform(const raft::handle_t& handle,
   dim3 grid_n(raft::ceildiv(inputs.n, TPB_X), 1, 1);
   dim3 blk(TPB_X, 1, 1);
 
-  FuzzySimplSetImpl::smooth_knn_dist<TPB_X, value_idx, value_t>(inputs.n,
-                                                                knn_graph.knn_indices,
-                                                                knn_graph.knn_dists,
-                                                                rhos.data(),
-                                                                sigmas.data(),
-                                                                params,
-                                                                params->n_neighbors,
-                                                                adjusted_local_connectivity,
-                                                                stream);
+  FuzzySimplSetImpl::smooth_knn_dist<TPB_X, value_idx, value_t, nnz_t>(inputs.n,
+                                                                       knn_graph.knn_indices,
+                                                                       knn_graph.knn_dists,
+                                                                       rhos.data(),
+                                                                       sigmas.data(),
+                                                                       params,
+                                                                       params->n_neighbors,
+                                                                       adjusted_local_connectivity,
+                                                                       stream);
   raft::common::nvtx::pop_range();
 
   /**
    * Compute graph of membership strengths
    */
 
-  int nnz = inputs.n * params->n_neighbors;
+  nnz_t nnz = static_cast<nnz_t>(inputs.n) * params->n_neighbors;
 
   dim3 grid_nnz(raft::ceildiv(nnz, TPB_X), 1, 1);
@@ -437,7 +441,8 @@ void _transform(const raft::handle_t& handle,
   raft::sparse::COO<value_t> graph_coo(stream, nnz, inputs.n, inputs.n);
-  FuzzySimplSetImpl::compute_membership_strength_kernel<TPB_X>
+  nnz_t to_process = static_cast<nnz_t>(graph_coo.n_rows) * params->n_neighbors;
+  FuzzySimplSetImpl::compute_membership_strength_kernel<TPB_X, value_idx, value_t, nnz_t>
     <<<grid_nnz, blk, 0, stream>>>(knn_graph.knn_indices,
                                    knn_graph.knn_dists,
                                    sigmas.data(),
                                    rhos.data(),
                                    graph_coo.vals(),
                                    graph_coo.rows(),
                                    graph_coo.cols(),
-                                   graph_coo.n_rows,
-                                   params->n_neighbors);
+                                   params->n_neighbors,
+                                   to_process);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 
-  rmm::device_uvector<int> row_ind(inputs.n, stream);
-  rmm::device_uvector<int> ia(inputs.n, stream);
+  rmm::device_uvector<int> row_ind(inputs.n, stream);
 
   raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), stream);
-  raft::sparse::linalg::coo_degree(&graph_coo, ia.data(), stream);
 
   rmm::device_uvector<value_t> vals_normed(graph_coo.nnz, stream);
   RAFT_CUDA_TRY(cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(value_t), stream));
@@ -471,9 +474,6 @@ void _transform(const raft::handle_t& handle,
     params->n_components,
     transformed,
     params->n_neighbors);
-  RAFT_CUDA_TRY(cudaPeekAtLastError());
-
-  RAFT_CUDA_TRY(cudaMemsetAsync(ia.data(), 0.0, ia.size() * sizeof(int), stream));
 
   RAFT_CUDA_TRY(cudaPeekAtLastError());
@@ -533,18 +533,18 @@ void _transform(const raft::handle_t& handle,
 
   auto initial_alpha = params->initial_alpha / 4.0;
 
-  SimplSetEmbedImpl::optimize_layout<TPB_X, value_t>(transformed,
-                                                     inputs.n,
-                                                     embedding,
-                                                     embedding_n,
-                                                     comp_coo.rows(),
-                                                     comp_coo.cols(),
-                                                     comp_coo.nnz,
-                                                     epochs_per_sample.data(),
-                                                     params->repulsion_strength,
-                                                     params,
-                                                     n_epochs,
-                                                     stream);
+  SimplSetEmbedImpl::optimize_layout<TPB_X, value_t, nnz_t>(transformed,
+                                                            inputs.n,
+                                                            embedding,
+                                                            embedding_n,
+                                                            comp_coo.rows(),
+                                                            comp_coo.cols(),
+                                                            comp_coo.nnz,
+                                                            epochs_per_sample.data(),
+                                                            params->repulsion_strength,
+                                                            params,
+                                                            n_epochs,
+                                                            stream);
   raft::common::nvtx::pop_range();
 
   if (params->callback) params->callback->on_train_end(transformed);
diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh
index 6be8b0235b..b5bccdac65 100644
--- a/cpp/src/umap/simpl_set_embed/algo.cuh
+++ b/cpp/src/umap/simpl_set_embed/algo.cuh
@@ -18,8 +18,6 @@
 
 #include "optimize_batch_kernel.cuh"
 
-#include <common/fast_int_div.cuh>
-
 #include
 #include
@@ -38,6 +36,7 @@
 #include
 #include
+#include <cstdint>
 #include
 #include
@@ -58,8 +57,9 @@ using namespace ML;
 * @param result: an array of number of epochs per sample, one for each 1-simplex
 * @param stream cuda stream
 */
-template <typename T>
-void make_epochs_per_sample(T* weights, int weights_n, int n_epochs, T* result, cudaStream_t stream)
+template <typename T, typename idx_t>
+void make_epochs_per_sample(
+  T* weights, idx_t weights_n, int n_epochs, T* result, cudaStream_t stream)
 {
   thrust::device_ptr<T> d_weights = thrust::device_pointer_cast(weights);
   T weights_max =
@@ -99,7 +99,7 @@ void optimization_iteration_finalization(
 
 /**
 * Update the embeddings and clear the buffers when using deterministic algorithm.
 */
-template <typename T>
+template <typename T, typename nnz_t>
 void apply_embedding_updates(T* head_embedding,
                              T* head_buffer,
                              int head_n,
                              T* tail_embedding,
                              T* tail_buffer,
                              int tail_n,
                              UMAPParams* params,
                              bool move_other,
                              rmm::cuda_stream_view stream)
 {
   ASSERT(params->deterministic, "Only used when deterministic is set to true.");
+  nnz_t n_components = params->n_components;
   if (move_other) {
-    auto n_components = params->n_components;
-    thrust::for_each(
-      rmm::exec_policy(stream),
-      thrust::make_counting_iterator(0u),
-      thrust::make_counting_iterator(0u) + std::max(head_n, tail_n) * params->n_components,
-      [=] __device__(uint32_t i) {
-        if (i < head_n * n_components) {
-          head_embedding[i] += head_buffer[i];
-          head_buffer[i] = 0.0f;
-        }
-        if (i < tail_n * n_components) {
-          tail_embedding[i] += tail_buffer[i];
-          tail_buffer[i] = 0.0f;
-        }
-      });
+    thrust::for_each(rmm::exec_policy(stream),
+                     thrust::make_counting_iterator(0u),
+                     thrust::make_counting_iterator(0u) + std::max(head_n, tail_n) * n_components,
+                     [=] __device__(uint32_t i) {
+                       if (i < head_n * n_components) {
+                         head_embedding[i] += head_buffer[i];
+                         head_buffer[i] = 0.0f;
+                       }
+                       if (i < tail_n * n_components) {
+                         tail_embedding[i] += tail_buffer[i];
+                         tail_buffer[i] = 0.0f;
+                       }
+                     });
   } else {
     // No need to update reference embedding
     thrust::for_each(rmm::exec_policy(stream),
                      thrust::make_counting_iterator(0u),
-                     thrust::make_counting_iterator(0u) + head_n * params->n_components,
+                     thrust::make_counting_iterator(0u) + head_n * n_components,
                      [=] __device__(uint32_t i) {
                        head_embedding[i] += head_buffer[i];
                        head_buffer[i] = 0.0f;
                      });
@@ -168,9 +167,9 @@ T create_rounding_factor(T max_abs, int n)
   return std::ldexp(static_cast<T>(1.0), exp);
 }
 
-template <typename T>
+template <typename T, typename nnz_t>
 T create_gradient_rounding_factor(
-  const int* head, int nnz, int n_samples, T alpha, rmm::cuda_stream_view stream)
+  const int* head, nnz_t nnz, int n_samples, T alpha, rmm::cuda_stream_view stream)
 {
   rmm::device_uvector<T> buffer(n_samples, stream);
   // calculate the maximum number of edges connected to 1 vertex.
@@ -195,14 +194,14 @@ T create_gradient_rounding_factor(
  * positive weights (neighbors in the 1-skeleton) and repelling
  * negative weights (non-neighbors in the 1-skeleton).
  */
-template <int TPB_X, typename T>
+template <int TPB_X, typename T, typename nnz_t>
 void optimize_layout(T* head_embedding,
                      int head_n,
                      T* tail_embedding,
                      int tail_n,
                      const int* head,
                      const int* tail,
-                     int nnz,
+                     nnz_t nnz,
                      T* epochs_per_sample,
                      float gamma,
                      UMAPParams* params,
                      int n_epochs,
                      cudaStream_t stream)
@@ -233,12 +232,13 @@ void optimize_layout(T* head_embedding,
   T* d_head_buffer = head_embedding;
   T* d_tail_buffer = tail_embedding;
   if (params->deterministic) {
-    head_buffer.resize(head_n * params->n_components, stream_view);
+    nnz_t n_components = params->n_components;
+    head_buffer.resize(head_n * n_components, stream_view);
     RAFT_CUDA_TRY(
       cudaMemsetAsync(head_buffer.data(), '\0', sizeof(T) * head_buffer.size(), stream));
     // No need for tail if it's not being written.
     if (move_other) {
-      tail_buffer.resize(tail_n * params->n_components, stream_view);
+      tail_buffer.resize(tail_n * n_components, stream_view);
       RAFT_CUDA_TRY(
         cudaMemsetAsync(tail_buffer.data(), '\0', sizeof(T) * tail_buffer.size(), stream));
     }
@@ -250,42 +250,40 @@ void optimize_layout(T* head_embedding,
   dim3 blk(TPB_X, 1, 1);
   uint64_t seed = params->random_state;
 
-  T rounding = create_gradient_rounding_factor<T>(head, nnz, head_n, alpha, stream_view);
+  T rounding = create_gradient_rounding_factor<T, nnz_t>(head, nnz, head_n, alpha, stream_view);
 
-  MLCommon::FastIntDiv tail_n_fast(tail_n);
   for (int n = 0; n < n_epochs; n++) {
-    call_optimize_batch_kernel<T, TPB_X>(head_embedding,
-                                         d_head_buffer,
-                                         head_n,
-                                         tail_embedding,
-                                         d_tail_buffer,
-                                         tail_n_fast,
-                                         head,
-                                         tail,
-                                         nnz,
-                                         epochs_per_sample,
-                                         epoch_of_next_negative_sample.data(),
-                                         epoch_of_next_sample.data(),
-                                         alpha,
-                                         gamma,
-                                         seed,
-                                         move_other,
-                                         params,
-                                         n,
-                                         grid,
-                                         blk,
-                                         stream,
-                                         rounding);
+    call_optimize_batch_kernel<T, TPB_X, nnz_t>(head_embedding,
+                                                d_head_buffer,
+                                                tail_embedding,
+                                                d_tail_buffer,
+                                                tail_n,
+                                                head,
+                                                tail,
+                                                nnz,
+                                                epochs_per_sample,
+                                                epoch_of_next_negative_sample.data(),
+                                                epoch_of_next_sample.data(),
+                                                alpha,
+                                                gamma,
+                                                seed,
+                                                move_other,
+                                                params,
+                                                n,
+                                                grid,
+                                                blk,
+                                                stream,
+                                                rounding);
     if (params->deterministic) {
-      apply_embedding_updates(head_embedding,
-                              d_head_buffer,
-                              head_n,
-                              tail_embedding,
-                              d_tail_buffer,
-                              tail_n,
-                              params,
-                              move_other,
-                              stream_view);
+      apply_embedding_updates<T, nnz_t>(head_embedding,
+                                        d_head_buffer,
+                                        head_n,
+                                        tail_embedding,
+                                        d_tail_buffer,
+                                        tail_n,
+                                        params,
+                                        move_other,
+                                        stream_view);
     }
     RAFT_CUDA_TRY(cudaGetLastError());
     optimization_iteration_finalization(params, head_embedding, alpha, n, n_epochs, seed);
@@ -297,11 +295,11 @@ void optimize_layout(T* head_embedding,
 * the fuzzy set cross entropy between the embeddings
 * and their 1-skeletons.
 */
-template <int TPB_X, typename T>
+template <int TPB_X, typename T, typename nnz_t>
 void launcher(
   int m, int n, raft::sparse::COO<T>* in, UMAPParams* params, T* embedding, cudaStream_t stream)
 {
-  int nnz = in->nnz;
+  nnz_t nnz = in->nnz;
 
   /**
    * Find vals.max()
@@ -347,18 +345,18 @@ void launcher(
     CUML_LOG_DEBUG(ss.str().c_str());
   }
 
-  optimize_layout<TPB_X, T>(embedding,
-                            m,
-                            embedding,
-                            m,
-                            out.rows(),
-                            out.cols(),
-                            out.nnz,
-                            epochs_per_sample.data(),
-                            params->repulsion_strength,
-                            params,
-                            n_epochs,
-                            stream);
+  optimize_layout<TPB_X, T, nnz_t>(embedding,
+                                   m,
+                                   embedding,
+                                   m,
+                                   out.rows(),
+                                   out.cols(),
+                                   out.nnz,
+                                   epochs_per_sample.data(),
+                                   params->repulsion_strength,
+                                   params,
+                                   n_epochs,
+                                   stream);
 
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
diff --git a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh
index 5fd34d2f3b..bbc0e9eb61 100644
--- a/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh
+++ b/cpp/src/umap/simpl_set_embed/optimize_batch_kernel.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,8 @@
 #include
 #include
 
+#include <cstdint>
+
 #include
 
 namespace UMAPAlgo {
@@ -96,16 +98,15 @@ DI T truncate_gradient(T const rounding_factor, T const x)
   return (rounding_factor + x) - rounding_factor;
 }
 
-template <typename T, int TPB_X, int n_components>
+template <typename T, int TPB_X, typename nnz_t, int n_components>
 CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
                                            T* head_buffer,
-                                           int head_n,
                                            T const* tail_embedding,
                                            T* tail_buffer,
-                                           const MLCommon::FastIntDiv tail_n,
+                                           MLCommon::FastIntDiv tail_n,
                                            const int* head,
                                            const int* tail,
-                                           int nnz,
+                                           nnz_t nnz,
                                            T const* epochs_per_sample,
                                            T* epoch_of_next_negative_sample,
                                            T* epoch_of_next_sample,
@@ -118,7 +119,7 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
                                            T nsr_inv,
                                            T rounding)
 {
-  int row = (blockIdx.x * TPB_X) + threadIdx.x;
+  nnz_t row = (blockIdx.x * TPB_X) + threadIdx.x;
   if (row >= nnz) return;
   auto _epoch_of_next_sample = epoch_of_next_sample[row];
   if (_epoch_of_next_sample > epoch) return;
@@ -127,8 +128,8 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
   /**
    * Positive sample stage (attractive forces)
    */
-  int j = head[row];
-  int k = tail[row];
+  nnz_t j = head[row];
+  nnz_t k = tail[row];
   T const* current = head_embedding + (j * n_components);
   T const* other   = tail_embedding + (k * n_components);
@@ -170,11 +171,11 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
   /**
    * Negative sampling stage
    */
-  raft::random::detail::PhiloxGenerator gen((uint64_t)seed, (uint64_t)row, 0);
+  raft::random::detail::PhiloxGenerator gen((uint64_t)seed, (nnz_t)row, 0);
   for (int p = 0; p < n_neg_samples; p++) {
     int r;
     gen.next(r);
-    int t = r % tail_n;
+    nnz_t t = r % tail_n;
     T const* negative_sample = tail_embedding + (t * n_components);
     T negative_sample_reg[n_components];
     for (int i = 0; i < n_components; ++i) {
@@ -210,16 +211,15 @@ CUML_KERNEL void optimize_batch_kernel_reg(T const* head_embedding,
     _epoch_of_next_negative_sample + n_neg_samples * epochs_per_negative_sample;
 }
 
-template <typename T, int TPB_X>
+template <typename T, int TPB_X, typename nnz_t>
 CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
                                        T* head_buffer,
-                                       int head_n,
                                        T const* tail_embedding,
                                        T* tail_buffer,
-                                       const MLCommon::FastIntDiv tail_n,
+                                       MLCommon::FastIntDiv tail_n,
                                        const int* head,
                                        const int* tail,
-                                       int nnz,
+                                       nnz_t nnz,
                                        T const* epochs_per_sample,
                                        T* epoch_of_next_negative_sample,
                                        T* epoch_of_next_sample,
@@ -233,7 +233,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
                                        T rounding)
 {
   extern __shared__ T embedding_shared_mem_updates[];
-  int row = (blockIdx.x * TPB_X) + threadIdx.x;
+  nnz_t row = (blockIdx.x * TPB_X) + threadIdx.x;
   if (row >= nnz) return;
   auto _epoch_of_next_sample = epoch_of_next_sample[row];
   if (_epoch_of_next_sample > epoch) return;
@@ -242,17 +242,19 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
   /**
    * Positive sample stage (attractive forces)
    */
-  int j = head[row];
-  int k = tail[row];
-  T const* current = head_embedding + (j * params.n_components);
-  T const* other   = tail_embedding + (k * params.n_components);
+  nnz_t n_components = params.n_components;
+
+  nnz_t j = head[row];
+  nnz_t k = tail[row];
+  T const* current = head_embedding + (j * n_components);
+  T const* other   = tail_embedding + (k * n_components);
 
-  T* cur_write = head_buffer + (j * params.n_components);
-  T* oth_write = tail_buffer + (k * params.n_components);
+  T* cur_write = head_buffer + (j * n_components);
+  T* oth_write = tail_buffer + (k * n_components);
 
   T* current_buffer{nullptr};
   if (use_shared_mem) { current_buffer = (T*)embedding_shared_mem_updates + threadIdx.x; }
-  auto dist_squared = rdist(current, other, params.n_components);
+  auto dist_squared = rdist(current, other, n_components);
   // Attractive force between the two vertices, since they
   // are connected by an edge in the 1-skeleton.
   auto attractive_grad_coeff = T(0.0);
@@ -264,7 +266,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
    * (update `other` embedding only if we are
    * performing unsupervised training).
    */
-  for (int d = 0; d < params.n_components; d++) {
+  for (int d = 0; d < n_components; d++) {
     auto grad_d = clip(attractive_grad_coeff * (current[d] - other[d]), T(-4.0), T(4.0));
     grad_d *= alpha;
     if (use_shared_mem) {
@@ -279,7 +281,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
   // storing gradients for negative samples back to global memory
   if (use_shared_mem && move_other) {
     __syncthreads();
-    for (int d = 0; d < params.n_components; d++) {
+    for (int d = 0; d < n_components; d++) {
       auto grad = current_buffer[d * TPB_X];
       raft::myAtomicAdd((T*)oth_write + d, truncate_gradient(rounding, -grad));
     }
@@ -291,13 +293,13 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
   /**
    * Negative sampling stage
    */
-  raft::random::detail::PhiloxGenerator gen((uint64_t)seed, (uint64_t)row, 0);
+  raft::random::detail::PhiloxGenerator gen((uint64_t)seed, (nnz_t)row, 0);
   for (int p = 0; p < n_neg_samples; p++) {
     int r;
     gen.next(r);
-    int t = r % tail_n;
-    T const* negative_sample = tail_embedding + (t * params.n_components);
-    dist_squared             = rdist(current, negative_sample, params.n_components);
+    nnz_t t = r % tail_n;
+    T const* negative_sample = tail_embedding + (t * n_components);
+    dist_squared             = rdist(current, negative_sample, n_components);
     // repulsive force between two vertices
     auto repulsive_grad_coeff = T(0.0);
     if (dist_squared > T(0.0)) {
@@ -309,7 +311,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
      * (which has been negatively sampled) by updating
      * their 'weights' to push them farther in Euclidean space.
      */
-    for (int d = 0; d < params.n_components; d++) {
+    for (int d = 0; d < n_components; d++) {
       auto grad_d = T(0.0);
       if (repulsive_grad_coeff > T(0.0))
         grad_d = clip(repulsive_grad_coeff * (current[d] - negative_sample[d]), T(-4.0), T(4.0));
@@ -327,7 +329,7 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
   // storing gradients for positive samples back to global memory
   if (use_shared_mem) {
     __syncthreads();
-    for (int d = 0; d < params.n_components; d++) {
+    for (int d = 0; d < n_components; d++) {
       raft::myAtomicAdd((T*)cur_write + d, truncate_gradient(rounding, current_buffer[d * TPB_X]));
     }
@@ -348,16 +350,15 @@ CUML_KERNEL void optimize_batch_kernel(T const* head_embedding,
 * @param rounding: Floating rounding factor used to truncate the gradient update for
 * deterministic result.
 */
-template <typename T, int TPB_X>
+template <typename T, int TPB_X, typename nnz_t>
 void call_optimize_batch_kernel(T const* head_embedding,
                                 T* head_buffer,
-                                int head_n,
                                 T const* tail_embedding,
                                 T* tail_buffer,
-                                const MLCommon::FastIntDiv& tail_n,
+                                int tail_n,
                                 const int* head,
                                 const int* tail,
-                                int nnz,
+                                nnz_t nnz,
                                 T const* epochs_per_sample,
                                 T* epoch_of_next_negative_sample,
                                 T* epoch_of_next_sample,
@@ -378,32 +379,31 @@ void call_optimize_batch_kernel(T const* head_embedding,
   T nsr_inv = T(1.0) / params->negative_sample_rate;
   if (params->n_components == 2) {
     // multicore implementation with registers
-    optimize_batch_kernel_reg<T, TPB_X, 2><<<grid, blk, 0, stream>>>(head_embedding,
-                                                                     head_buffer,
-                                                                     head_n,
-                                                                     tail_embedding,
-                                                                     tail_buffer,
-                                                                     tail_n,
-                                                                     head,
-                                                                     tail,
-                                                                     nnz,
-                                                                     epochs_per_sample,
-                                                                     epoch_of_next_negative_sample,
-                                                                     epoch_of_next_sample,
-                                                                     alpha,
-                                                                     n,
-                                                                     gamma,
-                                                                     seed,
-                                                                     move_other,
-                                                                     *params,
-                                                                     nsr_inv,
-                                                                     rounding);
+    optimize_batch_kernel_reg<T, TPB_X, nnz_t, 2>
+      <<<grid, blk, 0, stream>>>(head_embedding,
+                                 head_buffer,
+                                 tail_embedding,
+                                 tail_buffer,
+                                 tail_n,
+                                 head,
+                                 tail,
+                                 nnz,
+                                 epochs_per_sample,
+                                 epoch_of_next_negative_sample,
+                                 epoch_of_next_sample,
+                                 alpha,
+                                 n,
+                                 gamma,
+                                 seed,
+                                 move_other,
+                                 *params,
+                                 nsr_inv,
+                                 rounding);
   } else if (use_shared_mem) {
     // multicore implementation with shared memory
-    optimize_batch_kernel<T, TPB_X>
+    optimize_batch_kernel<T, TPB_X, nnz_t>
       <<<grid, blk, requiredSize, stream>>>(head_embedding,
                                             head_buffer,
-                                            head_n,
                                             tail_embedding,
                                             tail_buffer,
                                             tail_n,
                                             head,
                                             tail,
                                             nnz,
                                             epochs_per_sample,
                                             epoch_of_next_negative_sample,
                                             epoch_of_next_sample,
                                             alpha,
                                             n,
                                             gamma,
                                             seed,
                                             move_other,
                                             *params,
                                             nsr_inv,
                                             rounding);
   } else {
     // multicore implementation without shared memory
-    optimize_batch_kernel<T, TPB_X><<<grid, blk, 0, stream>>>(head_embedding,
-                                                              head_buffer,
-                                                              head_n,
-                                                              tail_embedding,
-                                                              tail_buffer,
-                                                              tail_n,
-                                                              head,
-                                                              tail,
-                                                              nnz,
-                                                              epochs_per_sample,
-                                                              epoch_of_next_negative_sample,
-                                                              epoch_of_next_sample,
-                                                              alpha,
-                                                              n,
-                                                              gamma,
-                                                              seed,
-                                                              move_other,
-                                                              *params,
-                                                              nsr_inv,
-                                                              rounding);
+    optimize_batch_kernel<T, TPB_X, nnz_t>
+      <<<grid, blk, 0, stream>>>(head_embedding,
+                                 head_buffer,
+                                 tail_embedding,
+                                 tail_buffer,
+                                 tail_n,
+                                 head,
+                                 tail,
+                                 nnz,
+                                 epochs_per_sample,
+                                 epoch_of_next_negative_sample,
+                                 epoch_of_next_sample,
+                                 alpha,
+                                 n,
+                                 gamma,
+                                 seed,
+                                 move_other,
+                                 *params,
+                                 nsr_inv,
+                                 rounding);
   }
 }
 }  // namespace Algo
diff --git a/cpp/src/umap/simpl_set_embed/runner.cuh b/cpp/src/umap/simpl_set_embed/runner.cuh
index dabedd4bdd..a2a6d08e27 100644
--- a/cpp/src/umap/simpl_set_embed/runner.cuh
+++ b/cpp/src/umap/simpl_set_embed/runner.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ namespace SimplSetEmbed {
 
 using namespace ML;
 
-template <int TPB_X, typename T>
+template <int TPB_X, typename T, typename nnz_t>
 void run(int m,
          int n,
          raft::sparse::COO<T>* coo,
@@ -38,7 +38,7 @@ void run(int m,
          int algorithm = 0)
 {
   switch (algorithm) {
-    case 0: SimplSetEmbed::Algo::launcher<TPB_X, T>(m, n, coo, params, embedding, stream);
+    case 0: SimplSetEmbed::Algo::launcher<TPB_X, T, nnz_t>(m, n, coo, params, embedding, stream);
   }
 }
 }  // namespace SimplSetEmbed
diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh
index 1a9739f280..8bb9bc56cb 100644
--- a/cpp/src/umap/supervised.cuh
+++ b/cpp/src/umap/supervised.cuh
@@ -101,7 +101,7 @@ void reset_local_connectivity(raft::sparse::COO<value_t>* in_coo,
 * and this will update the fuzzy simplicial set to respect that label
 * data.
 */
-template <int TPB_X, typename value_t>
+template <int TPB_X, typename value_t, typename nnz_t>
 void categorical_simplicial_set_intersection(raft::sparse::COO<value_t>* graph_coo,
                                              value_t* target,
                                              cudaStream_t stream,
@@ -119,7 +119,7 @@ void categorical_simplicial_set_intersection(raft::sparse::COO<value_t>* graph_coo,
     far_dist);
 }
 
-template <int TPB_X, typename value_t>
+template <int TPB_X, typename value_t, typename nnz_t>
 CUML_KERNEL void sset_intersection_kernel(int* row_ind1,
                                           int* cols1,
                                           value_t* vals1,
@@ -177,7 +177,7 @@ CUML_KERNEL void sset_intersection_kernel(int* row_ind1,
 * Computes the CSR column index pointer and values
 * for the general simplicial set intersection.
 */
-template <int TPB_X, typename value_t>
+template <int TPB_X, typename value_t, typename nnz_t>
 void general_simplicial_set_intersection(int* row1_ind,
                                          raft::sparse::COO<value_t>* in1,
                                          int* row2_ind,
@@ -236,28 +236,28 @@ void general_simplicial_set_intersection(int* row1_ind,
   dim3 grid(raft::ceildiv(in1->nnz, TPB_X), 1, 1);
   dim3 blk(TPB_X, 1, 1);
 
-  sset_intersection_kernel<TPB_X, value_t><<<grid, blk, 0, stream>>>(row1_ind,
-                                                                     in1->cols(),
-                                                                     in1->vals(),
-                                                                     in1->nnz,
-                                                                     row2_ind,
-                                                                     in2->cols(),
-                                                                     in2->vals(),
-                                                                     in2->nnz,
-                                                                     result_ind.data(),
-                                                                     result->cols(),
-                                                                     result->vals(),
-                                                                     result->nnz,
-                                                                     left_min,
-                                                                     right_min,
-                                                                     in1->n_rows,
-                                                                     weight);
+  sset_intersection_kernel<TPB_X, value_t, nnz_t><<<grid, blk, 0, stream>>>(row1_ind,
+                                                                            in1->cols(),
+                                                                            in1->vals(),
+                                                                            in1->nnz,
+                                                                            row2_ind,
+                                                                            in2->cols(),
+                                                                            in2->vals(),
+                                                                            in2->nnz,
+                                                                            result_ind.data(),
+                                                                            result->cols(),
+                                                                            result->vals(),
+                                                                            result->nnz,
+                                                                            left_min,
+                                                                            right_min,
+                                                                            in1->n_rows,
+                                                                            weight);
   RAFT_CUDA_TRY(cudaGetLastError());
 
   dim3 grid_n(raft::ceildiv(result->nnz, TPB_X), 1, 1);
 }
 
-template <int TPB_X, typename T>
+template <int TPB_X, typename T, typename nnz_t>
 void perform_categorical_intersection(T* y,
                                       raft::sparse::COO<T>* rgraph_coo,
                                       raft::sparse::COO<T>* final_coo,
@@ -267,7 +267,7 @@ void perform_categorical_intersection(T* y,
   float far_dist = 1.0e12;  // target weight
   if (params->target_weight < 1.0) far_dist = 2.5 * (1.0 / (1.0 - params->target_weight));
 
-  categorical_simplicial_set_intersection<TPB_X, T>(rgraph_coo, y, stream, far_dist);
+  categorical_simplicial_set_intersection<TPB_X, T, nnz_t>(rgraph_coo, y, stream, far_dist);
 
   raft::sparse::COO<T> comp_coo(stream);
   raft::sparse::op::coo_remove_zeros(rgraph_coo, &comp_coo, stream);
@@ -277,7 +277,7 @@ void perform_categorical_intersection(T* y,
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
 
-template <int TPB_X, typename value_idx, typename value_t>
+template <int TPB_X, typename value_idx, typename value_t, typename nnz_t>
 void perform_general_intersection(const raft::handle_t& handle,
                                   value_t* y,
                                   raft::sparse::COO<value_t>* rgraph_coo,
@@ -317,13 +317,13 @@ void perform_general_intersection(const raft::handle_t& handle,
    */
   raft::sparse::COO<value_t> ygraph_coo(stream);
 
-  FuzzySimplSet::run<TPB_X, value_idx, value_t>(rgraph_coo->n_rows,
-                                                y_knn_indices.data(),
-                                                y_knn_dists.data(),
-                                                params->target_n_neighbors,
-                                                &ygraph_coo,
-                                                params,
-                                                stream);
+  FuzzySimplSet::run<TPB_X, value_idx, value_t, nnz_t>(rgraph_coo->n_rows,
+                                                       y_knn_indices.data(),
+                                                       y_knn_dists.data(),
+                                                       params->target_n_neighbors,
+                                                       &ygraph_coo,
+                                                       params,
+                                                       stream);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 
   if (ML::default_logger().should_log(ML::level_enum::debug)) {
@@ -349,13 +349,13 @@ void perform_general_intersection(const raft::handle_t& handle,
   raft::sparse::convert::sorted_coo_to_csr(rgraph_coo, xrow_ind.data(), stream);
 
   raft::sparse::COO<value_t> result_coo(stream);
-  general_simplicial_set_intersection<TPB_X, value_t>(xrow_ind.data(),
-                                                      rgraph_coo,
-                                                      yrow_ind.data(),
-                                                      &cygraph_coo,
-                                                      &result_coo,
-                                                      params->target_weight,
-                                                      stream);
+  general_simplicial_set_intersection<TPB_X, value_t, nnz_t>(xrow_ind.data(),
+                                                             rgraph_coo,
+                                                             yrow_ind.data(),
+                                                             &cygraph_coo,
+                                                             &result_coo,
+                                                             params->target_weight,
+                                                             stream);
 
   /**
    * Remove zeros
diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu
index 899051f8de..9b06db078a 100644
--- a/cpp/src/umap/umap.cu
+++ b/cpp/src/umap/umap.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@
 #include
 #include
 
+#include <cstdint>
+
 #include
 
 namespace ML {
@@ -56,23 +58,28 @@ std::unique_ptr<raft::sparse::COO<float, int>> get_graph(
       UMAPAlgo::_get_graph_supervised<knn_indices_sparse_t,
                                       float,
                                       manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                                      uint64_t,
                                       TPB_X>(handle, inputs, params, graph.get());
     } else {
       UMAPAlgo::_get_graph<knn_indices_sparse_t,
                            float,
                            manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                           uint64_t,
                            TPB_X>(handle, inputs, params, graph.get());
     }
     return graph;
   } else {
     manifold_dense_inputs_t<float> inputs(X, y, n, d);
     if (y != nullptr) {
+      UMAPAlgo::_get_graph_supervised<knn_indices_dense_t,
+                                      float,
+                                      manifold_dense_inputs_t<float>,
+                                      uint64_t,
+                                      TPB_X>(handle, inputs, params, graph.get());
+    } else {
       UMAPAlgo::
-        _get_graph_supervised<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
+        _get_graph<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, uint64_t, TPB_X>(
           handle, inputs, params, graph.get());
-    } else {
-      UMAPAlgo::_get_graph<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
-        handle, inputs, params, graph.get());
     }
     return graph;
   }
@@ -88,7 +95,7 @@ void refine(const raft::handle_t& handle,
 {
   CUML_LOG_DEBUG("Calling UMAP::refine() with precomputed KNN");
   manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
-  UMAPAlgo::_refine<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
+  UMAPAlgo::_refine<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, uint64_t, TPB_X>(
     handle, inputs, params, graph, embeddings);
 }
@@ -102,8 +109,9 @@ void init_and_refine(const raft::handle_t& handle,
 {
   CUML_LOG_DEBUG("Calling UMAP::init_and_refine() with precomputed KNN");
   manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
-  UMAPAlgo::_init_and_refine<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
-    handle, inputs, params, graph, embeddings);
+  UMAPAlgo::
+    _init_and_refine<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, uint64_t, TPB_X>(
+      handle, inputs, params, graph, embeddings);
 }
 
 void fit(const raft::handle_t& handle,
@@ -126,21 +134,26 @@ void fit(const raft::handle_t& handle,
       UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
                                 float,
                                 manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                                uint64_t,
                                 TPB_X>(handle, inputs, params, embeddings, graph);
     } else {
       UMAPAlgo::_fit<knn_indices_sparse_t,
                      float,
                      manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                     uint64_t,
                      TPB_X>(handle, inputs, params, embeddings, graph);
     }
   } else {
     manifold_dense_inputs_t<float> inputs(X, y, n, d);
     if (y != nullptr) {
-      UMAPAlgo::_fit_supervised<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
-        handle, inputs, params, embeddings, graph);
+      UMAPAlgo::_fit_supervised<knn_indices_dense_t,
+                                float,
+                                manifold_dense_inputs_t<float>,
+                                uint64_t,
+                                TPB_X>(handle, inputs, params, embeddings, graph);
     } else {
-      UMAPAlgo::_fit<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
+      UMAPAlgo::_fit<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, uint64_t, TPB_X>(
         handle, inputs, params, embeddings, graph);
     }
   }
 }
@@ -167,11 +180,13 @@ void fit_sparse(const raft::handle_t& handle,
       UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
                                 float,
                                 manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                                uint64_t,
                                 TPB_X>(handle, inputs, params, embeddings, graph);
     } else {
       UMAPAlgo::_fit<knn_indices_sparse_t,
                      float,
                      manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                     uint64_t,
                      TPB_X>(handle, inputs, params, embeddings, graph);
     }
   } else {
@@ -180,11 +195,13 @@ void fit_sparse(const raft::handle_t& handle,
       UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
                                 float,
                                 manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                                uint64_t,
                                 TPB_X>(handle, inputs, params, embeddings, graph);
     } else {
       UMAPAlgo::_fit<knn_indices_sparse_t,
                      float,
                      manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
+                     uint64_t,
                      TPB_X>(handle, inputs, params, embeddings, graph);
     }
   }
 }
@@ -205,7 +222,7 @@ void transform(const raft::handle_t& handle,
          "build algo nn_descent not supported for transform()");
   manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
   manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
-  UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
+  UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, uint64_t, TPB_X>(
     handle, inputs, orig_inputs, embedding, embedding_n, params, transformed);
 }
@@ -233,8 +250,9 @@ void transform_sparse(const raft::handle_t& handle,
   manifold_sparse_inputs_t<knn_indices_sparse_t, float> orig_x_inputs(
     orig_x_indptr, orig_x_indices, orig_x_data, nullptr, orig_nnz, orig_n, d);
 
-  UMAPAlgo::_transform<knn_indices_sparse_t, float, manifold_sparse_inputs_t<knn_indices_sparse_t, float>, TPB_X>(
-    handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed);
+  UMAPAlgo::
+    _transform<knn_indices_sparse_t, float, manifold_sparse_inputs_t<knn_indices_sparse_t, float>, uint64_t, TPB_X>(
+      handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed);
 }
 }  // namespace UMAP
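
Note on the recurring pattern in this patch: every hot-path index computation is widened to 64 bits (nnz_t / uint64_t) before the row-times-neighbors product is formed, because computing the product in int first overflows once a kNN graph or COO buffer exceeds 2^31 - 1 entries. Below is a minimal standalone sketch of the failure mode and the fix; the values and the nnz_t alias are illustrative, not taken from the patch:

    #include <cstdint>
    #include <cstdio>

    using nnz_t = uint64_t;  // plays the same role as the nnz_t threaded through the patch

    int main()
    {
      int row         = 70000000;  // hypothetical: ~70M samples
      int n_neighbors = 32;

      // 70000000 * 32 = 2240000000 > INT32_MAX, so `row * n_neighbors` computed in
      // int overflows before any assignment to a wider type could rescue the result.
      // Widening one operand first keeps the entire multiply in 64-bit arithmetic,
      // which is what the patch's `static_cast<nnz_t>(row) * n_neighbors` idiom does:
      nnz_t offset = static_cast<nnz_t>(row) * n_neighbors;

      std::printf("offset = %llu\n", static_cast<unsigned long long>(offset));
      return 0;
    }

The same reasoning explains the kernel-side changes: thread indices (`nnz_t row = (blockIdx.x * TPB_X) + threadIdx.x;`) and pointer offsets (`j * n_components`) are held in nnz_t so that grids launched over more than 2^31 - 1 edges index correctly.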