diff --git a/.gitignore b/.gitignore index 80709dbb96..c2528d2cd0 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ log .DS_Store dask-worker-space/ *.egg-info/ +*.bin ## scikit-build _skbuild diff --git a/README.md b/README.md index 773d98e23a..a178d90008 100755 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ RAFT relies heavily on RMM which eases the burden of configuring different alloc ### Multi-dimensional Arrays -The APIs in RAFT currently accept raw pointers to device memory and we are in the process of simplifying the APIs with the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. +The APIs in RAFT accept the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index 43a4a186f8..a44314a6ce 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -37,7 +37,7 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"] +ExemptFiles = ["cpp/include/raft/neighbors/detail/faiss_select/"] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 83390ea881..2999045a0c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -223,6 +223,10 @@ target_link_libraries( ) target_compile_features(raft INTERFACE cxx_std_17 $) +target_compile_options( + raft INTERFACE $<$:--expt-extended-lambda + --expt-relaxed-constexpr> +) # Endian detection include(TestBigEndian) @@ -364,6 +368,17 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/matrix/specializations/detail/select_k_float_int64_t.cu src/distance/matrix/specializations/detail/select_k_half_uint32_t.cu src/distance/matrix/specializations/detail/select_k_half_int64_t.cu + src/distance/neighbors/ivf_flat_search.cu + src/distance/neighbors/ivf_flat_build.cu + src/distance/neighbors/specializations/ivfflat_build_float_int64_t.cu + src/distance/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_extend_float_int64_t.cu + src/distance/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_search_float_int64_t.cu + src/distance/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu src/distance/neighbors/ivfpq_build.cu src/distance/neighbors/ivfpq_deserialize.cu src/distance/neighbors/ivfpq_serialize.cu diff --git a/cpp/bench/cluster/kmeans.cu b/cpp/bench/cluster/kmeans.cu index 7079e33f25..f593ec090d 100644 --- a/cpp/bench/cluster/kmeans.cu +++ b/cpp/bench/cluster/kmeans.cu @@ 
-38,7 +38,13 @@ inline auto operator<<(std::ostream& os, const KMeansBenchParams& p) -> std::ost template struct KMeans : public BlobsFixture { - KMeans(const KMeansBenchParams& p) : BlobsFixture(p.data, p.blobs), params(p) {} + KMeans(const KMeansBenchParams& p) + : BlobsFixture(p.data, p.blobs), + params(p), + centroids(this->handle), + labels(this->handle) + { + } void run_benchmark(::benchmark::State& state) override { diff --git a/cpp/bench/cluster/kmeans_balanced.cu b/cpp/bench/cluster/kmeans_balanced.cu index 1cfdfbe49a..8dda155a59 100644 --- a/cpp/bench/cluster/kmeans_balanced.cu +++ b/cpp/bench/cluster/kmeans_balanced.cu @@ -32,7 +32,7 @@ struct KMeansBalancedBenchParams { template struct KMeansBalanced : public fixture { - KMeansBalanced(const KMeansBalancedBenchParams& p) : params(p) {} + KMeansBalanced(const KMeansBalancedBenchParams& p) : params(p), X(handle), centroids(handle) {} void run_benchmark(::benchmark::State& state) override { diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index b8babf0582..4b6e1ba286 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -172,7 +172,10 @@ struct BlobsParams { template class BlobsFixture : public fixture { public: - BlobsFixture(const DatasetParams dp, const BlobsParams bp) : data_params(dp), blobs_params(bp) {} + BlobsFixture(const DatasetParams dp, const BlobsParams bp) + : data_params(dp), blobs_params(bp), X(this->handle) + { + } virtual void run_benchmark(::benchmark::State& state) = 0; diff --git a/cpp/bench/distance/fused_l2_nn.cu b/cpp/bench/distance/fused_l2_nn.cu index 48473b2846..7531784707 100644 --- a/cpp/bench/distance/fused_l2_nn.cu +++ b/cpp/bench/distance/fused_l2_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -36,7 +36,16 @@ inline auto operator<<(std::ostream& os, const fusedl2nn_inputs& p) -> std::ostr template struct fusedl2nn : public fixture { - fusedl2nn(const fusedl2nn_inputs& p) : params(p) {} + fusedl2nn(const fusedl2nn_inputs& p) + : params(p), + workspace(this->handle), + x(this->handle), + y(this->handle), + x_norm(this->handle), + y_norm(this->handle), + out(this->handle) + { + } void allocate_data(const ::benchmark::State& state) override { diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 3869f0c5e1..929eed48c4 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -30,7 +30,7 @@ struct ArgminParams { template struct Argmin : public fixture { - Argmin(const ArgminParams& p) : params(p) {} + Argmin(const ArgminParams& p) : params(p), matrix(this->handle), indices(this->handle) {} void allocate_data(const ::benchmark::State& state) override { diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu index c5d80744cd..213e2aa55f 100644 --- a/cpp/bench/matrix/gather.cu +++ b/cpp/bench/matrix/gather.cu @@ -37,7 +37,10 @@ inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::os template struct Gather : public fixture { - Gather(const GatherParams& p) : params(p) {} + Gather(const GatherParams& p) + : params(p), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle) + { + } void allocate_data(const ::benchmark::State& state) override { diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 259d39d8f7..fe8c2c10d8 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -178,7 +178,6 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } @@ -189,13 +188,12 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; - auto queries_view = raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view); + handle, search_params, *index, queries_view, idxs_view, dists_view); } }; diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp new file mode 100644 index 0000000000..efab8a1601 --- /dev/null +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -0,0 +1,283 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
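The `ivf_pq` benchmark change above tracks the new mdspan-based `search` overload: `k` is no longer passed as a separate argument because it is taken from the extents of the output views. A minimal sketch of the updated call shape, assuming `dataset`, `query_ptr`, `out_idxs`, and `out_dists` are device pointers of the appropriate sizes (illustrative names, not from this diff):

```cpp
raft::device_resources handle;

raft::neighbors::ivf_pq::index_params index_params;
index_params.n_lists = 1024;

// Build from a device matrix view over the dataset.
auto data_view = raft::make_device_matrix_view<const float, int64_t>(dataset, n_samples, n_dims);
auto index     = raft::neighbors::ivf_pq::build(handle, index_params, data_view);

// k is implied by the second extent of the output views.
raft::neighbors::ivf_pq::search_params search_params;
auto queries = raft::make_device_matrix_view<const float, int64_t>(query_ptr, n_queries, n_dims);
auto idxs    = raft::make_device_matrix_view<int64_t, int64_t>(out_idxs, n_queries, k);
auto dists   = raft::make_device_matrix_view<float, int64_t>(out_dists, n_queries, k);

raft::neighbors::ivf_pq::search(handle, search_params, index, queries, idxs, dists);
```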
+ */ +#pragma once + +#include +#include +#include +#include + +namespace raft { + +template +class coordinate_structure_t : public sparse_structure { + public: + coordinate_structure_t(RowType n_rows, ColType n_cols, NZType nnz) + : sparse_structure(n_rows, n_cols, nnz){}; + + /** + * Return span containing underlying rows array + * @return span containing underlying rows array + */ + virtual span get_rows() = 0; + + /** + * Return span containing underlying cols array + * @return span containing underlying cols array + */ + virtual span get_cols() = 0; +}; + +/** + * A non-owning view into a coordinate structure + * + * The structure representation does not have a value/weight + * component so that its const-ness can be varied from it. + * + * @tparam RowType + * @tparam ColType + */ +template +class coordinate_structure_view + : public coordinate_structure_t { + public: + static constexpr SparsityType sparsity_type = PRESERVING; + using view_type = coordinate_structure_view; + using row_type = typename sparse_structure::row_type; + using col_type = typename sparse_structure::col_type; + using nnz_type = typename sparse_structure::nnz_type; + + coordinate_structure_view(span rows, + span cols, + row_type n_rows, + col_type n_cols) + : coordinate_structure_t(n_rows, n_cols, rows.size()), + rows_{rows}, + cols_{cols} + { + } + + /** + * Create a view from this view. Note that this is for interface compatibility + * @return + */ + view_type view() { return view_type(rows_, cols_, this->get_n_rows(), this->get_n_cols()); } + + /** + * Return span containing underlying rows array + * @return span containing underlying rows array + */ + span get_rows() override { return rows_; } + + /** + * Return span containing underlying cols array + * @return span containing underlying cols array + */ + span get_cols() override { return cols_; } + + protected: + raft::span rows_; + raft::span cols_; +}; + +/** + * Represents a sparse coordinate structure (or edge list) + * which can be used to model a COO matrix. + * + * The structure representation does not have a value/weight + * component so that its const-ness can be varied from it. 
+ * + * @tparam RowType + * @tparam ColType + * @tparam ContainerPolicy + */ +template + typename ContainerPolicy> +class coordinate_structure : public coordinate_structure_t { + public: + static constexpr SparsityType sparsity_type = OWNING; + using sparse_structure_type = coordinate_structure_t; + using row_type = typename sparse_structure_type::row_type; + using col_type = typename sparse_structure_type::col_type; + using nnz_type = typename sparse_structure_type::nnz_type; + using view_type = coordinate_structure_view; + using row_container_policy_type = ContainerPolicy; + using col_container_policy_type = ContainerPolicy; + using row_container_type = typename row_container_policy_type::container_type; + using col_container_type = typename col_container_policy_type::container_type; + + coordinate_structure( + raft::resources const& handle, + row_type n_rows, + col_type n_cols, + nnz_type nnz = 0) noexcept(std::is_nothrow_default_constructible_v) + : coordinate_structure_t(n_rows, n_cols, nnz), + cp_rows_{}, + cp_cols_{}, + c_rows_{cp_rows_.create(handle, 0)}, + c_cols_{cp_cols_.create(handle, 0)} {}; + + coordinate_structure(coordinate_structure const&) noexcept( + std::is_nothrow_copy_constructible_v) = default; + coordinate_structure(coordinate_structure&&) noexcept( + std::is_nothrow_move_constructible::value) = default; + + constexpr auto operator=(coordinate_structure const&) noexcept( + std::is_nothrow_copy_assignable::value) -> coordinate_structure& = default; + constexpr auto operator=(coordinate_structure&&) noexcept( + std::is_nothrow_move_assignable::value) -> coordinate_structure& = default; + + ~coordinate_structure() noexcept(std::is_nothrow_destructible::value) = + default; + + /** + * Return a view of the coordinate structure. Structural views are sparsity-preserving + * so while the structural elements can be updated in a non-const view, the sparsity + * itself (number of nonzeros) cannot be changed. + * @return coordinate structure view + */ + view_type view() + { + if (this->get_nnz() == 0) { + RAFT_LOG_WARN( + "Cannot create coordinate_structure.view() because it has not been initialized " + "(sparsity is 0)"); + } + auto row_span = raft::span(c_rows_.data(), this->get_nnz()); + auto col_span = raft::span(c_cols_.data(), this->get_nnz()); + return view_type(row_span, col_span, this->get_n_rows(), this->get_n_cols()); + } + + /** + * Return span containing underlying rows array + * @return span containing underlying rows array + */ + span get_rows() override + { + return raft::span(c_rows_.data(), this->get_n_rows()); + } + + /** + * Return span containing underlying cols array + * @return span containing underlying cols array + */ + span get_cols() override + { + return raft::span(c_cols_.data(), this->get_n_cols()); + } + + /** + * Change the sparsity of the current compressed structure. This will + * resize the underlying data arrays. 
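Because the coordinate structure carries no value/weight component, one owning `coordinate_structure` can back any number of sparsity-preserving value views. A minimal sketch using the device-memory factories introduced later in this diff (index types, sizes, and the `d_values_*` pointers are illustrative assumptions):

```cpp
raft::device_resources handle;
int n_rows = 100000, n_cols = 10000, nnz = 5000;

// Sparsity-owning structure; initialize_sparsity() allocates the row/column index arrays.
auto structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz);
structure.initialize_sparsity(nnz);
// ... fill structure.view().get_rows() / get_cols() with coordinates ...

// Two different value arrays viewed through the same structure: the values are
// mutable, but the sparsity pattern they share is not.
auto weights_a = raft::make_device_coo_matrix_view(d_values_a, structure.view());
auto weights_b = raft::make_device_coo_matrix_view(d_values_b, structure.view());
```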
+ * @param nnz new sparsity + */ + void initialize_sparsity(nnz_type nnz) + { + sparse_structure_type::initialize_sparsity(nnz); + c_rows_.resize(nnz); + c_cols_.resize(nnz); + } + + protected: + row_container_policy_type cp_rows_; + col_container_policy_type cp_cols_; + row_container_type c_rows_; + col_container_type c_cols_; +}; + +template +class coo_matrix_view + : public sparse_matrix_view, + is_device> { + public: + coo_matrix_view(raft::span element_span, + coordinate_structure_view structure_view) + : sparse_matrix_view, + is_device>(element_span, structure_view) + { + } +}; + +template + typename ContainerPolicy, + SparsityType sparsity_type = SparsityType::OWNING, + typename structure_type = std::conditional_t< + sparsity_type == SparsityType::OWNING, + coordinate_structure, + coordinate_structure_view>> +class coo_matrix + : public sparse_matrix, + is_device, + ContainerPolicy> { + public: + using element_type = ElementType; + using structure_view_type = typename structure_type::view_type; + using container_type = typename ContainerPolicy::container_type; + using sparse_matrix_type = + sparse_matrix, + is_device, + ContainerPolicy>; + static constexpr auto get_sparsity_type() { return sparsity_type; } + template > + coo_matrix(raft::resources const& handle, + RowType n_rows, + ColType n_cols, + NZType nnz = 0) noexcept(std::is_nothrow_default_constructible_v) + : sparse_matrix_type(handle, n_rows, n_cols, nnz){}; + + // Constructor that owns the data but not the structure + template > + coo_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + std::is_nothrow_default_constructible_v) + : sparse_matrix_type(handle, structure){}; + /** + * Return a view of the structure underlying this matrix + * @return + */ + structure_view_type structure_view() { return this->structure_.get()->view(); } + + /** + * Initialize the sparsity on this instance if it was not known upon construction + * Please note this will resize the underlying memory buffers + * @param nnz new sparsity to initialize. + */ + template > + void initialize_sparsity(NZType nnz) + { + sparse_matrix_type::initialize_sparsity(nnz); + this->structure_.get()->initialize_sparsity(nnz); + } +}; +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/csr_matrix.hpp b/cpp/include/raft/core/csr_matrix.hpp new file mode 100644 index 0000000000..fac656b3f9 --- /dev/null +++ b/cpp/include/raft/core/csr_matrix.hpp @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include + +namespace raft { + +template +class compressed_structure_t : public sparse_structure { + public: + /** + * Constructor when sparsity is already known + * @param n_rows total number of rows in matrix + * @param n_cols total number of columns in matrix + * @param nnz sparsity of matrix + */ + compressed_structure_t(IndptrType n_rows, IndicesType n_cols, NZType nnz) + : sparse_structure(n_rows, n_cols, nnz){}; + + /** + * Return span containing underlying indptr array + * @return span containing underlying indptr array + */ + virtual span get_indptr() = 0; + + /** + * Return span containing underlying indices array + * @return span containing underlying indices array + */ + virtual span get_indices() = 0; +}; + +/** + * A non-owning view into a compressed sparse structure + * + * The structure representation does not have a value/weight + * component so that its const-ness can be varied from it. + * + * @tparam IndptrType + * @tparam IndicesType + */ +template +class compressed_structure_view + : public compressed_structure_t { + public: + using sparse_structure_type = compressed_structure_t; + using view_type = compressed_structure_view; + using indptr_type = typename sparse_structure_type::row_type; + using indices_type = typename sparse_structure_type::col_type; + using nnz_type = typename sparse_structure_type::nnz_type; + + compressed_structure_view(span indptr, + span indices, + indices_type n_cols) + : sparse_structure_type(indptr.size() - 1, n_cols, indices.size()), + indptr_(indptr), + indices_(indices) + { + } + + /** + * Return span containing underlying indptr array + * @return span containing underlying indptr array + */ + span get_indptr() override { return indptr_; } + + /** + * Return span containing underlying indices array + * @return span containing underlying indices array + */ + span get_indices() override { return indices_; } + + /** + * Create a view from this view. Note that this is for interface compatibility + * @return + */ + view_type view() { return view_type(indptr_, indices_, this->get_n_cols()); } + + protected: + raft::span indptr_; + raft::span indices_; +}; + +/** + * Represents a sparse compressed structure (or adjacency list) + * which can be used to model both a CSR and CSC matrix. + * + * The structure representation does not have a value/weight + * component so that its const-ness can be varied from it. 
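A detail worth noting in `compressed_structure_view`: the row count is not stored separately but derived from the indptr span as `indptr.size() - 1`, and the number of non-zeros comes from `indices.size()`. A minimal sketch of wrapping an existing device-memory CSR structure with the factories added later in this diff (buffer names and sizes are illustrative):

```cpp
raft::device_resources handle;
int n_rows = 4, n_cols = 4, nnz = 6;

// Device buffers for a CSR structure that was produced elsewhere.
auto indptr  = raft::make_device_vector<int, int>(handle, n_rows + 1);
auto indices = raft::make_device_vector<int, int>(handle, nnz);
auto values  = raft::make_device_vector<float, int>(handle, nnz);

// Non-owning structural view; n_rows is implied by the indptr span inside the view.
auto csr_structure = raft::make_device_csr_structure_view(
  indptr.data_handle(), indices.data_handle(), n_rows, n_cols, nnz);

// Non-owning matrix view pairing a separate value array with that structure.
auto csr_view = raft::make_device_csr_matrix_view(values.data_handle(), csr_structure);
```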
+ * + * @tparam IndptrType + * @tparam IndicesType + * @tparam ContainerPolicy + */ +template + typename ContainerPolicy> +class compressed_structure + : public compressed_structure_t { + public: + using sparse_structure_type = compressed_structure_t; + using indptr_type = typename sparse_structure_type::row_type; + using indices_type = typename sparse_structure_type::col_type; + using nnz_type = typename sparse_structure_type::nnz_type; + using view_type = compressed_structure_view; + using indptr_container_policy_type = ContainerPolicy; + using indices_container_policy_type = ContainerPolicy; + using indptr_container_type = typename indptr_container_policy_type::container_type; + using indices_container_type = typename indices_container_policy_type::container_type; + + constexpr compressed_structure( + raft::resources const& handle, + IndptrType n_rows, + IndicesType n_cols, + NZType nnz = 0) noexcept(std::is_nothrow_default_constructible_v) + : sparse_structure_type{n_rows, n_cols, nnz}, + cp_indptr_{}, + cp_indices_{}, + c_indptr_{cp_indptr_.create(handle, n_rows + 1)}, + c_indices_{cp_indices_.create(handle, nnz)} {}; + + compressed_structure(compressed_structure const&) noexcept( + std::is_nothrow_copy_constructible_v) = default; + compressed_structure(compressed_structure&&) noexcept( + std::is_nothrow_move_constructible::value) = default; + + constexpr auto operator=(compressed_structure const&) noexcept( + std::is_nothrow_copy_assignable::value) + -> compressed_structure& = default; + constexpr auto operator =(compressed_structure&&) noexcept( + std::is_nothrow_move_assignable::value) + -> compressed_structure& = default; + + /** + * Return span containing underlying indptr array + * @return span containing underlying indptr array + */ + span get_indptr() override + { + return raft::span(c_indptr_.data(), this->get_n_rows() + 1); + } + + /** + * Return span containing underlying indices array + * @return span containing underlying indices array + */ + span get_indices() override + { + if (this->get_nnz() == 0) { + RAFT_LOG_WARN("Indices requested for structure that has uninitialized sparsity."); + } + return raft::span(c_indices_.data(), this->get_nnz()); + } + + ~compressed_structure() noexcept(std::is_nothrow_destructible::value) = + default; + + /** + * Return a view of the compressed structure. Structural views are sparsity-preserving + * so while the structural elements can be updated in a non-const view, the sparsity + * itself (number of nonzeros) cannot be changed. + * @return compressed structure view + */ + view_type view() + { + if (this->get_nnz() == 0) { + RAFT_LOG_WARN( + "Cannot create compressed_structure.view() because it has not been initialized (sparsity " + "is 0)"); + } + auto indptr_span = raft::span(c_indptr_.data(), this->get_n_rows() + 1); + auto indices_span = raft::span(c_indices_.data(), this->get_nnz()); + return view_type(indptr_span, indices_span, this->get_n_cols()); + } + + /** + * Change the sparsity of the current compressed structure. This will + * resize the underlying data arrays. 
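On the owning side, `compressed_structure` sizes its indptr array from `n_rows + 1` up front, while the indices (and, in `csr_matrix`, the values) stay empty until the sparsity is known. A minimal sketch of the deferred-sparsity workflow through the sparsity-owning device factory from later in this diff (types and counts are illustrative):

```cpp
raft::device_resources handle;
int n_rows = 100000, n_cols = 10000;

// Sparsity-owning CSR matrix; nnz defaults to 0 at construction.
auto csr = raft::make_device_csr_matrix<float, int, int, int>(handle, n_rows, n_cols);

// ... compute the expected number of non-zeros ...
int nnz = 5000;
csr.initialize_sparsity(nnz);  // resizes the indices and values arrays

// Fill the structural arrays through the sparsity-preserving view.
auto structure = csr.structure_view();
auto indptr    = structure.get_indptr();   // span of size n_rows + 1
auto indices   = structure.get_indices();  // span of size nnz
```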
+ * @param nnz new sparsity + */ + void initialize_sparsity(NZType nnz) override + { + sparse_structure_type::initialize_sparsity(nnz); + c_indptr_.resize(this->get_n_rows() + 1); + c_indices_.resize(nnz); + } + + protected: + indptr_container_policy_type cp_indptr_; + indices_container_policy_type cp_indices_; + indptr_container_type c_indptr_; + indices_container_type c_indices_; +}; +template +class csr_matrix_view + : public sparse_matrix_view, + is_device> { + public: + csr_matrix_view( + raft::span element_span, + compressed_structure_view structure_view) + : sparse_matrix_view, + is_device>(element_span, structure_view){}; +}; + +template + typename ContainerPolicy, + SparsityType sparsity_type = SparsityType::OWNING, + typename structure_type = std::conditional_t< + sparsity_type == SparsityType::OWNING, + compressed_structure, + compressed_structure_view>> +class csr_matrix + : public sparse_matrix, + is_device, + ContainerPolicy> { + public: + using element_type = ElementType; + using structure_view_type = typename structure_type::view_type; + static constexpr auto get_sparsity_type() { return sparsity_type; } + using sparse_matrix_type = + sparse_matrix, + is_device, + ContainerPolicy>; + using container_type = typename ContainerPolicy::container_type; + + template > + csr_matrix(raft::resources const& handle, + IndptrType n_rows, + IndicesType n_cols, + NZType nnz = 0) noexcept(std::is_nothrow_default_constructible_v) + : sparse_matrix_type(handle, n_rows, n_cols, nnz){}; + + // Constructor that owns the data but not the structure + + template > + csr_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + std::is_nothrow_default_constructible_v) + : sparse_matrix_type(handle, structure){}; + + /** + * Initialize the sparsity on this instance if it was not known upon construction + * Please note this will resize the underlying memory buffers + * @param nnz new sparsity to initialize. + */ + template > + void initialize_sparsity(NZType nnz) + { + sparse_matrix_type::initialize_sparsity(nnz); + this->structure_.get()->initialize_sparsity(nnz); + } + + /** + * Return a view of the structure underlying this matrix + * @return + */ + structure_view_type structure_view() { return this->structure_.get()->view(); } +}; +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/device_container_policy.hpp similarity index 90% rename from cpp/include/raft/core/detail/device_mdarray.hpp rename to cpp/include/raft/core/device_container_policy.hpp index 31dfaba70a..eef981e56f 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/device_container_policy.hpp @@ -22,19 +22,20 @@ */ #pragma once #include -#include #include #include // dynamic_extent #include +#include +#include #include #include #include #include -namespace raft::detail { +namespace raft { /** * @brief A simplified version of thrust::device_reference with support for CUDA stream. 
*/ @@ -137,6 +138,8 @@ class device_uvector { return device_reference{thrust::device_ptr{data_.data() + i}, data_.stream()}; } + void resize(size_type size) { data_.resize(size, data_.stream()); } + [[nodiscard]] auto data() noexcept -> pointer { return data_.data(); } [[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data(); } }; @@ -146,9 +149,6 @@ class device_uvector { */ template class device_uvector_policy { - rmm::cuda_stream_view stream_; - rmm::mr::device_memory_resource* mr_; - public: using element_type = ElementType; using container_type = device_uvector; @@ -162,19 +162,12 @@ class device_uvector_policy { using const_accessor_policy = std::experimental::default_accessor; public: - auto create(size_t n) -> container_type + auto create(raft::resources const& res, size_t n) -> container_type { - return mr_ ? container_type(n, stream_, mr_) : container_type(n, stream_); + return container_type(n, resource::get_cuda_stream(res), resource::get_workspace_resource(res)); } - device_uvector_policy() = delete; - explicit device_uvector_policy( - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = - nullptr) noexcept(std::is_nothrow_copy_constructible_v) - : stream_{stream}, mr_(mr) - { - } + device_uvector_policy() = default; [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference { @@ -190,4 +183,4 @@ class device_uvector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -} // namespace raft::detail +} // namespace raft diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp new file mode 100644 index 0000000000..b1e9ca30fc --- /dev/null +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -0,0 +1,387 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
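The `device_uvector_policy` change above is the core of the refactor: container policies no longer capture a stream and memory resource at construction. Instead, `create()` receives the `raft::resources` handle and pulls the CUDA stream and workspace allocator from it, which is why the policy is now default-constructible and why the owning container gained a `resize()`. A minimal sketch of the new policy contract (ordinary user code would go through the `make_device_*` factories instead):

```cpp
raft::device_resources handle;

// Default-constructible policy; allocation parameters come from the handle at create() time.
raft::device_uvector_policy<float> policy;
auto buffer = policy.create(handle, 1024);  // stream and workspace resource taken from `handle`

buffer.resize(2048);  // newly added; resizes on the container's own stream
```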
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { + +template typename ContainerPolicy = device_uvector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using device_coo_matrix = + coo_matrix; + +/** + * Specialization for a coo matrix view which uses device memory + */ +template +using device_coo_matrix_view = coo_matrix_view; + +/** + * Specialization for a sparsity-owning coo matrix which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_owning_coo_matrix = + coo_matrix; + +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_preserving_coo_matrix = coo_matrix; + +/** + * Specialization for a sparsity-owning coordinate structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_coordinate_structure = + coordinate_structure; + +/** + * Specialization for a sparsity-preserving coordinate structure view which uses device memory + */ +template +using device_coordinate_structure_view = coordinate_structure_view; + +template +struct is_device_coo_matrix : std::false_type { +}; + +template + typename ContainerPolicy, + SparsityType sparsity_type> +struct is_device_coo_matrix< + device_coo_matrix> + : std::true_type { +}; + +template +constexpr bool is_device_coo_matrix_v = is_device_coo_matrix::value; + +template +constexpr bool is_device_coo_sparsity_owning_v = + is_device_coo_matrix::value and T::get_sparsity_type() == OWNING; + +template +constexpr bool is_device_coo_sparsity_preserving_v = + is_device_coo_matrix::value and T::get_sparsity_type() == PRESERVING; + +/** + * Create a sparsity-owning sparse matrix in the coordinate format. sparsity-owning means that + * all of the underlying vectors (data, indptr, indices) are owned by the coo_matrix instance. If + * not known up front, the sparsity can be ignored in this factory function and `resize()` invoked + * on the instance once the sparsity is known. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * coo_matrix = raft::make_device_coo_matrix(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * int nnz = 5000; + * coo_matrix.initialize_sparsity(nnz); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] handle a raft handle for managing expensive device resources + * @param[in] n_rows total number of rows in the matrix + * @param[in] n_cols total number of columns in the matrix + * @param[in] nnz number of non-zeros in the matrix if known [optional] + * @return a sparsity-owning sparse matrix in coordinate (coo) format + */ +template +auto make_device_coo_matrix(raft::resources const& handle, + RowType n_rows, + ColType n_cols, + NZType nnz = 0) +{ + return device_sparsity_owning_coo_matrix( + handle, n_rows, n_cols, nnz); +} + +/** + * Create a sparsity-preserving sparse matrix in the coordinate format. sparsity-preserving means + * that a view of the coo sparsity is supplied, allowing the values in the sparsity to change but + * not the sparsity itself. The csr_matrix instance does not own the sparsity, the sparsity must + * be known up front, and cannot be resized later. 
+ * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * coo_structure.initialize_sparsity(nnz); + * coo_matrix = raft::make_device_coo_matrix(handle, coo_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] handle raft handle for managing expensive device resources + * @param[in] structure_ a sparsity-preserving coordinate structural view + * @return a sparsity-preserving sparse matrix in coordinate (coo) format + */ +template +auto make_device_coo_matrix(raft::resources const& handle, + device_coordinate_structure_view structure_) +{ + return device_sparsity_preserving_coo_matrix( + handle, + std::make_shared>(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the coordinate format. This is sparsity-preserving, + * meaning that the underlying sparsity is known and cannot be changed. Use the sparsity-owning + * coo_matrix if sparsity needs to be mutable. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointer is assumed to reference device memory for a size of nnz + * float* d_elm_ptr = ...; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz); + * coo_matrix_view = raft::make_device_coo_matrix_view(handle, d_elm_ptr, coo_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz) + * @param[in] structure_ a sparsity-preserving coordinate structural view + * @return a sparsity-preserving sparse matrix in coordinate (coo) format + */ +template +auto make_device_coo_matrix_view( + ElementType* ptr, device_coordinate_structure_view structure_) +{ + return device_coo_matrix_view( + raft::device_span(ptr, structure_.get_nnz()), + std::make_shared>(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the coordinate format. This is sparsity-preserving, + * meaning that the underlying sparsity is known and cannot be changed. Use the sparsity-owning + * coo_matrix if sparsity needs to be mutable. 
+ * + * @code{.cpp} + * #include + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following span is assumed to be of size nnz + * raft::device_span d_elm_ptr; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz); + * coo_matrix_view = raft::make_device_coo_matrix_view(handle, d_elm_ptr, coo_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] elements a device span containing nonzero matrix elements (size nnz) + * @param[in] structure_ a sparsity-preserving coordinate structural view + * @return + */ +template +auto make_device_coo_matrix_view( + raft::device_span elements, + device_coordinate_structure_view structure_) +{ + RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + "Size of elements must be equal to the nnz from the structure"); + return device_coo_matrix_view( + elements, + std::make_shared>(structure_)); +} + +/** + * Create a sparsity-owning coordinate structure object. If not known up front, this object can be + * resized() once the sparsity (number of non-zeros) is known, postponing the allocation of the + * underlying data arrays. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz); + * * ... + * // compute expected sparsity + * ... + * coo_structure.initialize_sparsity(nnz); + * @endcode + * + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources on device + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of cols + * @param[in] nnz number of non-zeros + * @return a sparsity-owning coordinate structure instance + */ +template +auto make_device_coordinate_structure(raft::resources const& handle, + RowType n_rows, + ColType n_cols, + NZType nnz = 0) +{ + return device_coordinate_structure(handle, n_rows, n_cols, nnz); +} + +/** + * Create a non-owning sparsity-preserved coordinate structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. 
+ * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointers are assumed to reference device memory of size nnz + * int *rows = ...; + * int *cols = ...; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_coordinate_structure_view(handle, rows, cols, n_rows, n_cols, + * nnz); + * @endcode + * + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] rows pointer to row indices array on device (size nnz) + * @param[in] cols pointer to column indices array on device (size nnz) + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of columns + * @param[in] nnz number of non-zeros + * @return a sparsity-preserving coordinate structural view + */ +template +auto make_device_coo_structure_view( + RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz) +{ + return device_coordinate_structure_view( + raft::device_span(rows, nnz), raft::device_span(cols, nnz), n_rows, n_cols); +} + +/** + * Create a non-owning sparsity-preserved coordinate structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following device spans are assumed to be of size nnz + * raft::device_span rows; + * raft::device_span cols; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_coordinate_structure_view(handle, rows, cols, n_rows, n_cols); + * @endcode + * + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] rows a device span containing row indices (size nnz) + * @param[in] cols a device span containing column indices (size nnz) + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of columns + * @return a sparsity-preserving coordinate structural view + */ +template +auto make_device_coo_structure_view(raft::device_span rows, + raft::device_span cols, + RowType n_rows, + ColType n_cols) +{ + return device_coordinate_structure_view(rows, cols, n_rows, n_cols); +} + +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp new file mode 100644 index 0000000000..59cabacf6d --- /dev/null +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +template typename ContainerPolicy = device_uvector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using device_csr_matrix = + csr_matrix; + +/** + * Specialization for a sparsity-owning csr matrix which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_owning_csr_matrix = + csr_matrix; + +template +struct is_device_csr_matrix : std::false_type { +}; + +template + typename ContainerPolicy, + SparsityType sparsity_type> +struct is_device_csr_matrix< + device_csr_matrix> + : std::true_type { +}; + +template +constexpr bool is_device_csr_matrix_v = is_device_csr_matrix::value; + +template +constexpr bool is_device_csr_sparsity_owning_v = + is_device_csr_matrix::value and T::get_sparsity_type() == OWNING; + +template +constexpr bool is_device_csr_sparsity_preserving_v = + is_device_csr_matrix::value and T::get_sparsity_type() == PRESERVING; + +/** + * Specialization for a csr matrix view which uses device memory + */ +template +using device_csr_matrix_view = csr_matrix_view; + +/** + * Specialization for a sparsity-preserving csr matrix which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_preserving_csr_matrix = csr_matrix; + +/** + * Specialization for a csr matrix view which uses device memory + */ +template +using device_csr_matrix_view = csr_matrix_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_compressed_structure = + compressed_structure; + +/** + * Specialization for a sparsity-preserving compressed structure view which uses device memory + */ +template +using device_compressed_structure_view = + compressed_structure_view; + +/** + * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning + * means that all of the underlying vectors (data, indptr, indices) are owned by the csr_matrix + * instance. If not known up front, the sparsity can be ignored in this factory function and + * `resize()` invoked on the instance once the sparsity is known. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * csr_matrix = raft::make_device_csr_matrix(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * int nnz = 5000; + * csr_matrix.initialize_sparsity(nnz); + * @endcode + + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] handle a raft handle for managing expensive device resources + * @param[in] n_rows total number of rows in the matrix + * @param[in] n_cols total number of columns in the matrix + * @param[in] nnz number of non-zeros in the matrix if known [optional] + * @return a sparsity-owning sparse matrix in compressed (csr) format + */ +template +auto make_device_csr_matrix(raft::device_resources const& handle, + IndptrType n_rows, + IndicesType n_cols, + NZType nnz = 0) +{ + return device_sparsity_owning_csr_matrix( + handle, n_rows, n_cols, nnz); +} + +/** + * Create a sparsity-preserving sparse matrix in the compressed-sparse row format. + * sparsity-preserving means that a view of the csr sparsity is supplied, allowing the values in + * the sparsity to change but not the sparsity itself. 
The csr_matrix instance does not own the + * sparsity, the sparsity must be known up front, and cannot be resized later. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * coo_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * csr_structure.initialize_sparsity(nnz); + * csr_matrix = raft::make_device_csr_matrix(handle, csr_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] handle raft handle for managing expensive device resources + * @param[in] structure_ a sparsity-preserving compressed structural view + * @return a sparsity-preserving sparse matrix in compressed (csr) format + */ +template +auto make_device_csr_matrix( + raft::device_resources const& handle, + device_compressed_structure_view structure_) +{ + return device_sparsity_preserving_csr_matrix( + handle, + std::make_shared>( + structure_)); +} + +/** + * Create a non-owning sparse matrix view in the coordinate format. This is sparsity-preserving, + * meaning that the underlying sparsity is known and cannot be changed. Use the sparsity-owning + * coo_matrix if sparsity needs to be mutable. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointer is assumed to reference device memory for a size of nnz + * float* d_elm_ptr = ...; + * + * raft::device_resources handle; + * csr_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols, nnz); + * csr_matrix_view = raft::make_device_csr_matrix_view(handle, d_elm_ptr, csr_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz) + * @param[in] structure_ a sparsity-preserving compressed sparse structural view + * @return a sparsity-preserving csr matrix view + */ +template +auto make_device_csr_matrix_view( + ElementType* ptr, device_compressed_structure_view structure_) +{ + return device_csr_matrix_view( + raft::device_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the compressed-sparse row format. This is + * sparsity-preserving, meaning that the underlying sparsity is known and cannot be changed. Use the + * sparsity-owning coo_matrix if sparsity needs to be mutable. 
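The `is_device_csr_matrix` and sparsity-type traits above let generic code constrain itself to the right kind of matrix at compile time. A small sketch of how they might be used; the `reserve_nonzeros` helper is hypothetical, not part of this diff:

```cpp
#include <type_traits>

// Hypothetical helper: accepts only device CSR matrices that own their sparsity,
// since only sparsity-owning matrices can be resized after construction.
template <typename MatrixT,
          typename = std::enable_if_t<raft::is_device_csr_sparsity_owning_v<MatrixT>>>
void reserve_nonzeros(MatrixT& mat, int nnz)
{
  mat.initialize_sparsity(nnz);
}
```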
+ * + * @code{.cpp} + * #include + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following span is assumed to be of size nnz + * raft::device_span d_elm_ptr; + * + * raft::device_resources handle; + * csr_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols, nnz); + * csr_matrix_view = raft::make_device_csr_matrix_view(handle, d_elm_ptr, csr_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] elements device span containing array of matrix elements (size nnz) + * @param[in] structure_ a sparsity-preserving structural view + * @return a sparsity-preserving csr matrix view + */ +template +auto make_device_csr_matrix_view( + raft::device_span elements, + device_compressed_structure_view structure_) +{ + RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + "Size of elements must be equal to the nnz from the structure"); + return device_csr_matrix_view( + elements, std::make_shared(structure_)); +} + +/** + * Create a sparsity-owning compressed structure. This is not sparsity-preserving, meaning that + * the underlying sparsity does not need to be known upon construction. When not known up front, + * the allocation of the underlying indices array is delayed until `resize(nnz)` is invoked. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * raft::device_resources handle; + * csr_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols, nnz); + * ... + * // compute expected sparsity + * ... + * csr_structure.initialize_sparsity(nnz); + * @endcode + * + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] handle raft handle for managing expensive device resources + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of cols + * @param[in] nnz total number of nonzeros, if known + * @return a sparsity-owning compressed structure instance + */ +template +auto make_device_compressed_structure(raft::device_resources const& handle, + IndptrType n_rows, + IndicesType n_cols, + NZType nnz = 0) +{ + return device_compressed_structure(handle, n_rows, n_cols, nnz); +} + +/** + * Create a non-owning sparsity-preserved compressed structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. 
+ * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointer is assumed to reference device memory of size n_rows+1 + * int *indptr = ...; + * + * // The following pointer is assumed to reference device memory of size nnz + * int *indices = ...; + * + * raft::device_resources handle; + * csr_structure = raft::make_device_compressed_structure_view(handle, indptr, indices, n_rows, + * n_cols, nnz); + * @endcode * + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] indptr structural indptr (size n_rows+1) + * @param[in] indices structural indices (size nnz) + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of columns + * @param[in] nnz number of non-zeros + * @return a sparsity-preserving compressed structural view + */ +template +auto make_device_csr_structure_view( + IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz) +{ + return device_compressed_structure_view( + raft::device_span(indptr, n_rows + 1), + raft::device_span(indices, nnz), + n_cols); +} + +/** + * Create a non-owning sparsity-preserved compressed structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following device spans is assumed to be of size n_rows+1 + * raft::device_span indptr; + * + * // The following device span is assumed to be of size nnz + * raft::device_span indices; + * + * raft::device_resources handle; + * csr_structure = raft::make_device_compressed_structure_view(handle, indptr, indices, n_rows, + * n_cols); + * @endcode + * + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] indptr structural indptr (size n_rows+1) + * @param[in] indices structural indices (size nnz) + * @param[in] n_cols total number of columns + * @return a sparsity-preserving compressed structural view + * + */ +template +auto make_device_csr_structure_view(raft::device_span indptr, + raft::device_span indices, + IndicesType n_cols) +{ + return device_compressed_structure_view(indptr, indices, n_cols); +} + +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 03cb09eecb..2c0cb56910 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -17,9 +17,10 @@ #pragma once #include -#include +#include #include #include +#include namespace raft { @@ -33,7 +34,7 @@ namespace raft { template > + typename ContainerPolicy = device_uvector_policy> using device_mdarray = mdarray>; @@ -80,14 +81,14 @@ template -auto make_device_mdarray(raft::device_resources const& handle, extents exts) +auto make_device_mdarray(raft::resources const& handle, extents exts) { using mdarray_t = device_mdarray; typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy{handle.get_stream()}; + typename mdarray_t::container_policy_type policy{}; - return mdarray_t{layout, policy}; + return mdarray_t{handle, layout, policy}; } /** @@ -104,16 +105,16 @@ template -auto make_device_mdarray(raft::device_resources const& handle, +auto make_device_mdarray(raft::resources const& handle, 
rmm::mr::device_memory_resource* mr, extents exts) { using mdarray_t = device_mdarray; typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy{handle.get_stream(), mr}; + typename mdarray_t::container_policy_type policy{}; - return mdarray_t{layout, policy}; + return mdarray_t{handle, layout, policy}; } /** @@ -130,10 +131,10 @@ auto make_device_mdarray(raft::device_resources const& handle, template -auto make_device_matrix(raft::device_resources const& handle, IndexType n_rows, IndexType n_cols) +auto make_device_matrix(raft::resources const& handle, IndexType n_rows, IndexType n_cols) { return make_device_mdarray( - handle.get_stream(), make_extents(n_rows, n_cols)); + handle, make_extents(n_rows, n_cols)); } /** @@ -146,12 +147,12 @@ auto make_device_matrix(raft::device_resources const& handle, IndexType n_rows, * @return raft::device_scalar */ template -auto make_device_scalar(raft::device_resources const& handle, ElementType const& v) +auto make_device_scalar(raft::resources const& handle, ElementType const& v) { scalar_extent extents; using policy_t = typename device_scalar::container_policy_type; - policy_t policy{handle.get_stream()}; - auto scalar = device_scalar{extents, policy}; + policy_t policy{}; + auto scalar = device_scalar{handle, extents, policy}; scalar(0) = v; return scalar; } @@ -168,9 +169,9 @@ auto make_device_scalar(raft::device_resources const& handle, ElementType const& template -auto make_device_vector(raft::device_resources const& handle, IndexType n) +auto make_device_vector(raft::resources const& handle, IndexType n) { - return make_device_mdarray(handle.get_stream(), + return make_device_mdarray(handle, make_extents(n)); } diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 68c56dc9b6..df6b39a368 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -69,7 +69,6 @@ class device_resources : public resources { } device_resources(const device_resources& handle) : resources{handle} {} - device_resources(device_resources&&) = delete; device_resources& operator=(device_resources&&) = delete; diff --git a/cpp/include/raft/core/detail/host_mdarray.hpp b/cpp/include/raft/core/host_container_policy.hpp similarity index 86% rename from cpp/include/raft/core/detail/host_mdarray.hpp rename to cpp/include/raft/core/host_container_policy.hpp index 74bd55e78c..3b3538ea20 100644 --- a/cpp/include/raft/core/detail/host_mdarray.hpp +++ b/cpp/include/raft/core/host_container_policy.hpp @@ -6,7 +6,7 @@ */ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,9 +22,10 @@ */ #pragma once #include +#include #include -namespace raft::detail { +namespace raft { /** * @brief A container policy for host mdarray. 
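With the container policies now default-constructible, the `make_device_*` factories above take the `raft::resources` (or `device_resources`) handle directly rather than a stream, and forward it into the mdarray constructor; this is also why the benchmark fixtures earlier in this diff initialize their mdarray members with `this->handle`. A minimal sketch of the updated construction pattern (shapes and element types are illustrative):

```cpp
raft::device_resources handle;  // derives from raft::resources

// Owning device containers are created from the handle; no explicit stream or memory resource.
auto vec    = raft::make_device_vector<float, int>(handle, 1024);
auto mat    = raft::make_device_matrix<float, int>(handle, 128, 64);
auto scalar = raft::make_device_scalar<float>(handle, 1.0f);
```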
@@ -43,15 +44,10 @@ class host_vector_policy { using const_accessor_policy = std::experimental::default_accessor; public: - auto create(size_t n) -> container_type { return container_type(n); } + auto create(raft::resources const&, size_t n) -> container_type { return container_type(n); } constexpr host_vector_policy() noexcept(std::is_nothrow_default_constructible_v) = default; - explicit constexpr host_vector_policy(rmm::cuda_stream_view) noexcept( - std::is_nothrow_default_constructible_v) - : host_vector_policy() - { - } [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference { @@ -66,4 +62,4 @@ class host_vector_policy { [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -} // namespace raft::detail +} // namespace raft diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp new file mode 100644 index 0000000000..45ec278a7d --- /dev/null +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -0,0 +1,382 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft { + +template typename ContainerPolicy = host_vector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using host_coo_matrix = + coo_matrix; + +/** + * Specialization for a coo matrix view which uses host memory + */ +template +using host_coo_matrix_view = coo_matrix_view; + +/** + * Specialization for a sparsity-owning coo matrix which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_owning_coo_matrix = + coo_matrix; + +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_preserving_coo_matrix = coo_matrix; + +/** + * Specialization for a sparsity-owning coordinate structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_coordinate_structure = + coordinate_structure; + +/** + * Specialization for a sparsity-preserving coordinate structure view which uses host memory + */ +template +using host_coordinate_structure_view = coordinate_structure_view; + +template +struct is_host_coo_matrix : std::false_type { +}; + +template + typename ContainerPolicy, + SparsityType sparsity_type> +struct is_host_coo_matrix< + host_coo_matrix> + : std::true_type { +}; + +template +constexpr bool is_host_coo_matrix_v = is_host_coo_matrix::value; + +template +constexpr bool is_host_coo_sparsity_owning_v = + is_host_coo_matrix::value and T::get_sparsity_type() == OWNING; + +template +constexpr bool is_host_coo_sparsity_preserving_v = + is_host_coo_matrix::value and T::get_sparsity_type() == PRESERVING; + +/** + * Create a sparsity-owning sparse matrix in the coordinate format. 
sparsity-owning means that + * all of the underlying vectors (data, indptr, indices) are owned by the coo_matrix instance. If + * not known up front, the sparsity can be ignored in this factory function and `resize()` invoked + * on the instance once the sparsity is known. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::resources handle; + * coo_matrix = raft::make_host_coo_matrix(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * int nnz = 5000; + * coo_matrix.initialize_sparsity(nnz); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows total number of rows in the matrix + * @param[in] n_cols total number of columns in the matrix + * @param[in] nnz number of non-zeros in the matrix if known [optional] + * @return a sparsity-owning sparse matrix in coordinate (coo) format + */ +template +auto make_host_coo_matrix(raft::resources const& handle, + RowType n_rows, + ColType n_cols, + NZType nnz = 0) +{ + return host_sparsity_owning_coo_matrix( + handle, n_rows, n_cols, nnz); +} + +/** + * Create a sparsity-preserving sparse matrix in the coordinate format. sparsity-preserving means + * that a view of the coo sparsity is supplied, allowing the values in the sparsity to change but + * not the sparsity itself. The coo_matrix instance does not own the sparsity, the sparsity must + * be known up front, and cannot be resized later. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::resources handle; + * coo_structure = raft::make_host_coordinate_structure(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * coo_structure.initialize_sparsity(nnz); + * coo_matrix = raft::make_host_coo_matrix(handle, coo_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources + * @param[in] structure_ a sparsity-preserving coordinate structural view + * @return a sparsity-preserving sparse matrix in coordinate (coo) format + */ +template +auto make_host_coo_matrix(raft::resources const& handle, + host_coordinate_structure_view structure_) +{ + return host_sparsity_preserving_coo_matrix( + handle, std::make_shared>(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the coordinate format. This is sparsity-preserving, + * meaning that the underlying sparsity is known and cannot be changed. Use the sparsity-owning + * coo_matrix if sparsity needs to be mutable. 
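A short sketch of the sparsity-owning COO path described above, with illustrative sizes; the explicit template arguments follow the documented ElementType/RowType/ColType/NZType order:

@code{.cpp}
#include <raft/core/host_coo_matrix.hpp>
#include <raft/core/resources.hpp>

void coo_owning_example()
{
  raft::resources handle;

  int n_rows = 100000;
  int n_cols = 10000;

  // Sparsity-owning COO matrix: rows, columns and values are all owned by the
  // instance; the number of non-zeros does not have to be known yet.
  auto coo = raft::make_host_coo_matrix<float, int, int, int>(handle, n_rows, n_cols);

  // ... compute the expected sparsity ...
  int nnz = 5000;

  // Allocation of the underlying arrays is deferred until the sparsity is known.
  coo.initialize_sparsity(nnz);

  // A sparsity-preserving view can then be handed to other APIs.
  auto coo_view = coo.view();
}
@endcode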
+ * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointer is assumed to reference host-accessible memory for a size of nnz + * float* h_elm_ptr = ...; + * + * raft::resources handle; + * coo_structure = raft::make_host_coordinate_structure(handle, n_rows, n_cols, nnz); + * coo_matrix_view = raft::make_host_coo_matrix_view(handle, h_elm_ptr, coo_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] ptr a pointer to array of nonzero matrix elements on host (size nnz) + * @param[in] structure_ a sparsity-preserving coordinate structural view + * @return a sparsity-preserving sparse matrix in coordinate (coo) format + */ +template +auto make_host_coo_matrix_view(ElementType* ptr, + host_coordinate_structure_view structure_) +{ + return host_coo_matrix_view( + raft::host_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the coordinate format. This is sparsity-preserving, + * meaning that the underlying sparsity is known and cannot be changed. Use the sparsity-owning + * coo_matrix if sparsity needs to be mutable. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following span is assumed to be of size nnz + * raft::host_span h_elm_ptr; + * + * raft::resources handle; + * coo_structure = raft::make_host_coordinate_structure(handle, n_rows, n_cols, nnz); + * coo_matrix_view = raft::make_host_coo_matrix_view(handle, h_elm_ptr, coo_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] elements a host span containing nonzero matrix elements (size nnz) + * @param[in] structure_ a sparsity-preserving coordinate structural view + * @return + */ +template +auto make_host_coo_matrix_view(raft::host_span elements, + host_coordinate_structure_view structure_) +{ + RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + "Size of elements must be equal to the nnz from the structure"); + return host_coo_matrix_view(elements, + std::make_shared(structure_)); +} + +/** + * Create a sparsity-owning coordinate structure object. If not known up front, this object can be + * resized() once the sparsity (number of non-zeros) is known, postponing the allocation of the + * underlying data arrays. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * raft::resources handle; + * coo_structure = raft::make_host_coordinate_structure(handle, n_rows, n_cols, nnz); + * * ... + * // compute expected sparsity + * ... + * coo_structure.initialize_sparsity(nnz); + * @endcode + * + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources on host + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of cols + * @param[in] nnz number of non-zeros + * @return a sparsity-owning coordinate structure instance + */ +template +auto make_host_coordinate_structure(raft::resources const& handle, + RowType n_rows, + ColType n_cols, + NZType nnz = 0) +{ + return host_coordinate_structure(handle, n_rows, n_cols, nnz); +} + +/** + * Create a non-owning sparsity-preserved coordinate structure view. 
Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointers are assumed to reference host-accessible memory of size nnz + * int *rows = ...; + * int *cols = ...; + * + * raft::resources handle; + * coo_structure = raft::make_host_coordinate_structure_view(handle, rows, cols, n_rows, n_cols, + * nnz); + * @endcode + * + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] rows pointer to row indices array on host (size nnz) + * @param[in] cols pointer to column indices array on host (size nnz) + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of columns + * @param[in] nnz number of non-zeros + * @return a sparsity-preserving coordinate structural view + */ +template +auto make_host_coo_structure_view( + RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz) +{ + return host_coordinate_structure_view( + raft::host_span(rows, nnz), raft::host_span(cols, nnz), n_rows, n_cols); +} + +/** + * Create a non-owning sparsity-preserved coordinate structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following host spans are assumed to be of size nnz + * raft::host_span rows; + * raft::host_span cols; + * + * raft::resources handle; + * coo_structure = raft::make_host_coordinate_structure_view(handle, rows, cols, n_rows, n_cols); + * @endcode + * + * @tparam RowType + * @tparam ColType + * @tparam NZType + * @param[in] rows a host span containing row indices (size nnz) + * @param[in] cols a host span containing column indices (size nnz) + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of columns + * @return a sparsity-preserving coordinate structural view + */ +template +auto make_host_coo_structure_view(raft::host_span rows, + raft::host_span cols, + RowType n_rows, + ColType n_cols) +{ + return host_coordinate_structure_view(rows, cols, n_rows, n_cols); +} + +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp new file mode 100644 index 0000000000..437f60814e --- /dev/null +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft { + +template typename ContainerPolicy = host_vector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using host_csr_matrix = + csr_matrix; + +/** + * Specialization for a sparsity-owning csr matrix which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_owning_csr_matrix = + csr_matrix; + +template +struct is_host_csr_matrix : std::false_type { +}; + +template + typename ContainerPolicy, + SparsityType sparsity_type> +struct is_host_csr_matrix< + host_csr_matrix> + : std::true_type { +}; + +template +constexpr bool is_host_csr_matrix_v = is_host_csr_matrix::value; + +template +constexpr bool is_host_csr_sparsity_owning_v = + is_host_csr_matrix::value and T::get_sparsity_type() == OWNING; + +template +constexpr bool is_host_csr_sparsity_preserving_v = + is_host_csr_matrix::value and T::get_sparsity_type() == PRESERVING; + +/** + * Specialization for a csr matrix view which uses host memory + */ +template +using host_csr_matrix_view = csr_matrix_view; + +/** + * Specialization for a sparsity-preserving csr matrix which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_preserving_csr_matrix = csr_matrix; + +/** + * Specialization for a csr matrix view which uses host memory + */ +template +using host_csr_matrix_view = csr_matrix_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_compressed_structure = + compressed_structure; + +/** + * Specialization for a sparsity-preserving compressed structure view which uses host memory + */ +template +using host_compressed_structure_view = + compressed_structure_view; + +/** + * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning + * means that all of the underlying vectors (data, indptr, indices) are owned by the csr_matrix + * instance. If not known up front, the sparsity can be ignored in this factory function and + * `resize()` invoked on the instance once the sparsity is known. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::resources handle; + * csr_matrix = raft::make_host_csr_matrix(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * int nnz = 5000; + * csr_matrix.initialize_sparsity(nnz); + * @endcode + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows total number of rows in the matrix + * @param[in] n_cols total number of columns in the matrix + * @param[in] nnz number of non-zeros in the matrix if known [optional] + * @return a sparsity-owning sparse matrix in compressed (csr) format + */ +template +auto make_host_csr_matrix(raft::resources const& handle, + IndptrType n_rows, + IndicesType n_cols, + NZType nnz = 0) +{ + return host_sparsity_owning_csr_matrix( + handle, n_rows, n_cols, nnz); +} + +/** + * Create a sparsity-preserving sparse matrix in the compressed-sparse row format. + * sparsity-preserving means that a view of the csr sparsity is supplied, allowing the values in + * the sparsity to change but not the sparsity itself. The csr_matrix instance does not own the + * sparsity, the sparsity must be known up front, and cannot be resized later. 
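A brief sketch of the sparsity-owning CSR path described above (sizes illustrative; the explicit template arguments follow the documented ElementType/IndptrType/IndicesType/NZType ordering):

@code{.cpp}
#include <raft/core/host_csr_matrix.hpp>
#include <raft/core/resources.hpp>

void csr_owning_example()
{
  raft::resources handle;

  int n_rows = 100000;
  int n_cols = 10000;

  // Sparsity-owning CSR matrix: indptr, indices and values are all owned by
  // the instance. The number of non-zeros may be supplied later.
  auto csr = raft::make_host_csr_matrix<float, int, int, int>(handle, n_rows, n_cols);

  // ... compute the expected sparsity ...
  int nnz = 5000;

  // Allocation of the indices and values arrays is deferred until the
  // sparsity is known.
  csr.initialize_sparsity(nnz);
}
@endcode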
+ * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * + * raft::resources handle; + * coo_structure = raft::make_host_compressed_structure(handle, n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * csr_structure.initialize_sparsity(nnz); + * csr_matrix = raft::make_host_csr_matrix(handle, csr_structure.view()); + * @endcode + + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources + * @param[in] structure_ a sparsity-preserving compressed structural view + * @return a sparsity-preserving sparse matrix in compressed (csr) format + */ +template +auto make_host_csr_matrix( + raft::resources const& handle, + host_compressed_structure_view structure_) +{ + return host_sparsity_preserving_csr_matrix( + handle, + std::make_shared>(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the coordinate format. This is sparsity-preserving, + * meaning that the underlying sparsity is known and cannot be changed. Use the sparsity-owning + * coo_matrix if sparsity needs to be mutable. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointer is assumed to reference device memory for a size of nnz + * float* h_elm_ptr = ...; + * + * raft::resources handle; + * csr_structure = raft::make_host_compressed_structure(handle, n_rows, n_cols, nnz); + * csr_matrix_view = raft::make_host_csr_matrix_view(handle, h_elm_ptr, csr_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] ptr a pointer to array of nonzero matrix elements on host (size nnz) + * @param[in] structure_ a sparsity-preserving compressed sparse structural view + * @return a sparsity-preserving csr matrix view + */ +template +auto make_host_csr_matrix_view( + ElementType* ptr, host_compressed_structure_view structure_) +{ + return host_csr_matrix_view( + raft::host_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); +} + +/** + * Create a non-owning sparse matrix view in the compressed-sparse row format. This is + * sparsity-preserving, meaning that the underlying sparsity is known and cannot be changed. Use the + * sparsity-owning coo_matrix if sparsity needs to be mutable. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following span is assumed to be of size nnz + * raft::host_span h_elm_ptr; + * + * raft::resources handle; + * csr_structure = raft::make_host_compressed_structure(handle, n_rows, n_cols, nnz); + * csr_matrix_view = raft::make_host_csr_matrix_view(handle, h_elm_ptr, csr_structure.view()); + * @endcode + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] elements host span containing array of matrix elements (size nnz) + * @param[in] structure_ a sparsity-preserving structural view + * @return a sparsity-preserving csr matrix view + */ +template +auto make_host_csr_matrix_view( + raft::host_span elements, + host_compressed_structure_view structure_) +{ + RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + "Size of elements must be equal to the nnz from the structure"); + return host_csr_matrix_view( + elements, std::make_shared(structure_)); +} + +/** + * Create a sparsity-owning compressed structure. 
This is not sparsity-preserving, meaning that + * the underlying sparsity does not need to be known upon construction. When not known up front, + * the allocation of the underlying indices array is delayed until `resize(nnz)` is invoked. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * raft::resources handle; + * csr_structure = raft::make_host_compressed_structure(handle, n_rows, n_cols, nnz); + * ... + * // compute expected sparsity + * ... + * csr_structure.initialize_sparsity(nnz); + * @endcode * + * + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of cols + * @param[in] nnz total number of nonzeros, if known + * @return a sparsity-owning compressed structure instance + */ +template +auto make_host_compressed_structure(raft::resources const& handle, + IndptrType n_rows, + IndicesType n_cols, + NZType nnz = 0) +{ + return host_compressed_structure(handle, n_rows, n_cols, nnz); +} + +/** + * Create a non-owning sparsity-preserved compressed structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. + * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following pointer is assumed to reference host-accessible memory of size n_rows+1 + * int *indptr = ...; + * + * // The following pointer is assumed to reference host-accessible memory of size nnz + * int *indices = ...; + * + * raft::resources handle; + * csr_structure = raft::make_host_compressed_structure_view(handle, indptr, indices, n_rows, + * n_cols, nnz); + * @endcode + * + * + * @tparam ElementType + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] indptr structural indptr (size n_rows+1) + * @param[in] indices structural indices (size nnz) + * @param[in] n_rows total number of rows + * @param[in] n_cols total number of columns + * @param[in] nnz number of non-zeros + * @return a sparsity-preserving compressed structural view + */ +template +auto make_host_csr_structure_view( + IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz) +{ + return host_compressed_structure_view( + raft::host_span(indptr, n_rows + 1), + raft::host_span(indices, nnz), + n_cols); +} + +/** + * Create a non-owning sparsity-preserved compressed structure view. Sparsity-preserving means that + * the underlying sparsity is known and cannot be changed. Use the sparsity-owning version if the + * sparsity is not known up front. 
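When the arrays already exist on the host, the non-owning path sketched in the docs above can be assembled roughly as follows (all types are deduced from the pointer arguments; names are illustrative):

@code{.cpp}
#include <raft/core/host_csr_matrix.hpp>

void csr_view_example(int* indptr, int* indices, float* values, int n_rows, int n_cols, int nnz)
{
  // Wrap pre-existing host arrays without taking ownership: indptr has
  // n_rows + 1 entries, indices and values have nnz entries each.
  auto structure = raft::make_host_csr_structure_view(indptr, indices, n_rows, n_cols, nnz);

  // Couple the values with the structural view into a csr matrix view.
  auto csr_view = raft::make_host_csr_matrix_view(values, structure);

  // The element span and structure are then accessible for downstream APIs.
  auto elements = csr_view.get_elements();
}
@endcode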
+ * + * @code{.cpp} + * #include + * #include + * + * int n_rows = 100000; + * int n_cols = 10000; + * int nnz = 5000; + * + * // The following host span is assumed to be of size n_rows+1 + * raft::host_span indptr; + * + * // The following host span is assumed to be of size nnz + * raft::host_span indices; + * + * raft::resources handle; + * csr_structure = raft::make_host_compressed_structure_view(handle, indptr, indices, n_rows, + * n_cols); + * @endcode + * + * @tparam IndptrType + * @tparam IndicesType + * @tparam NZType + * @param[in] indptr structural indptr (size n_rows+1) + * @param[in] indices structural indices (size nnz) + * @param[in] n_cols total number of columns + * @return a sparsity-preserving compressed structural view + * + */ +template +auto make_host_csr_structure_view(raft::host_span indptr, + raft::host_span indices, + IndicesType n_cols) +{ + return host_compressed_structure_view(indptr, indices, n_cols); +} + +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 20cb5c1446..9ba29e38d4 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,9 @@ #include #include +#include -#include +#include #include namespace raft { @@ -33,7 +34,7 @@ namespace raft { template > + typename ContainerPolicy = host_vector_policy> using host_mdarray = mdarray>; /** @@ -66,12 +67,38 @@ template using host_matrix = host_mdarray, LayoutPolicy>; +/** + * @brief Create a host mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] res raft handle for managing expensive resources + * @param[in] exts dimensionality of the array (series of integers) + * @return raft::host_mdarray + */ +template +auto make_host_mdarray(raft::resources& res, extents exts) +{ + using mdarray_t = host_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy; + + return mdarray_t{res, layout, policy}; +} + /** * @brief Create a host mdarray. * @tparam ElementType the data type of the matrix elements * @tparam IndexType the index type of the extents * @tparam LayoutPolicy policy for strides and layout ordering * @param exts dimensionality of the array (series of integers) + * Note: This function is deprecated and will be removed in a future version. Please use version + * that accepts raft::resources. + * * @return raft::host_mdarray */ template exts) typename mdarray_t::mapping_type layout{exts}; typename mdarray_t::container_policy_type policy; - return mdarray_t{layout, policy}; + raft::resources res; + return mdarray_t{res, layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous host mdarray. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] res raft handle for managing expensive resources + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::host_matrix + */ +template +auto make_host_matrix(raft::resources& res, IndexType n_rows, IndexType n_cols) +{ + return make_host_mdarray( + res, make_extents(n_rows, n_cols)); } /** @@ -95,6 +142,9 @@ auto make_host_mdarray(extents exts) * @tparam LayoutPolicy policy for strides and layout ordering * @param[in] n_rows number or rows in matrix * @param[in] n_cols number of columns in matrix + * Note: This function is deprecated and will be removed in a future version. Please use version + * that accepts raft::resources. + * * @return raft::host_matrix */ template +auto make_host_scalar(raft::resources& res, ElementType const& v) +{ + // FIXME(jiamingy): We can optimize this by using std::array as container policy, which + // requires some more compile time dispatching. This is enabled in the ref impl but + // hasn't been ported here yet. + scalar_extent extents; + using policy_t = typename host_scalar::container_policy_type; + policy_t policy; + auto scalar = host_scalar{res, extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a host scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] v scalar type to wrap + * Note: This function is deprecated and will be removed in a future version. Please use version + * that accepts raft::resources. + * + * @return raft::host_scalar + */ +template auto make_host_scalar(ElementType const& v) { // FIXME(jiamingy): We can optimize this by using std::array as container policy, which @@ -123,17 +199,38 @@ auto make_host_scalar(ElementType const& v) scalar_extent extents; using policy_t = typename host_scalar::container_policy_type; policy_t policy; - auto scalar = host_scalar{extents, policy}; + raft::resources handle; + auto scalar = host_scalar{handle, extents, policy}; scalar(0) = v; return scalar; } +/** + * @brief Create a 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] res raft handle for managing expensive resources + * @param[in] n number of elements in vector + * @return raft::host_vector + */ +template +auto make_host_vector(raft::resources& res, IndexType n) +{ + return make_host_mdarray(res, make_extents(n)); +} + /** * @brief Create a 1-dim host mdarray. * @tparam ElementType the data type of the vector elements * @tparam IndexType the index type of the extents * @tparam LayoutPolicy policy for strides and layout ordering * @param[in] n number of elements in vector + * + * Note: This function is deprecated and will be removed in a future version. Please use version + * that accepts raft::resources. 
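The host factories mirror the device ones. A brief sketch of the new resources-based overloads (variable names illustrative):

@code{.cpp}
#include <raft/core/host_mdarray.hpp>
#include <raft/core/resources.hpp>

void host_factory_example()
{
  raft::resources res;

  auto matrix = raft::make_host_matrix<float, int>(res, 128, 64);
  auto vector = raft::make_host_vector<float, int>(res, 64);
  auto scalar = raft::make_host_scalar<float>(res, 1.0f);

  // Element access through operator(), as with any mdarray.
  matrix(0, 0) = 1.0f;
  vector(0)    = scalar(0);
}
@endcode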
* @return raft::host_vector */ template ; public: - constexpr mdarray() noexcept(std::is_nothrow_default_constructible_v) - : cp_{rmm::cuda_stream_default}, c_{cp_.create(0)} {}; + constexpr mdarray(raft::resources const& handle) noexcept( + std::is_nothrow_default_constructible_v) + : cp_{}, c_{cp_.create(handle, 0)} {}; constexpr mdarray(mdarray const&) noexcept(std::is_nothrow_copy_constructible_v) = default; constexpr mdarray(mdarray&&) noexcept(std::is_nothrow_move_constructible::value) = @@ -203,12 +204,16 @@ class mdarray * @brief The only constructor that can create storage, this is to make sure CUDA stream is being * used. */ - RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(mapping_type const& m, container_policy_type const& cp) - : cp_(cp), map_(m), c_(cp_.create(map_.required_span_size())) + RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(raft::resources const& handle, + mapping_type const& m, + container_policy_type const& cp) + : cp_(cp), map_(m), c_(cp_.create(handle, map_.required_span_size())) { } - RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(mapping_type const& m, container_policy_type& cp) - : cp_(cp), map_(m), c_(cp_.create(map_.required_span_size())) + RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(raft::resources const& handle, + mapping_type const& m, + container_policy_type& cp) + : cp_(cp), map_(m), c_(cp_.create(handle, map_.required_span_size())) { } diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp index 710fcc7e60..c8d8ee4c02 100644 --- a/cpp/include/raft/core/resource/cublas_handle.hpp +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -71,7 +71,9 @@ inline cublasHandle_t get_cublas_handle(resources const& res) cudaStream_t stream = get_cuda_stream(res); res.add_resource_factory(std::make_shared(stream)); } - return *res.get_resource(resource_type::CUBLAS_HANDLE); + auto ret = *res.get_resource(resource_type::CUBLAS_HANDLE); + RAFT_CUBLAS_TRY(cublasSetStream(ret, get_cuda_stream(res))); + return ret; }; /** diff --git a/cpp/include/raft/core/sparse_types.hpp b/cpp/include/raft/core/sparse_types.hpp new file mode 100644 index 0000000000..207cc944d2 --- /dev/null +++ b/cpp/include/raft/core/sparse_types.hpp @@ -0,0 +1,216 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft { + +enum SparsityType { OWNING, PRESERVING }; + +/** + * Maintains metadata about the structure and sparsity of a sparse matrix. 
+ * @tparam RowType + * @tparam ColType + * @tparam NZType + * @tparam is_device + */ +template +class sparse_structure { + public: + using row_type = RowType; + using col_type = ColType; + using nnz_type = NZType; + + /** + * Constructor when sparsity is already known + * @param n_rows total number of rows in matrix + * @param n_cols total number of columns in matrix + * @param nnz sparsity of matrix + */ + sparse_structure(row_type n_rows, col_type n_cols, nnz_type nnz) + : n_rows_(n_rows), n_cols_(n_cols), nnz_(nnz){}; + + /** + * Constructor when sparsity is not yet known + * @param n_rows total number of rows in matrix + * @param n_cols total number of columns in matrix + */ + sparse_structure(row_type n_rows, col_type n_cols) : n_rows_(n_rows), n_cols_(n_cols), nnz_(0) {} + + /** + * Return the sparsity of the matrix (this will be 0 when sparsity is not yet known) + * @return sparsity of matrix + */ + nnz_type get_nnz() { return nnz_; } + + /** + * Return the total number of rows in the matrix + * @return total number of rows in the matriz + */ + row_type get_n_rows() { return n_rows_; } + + /** + * Return the total number of columns in the matrix + * @return total number of columns + */ + col_type get_n_cols() { return n_cols_; } + + /** + * Initialize the matrix sparsity when it was not known + * upon construction. + * @param nnz + */ + virtual void initialize_sparsity(nnz_type nnz) { nnz_ = nnz; } + + protected: + row_type n_rows_; + col_type n_cols_; + nnz_type nnz_; +}; + +/** + * A non-owning view of a sparse matrix, which includes a + * structure component coupled with its elements/weights + * + * @tparam ElementType + * @tparam sparse_structure + */ +template +class sparse_matrix_view { + public: + using element_type = ElementType; + using structure_view_type = typename StructureType::view_type; + + sparse_matrix_view(raft::span element_span, + structure_view_type structure_view) + : element_span_(element_span), structure_view_(structure_view) + { + // FIXME: Validate structure sizes match span size. + } + + /** + * Return a view of the structure underlying this matrix + * @return + */ + structure_view_type get_structure() { return structure_view_; } + + /** + * Return a span of the nonzero elements of the matrix + * @return span of the nonzero elements of the matrix + */ + span get_elements() { return element_span_; } + + protected: + raft::span element_span_; + structure_view_type structure_view_; +}; + +/** + * TODO: Need to support the following types of configurations: + * 1. solid: immutable_sparse_matrix_view + * - This is an immutable view type, nothing can change. + * 2. liquid: sparse_matrix + * - sparse_matrix owning container w/ StructureType=immutable view? + * 3. gas: sparse_matrix + * - sparse_matrix owning container w/ StructureType owning container? 
+ */ + +/** + * An owning container for a sparse matrix, which includes a + * structure component coupled with its elements/weights + * @tparam ElementType + * @tparam sparse_structure + * @tparam ContainerPolicy + */ +template + typename ContainerPolicy> +class sparse_matrix { + public: + using view_type = ViewType; + using element_type = typename view_type::element_type; + using structure_type = StructureType; + using row_type = typename structure_type::row_type; + using col_type = typename structure_type::col_type; + using nnz_type = typename structure_type::nnz_type; + + using structure_view_type = typename structure_type::view_type; + using container_policy_type = ContainerPolicy; + using container_type = typename container_policy_type::container_type; + + sparse_matrix(raft::resources const& handle, + row_type n_rows, + col_type n_cols, + nnz_type nnz = 0) noexcept(std::is_nothrow_default_constructible_v) + : structure_{std::make_shared(handle, n_rows, n_cols, nnz)}, + cp_{}, + c_elements_{cp_.create(handle, 0)} {}; + + // Constructor that owns the data but not the structure + sparse_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + std::is_nothrow_default_constructible_v) + : structure_{structure}, cp_{}, c_elements_{cp_.create(handle, structure.get()->get_nnz())} {}; + + constexpr sparse_matrix(sparse_matrix const&) noexcept( + std::is_nothrow_copy_constructible_v) = default; + constexpr sparse_matrix(sparse_matrix&&) noexcept( + std::is_nothrow_move_constructible::value) = default; + + constexpr auto operator=(sparse_matrix const&) noexcept( + std::is_nothrow_copy_assignable::value) -> sparse_matrix& = default; + constexpr auto operator=(sparse_matrix&&) noexcept( + std::is_nothrow_move_assignable::value) -> sparse_matrix& = default; + + ~sparse_matrix() noexcept(std::is_nothrow_destructible::value) = default; + + void initialize_sparsity(nnz_type nnz) { c_elements_.resize(nnz); }; + + raft::span get_elements() + { + return raft::span(c_elements_.data(), structure_view().get_nnz()); + } + + /** + * Return a view of the structure underlying this matrix + * @return + */ + virtual structure_view_type structure_view() = 0; + + /** + * Return a sparsity-preserving view of this sparse matrix + * @return view of this sparse matrix + */ + view_type view() + { + auto struct_view = structure_view(); + auto element_span = + raft::span(c_elements_.data(), struct_view.get_nnz()); + return view_type(element_span, struct_view); + } + + protected: + std::shared_ptr structure_; + container_policy_type cp_; + container_type c_elements_; +}; +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 5e6ae84eb5..194471c5de 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -43,7 +43,7 @@ namespace raft { template typename ContainerPolicy = detail::device_uvector_policy> + template typename ContainerPolicy = device_uvector_policy> class temporary_device_buffer { using view_type = device_mdspan; using index_type = typename Extents::index_type; @@ -89,9 +89,9 @@ class temporary_device_buffer { { if (device_id_ == -1) { typename owning_device_buffer::mapping_type layout{extents_}; - typename owning_device_buffer::container_policy_type policy{handle.get_stream()}; + typename owning_device_buffer::container_policy_type policy{}; - owning_device_buffer device_data{layout, policy}; + 
owning_device_buffer device_data{handle, layout, policy}; raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); data_ = data_store{std::in_place_index<1>, std::move(device_data)}; } else { @@ -167,7 +167,7 @@ class temporary_device_buffer { template typename ContainerPolicy = detail::device_uvector_policy, + template typename ContainerPolicy = device_uvector_policy, size_t... Extents> auto make_temporary_device_buffer(raft::device_resources const& handle, ElementType* data, @@ -209,7 +209,7 @@ auto make_temporary_device_buffer(raft::device_resources const& handle, template typename ContainerPolicy = detail::device_uvector_policy, + template typename ContainerPolicy = device_uvector_policy, size_t... Extents> auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, ElementType* data, @@ -252,7 +252,7 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, template typename ContainerPolicy = detail::device_uvector_policy, + template typename ContainerPolicy = device_uvector_policy, size_t... Extents, typename = std::enable_if_t>> auto make_writeback_temporary_device_buffer(raft::device_resources const& handle, diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 7887eb96be..f469250b45 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -27,24 +27,12 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - +#include #include #include #include +#include #include #include @@ -126,7 +114,7 @@ void distance_impl(raft::resources const& handle, const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -205,7 +193,7 @@ void distance_impl(raft::resources const& handle, using OpT = ops::correlation_distance_op; OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); - distance_matrix_dispatch( + pairwise_matrix_dispatch( corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); } @@ -248,34 +236,9 @@ void distance_impl(raft::resources const& handle, norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel otherwise. - - if constexpr (__CUDACC_VER_MAJOR__ == 12) { - // Always execute legacy kernels on CUDA 12 - ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { - // If device is SM_80 or later, use CUTLASS-based kernel. 
- using Op = ops::cosine_cutlass_op; - Op distance_op{}; - - distance_matrix_cutlass_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - // Else use "legacy" cosine kernel - ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } - } + ops::cosine_distance_op distance_op{}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } template @@ -300,7 +263,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -362,7 +325,7 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); // Finally revert sqrt of x and y @@ -394,7 +357,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -438,7 +401,7 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - distance_matrix_dispatch( + pairwise_matrix_dispatch( kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); if (x != y) { @@ -469,7 +432,7 @@ void distance_impl(raft::resources const& handle, const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -514,34 +477,9 @@ void distance_impl_l2_expanded( // NOTE: different name norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel otherwise. - - if constexpr (__CUDACC_VER_MAJOR__ == 12) { - // Always execute legacy kernels on CUDA 12 - ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { - // If device is SM_80 or later, use CUTLASS-based kernel. 
- using L2Op = ops::l2_exp_cutlass_op; - L2Op l2_op(perform_sqrt); - - distance_matrix_cutlass_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - // Else use "legacy" L2 - ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } - } + ops::l2_exp_distance_op distance_op{perform_sqrt}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } template @@ -610,7 +548,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } @@ -638,7 +576,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } @@ -664,7 +602,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -690,7 +628,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -716,7 +654,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } diff --git a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh new file mode 100644 index 0000000000..3e8f4e86fb --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +// Defines a named requirement "has_cutlass_op" +#include + +// The distance operations: +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 45bea08a95..930294ce31 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) */ -template +template struct canberra_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 3832104280..289b69070a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -28,8 +28,12 @@ namespace raft::distance::detail::ops { * / * (|| x - mean(x) ||_2 || y - mean(y) ||_2) */ -template +template struct correlation_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + const DataT* x2n; const DataT* y2n; IdxT m; diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index c3f3b75e62..7c37c27b4e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -20,6 +20,17 @@ namespace raft::distance::detail::ops { +// Epilogue operator for CUTLASS based kernel +template +struct cosine_cutlass_op { + __device__ cosine_cutlass_op() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + /** * @brief the expanded cosine distance matrix calculation * @@ -27,8 +38,12 @@ namespace raft::distance::detail::ops { * * d(x, y) = 1 - (x â‹… y) / ( ||x||_2 ||y||_2) */ -template +template struct cosine_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = true; // Whether the core function requires so many instructions that it makes sense @@ -60,16 +75,8 @@ struct cosine_distance_op { } } } -}; -template -struct cosine_cutlass_op { - __device__ cosine_cutlass_op() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + cosine_cutlass_op get_cutlass_op() { return cosine_cutlass_op(); } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh new file mode 100644 index 0000000000..d3eb90467b --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance::detail::ops { + +// This file defines the named requirement "has_cutlass_op" that can be used to +// determine if a distance operation has a CUTLASS op that can be used to pass +// to CUTLASS. Examples of distance operations that satisfy this requirement are +// cosine_distance_op and l2_exp_distance_op. + +// Primary template handles types that do not support CUTLASS. +// This pattern is described in: +// https://en.cppreference.com/w/cpp/types/void_t +template +struct has_cutlass_op : std::false_type { +}; + +// Specialization recognizes types that do support CUTLASS +template +struct has_cutlass_op> : std::true_type { +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 98acf11560..1cfdcfdc73 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k (x_ik != y_kj) / k */ -template +template struct hamming_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + IdxT k; hamming_distance_op(IdxT k_) noexcept : k(k_) {} diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index c5e2b84ac2..c4aecc7a6f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) * */ -template +template struct hellinger_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index df5aadcf3b..41eeb9dd83 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -29,8 +29,12 @@ namespace raft::distance::detail::ops { * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) */ -template +template struct jensen_shannon_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index 526927243f..d046b62c30 100644 --- 
a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = 0.5 * sum(x * log (x / y)); */ -template +template struct kl_divergence_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + const bool is_row_major; const bool x_equal_y; diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index b02971bac7..8ec4000827 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k abs(x_ik - y_kj) */ -template +template struct l1_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index fb00f8d66a..2a7af53813 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -20,6 +20,26 @@ namespace raft::distance::detail::ops { +// Epilogue operator for CUTLASS based kernel +template +struct l2_exp_cutlass_op { + bool sqrt; + + __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} + __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + // outVal could be negative due to numerical instability, especially when + // calculating self distance. + // clamp to 0 to avoid potential NaN in sqrt + outVal = outVal * (outVal > DataT(0.0)); + return sqrt ? raft::sqrt(outVal) : outVal; + } + + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + /** * @brief the expanded euclidean distance matrix calculation * @@ -28,8 +48,12 @@ namespace raft::distance::detail::ops { * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 * */ -template +template struct l2_exp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + bool sqrt; l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} @@ -62,6 +86,8 @@ struct l2_exp_distance_op { #pragma unroll for (int j = 0; j < Policy::AccColsPerTh; ++j) { DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + // val could be negative due to numerical instability, especially when + // calculating self distance. Clamp to 0 to avoid potential NaN in sqrt acc[i][j] = val * (val > DataT(0.0)); } } @@ -75,26 +101,8 @@ struct l2_exp_distance_op { } } } -}; - -// Epilogue operator for CUTLASS based kernel -template -struct l2_exp_cutlass_op { - bool sqrt; - - __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} - __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - // outVal could be negative due to numerical instability, especially when - // calculating self distance. - // clamp to 0 to avoid potential NaN in sqrt - outVal = outVal * (outVal > DataT(0.0)); - return sqrt ? 
raft::sqrt(outVal) : outVal; - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + l2_exp_cutlass_op get_cutlass_op() { return l2_exp_cutlass_op(sqrt); } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index e03eb0a97e..f0ea591eaf 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) */ -template +template struct l2_unexp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + bool sqrt; l2_unexp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index caa1379133..fb21fb1a21 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = max_k | x_ik - y_kj | */ -template +template struct l_inf_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index a4a090d058..71dfd51a6e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) */ -template +template struct lp_unexp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + DataT p; lp_unexp_distance_op(DataT p_) noexcept : p(p_) {} diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index 7acd858e49..ea09e4d1db 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = (k - (sum_k x_ik * y_kj)) / k */ -template +template struct russel_rao_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + IdxT k; const float one_over_k; diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index b0f40123aa..6998f3cad4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -24,8 +24,12 @@ namespace raft::distance::detail::ops { // // Fill in the TODO items. 
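Since every op now carries its DataT/AccT/IdxT aliases and, where applicable, a get_cutlass_op() member, the dispatch layer can detect CUTLASS support generically through the has_cutlass_op named requirement introduced above. A small illustration of the mechanism (these are internal detail headers compiled as CUDA code; shown only to clarify how the trait behaves):

@code{.cpp}
#include <raft/distance/detail/distance_ops/all_ops.cuh>

namespace ops = raft::distance::detail::ops;

// Ops that expose get_cutlass_op() (cosine, expanded L2) satisfy the named
// requirement; ops without a CUTLASS epilogue (e.g. L1) do not, and fall back
// to the regular pairwise kernels.
static_assert(ops::has_cutlass_op<ops::cosine_distance_op<float, float, int>>::value,
              "cosine op provides a CUTLASS epilogue");
static_assert(ops::has_cutlass_op<ops::l2_exp_distance_op<float, float, int>>::value,
              "expanded L2 op provides a CUTLASS epilogue");
static_assert(!ops::has_cutlass_op<ops::l1_distance_op<float, float, int>>::value,
              "L1 op has no CUTLASS epilogue");
@endcode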
-template +template struct template_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + TODO member; template_distance_op(TODO member_) noexcept : member(member_) {} diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 2ab5c69b0d..c5fdd28117 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -26,6 +26,7 @@ #endif #include +#include #include #include @@ -36,6 +37,8 @@ #include #include +#include + #include "./pairwise_distance_epilogue_elementwise.h" #include "./pairwise_distance_gemm.h" @@ -59,26 +62,29 @@ template -void cutlassDistanceKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - DistanceFn dist_op, - cudaStream_t stream) +typename std::enable_if::value>::type cutlassDistanceKernel( + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); + auto dist_op = distance_op.get_cutlass_op(); + using DistanceFn = decltype(dist_op); using EpilogueOutputOp = cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise -#include -#include +#include +#include +#include +#include +#include #include -#include +#include +#include namespace raft::distance::detail { -/** - * @brief: Computes minimal common alignment of the rows in a 2D array in bytes - * - * The 2D matrix `x` is assumed to be row-major. This function computes the - * minimal alignment in bytes of the first elements of each row. - * Output can be 16, 8, 4, 2, 1. - * - * @param x Base pointer of row-major input matrix - * @param stride Stride in number of element between consecutive rows. - */ -template -size_t alignment_of_2d_array(const DataT* x, size_t stride) -{ - auto base = reinterpret_cast(x); - size_t stride_bytes = sizeof(DataT) * stride; - - for (int align = 16; align >= 0; align /= 2) { - bool base_aligned = base % align == 0; - bool stride_aligned = stride_bytes % align == 0; - if (base_aligned && stride_aligned) { return align; } - } - return 1; -} - -template -using vec_len_constant = std::integral_constant; - -/** - * @brief: Converts run-time arguments to compile-time arguments - * - * Converts run-time arguments row_major and vec_len to compile-time arguments - * and dispatches a lambda f with these compile-time arguments. - * - * This is equivalent to copying and pasting the lambda function `f` in each of - * the switch case statements. - * - * @tparam F Type of lambda f. - * @param row_major Boolean indicating whether input arrays have row-major layout. - * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of - * the KernelPolicy. - * @param f Lambda that takes two std::integral_constant parameters representing - * row_major and vec_len. 
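The enable_if return type above gates cutlassDistanceKernel on ops::has_cutlass_op, which is defined elsewhere in the ops headers. As a hedged illustration only (one common way to write such a trait, not necessarily RAFT's exact definition), a detection-idiom version could look like this:

#include <type_traits>
#include <utility>

// True when T has a callable member get_cutlass_op().
template <typename T, typename = void>
struct has_cutlass_op_example : std::false_type {};

template <typename T>
struct has_cutlass_op_example<T, std::void_t<decltype(std::declval<T>().get_cutlass_op())>>
  : std::true_type {};

// Usage: restrict an overload to ops that provide a CUTLASS epilogue.
template <typename OpT>
std::enable_if_t<has_cutlass_op_example<OpT>::value> cutlass_only_path(OpT op)
{
  auto epilogue = op.get_cutlass_op();  // guaranteed to exist here
  (void)epilogue;
}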
- */ -template -void dispatch(bool row_major, int vec_len, F&& f) -{ - if (row_major) { - switch (vec_len) { - case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - default: f(std::bool_constant(), vec_len_constant<1>()); break; - } - } else { - switch (vec_len) { - case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - default: f(std::bool_constant(), vec_len_constant<1>()); break; - } - } -} - template -void distance_matrix_dispatch(OpT distance_op, +void pairwise_matrix_dispatch(OpT distance_op, IdxT m, IdxT n, IdxT k, @@ -104,114 +45,51 @@ void distance_matrix_dispatch(OpT distance_op, cudaStream_t stream, bool is_row_major) { - // Determine leading dimensions and, if column-major, flip order of passing x - // and y. - IdxT ldx, ldy, ld_out; - if (is_row_major) { - ldx = k, ldy = k, ld_out = n; + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel below SM_80 + + constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); + + if constexpr (is_ctk_12 || cutlass_op_unavailable) { + // Always execute legacy kernels on CUDA 12 + auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else { - // Flip x, y, and m, n. - std::swap(x, y); - std::swap(x_norm, y_norm); - std::swap(m, n); - ldx = m, ldy = n, ld_out = n; - } - - size_t align_x = alignment_of_2d_array(x, ldx); - size_t align_y = alignment_of_2d_array(y, ldy); - size_t byte_alignment = min(align_x, align_y); - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, - "Input matrix must be aligned to size of elements."); - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - int vec_len_aligned; - if (byte_alignment % sizeof(DataT) == 0) { - // In the future, we might support `int8_t` input. In that case, - // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to - // prevent adding more cases in dispatch (which are expensive to compile). - vec_len_aligned = std::min(4, int(byte_alignment / sizeof(DataT))); - } else { - vec_len_aligned = 1; - } - - dispatch(is_row_major, vec_len_aligned, [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 
1 : vec_len_aligned(); - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; - - return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); - }); -} - -template -void distance_matrix_cutlass_dispatch(opT cutlass_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Determine leading dimensions and possibly flip order of passing x and y if - // column_major. - IdxT ldx, ldy, ld_out; - if (is_row_major) { - ldx = k, ldy = k, ld_out = n; - } else { - std::swap(x, y); - std::swap(x_norm, y_norm); - std::swap(m, n); - ldx = m, ldy = n, ld_out = n; + auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + + // Get pointer to SM60 kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); + void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); + auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + } else { + // Reuse kernel wrapper that we obtained above. This avoids performing the + // dispatch twice. + sm60_wrapper.launch(distance_op, params, stream); + } } - - size_t align_x = alignment_of_2d_array(x, ldx); - size_t align_y = alignment_of_2d_array(y, ldy); - size_t byte_alignment = min(align_x, align_y); - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, - "Input matrix must be aligned to size of elements."); - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; - - dispatch(is_row_major, vec_len_aligned, [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); - - cutlassDistanceKernel( - x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); - }); } }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh new file mode 100644 index 0000000000..c1e4c08af4 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "kernel_sm60.cuh" +#include +#include + +namespace raft::distance::detail { + +/** + * @brief: Computes minimal common alignment of the rows in a 2D array in bytes + * + * The 2D matrix `x` is assumed to be row-major. This function computes the + * minimal alignment in bytes of the first elements of each row. + * Output can be 16, 8, 4, 2, 1. + * + * @param x Base pointer of row-major input matrix + * @param stride Stride in number of element between consecutive rows. + */ +template +size_t alignment_of_2d_array(const DataT* x, size_t stride) +{ + auto base = reinterpret_cast(x); + size_t stride_bytes = sizeof(DataT) * stride; + + for (int align = 16; align >= 0; align /= 2) { + bool base_aligned = base % align == 0; + bool stride_aligned = stride_bytes % align == 0; + if (base_aligned && stride_aligned) { return align; } + } + return 1; +} + +/** + * @brief: Computes the vec_len parameter kernel policy parameter + * + * @param params Kernel parameters + */ +template +int determine_vec_len(pairwise_matrix_params params) +{ + size_t align_x = alignment_of_2d_array(params.x, params.ldx); + size_t align_y = alignment_of_2d_array(params.y, params.ldy); + size_t byte_alignment = min(align_x, align_y); + + // Since alignment is in bytes, it could be smaller than sizeof(DataT). + // Handle this (unlikely) case here. + RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, + "Input matrix must be aligned to size of elements."); + + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; + + // In the future, pairwise_matrix might support `int8_t` input. In that case, + // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to + // prevent adding more cases in dispatch_layout below (which are expensive to + // compile). + vec_len_aligned = std::min(vec_len_aligned, 4); + + return vec_len_aligned; +} + +template +using vec_len_constant = std::integral_constant; + +/** + * @brief: Converts run-time arguments to compile-time arguments + * + * Converts run-time arguments row_major and vec_len to compile-time arguments + * and dispatches a lambda f with these compile-time arguments. + * + * This is equivalent to copying and pasting the lambda function `f` in each of + * the switch case statements. + * + * @tparam F Type of lambda f. + * @param row_major Boolean indicating whether input arrays have row-major layout. + * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of + * the KernelPolicy. + * @param f Lambda that takes two std::integral_constant parameters representing + * row_major and vec_len. 
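To make the two helpers above concrete, here is a worked example with illustrative values (a cudaMalloc'd float matrix, so the base pointer is at least 256-byte aligned):

// alignment_of_2d_array / determine_vec_len, worked through by hand:
//   stride = 10 floats = 40 bytes: 40 % 16 != 0, 40 % 8 == 0  -> byte_alignment = 8
//                                  8 / sizeof(float) = 2      -> vec_len = 2
//   stride = 12 floats = 48 bytes: 48 % 16 == 0               -> byte_alignment = 16
//                                  16 / sizeof(float) = 4     -> vec_len = 4 (the cap)
//   stride =  7 floats = 28 bytes: only 28 % 4 == 0           -> byte_alignment = 4, vec_len = 1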
+ */ +template +auto dispatch_layout(bool row_major, int vec_len, F&& f) +{ + if (row_major) { + switch (vec_len) { + case 4: return f(std::bool_constant(), vec_len_constant<4>()); + case 2: return f(std::bool_constant(), vec_len_constant<2>()); + default: return f(std::bool_constant(), vec_len_constant<1>()); + } + } else { + switch (vec_len) { + case 4: return f(std::bool_constant(), vec_len_constant<4>()); + case 2: return f(std::bool_constant(), vec_len_constant<2>()); + default: return f(std::bool_constant(), vec_len_constant<1>()); + } + } +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh new file mode 100644 index 0000000000..6e284007ea --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::distance::detail { + +template +pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( + OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range) +{ + int vec_len = determine_vec_len(params); + + return dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. + + // To keep compile times in check, we only specialize on veclen > 1 when + // the inner loop is relatively cheap (< 5 flops). + constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; + + auto wrapper = + make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); + + return wrapper; + }); +} + +template +void pairwise_matrix_sm60_dispatch(OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range, + cudaStream_t stream) +{ + auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); + + wrapper.launch(distance_op, params, stream); +} + +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh new file mode 100644 index 0000000000..ec2d522c25 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // std::min +#include +#include + +namespace raft::distance::detail { + +template +void pairwise_matrix_sm80_dispatch(OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range, + cudaStream_t stream) +{ + int vec_len = determine_vec_len(params); + + dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); + + using AccT = typename OpT::AccT; + cutlassDistanceKernel(params.x, + params.y, + params.x_norm, + params.y_norm, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.out, + params.fin_op, + distance_op, + stream); + }); +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 7c1052d726..6e3ab7b26b 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -18,33 +18,33 @@ #include #include #include +#include +#include namespace raft::distance::detail { template -__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - opT distance_op, - FinOpT fin_op) +__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( + OpT distance_op, pairwise_matrix_params params) { + // Early exit to minimize the size of the kernel when it is not supposed to be compiled. + constexpr SM_compat_t sm_compat_range{}; + if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { + assert(false); + return; + } + extern __shared__ char smem[]; + using AccT = typename OpT::AccT; + // Wrap operator back into lambdas. This is temporary and should be removed. 
// See: https://github.com/rapidsai/raft/issues/1323 auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { @@ -74,50 +74,39 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(co Policy, decltype(core_op), decltype(epilog_op), - decltype(fin_op), + decltype(params.fin_op), decltype(row_epilog_op), row_major, write_out> - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - _xn, - _yn, - dOutput, + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, smem, core_op, epilog_op, - fin_op, + params.fin_op, row_epilog_op); obj.run(); } template void pairwise_matrix(OpT distance_op, - FinOpT fin_op, - const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, + pairwise_matrix_params params, cudaStream_t stream) { dim3 blk(Policy::Nthreads); @@ -125,12 +114,83 @@ void pairwise_matrix(OpT distance_op, // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel - auto kernel = pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); + auto kernel = + pairwise_matrix_kernel; + dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - kernel<<>>( - x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); + kernel<<>>(distance_op, params); RAFT_CUDA_TRY(cudaGetLastError()); } +// The type of a pointer to the pairwise matrix kernel. The following template +// arguments are type-erased: +// +// - The kernel policy +// - row_major +// - SM_compat_t +template +using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); + +// A wrapper for the pairwise matrix kernel launch. Includes kernel launch +// parameters. +template +struct pairwise_matrix_sm60_wrapper { + dim3 grid; + dim3 block; + int smem_size; + pairwise_matrix_kernel_t kernel_ptr; + + void launch(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) + { + kernel_ptr<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); + } +}; + +/** @brief: Create kernel launch wrapper for pairwise matrix kernel + * + * This can be used to type-erase the kernel execution policy, row_major, and SM + * compatibility range. + * + * @tparam Policy: Kernel execution policy + * @tparam row_major: Indicates whether input matrices are row major + * @tparam OpT: Type of distance operation + * @tparam IdxT: Index type + * @tparam DataT: Data type + * @tparam OutT: Output data type + * @tparam FinOpT: Final operation type + * @tparam SM_compat_t: Type of the SM architecture compatibility + * + * @param distance_op: Distance operation + * @param params: Parameters + * @param sm_compat_range: Which SM architectures to compile for. 
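The wrapper documented above type-erases the kernel policy and layout so that the launch configuration and kernel pointer can be computed once and reused (for example, by the runtime-architecture check in the dispatch code). A stand-alone sketch of the same idea (toy_kernel and toy_launch_wrapper are illustrative, not RAFT code):

__global__ void toy_kernel(int* out) { out[blockIdx.x * blockDim.x + threadIdx.x] = 1; }

struct toy_launch_wrapper {
  dim3 grid;
  dim3 block;
  int smem_size;
  void (*kernel_ptr)(int*);  // type-erased: callers never see how the kernel was chosen

  void launch(int* out, cudaStream_t stream)
  {
    kernel_ptr<<<grid, block, smem_size, stream>>>(out);
  }
};

// toy_launch_wrapper w{dim3(4), dim3(256), 0, toy_kernel};
// w.launch(out_device_ptr, stream);  // out_device_ptr: hypothetical device buffer of 1024 ints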
+ */ +template +pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( + OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range) +{ + dim3 block(Policy::Nthreads); + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_size = distance_op.template shared_mem_size(); + // Obtain function pointer to kernel + auto kernel = + pairwise_matrix_kernel; + dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); + + return pairwise_matrix_sm60_wrapper{ + grid, block, smem_size, kernel}; +} + }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh new file mode 100644 index 0000000000..005b95afe9 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace raft::distance::detail { + +template +struct pairwise_matrix_params { + IdxT m; + IdxT n; + IdxT k; + IdxT ldx; + IdxT ldy; + IdxT ld_out; + const DataT* x; + const DataT* y; + const DataT* x_norm; + const DataT* y_norm; + OutT* out; + FinOpT fin_op; + bool is_row_major; + + /// @brief: Flips the x and y input and corresponding sizes + void flip_x_and_y() + { + // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. + std::swap(m, n); + std::swap(ldx, ldy); + std::swap(x, y); + std::swap(x_norm, y_norm); + } +}; + +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index f5ed68af4a..4060147f1d 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,6 +66,26 @@ enum DistanceType : unsigned short { Precomputed = 100 }; +/** + * Whether minimal distance corresponds to similar elements (using the given metric). + */ +inline bool is_min_close(DistanceType metric) +{ + bool select_min; + switch (metric) { + case DistanceType::InnerProduct: + case DistanceType::CosineExpanded: + case DistanceType::CorrelationExpanded: + // Similarity metrics have the opposite meaning, i.e. 
nearest neighbors are those with larger + // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 + // {perform_k_selection}) + select_min = false; + break; + default: select_min = true; + } + return select_min; +} + namespace kernels { enum KernelType { LINEAR, POLYNOMIAL, RBF, TANH }; diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 9e7b236fed..05588bda9c 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -37,6 +37,7 @@ void transpose(raft::device_resources const& handle, cudaStream_t stream) { cublasHandle_t cublas_h = handle.get_cublas_handle(); + RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); int out_n_rows = n_cols; int out_n_cols = n_rows; @@ -90,6 +91,7 @@ void transpose_row_major_impl( auto out_n_cols = in.extent(0); T constexpr kOne = 1; T constexpr kZero = 0; + CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, @@ -116,6 +118,7 @@ void transpose_col_major_impl( auto out_n_cols = in.extent(0); T constexpr kOne = 1; T constexpr kZero = 0; + CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index a336f844bf..7dfaa18911 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -19,9 +19,9 @@ #pragma once #include "detail/gemm.hpp" - #include #include +#include #include #include #include diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index b4d7a37f33..643a63d9db 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -32,6 +32,7 @@ #include #include +#include namespace raft::matrix::detail::select::radix { diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index ac9d14ce17..4891cc5f8d 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -18,9 +18,8 @@ #include #include +#include #include -#include -#include namespace raft::neighbors::brute_force { @@ -96,15 +95,15 @@ inline void knn_merge_parts( "Number of columns in output indices and distances matrices must be equal to k"); auto n_parts = in_keys.extent(0) / n_samples; - spatial::knn::detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - handle.get_stream(), - translations.value_or(nullptr)); + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + in_keys.extent(1), + handle.get_stream(), + translations.value_or(nullptr)); } /** @@ -181,21 +180,21 @@ void knn(raft::device_resources const& handle, std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - raft::spatial::knn::detail::brute_force_knn_impl(handle, - inputs, - sizes, - static_cast(index[0].extent(1)), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - static_cast(search.extent(0)), - indices.data_handle(), - distances.data_handle(), - k, - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f)); + raft::neighbors::detail::brute_force_knn_impl(handle, + inputs, + sizes, + static_cast(index[0].extent(1)), + // TODO: This is unfortunate. 
Need to fix. + const_cast(search.data_handle()), + static_cast(search.extent(0)), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + trans_arg, + metric, + metric_arg.value_or(2.0f)); } /** diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh b/cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh similarity index 84% rename from cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh index 173c06af30..1a34d2f68c 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/Comparators.cuh @@ -10,7 +10,7 @@ #include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { template struct Comparator { @@ -26,4 +26,4 @@ struct Comparator { __device__ static inline bool gt(half a, half b) { return __hgt(a, b); } }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h b/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h new file mode 100644 index 0000000000..cd4a52e5df --- /dev/null +++ b/cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +namespace raft::neighbors::detail::faiss_select { +// If the inner size (dim) of the vectors is small, we want a larger query tile +// size, like 1024 +inline void chooseTileSize(size_t numQueries, + size_t numCentroids, + size_t dim, + size_t elementSize, + size_t totalMem, + size_t& tileRows, + size_t& tileCols) +{ + // The matrix multiplication should be large enough to be efficient, but if + // it is too large, we seem to lose efficiency as opposed to + // double-streaming. Each tile size here defines 1/2 of the memory use due + // to double streaming. We ignore available temporary memory, as that is + // adjusted independently by the user and can thus meet these requirements + // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs, + // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage. + size_t targetUsage = 0; + + if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) { + targetUsage = 512 * 1024 * 1024; + } else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) { + targetUsage = 768 * 1024 * 1024; + } else { + targetUsage = 1024 * 1024 * 1024; + } + + targetUsage /= 2 * elementSize; + + // 512 seems to be a batch size sweetspot for float32. + // If we are on float16, increase to 512. + // If the k size (vec dim) of the matrix multiplication is small (<= 32), + // increase to 1024. 
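Worked through with illustrative numbers, the heuristic above behaves as follows:

//   totalMem    = 16 GB         -> targetUsage = 1 GB
//   elementSize = 4 (float32)   -> targetUsage = 1 GB / (2 * 4) = 134,217,728 elements
//   dim         = 128 (> 32)    -> preferredTileRows stays at 512
//   tileRows    = min(512, numQueries)
//   tileCols    = min(134,217,728 / 512, numCentroids) = min(262,144, numCentroids)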
+ size_t preferredTileRows = 512; + if (dim <= 32) { preferredTileRows = 1024; } + + tileRows = std::min(preferredTileRows, numQueries); + + // tileCols is the remainder size + tileCols = std::min(targetUsage / preferredTileRows, numCentroids); +} +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh index d923b41ded..79e3f95be0 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh @@ -8,10 +8,10 @@ #pragma once #include -#include -#include +#include +#include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // Merge pairs of lists smaller than blockDim.x (NumThreads) template ::merge(listK, listV); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh similarity index 79% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh index 2cb01f9199..78f794bff4 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh @@ -7,7 +7,7 @@ #pragma once -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { template inline __device__ void swap(bool swap, T& x, T& y) @@ -22,4 +22,4 @@ inline __device__ void assign(bool assign, T& x, T y) { x = assign ? 
y : x; } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh similarity index 98% rename from cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh index bce739b2d8..04f7f90aac 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh @@ -7,12 +7,12 @@ #pragma once -#include -#include +#include +#include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // // This file contains functions to: @@ -518,4 +518,4 @@ inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) BitonicSortStep::sort(k, v); } -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh b/cpp/include/raft/neighbors/detail/faiss_select/Select.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/Select.cuh index e4faff7a6c..4aa7d68f54 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/Select.cuh @@ -7,14 +7,14 @@ #pragma once -#include -#include -#include +#include +#include +#include #include #include -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // Specialization for block-wide monotonic merges producing a merge sort // since what we really want is a constexpr loop expansion @@ -552,4 +552,4 @@ struct WarpSelect { V threadV; }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h b/cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h similarity index 91% rename from cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h rename to cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h index bac051b68c..5a25c7a321 100644 --- a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h +++ b/cpp/include/raft/neighbors/detail/faiss_select/StaticUtils.h @@ -15,7 +15,7 @@ #define __device__ #endif -namespace raft::spatial::knn::detail::faiss_select::utils { +namespace raft::neighbors::detail::faiss_select::utils { template constexpr __host__ __device__ bool isPowerOf2(T v) @@ -45,4 +45,4 @@ static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPower static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, "nextHighestPowerOf2"); -} // namespace raft::spatial::knn::detail::faiss_select::utils +} // namespace raft::neighbors::detail::faiss_select::utils diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh b/cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh similarity index 96% rename from cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh rename to cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh index 617a26a243..ff06b7dca4 100644 --- 
a/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh +++ b/cpp/include/raft/neighbors/detail/faiss_select/key_value_block_select.cuh @@ -7,14 +7,14 @@ #pragma once -#include -#include +#include +#include // TODO: Need to think further about the impact (and new boundaries created) on the registers // because this will change the max k that can be processed. One solution might be to break // up k into multiple batches for larger k. -namespace raft::spatial::knn::detail::faiss_select { +namespace raft::neighbors::detail::faiss_select { // `Dir` true, produce largest values. // `Dir` false, produce smallest values. @@ -221,4 +221,4 @@ struct KeyValueBlockSelect { int kMinus1; }; -} // namespace raft::spatial::knn::detail::faiss_select +} // namespace raft::neighbors::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh similarity index 54% rename from cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh rename to cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index c417a97531..bf7248b983 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -16,31 +16,68 @@ #pragma once -#include "../ivf_flat_types.hpp" -#include "ann_utils.cuh" - #include #include #include #include #include #include -#include #include #include #include +#include +#include +#include +#include #include #include #include #include -#include -namespace raft::spatial::knn::ivf_flat::detail { +namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT +template +auto clone(const raft::device_resources& res, const index& source) -> index +{ + auto stream = res.get_stream(); + + // Allocate the new index + index target(res, + source.metric(), + source.n_lists(), + source.adaptive_centers(), + source.conservative_memory_allocation(), + source.dim()); + + // Copy the independent parts + copy(target.list_sizes().data_handle(), + source.list_sizes().data_handle(), + source.list_sizes().size(), + stream); + copy(target.centers().data_handle(), + source.centers().data_handle(), + source.centers().size(), + stream); + if (source.center_norms().has_value()) { + target.allocate_center_norms(res); + copy(target.center_norms()->data_handle(), + source.center_norms()->data_handle(), + source.center_norms()->size(), + stream); + } + // Copy shared pointers + target.lists() = source.lists(); + + // Make sure the device pointers point to the new lists + target.recompute_internal_state(res); + + return target; +} + /** * @brief Record the dataset into the index, one source row at a time. * @@ -60,11 +97,11 @@ using namespace raft::spatial::knn::detail; // NOLINT * we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1. * * @param[in] labels device pointer to the cluster ids for each row [n_rows] - * @param[in] list_offsets device pointer to the cluster offsets in the output (index) [n_lists] * @param[in] source_vecs device pointer to the input data [n_rows, dim] * @param[in] source_ixs device pointer to the input indices [n_rows] - * @param[out] list_data device pointer to the output [index_size, dim] - * @param[out] list_index device pointer to the source ids corr. to the output [index_size] + * @param[out] list_data_ptrs device pointer to the index data of size [n_lists][index_size, dim] + * @param[out] list_index_ptrs device pointer to the source ids corr. 
to the output [n_lists] + * [index_size] * @param[out] list_sizes_ptr device pointer to the cluster sizes [n_lists]; * it's used as an atomic counter, and must be initialized with zeros. * @param n_rows source length @@ -74,11 +111,10 @@ using namespace raft::spatial::knn::detail; // NOLINT */ template __global__ void build_index_kernel(const LabelT* labels, - const IdxT* list_offsets, const T* source_vecs, const IdxT* source_ixs, - T* list_data, - IdxT* list_index, + T** list_data_ptrs, + IdxT** list_index_ptrs, uint32_t* list_sizes_ptr, IdxT n_rows, uint32_t dim, @@ -89,10 +125,11 @@ __global__ void build_index_kernel(const LabelT* labels, auto list_id = labels[i]; auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1); - auto list_offset = list_offsets[list_id]; + auto* list_index = list_index_ptrs[list_id]; + auto* list_data = list_data_ptrs[list_id]; // Record the source vector id in the index - list_index[list_offset + inlist_id] = source_ixs == nullptr ? i : source_ixs[i]; + list_index[inlist_id] = source_ixs == nullptr ? i : source_ixs[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = Pow2; @@ -100,7 +137,7 @@ __global__ void build_index_kernel(const LabelT* labels, auto ingroup_id = interleaved_group::mod(inlist_id) * veclen; // Point to the location of the interleaved group of vectors - list_data += (list_offset + group_offset) * dim; + list_data += group_offset * dim; // Point to the source vector if constexpr (gather_src) { @@ -117,58 +154,53 @@ __global__ void build_index_kernel(const LabelT* labels, } } -/** See raft::spatial::knn::ivf_flat::extend docs */ +/** See raft::neighbors::ivf_flat::extend docs */ template -inline auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) { using LabelT = uint32_t; + RAFT_EXPECTS(index != nullptr, "index cannot be empty."); auto stream = handle.get_stream(); - auto n_lists = orig_index.n_lists(); - auto dim = orig_index.dim(); + auto n_lists = index->n_lists(); + auto dim = index->dim(); + list_spec list_device_spec{index->dim(), + index->conservative_memory_allocation()}; common::nvtx::range fun_scope( "ivf_flat::extend(%zu, %u)", size_t(n_rows), dim); - RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, + RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, "You must pass data indices when the index is non-empty."); - rmm::device_uvector new_labels(n_rows, stream); + auto new_labels = raft::make_device_vector(handle, n_rows); raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = orig_index.metric(); - auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); - auto orig_centroids_view = raft::make_device_matrix_view( - orig_index.centers().data_handle(), n_lists, dim); - auto labels_view = raft::make_device_vector_view(new_labels.data(), n_rows); + kmeans_params.metric = index->metric(); + auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); + auto orig_centroids_view = + raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, new_vectors_view, orig_centroids_view, - labels_view, + new_labels.view(), utils::mapping{}); - index ext_index( - handle, orig_index.metric(), n_lists, 
orig_index.adaptive_centers(), dim); - - auto list_sizes_ptr = ext_index.list_sizes().data_handle(); - auto list_offsets_ptr = ext_index.list_offsets().data_handle(); - auto centers_ptr = ext_index.centers().data_handle(); + auto* list_sizes_ptr = index->list_sizes().data_handle(); + auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); + copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); // Calculate the centers and sizes on the new data, starting from the original values - raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream); - - if (ext_index.adaptive_centers()) { - raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); - auto centroids_view = raft::make_device_matrix_view(centers_ptr, n_lists, dim); + if (index->adaptive_centers()) { + auto centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); - auto const_labels_view = - raft::make_device_vector_view(new_labels.data(), n_rows); + auto const_labels_view = make_const_mdspan(new_labels.view()); raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, new_vectors_view, const_labels_view, @@ -180,90 +212,89 @@ inline auto extend(raft::device_resources const& handle, raft::stats::histogram(raft::stats::HistTypeAuto, reinterpret_cast(list_sizes_ptr), IdxT(n_lists), - new_labels.data(), + new_labels.data_handle(), n_rows, 1, stream); raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream); + list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); } - // Calculate new offsets - IdxT index_size = 0; - update_device(list_offsets_ptr, &index_size, 1, stream); - thrust::inclusive_scan( - rmm::exec_policy(stream), - list_sizes_ptr, - list_sizes_ptr + n_lists, - list_offsets_ptr + 1, - [] __device__(IdxT s, uint32_t l) { return s + Pow2::roundUp(l); }); - update_host(&index_size, list_offsets_ptr + n_lists, 1, stream); - handle.sync_stream(stream); - - ext_index.allocate(handle, index_size); - - // Populate index with the old data - if (orig_index.size() > 0) { - utils::block_copy(orig_index.list_offsets().data_handle(), - list_offsets_ptr, - IdxT(n_lists), - orig_index.data().data_handle(), - ext_index.data().data_handle(), - IdxT(dim), - stream); - - utils::block_copy(orig_index.list_offsets().data_handle(), - list_offsets_ptr, - IdxT(n_lists), - orig_index.indices().data_handle(), - ext_index.indices().data_handle(), - IdxT(1), - stream); + // Calculate and allocate new list data + std::vector new_list_sizes(n_lists); + std::vector old_list_sizes(n_lists); + { + copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); + copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); + handle.sync_stream(); + auto& lists = index->lists(); + for (uint32_t label = 0; label < n_lists; label++) { + ivf::resize_list(handle, + lists[label], + list_device_spec, + new_list_sizes[label], + Pow2::roundUp(old_list_sizes[label])); + } } - + // Update the pointers and the sizes + index->recompute_internal_state(handle); // Copy the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. 
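For the rounding in the list resize above, assume for illustration that kIndexGroupSize is 32 (the value itself is not restated in this patch):

// Pow2<32>::roundUp(100) == 128, Pow2<32>::roundUp(128) == 128.
// The old contents preserved on resize therefore always cover whole interleaved
// groups, while list_sizes_ptr keeps tracking the exact number of valid vectors
// and is rebuilt as an atomic counter by build_index_kernel below.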
- raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); + raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + // Kernel to insert the new vectors const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x)); - build_index_kernel<<>>(new_labels.data(), - list_offsets_ptr, + build_index_kernel<<>>(new_labels.data_handle(), new_vectors, new_indices, - ext_index.data().data_handle(), - ext_index.indices().data_handle(), + index->data_ptrs().data_handle(), + index->inds_ptrs().data_handle(), list_sizes_ptr, n_rows, dim, - ext_index.veclen()); + index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Precompute the centers vector norms for L2Expanded distance - if (ext_index.center_norms().has_value()) { - if (!ext_index.adaptive_centers() && orig_index.center_norms().has_value()) { - raft::copy(ext_index.center_norms()->data_handle(), - orig_index.center_norms()->data_handle(), - orig_index.center_norms()->size(), - stream); - } else { - raft::linalg::rowNorm(ext_index.center_norms()->data_handle(), - ext_index.centers().data_handle(), + if (!index->center_norms().has_value()) { + index->allocate_center_norms(handle); + if (index->center_norms().has_value()) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), dim, n_lists, raft::linalg::L2Norm, true, stream); - RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } + } else if (index->center_norms().has_value() && index->adaptive_centers()) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream); + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } +} - // assemble the index +/** See raft::neighbors::ivf_flat::extend docs */ +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + auto ext_index = clone(handle, orig_index); + detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); return ext_index; } -/** See raft::spatial::knn::ivf_flat::build docs */ +/** See raft::neighbors::ivf_flat::build docs */ template inline auto build(raft::device_resources const& handle, const index_params& params, @@ -280,7 +311,8 @@ inline auto build(raft::device_resources const& handle, index index(handle, params, dim); utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.list_offsets().data_handle(), index.list_offsets().size(), stream); + utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); + utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); // Train the kmeans clustering { @@ -310,10 +342,9 @@ inline auto build(raft::device_resources const& handle, // add the data if necessary if (params.add_data_on_build) { - return detail::extend(handle, index, dataset, nullptr, n_rows); - } else { - return index; + detail::extend(handle, &index, dataset, nullptr, n_rows); } + return index; } /** @@ -356,20 +387,17 @@ inline void fill_refinement_index(raft::device_resources const& handle, new_labels_view, raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); - auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); - auto 
list_offsets_ptr = refinement_index->list_offsets().data_handle(); + auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); // We do not fill centers and center norms, since we will not run coarse search. - // Calculate new offsets - uint32_t n_roundup = Pow2::roundUp(n_candidates); - auto list_offsets_view = raft::make_device_vector_view( - list_offsets_ptr, refinement_index->list_offsets().size()); - linalg::map_offset(handle, - list_offsets_view, - raft::compose_op(raft::cast_op(), raft::mul_const_op(n_roundup))); - - IdxT index_size = n_roundup * n_lists; - refinement_index->allocate(handle, index_size); + // Allocate new memory + auto& lists = refinement_index->lists(); + list_spec list_device_spec{refinement_index->dim(), false}; + for (uint32_t label = 0; label < n_lists; label++) { + ivf::resize_list(handle, lists[label], list_device_spec, n_candidates, uint32_t(0)); + } + // Update the pointers and the sizes + refinement_index->recompute_internal_state(handle); RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -377,121 +405,14 @@ inline void fill_refinement_index(raft::device_resources const& handle, const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); build_index_kernel <<>>(new_labels.data(), - list_offsets_ptr, dataset, candidate_idx, - refinement_index->data().data_handle(), - refinement_index->indices().data_handle(), + refinement_index->data_ptrs().data_handle(), + refinement_index->inds_ptrs().data_handle(), list_sizes_ptr, n_queries * n_candidates, refinement_index->dim(), refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } - -// Serialization version 2 -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. -constexpr int serialization_version = 2; - -static_assert(sizeof(index) == 408, - "The size of the index struct has changed since the last update; " - "paste in the new size and consider updating the save/load logic"); - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index_ IVF-Flat index - * - */ -template -void serialize(raft::device_resources const& handle, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - RAFT_LOG_DEBUG( - "Saving IVF-PQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - serialize_scalar(handle, of, serialization_version); - serialize_scalar(handle, of, index_.size()); - serialize_scalar(handle, of, index_.dim()); - serialize_scalar(handle, of, index_.n_lists()); - serialize_scalar(handle, of, index_.metric()); - serialize_scalar(handle, of, index_.veclen()); - serialize_scalar(handle, of, index_.adaptive_centers()); - serialize_mdspan(handle, of, index_.data()); - serialize_mdspan(handle, of, index_.indices()); - serialize_mdspan(handle, of, index_.list_sizes()); - serialize_mdspan(handle, of, index_.list_offsets()); - serialize_mdspan(handle, of, index_.centers()); - if (index_.center_norms()) { - bool has_norms = true; - serialize_scalar(handle, of, has_norms); - serialize_mdspan(handle, of, *index_.center_norms()); - } else { - bool has_norms = false; - serialize_scalar(handle, of, has_norms); - } - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ IVF-Flat index - * - */ -template -auto deserialize(raft::device_resources const& handle, const std::string& filename) - -> index -{ - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - auto ver = deserialize_scalar(handle, infile); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(handle, infile); - auto dim = deserialize_scalar(handle, infile); - auto n_lists = deserialize_scalar(handle, infile); - auto metric = deserialize_scalar(handle, infile); - auto veclen = deserialize_scalar(handle, infile); - bool adaptive_centers = deserialize_scalar(handle, infile); - - index index_ = - raft::spatial::knn::ivf_flat::index(handle, metric, n_lists, adaptive_centers, dim); - - index_.allocate(handle, n_rows); - auto data = index_.data(); - deserialize_mdspan(handle, infile, data); - deserialize_mdspan(handle, infile, index_.indices()); - deserialize_mdspan(handle, infile, index_.list_sizes()); - deserialize_mdspan(handle, infile, index_.list_offsets()); - deserialize_mdspan(handle, infile, index_.centers()); - bool has_norms = deserialize_scalar(handle, infile); - if (has_norms) { - if (!index_.center_norms()) { - RAFT_FAIL("Error inconsistent center norms"); - } else { - auto center_norms = *index_.center_norms(); - deserialize_mdspan(handle, infile, center_norms); - } - } - infile.close(); - return index_; -} -} // namespace raft::spatial::knn::ivf_flat::detail +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh similarity index 97% rename from cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh rename to 
cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index 7f70d4b8a5..f657070df4 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -16,9 +16,6 @@ #pragma once -#include "../ivf_flat_types.hpp" -#include "ann_utils.cuh" - #include #include #include @@ -30,6 +27,8 @@ #include #include #include +#include +#include #include #include #include @@ -39,7 +38,7 @@ #include #include -namespace raft::spatial::knn::ivf_flat::detail { +namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT @@ -673,10 +672,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* list_indices, - const T* list_data, + const IdxT* const* list_indices_ptrs, + const T* const* list_data_ptrs, const uint32_t* list_sizes, - const IdxT* list_offsets, const uint32_t n_probes, const uint32_t k, const uint32_t dim, @@ -722,8 +720,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Every CUDA block scans one cluster at a time. for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - const size_t list_offset = list_offsets[list_id]; + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) // The number of vectors in each cluster(list); [nlist] const uint32_t list_length = list_sizes[list_id]; @@ -741,7 +738,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_id += kNumWarps) { AccT dist = 0; // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim; + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; // This is the vector a given lane/thread handles const uint32_t vec_id = group_id * WarpSize + lane_id; @@ -778,7 +775,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices[list_offset + vec_id]) : 0; + const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; queue.add(val, idx); } } @@ -819,7 +816,7 @@ template void launch_kernel(Lambda lambda, PostLambda post_process, - const ivf_flat::index& index, + const index& index, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, @@ -869,10 +866,9 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.indices().data_handle(), - index.data().data_handle(), + index.inds_ptrs().data_handle(), + index.data_ptrs().data_handle(), index.list_sizes().data_handle(), - index.list_offsets().data_handle(), n_probes, k, index.dim(), @@ -1056,7 +1052,7 @@ struct select_interleaved_scan_kernel { * @param stream */ template -void ivfflat_interleaved_scan(const ivf_flat::index& index, +void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, @@ -1248,27 +1244,7 @@ void search_impl(raft::device_resources const& handle, } } -/** - * Whether minimal distance corresponds to similar elements (using the given metric). 
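// A simplified sketch of the addressing change in the scan kernel above (it ignores the
// interleaved kIndexGroupSize grouping and uses made-up names): instead of one flat array
// indexed through per-list offsets, each list owns its storage and the kernel receives an
// array of per-list base pointers.
#include <cstddef>
#include <cstdint>

// old style: flat storage plus a prefix-sum of list offsets
inline const float* nth_vector_old(const float* flat_data,
                                   const std::uint64_t* list_offsets,
                                   std::uint32_t list_id,
                                   std::uint32_t pos,
                                   std::uint32_t dim)
{
  return flat_data + (list_offsets[list_id] + pos) * dim;
}

// new style: one base pointer per list, no global offsets needed
inline const float* nth_vector_new(const float* const* list_data_ptrs,
                                   std::uint32_t list_id,
                                   std::uint32_t pos,
                                   std::uint32_t dim)
{
  return list_data_ptrs[list_id] + static_cast<std::size_t>(pos) * dim;
}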
- */ -inline bool is_min_close(distance::DistanceType metric) -{ - bool select_min; - switch (metric) { - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: - // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger - // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 - // {perform_k_selection}) - select_min = false; - break; - default: select_min = true; - } - return select_min; -} - -/** See raft::spatial::knn::ivf_flat::search docs */ +/** See raft::neighbors::ivf_flat::search docs */ template inline void search(raft::device_resources const& handle, const search_params& params, @@ -1299,10 +1275,10 @@ inline void search(raft::device_resources const& handle, n_queries, k, n_probes, - is_min_close(index.metric()), + raft::distance::is_min_close(index.metric()), neighbors, distances, mr); } -} // namespace raft::spatial::knn::ivf_flat::detail +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh new file mode 100644 index 0000000000..1bb7f97123 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::ivf_flat::detail { + +// Serialization version 3 +// No backward compatibility yet; that is, can't add additional fields without breaking +// backward compatibility. +// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward +// compatible fashion. +constexpr int serialization_version = 3; + +// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error +// message. +template +struct check_index_layout { + static_assert(RealSize == ExpectedSize, + "The size of the index struct has changed since the last update; " + "paste in the new size and consider updating the serialization logic"); +}; + +template struct check_index_layout), 296>; + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
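// A minimal, self-contained sketch of the layout-check idiom used above: wrapping the
// static_assert in a template makes the offending sizeof() value appear directly in the
// compiler error message. `my_index` is a hypothetical stand-in for the real index type.
#include <cstddef>
#include <cstdint>

struct my_index {
  std::uint32_t dim;
  std::uint64_t size;
};

template <std::size_t RealSize, std::size_t ExpectedSize>
struct check_layout {
  static_assert(RealSize == ExpectedSize,
                "The size of the struct has changed; update the serialization logic");
};

// If sizeof(my_index) ever drifts away from 16, the error text shows the new value.
template struct check_layout<sizeof(my_index), 16>;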
+ * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index_ IVF-Flat index + * + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index_) +{ + RAFT_LOG_DEBUG( + "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + + serialize_scalar(handle, os, serialization_version); + serialize_scalar(handle, os, index_.size()); + serialize_scalar(handle, os, index_.dim()); + serialize_scalar(handle, os, index_.n_lists()); + serialize_scalar(handle, os, index_.metric()); + serialize_scalar(handle, os, index_.adaptive_centers()); + serialize_scalar(handle, os, index_.conservative_memory_allocation()); + serialize_mdspan(handle, os, index_.centers()); + if (index_.center_norms()) { + bool has_norms = true; + serialize_scalar(handle, os, has_norms); + serialize_mdspan(handle, os, *index_.center_norms()); + } else { + bool has_norms = false; + serialize_scalar(handle, os, has_norms); + } + auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); + copy(sizes_host.data_handle(), + index_.list_sizes().data_handle(), + sizes_host.size(), + handle.get_stream()); + handle.sync_stream(); + serialize_mdspan(handle, os, sizes_host.view()); + + list_spec list_store_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::serialize_list(handle, + os, + index_.lists()[label], + list_store_spec, + Pow2::roundUp(sizes_host(label))); + } + handle.sync_stream(); +} + +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize(handle, of, index_); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +/** Load an index from file. + * + * Experimental, both the API and the serialization format are subject to change. 
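// A hedged sketch of an in-memory round trip through the helpers in this file; it assumes an
// already-built index and that this header is included. `deserialize` is defined just below.
#include <sstream>

template <typename T, typename IdxT>
auto roundtrip(raft::device_resources const& handle,
               const raft::neighbors::ivf_flat::index<T, IdxT>& idx)
  -> raft::neighbors::ivf_flat::index<T, IdxT>
{
  std::stringstream ss;  // any std::iostream works; a file stream is the common case
  raft::neighbors::ivf_flat::detail::serialize(handle, ss, idx);
  return raft::neighbors::ivf_flat::detail::deserialize<T, IdxT>(handle, ss);
}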
+ * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index_ IVF-Flat index + * + */ +template +auto deserialize(raft::device_resources const& handle, std::istream& is) -> index +{ + auto ver = deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto n_rows = deserialize_scalar(handle, is); + auto dim = deserialize_scalar(handle, is); + auto n_lists = deserialize_scalar(handle, is); + auto metric = deserialize_scalar(handle, is); + bool adaptive_centers = deserialize_scalar(handle, is); + bool cma = deserialize_scalar(handle, is); + + index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); + + deserialize_mdspan(handle, is, index_.centers()); + bool has_norms = deserialize_scalar(handle, is); + if (has_norms) { + index_.allocate_center_norms(handle); + if (!index_.center_norms()) { + RAFT_FAIL("Error inconsistent center norms"); + } else { + auto center_norms = index_.center_norms().value(); + deserialize_mdspan(handle, is, center_norms); + } + } + deserialize_mdspan(handle, is, index_.list_sizes()); + + list_spec list_device_spec{index_.dim(), cma}; + list_spec list_store_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); + } + handle.sync_stream(); + + index_.recompute_internal_state(handle); + + return index_; +} + +template +auto deserialize(raft::device_resources const& handle, const std::string& filename) + -> index +{ + std::ifstream is(filename, std::ios::in | std::ios::binary); + + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto index = detail::deserialize(handle, is); + + is.close(); + + return index; +} +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index bf3014568a..1a563d213e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -664,7 +664,7 @@ __launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel( const uint32_t pq_len = pq_centers.extent(1); const uint32_t pq_dim = new_vectors.extent(1) / pq_len; - auto pq_extents = list_spec{PqBits, pq_dim, true}.make_list_extents(out_ix + 1); + auto pq_extents = list_spec{PqBits, pq_dim, true}.make_list_extents(out_ix + 1); auto pq_extents_vectorized = make_extents(pq_extents.extent(0), pq_extents.extent(1), pq_extents.extent(2)); auto pq_dataset = make_mdspan( @@ -899,14 +899,15 @@ void extend(raft::device_resources const& handle, &managed_memory_upstream, 1024 * 1024); // The spec defines how the clusters look like - auto spec = list_spec{index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; + auto spec = list_spec{ + index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; // Try to allocate an index with the same parameters and the projected new size // (which can be slightly larger than index->size() + n_rows, due to padding). // If this fails, the index would be too big to fit in the device anyway. 
std::optional> placeholder_list( std::in_place_t{}, handle, - list_spec(spec), + list_spec{spec}, n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); // Available device memory diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 0701b0feb5..7d70ab9fbe 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -45,7 +45,9 @@ struct check_index_layout { "The size of the index struct has changed since the last update; " "paste in the new size and consider updating the serialization logic"); }; -template struct check_index_layout), 536>; + +// TODO: Recompute this and come back to it. +template struct check_index_layout), 448>; /** * Write the index to an output stream @@ -89,10 +91,9 @@ void serialize(raft::device_resources const& handle_, std::ostream& os, const in handle_.get_stream()); handle_.sync_stream(); serialize_mdspan(handle_, os, sizes_host.view()); - auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; + auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { - ivf::serialize_list( - handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); + ivf::serialize_list(handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); } } @@ -162,11 +163,10 @@ auto deserialize(raft::device_resources const& handle_, std::istream& is) -> ind deserialize_mdspan(handle_, is, index.centers_rot()); deserialize_mdspan(handle_, is, index.rotation_matrix()); deserialize_mdspan(handle_, is, index.list_sizes()); - auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; - auto list_store_spec = list_spec{pq_bits, pq_dim, true}; + auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; + auto list_store_spec = list_spec{pq_bits, pq_dim, true}; for (auto& list : index.lists()) { - ivf::deserialize_list( - handle_, is, list, list_store_spec, list_device_spec); + ivf::deserialize_list(handle_, is, list, list_store_spec, list_device_spec); } handle_.sync_stream(); diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh new file mode 100644 index 0000000000..875fc3b37c --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -0,0 +1,455 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::detail { +using namespace raft::spatial::knn::detail; +using namespace raft::spatial::knn; + +/** + * Calculates brute force knn, using a fixed memory budget + * by tiling over both the rows and columns of pairwise_distances + */ +template +void tiled_brute_force_knn(const raft::device_resources& handle, + const ElementType* search, // size (m ,d) + const ElementType* index, // size (n ,d) + size_t m, + size_t n, + size_t d, + int k, + ElementType* distances, // size (m, k) + IndexType* indices, // size (m, k) + raft::distance::DistanceType metric, + float metric_arg = 0.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0) +{ + // Figure out the number of rows/cols to tile for + size_t tile_rows = 0; + size_t tile_cols = 0; + auto stream = handle.get_stream(); + auto device_memory = handle.get_workspace_resource(); + auto total_mem = device_memory->get_mem_info(stream).second; + faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); + + // for unittesting, its convenient to be able to put a max size on the tiles + // so we can test the tiling logic without having to use huge inputs. + if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } + if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } + + // tile_cols must be at least k items + tile_cols = std::max(tile_cols, static_cast(k)); + + // stores pairwise distances for the current tile + rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); + + // calculate norms for L2 expanded distances - this lets us avoid calculating + // norms repeatedly per-tile, and just do once for the entire input + auto pairwise_metric = metric; + rmm::device_uvector search_norms(0, stream); + rmm::device_uvector index_norms(0, stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + search_norms.resize(m, stream); + index_norms.resize(n, stream); + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + pairwise_metric = raft::distance::DistanceType::InnerProduct; + } + + // if we're tiling over columns, we need additional buffers for temporary output + // distances/indices + size_t num_col_tiles = raft::ceildiv(n, tile_cols); + size_t temp_out_cols = k * num_col_tiles; + + // the final column tile could have less than 'k' items in it + // in which case the number of columns here is too high in the temp output. 
+ // adjust if necessary + auto last_col_tile_size = n % tile_cols; + if (last_col_tile_size && (last_col_tile_size < static_cast(k))) { + temp_out_cols -= k - last_col_tile_size; + } + + // if we have less than k items in the index, we should fill out the result + // to indicate that we are missing items (and match behaviour in faiss) + if (n < static_cast(k)) { + raft::matrix::fill(handle, + raft::make_device_matrix_view(distances, m, static_cast(k)), + std::numeric_limits::lowest()); + + if constexpr (std::is_signed_v) { + raft::matrix::fill( + handle, raft::make_device_matrix_view(indices, m, static_cast(k)), IndexType{-1}); + } + } + + rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); + rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); + + bool select_min = raft::distance::is_min_close(metric); + + for (size_t i = 0; i < m; i += tile_rows) { + size_t current_query_size = std::min(tile_rows, m - i); + + for (size_t j = 0; j < n; j += tile_cols) { + size_t current_centroid_size = std::min(tile_cols, n - j); + size_t current_k = std::min(current_centroid_size, static_cast(k)); + + // calculate the top-k elements for the current tile, by calculating the + // full pairwise distance for the tile - and then selecting the top-k from that + // note: we're using a int32 IndexType here on purpose in order to + // use the pairwise_distance specializations. Since the tile size will ensure + // that the total memory is < 1GB per tile, this will not cause any issues + distance::pairwise_distance(handle, + search + i * d, + index + j * d, + temp_distances.data(), + current_query_size, + current_centroid_size, + d, + pairwise_metric, + true, + metric_arg); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + auto row_norms = search_norms.data() + i; + auto col_norms = index_norms.data() + j; + auto dist = temp_distances.data(); + + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(dist, current_query_size * current_centroid_size), + [=] __device__(IndexType i) { + IndexType row = i / current_centroid_size, col = i % current_centroid_size; + + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + + // due to numerical instability (especially around self-distance) + // the distances here could be slightly negative, which will + // cause NaN values in the subsequent sqrt. Clamp to 0 + val = val * (val >= 0.0001); + if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } + return val; + }); + } + + select_k(temp_distances.data(), + nullptr, + current_query_size, + current_centroid_size, + distances + i * k, + indices + i * k, + select_min, + current_k, + stream); + + // if we're tiling over columns, we need to do a couple things to fix up + // the output of select_k + // 1. The column id's in the output are relative to the tile, so we need + // to adjust the column ids by adding the column the tile starts at (j) + // 2. select_k writes out output in a row-major format, which means we + // can't just concat the output of all the tiles and do a select_k on the + // concatenation. 
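// A small host-side sketch (hypothetical sizes) of the remapping described above: each column
// tile contributes `k` candidate columns per query row, and the tile starting at global column
// `j` owns columns [j*k/tile_cols, j*k/tile_cols + k) of the row-major temporary buffer that
// the final select_k pass runs over.
#include <cassert>
#include <cstddef>

int main()
{
  const std::size_t k = 4, tile_cols = 8, n = 20;
  const std::size_t num_col_tiles = (n + tile_cols - 1) / tile_cols;  // 3
  const std::size_t temp_out_cols = k * num_col_tiles;                // 12
  // entry `col` of the per-tile top-k for query `row`, from the tile starting at column `j`
  auto out_index = [&](std::size_t row, std::size_t j, std::size_t col) {
    return row * temp_out_cols + j * k / tile_cols + col;
  };
  assert(out_index(0, 0, 0) == 0);    // first tile writes temp columns 0..3
  assert(out_index(0, 8, 0) == 4);    // second tile (j = 8) starts at temp column k = 4
  assert(out_index(1, 16, 3) == 23);  // last entry of row 1: 1*12 + 8 + 3
  return 0;
}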
+ // Fix both of these problems in a single pass here + if (tile_cols != n) { + const ElementType* in_distances = distances + i * k; + const IndexType* in_indices = indices + i * k; + ElementType* out_distances = temp_out_distances.data(); + IndexType* out_indices = temp_out_indices.data(); + + auto count = thrust::make_counting_iterator(0); + thrust::for_each(handle.get_thrust_policy(), + count, + count + current_query_size * current_k, + [=] __device__(IndexType i) { + IndexType row = i / current_k, col = i % current_k; + IndexType out_index = row * temp_out_cols + j * k / tile_cols + col; + + out_distances[out_index] = in_distances[i]; + out_indices[out_index] = in_indices[i] + j; + }); + } + } + + if (tile_cols != n) { + // select the actual top-k items here from the temporary output + select_k(temp_out_distances.data(), + temp_out_indices.data(), + current_query_size, + temp_out_cols, + distances + i * k, + indices + i * k, + select_min, + k, + stream); + } + } +} + +/** + * Search the kNN for the k-nearest neighbors of a set of query vectors + * @param[in] input vector of device device memory array pointers to search + * @param[in] sizes vector of memory sizes for each device array pointer in input + * @param[in] D number of cols in input and search_items + * @param[in] search_items set of vectors to query for neighbors + * @param[in] n number of items in search_items + * @param[out] res_I pointer to device memory for returning k nearest indices + * @param[out] res_D pointer to device memory for returning k nearest distances + * @param[in] k number of neighbors to query + * @param[in] userStream the main cuda stream to use + * @param[in] internalStreams optional when n_params > 0, the index partitions can be + * queried in parallel using these streams. Note that n_int_streams also + * has to be > 0 for these to be used and their cardinality does not need + * to correspond to n_parts. + * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the + * user stream will be used. + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] translations translation ids for indices when index rows represent + * non-contiguous partitions + * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) + * @param[in] metricArg metric argument to use. 
Corresponds to the p arg for lp norm + */ +template +void brute_force_knn_impl( + raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + IntType D, + value_t* search_items, + IntType n, + IdxType* res_I, + value_t* res_D, + IntType k, + bool rowMajorIndex = true, + bool rowMajorQuery = true, + std::vector* translations = nullptr, + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, + float metricArg = 0) +{ + auto userStream = handle.get_stream(); + + ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); + + std::vector* id_ranges; + if (translations == nullptr) { + // If we don't have explicit translations + // for offsets of the indices, build them + // from the local partitions + id_ranges = new std::vector(); + IdxType total_n = 0; + for (size_t i = 0; i < input.size(); i++) { + id_ranges->push_back(total_n); + total_n += sizes[i]; + } + } else { + // otherwise, use the given translations + id_ranges = translations; + } + + // perform preprocessing + std::unique_ptr> query_metric_processor = + create_processor(metric, n, D, k, rowMajorQuery, userStream); + query_metric_processor->preprocess(search_items); + + std::vector>> metric_processors(input.size()); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i] = + create_processor(metric, sizes[i], D, k, rowMajorQuery, userStream); + metric_processors[i]->preprocess(input[i]); + } + + int device; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + + rmm::device_uvector trans(id_ranges->size(), userStream); + raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); + + rmm::device_uvector all_D(0, userStream); + rmm::device_uvector all_I(0, userStream); + + value_t* out_D = res_D; + IdxType* out_I = res_I; + + if (input.size() > 1) { + all_D.resize(input.size() * k * n, userStream); + all_I.resize(input.size() * k * n, userStream); + + out_D = all_D.data(); + out_I = all_I.data(); + } + + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitations of the pairwise_distance API: + // 1) pairwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (like rowMajorQuery/rowMajorIndex) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major.
+ auto search = search_items; + rmm::device_uvector search_row_major(0, userStream); + if (!rowMajorQuery) { + search_row_major.resize(n * D, userStream); + raft::linalg::transpose(handle, search, search_row_major.data(), n, D, userStream); + search = search_row_major.data(); + } + + // transpose into a temporary buffer if necessary + rmm::device_uvector index_row_major(0, userStream); + if (!rowMajorIndex) { + size_t total_size = 0; + for (auto size : sizes) { + total_size += size; + } + index_row_major.resize(total_size * D, userStream); + } + + // Make other streams from pool wait on main stream + handle.wait_stream_pool_on_stream(); + + size_t total_rows_processed = 0; + for (size_t i = 0; i < input.size(); i++) { + value_t* out_d_ptr = out_D + (i * k * n); + IdxType* out_i_ptr = out_I + (i * k * n); + + auto stream = handle.get_next_usable_stream(i); + + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { + fusedL2Knn(D, + out_i_ptr, + out_d_ptr, + input[i], + search_items, + sizes[i], + n, + k, + rowMajorIndex, + rowMajorQuery, + stream, + metric); + + // Perform necessary post-processing + if (metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::LpUnexpanded) { + float p = 0.5; // standard l2 + if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; + raft::linalg::unaryOp( + res_D, + res_D, + n * k, + [p] __device__(float input) { return powf(fabsf(input), p); }, + stream); + } + } else { + switch (metric) { + case raft::distance::DistanceType::Haversine: + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); + break; + default: + // Create a new handle with the current stream from the stream pool + raft::device_resources stream_pool_handle(handle); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + + auto index = input[i]; + if (!rowMajorIndex) { + index = index_row_major.data() + total_rows_processed * D; + total_rows_processed += sizes[i]; + raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); + } + + // cosine/correlation are handled by metric processor, use IP distance + // for brute force knn call. + auto tiled_metric = metric; + if (metric == raft::distance::DistanceType::CosineExpanded || + metric == raft::distance::DistanceType::CorrelationExpanded) { + tiled_metric = raft::distance::DistanceType::InnerProduct; + } + + tiled_brute_force_knn(stream_pool_handle, + search, + index, + n, + sizes[i], + D, + k, + out_d_ptr, + out_i_ptr, + tiled_metric, + metricArg); + break; + } + } + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + // Sync internal streams if used. We don't need to + // sync the user stream because we'll already have + // fully serial execution. + handle.sync_stream_pool(); + + if (input.size() > 1 || translations != nullptr) { + // This is necessary for proper index translations. If there are + // no translations or partitions to combine, it can be skipped. 
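// A short sketch (made-up helper, hypothetical sizes) of the layout the merge call just below
// consumes: per-partition results are stacked, so partition `p`'s entry for query `row`,
// neighbor `col` lives at p*(n_queries*k) + row*k + col, and translations[p] is added to the
// indices so they refer back to the global dataset.
#include <cstddef>

inline std::size_t part_entry(std::size_t p, std::size_t row, std::size_t col,
                              std::size_t n_queries, std::size_t k)
{
  return p * (n_queries * k) + row * k + col;
}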
+ knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); + } + + query_metric_processor->revert(search_items); + query_metric_processor->postprocess(out_D); + for (size_t i = 0; i < input.size(); i++) { + metric_processors[i]->revert(input[i]); + } + + if (translations == nullptr) delete id_ranges; +}; + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh new file mode 100644 index 0000000000..e2b5c41fb0 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace raft::neighbors::detail { + +template +__global__ void knn_merge_parts_kernel(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + value_t initK, + value_idx initV, + int k, + value_idx* translations) +{ + constexpr int kNumWarps = tpb / WarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + /** + * Uses shared memory + */ + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + int total_k = k * n_parts; + + int i = threadIdx.x; + + // Get starting pointers for cols in current thread + int part = i / k; + size_t row_idx = (row * k) + (part * n_samples * k); + + int col = i % k; + + value_t* inKStart = inK + (row_idx + col); + value_idx* inVStart = inV + (row_idx + col); + + int limit = Pow2::roundDown(total_k); + value_idx translation = 0; + + for (; i < limit; i += tpb) { + translation = translations[part]; + heap.add(*inKStart, (*inVStart) + translation); + + part = (i + tpb) / k; + row_idx = (row * k) + (part * n_samples * k); + + col = (i + tpb) % k; + + inKStart = inK + (row_idx + col); + inVStart = inV + (row_idx + col); + } + + // Handle last remainder fraction of a warp of elements + if (i < total_k) { + translation = translations[part]; + heap.addThreadQ(*inKStart, (*inVStart) + translation); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void knn_merge_parts_impl(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + auto grid = dim3(n_samples); + + constexpr int n_threads = (warp_q <= 1024) ? 
128 : 64; + auto block = dim3(n_threads); + + auto kInit = std::numeric_limits::max(); + auto vInit = -1; + knn_merge_parts_kernel + <<>>( + inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * @brief Merge knn distances and index matrix, which have been partitioned + * by row, into a single matrix with only the k-nearest neighbors. + * + * @param inK partitioned knn distance matrix + * @param inV partitioned knn index matrix + * @param outK merged knn distance matrix + * @param outV merged knn index matrix + * @param n_samples number of samples per partition + * @param n_parts number of partitions + * @param k number of neighbors per partition (also number of merged neighbors) + * @param stream CUDA stream to use + * @param translations mapping of index offsets for each partition + */ +template +inline void knn_merge_parts(value_t* inK, + value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + if (k == 1) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 32) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 64) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 128) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 256) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 512) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 1024) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); +} +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b264643584..f244d5875c 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -20,9 +20,9 @@ #include #include #include +#include +#include #include -#include -#include #include #include @@ -108,17 +108,17 @@ void refine_device(raft::device_resources const& handle, handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries); raft::neighbors::ivf_flat::index refinement_index( - handle, metric, n_queries, false, dim); + handle, metric, n_queries, false, true, dim); - raft::spatial::knn::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); + raft::neighbors::ivf_flat::detail::fill_refinement_index(handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + n_queries, + n_candidates); uint32_t grid_dim_x = 1; - raft::spatial::knn::ivf_flat::detail::ivfflat_interleaved_scan< + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, idx_t>(refinement_index, @@ -128,7 +128,7 @@ void refine_device(raft::device_resources const& handle, refinement_index.metric(), 1, k, - raft::spatial::knn::ivf_flat::detail::is_min_close(metric), + raft::distance::is_min_close(metric), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh 
b/cpp/include/raft/neighbors/detail/selection_faiss.cuh similarity index 95% rename from cpp/include/raft/spatial/knn/detail/selection_faiss.cuh rename to cpp/include/raft/neighbors/detail/selection_faiss.cuh index 5264f5d12e..5df42e94b9 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss.cuh @@ -16,16 +16,12 @@ #pragma once -#include #include #include -#include +#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { +namespace raft::neighbors::detail { template constexpr int kFaissMaxK() @@ -170,8 +166,4 @@ inline void select_k(const key_t* inK, else ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); } - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +}; // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f18611b9f1..c573676504 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -16,9 +16,10 @@ #pragma once +#include +#include +#include #include -#include -#include #include @@ -67,7 +68,7 @@ auto build(raft::device_resources const& handle, IdxT n_rows, uint32_t dim) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); + return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); } /** @@ -99,7 +100,6 @@ auto build(raft::device_resources const& handle, * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] params configure the index building @@ -109,14 +109,61 @@ auto build(raft::device_resources const& handle, */ template auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) + -> index +{ + return raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, dataset, index_params, index); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +template +void build(raft::device_resources const& handle, + const index_params& params, raft::device_matrix_view dataset, - const index_params& params) -> index + raft::neighbors::ivf_flat::index& idx) { - return raft::spatial::knn::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); + idx = raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); } /** @} */ @@ -160,7 +207,7 @@ auto extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT n_rows) -> index { - return raft::spatial::knn::ivf_flat::detail::extend( + return raft::neighbors::ivf_flat::detail::extend( handle, orig_index, new_vectors, new_indices, n_rows); } @@ -190,24 +237,21 @@ auto extend(raft::device_resources const& handle, * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_matrix_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. 
+ * @param[in] orig_index original index * * @return the constructed extended ivf-flat index */ template auto extend(raft::device_resources const& handle, - const index& orig_index, raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) - -> index + std::optional> new_indices, + const index& orig_index) -> index { return extend( handle, @@ -252,7 +296,7 @@ void extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT n_rows) { - *index = extend(handle, *index, new_vectors, new_indices, n_rows); + raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); } /** @@ -272,32 +316,31 @@ void extend(raft::device_resources const& handle, * // train the index from a [N, D] dataset * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset); + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_opt, &index_empty); * @endcode * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle - * @param[inout] index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_matrix_view to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] index pointer to index, to be overwritten in-place */ template void extend(raft::device_resources const& handle, - index* index, raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) + std::optional> new_indices, + index* index) { - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); + extend(handle, + index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); } /** @} */ @@ -355,7 +398,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::spatial::knn::ivf_flat::detail::search( + return raft::neighbors::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr); } @@ -388,33 +431,29 @@ void search(raft::device_resources const& handle, * @tparam value_t data element type * @tparam idx_t type of the indices * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle + * @param[in] params configure the search * @param[in] index ivf-flat constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] params configure the search - * @param[in] k the number of neighbors to find for each query. 
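// A hedged usage sketch of the mdspan search overload declared below, assuming float data,
// int64_t indices and row-major views (the exact view types are assumptions): `k` is no longer
// passed explicitly, it is inferred from the column extent of the output views.
#include <cstdint>

void search_example(raft::device_resources const& handle,
                    const raft::neighbors::ivf_flat::index<float, std::int64_t>& idx,
                    raft::device_matrix_view<const float, std::int64_t, raft::row_major> queries,
                    raft::device_matrix_view<std::int64_t, std::int64_t, raft::row_major> neighbors,
                    raft::device_matrix_view<float, std::int64_t, raft::row_major> distances)
{
  raft::neighbors::ivf_flat::search_params params;
  params.n_probes = 32;  // example value
  // neighbors.extent(1) == distances.extent(1) plays the role of `k`
  raft::neighbors::ivf_flat::search(handle, params, idx, queries, neighbors, distances);
}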
*/ -template +template void search(raft::device_resources const& handle, + const search_params& params, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - const search_params& params, - int_t k) + raft::device_matrix_view distances) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), "Number of rows in output neighbors and distances matrices must equal the number of queries."); - RAFT_EXPECTS( - neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast(k), - "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number of query dimensions should equal number of dimensions in the index."); @@ -424,7 +463,7 @@ void search(raft::device_resources const& handle, index, queries.data_handle(), static_cast(queries.extent(0)), - static_cast(k), + static_cast(neighbors.extent(1)), neighbors.data_handle(), distances.data_handle(), nullptr); diff --git a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh new file mode 100644 index 0000000000..77fce13e61 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/ivf_flat_serialize.cuh" + +namespace raft::neighbors::ivf_flat { + +/** + * \defgroup ivf_flat_serialize IVF-Flat Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = ivf_flat::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index IVF-Flat index + * + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +{ + detail::serialize(handle, os, index); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_flat::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-Flat index + * + */ +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::ivf_flat::index + */ +template +index deserialize(raft::device_resources const& handle, std::istream& is) +{ + return detail::deserialize(handle, is); +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::ivf_flat::index + */ +template +index deserialize(raft::device_resources const& handle, const std::string& filename) +{ + return detail::deserialize(handle, filename); +} + +/**@}*/ + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index d234822a23..2a6aa12847 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -19,11 +19,17 @@ #include "ann_types.hpp" #include +#include #include +#include +#include #include +#include #include +#include #include +#include #include namespace raft::neighbors::ivf_flat { @@ -55,6 +61,16 @@ struct index_params : ann::index_params { * `index.centers()` "drift" together with the changing distribution of the newly added data. */ bool adaptive_centers = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. 
+ */ + bool conservative_memory_allocation = false; }; struct search_params : ann::search_params { @@ -65,6 +81,40 @@ struct search_params : ann::search_params { static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); +template +struct list_spec { + using value_type = ValueT; + using list_extents = matrix_extent; + using index_type = IdxT; + + SizeT align_max; + SizeT align_min; + uint32_t dim; + + constexpr list_spec(uint32_t dim, bool conservative_memory_allocation) + : dim(dim), + align_min(kIndexGroupSize), + align_max(conservative_memory_allocation ? kIndexGroupSize : 1024) + { + } + + // Allow casting between different size-types (for safer size and offset calculations) + template + constexpr explicit list_spec(const list_spec& other_spec) + : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} + { + } + + /** Determine the extents of an array enough to hold a given amount of data. */ + constexpr auto make_list_extents(SizeT n_rows) const -> list_extents + { + return make_extents(n_rows, dim); + } +}; + +template +using list_data = ivf::list; + /** * @brief IVF-flat index. * @@ -118,59 +168,27 @@ struct index : ann::index { * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - , * */ - inline auto data() noexcept -> device_mdspan, row_major> - { - return data_.view(); - } - [[nodiscard]] inline auto data() const noexcept - -> device_mdspan, row_major> - { - return data_.view(); - } - - /** Inverted list indices: ids of items in the source data [size] */ - inline auto indices() noexcept -> device_mdspan, row_major> - { - return indices_.view(); - } - [[nodiscard]] inline auto indices() const noexcept - -> device_mdspan, row_major> - { - return indices_.view(); - } - - /** Sizes of the lists (clusters) [n_lists] */ - inline auto list_sizes() noexcept -> device_mdspan, row_major> + /** Sizes of the lists (clusters) [n_lists] + * NB: This may differ from the actual list size if the shared lists have been extended by another + * index + */ + inline auto list_sizes() noexcept -> device_vector_view { return list_sizes_.view(); } [[nodiscard]] inline auto list_sizes() const noexcept - -> device_mdspan, row_major> + -> device_vector_view { return list_sizes_.view(); } - /** - * Offsets into the lists [n_lists + 1]. - * The last value contains the total length of the index. - */ - inline auto list_offsets() noexcept -> device_mdspan, row_major> - { - return list_offsets_.view(); - } - [[nodiscard]] inline auto list_offsets() const noexcept - -> device_mdspan, row_major> - { - return list_offsets_.view(); - } - /** k-means cluster centers corresponding to the lists [n_lists, dim] */ - inline auto centers() noexcept -> device_mdspan, row_major> + inline auto centers() noexcept -> device_matrix_view { return centers_.view(); } [[nodiscard]] inline auto centers() const noexcept - -> device_mdspan, row_major> + -> device_matrix_view { return centers_.view(); } @@ -181,39 +199,33 @@ struct index : ann::index { * NB: this may be empty if the index is empty or if the metric does not require the center norms * calculation. 
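// A rough sketch (not the exact RAFT logic) of the allocation policy that list_spec and
// conservative_memory_allocation control: capacities are padded generously by default to
// amortize future extend() calls, or rounded only to the small interleaving granularity when
// conservative allocation is requested. The constants here are assumptions.
#include <cstdint>

inline std::uint32_t round_up(std::uint32_t x, std::uint32_t align)
{
  return ((x + align - 1) / align) * align;
}

inline std::uint32_t list_capacity(std::uint32_t requested, bool conservative)
{
  const std::uint32_t align_min = 32;    // stand-in for the kIndexGroupSize granularity
  const std::uint32_t align_max = 1024;  // generous padding used by the default policy
  return round_up(requested, conservative ? align_min : align_max);
}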
*/ - inline auto center_norms() noexcept - -> std::optional, row_major>> + inline auto center_norms() noexcept -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional, row_major>>( - center_norms_->view()); + return std::make_optional>(center_norms_->view()); } else { return std::nullopt; } } [[nodiscard]] inline auto center_norms() const noexcept - -> std::optional, row_major>> + -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional, row_major>>( - center_norms_->view()); + return std::make_optional>(center_norms_->view()); } else { return std::nullopt; } } /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return indices_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return total_size_; } /** Dimensionality of the data. */ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return centers_.extent(1); } /** Number of clusters/inverted lists. */ - [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t - { - return centers_.extent(0); - } + [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t { return lists_.size(); } // Don't allow copying the index for performance reasons (try avoiding copying data) index(const index&) = delete; @@ -223,51 +235,111 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::device_resources const& handle, + index(raft::device_resources const& res, raft::distance::DistanceType metric, uint32_t n_lists, bool adaptive_centers, + bool conservative_memory_allocation, uint32_t dim) : ann::index(), veclen_(calculate_veclen(dim)), metric_(metric), adaptive_centers_(adaptive_centers), - data_(make_device_mdarray(handle, make_extents(0, dim))), - indices_(make_device_mdarray(handle, make_extents(0))), - list_sizes_(make_device_mdarray(handle, make_extents(n_lists))), - list_offsets_(make_device_mdarray(handle, make_extents(n_lists + 1))), - centers_(make_device_mdarray(handle, make_extents(n_lists, dim))), - center_norms_(std::nullopt) + conservative_memory_allocation_{conservative_memory_allocation}, + centers_(make_device_matrix(res, n_lists, dim)), + center_norms_(std::nullopt), + lists_{n_lists}, + list_sizes_{make_device_vector(res, n_lists)}, + data_ptrs_{make_device_vector(res, n_lists)}, + inds_ptrs_{make_device_vector(res, n_lists)}, + total_size_{0} { check_consistency(); } /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::device_resources const& handle, const index_params& params, uint32_t dim) - : index(handle, params.metric, params.n_lists, params.adaptive_centers, dim) + index(raft::device_resources const& res, const index_params& params, uint32_t dim) + : index(res, + params.metric, + params.n_lists, + params.adaptive_centers, + params.conservative_memory_allocation, + dim) { } + /** Pointers to the inverted lists (clusters) data [n_lists]. */ + inline auto data_ptrs() noexcept -> device_vector_view { return data_ptrs_.view(); } + [[nodiscard]] inline auto data_ptrs() const noexcept -> device_vector_view + { + return data_ptrs_.view(); + } + + /** Pointers to the inverted lists (clusters) indices [n_lists]. 
+  inline auto inds_ptrs() noexcept -> device_vector_view<IdxT*, uint32_t>
+  {
+    return inds_ptrs_.view();
+  }
+  [[nodiscard]] inline auto inds_ptrs() const noexcept -> device_vector_view<IdxT* const, uint32_t>
+  {
+    return inds_ptrs_.view();
+  }
 
   /**
-   * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount
-   * of data.
+   * Whether to use conservative memory allocation when extending the list (cluster) data
+   * (see index_params.conservative_memory_allocation).
    */
-  void allocate(raft::device_resources const& handle, IdxT index_size)
+  [[nodiscard]] constexpr inline auto conservative_memory_allocation() const noexcept -> bool
   {
-    data_    = make_device_mdarray<T>(handle, make_extents<IdxT>(index_size, dim()));
-    indices_ = make_device_mdarray<IdxT>(handle, make_extents<IdxT>(index_size));
+    return conservative_memory_allocation_;
+  }
+  /**
+   * Update the state of the dependent index members.
+   */
+  void recompute_internal_state(raft::device_resources const& res)
+  {
+    auto stream = res.get_stream();
+
+    // Actualize the list pointers
+    auto this_lists     = lists();
+    auto this_data_ptrs = data_ptrs();
+    auto this_inds_ptrs = inds_ptrs();
+    IdxT recompute_total_size = 0;
+    for (uint32_t label = 0; label < this_lists.size(); label++) {
+      auto& list           = this_lists[label];
+      const auto data_ptr  = list ? list->data.data_handle() : nullptr;
+      const auto inds_ptr  = list ? list->indices.data_handle() : nullptr;
+      const auto list_size = list ? IdxT(list->size) : 0;
+      copy(&this_data_ptrs(label), &data_ptr, 1, stream);
+      copy(&this_inds_ptrs(label), &inds_ptr, 1, stream);
+      recompute_total_size += list_size;
+    }
+    total_size_ = recompute_total_size;
+    check_consistency();
+  }
+
+  void allocate_center_norms(raft::device_resources const& res)
+  {
     switch (metric_) {
       case raft::distance::DistanceType::L2Expanded:
       case raft::distance::DistanceType::L2SqrtExpanded:
       case raft::distance::DistanceType::L2Unexpanded:
      case raft::distance::DistanceType::L2SqrtUnexpanded:
-        center_norms_ = make_device_mdarray<float>(handle, make_extents<uint32_t>(n_lists()));
+        center_norms_ = make_device_vector<float, uint32_t>(res, n_lists());
         break;
       default: center_norms_ = std::nullopt;
     }
+  }
 
-    check_consistency();
+  /** Lists' data and indices. */
+  inline auto lists() noexcept -> std::vector<std::shared_ptr<list_data<T, IdxT>>>&
+  {
+    return lists_;
+  }
+  [[nodiscard]] inline auto lists() const noexcept
+    -> const std::vector<std::shared_ptr<list_data<T, IdxT>>>&
+  {
+    return lists_;
   }
 
  private:
@@ -278,26 +350,29 @@ struct index : ann::index {
   uint32_t veclen_;
   raft::distance::DistanceType metric_;
   bool adaptive_centers_;
-  device_mdarray<T, extent_2d<IdxT>, row_major> data_;
-  device_mdarray<IdxT, extent_1d<IdxT>, row_major> indices_;
-  device_mdarray<uint32_t, extent_1d<uint32_t>, row_major> list_sizes_;
-  device_mdarray<IdxT, extent_1d<uint32_t>, row_major> list_offsets_;
-  device_mdarray<float, extent_2d<uint32_t>, row_major> centers_;
-  std::optional<device_mdarray<float, extent_1d<uint32_t>, row_major>> center_norms_;
+  bool conservative_memory_allocation_;
+  std::vector<std::shared_ptr<list_data<T, IdxT>>> lists_;
+  device_vector<uint32_t, uint32_t> list_sizes_;
+  device_matrix<float, uint32_t, row_major> centers_;
+  std::optional<device_vector<float, uint32_t>> center_norms_;
+
+  // Computed members
+  device_vector<T*, uint32_t> data_ptrs_;
+  device_vector<IdxT*, uint32_t> inds_ptrs_;
+  IdxT total_size_;
 
   /** Throw an error if the index content is inconsistent. */
   void check_consistency()
   {
+    auto n_lists = lists_.size();
     RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen");
-    RAFT_EXPECTS(data_.extent(0) == indices_.extent(0), "inconsistent index size");
-    RAFT_EXPECTS(data_.extent(1) == IdxT(centers_.extent(1)), "inconsistent data dimensionality");
-    RAFT_EXPECTS(                                               //
-      (centers_.extent(0) == list_sizes_.extent(0)) &&          //
-        (centers_.extent(0) + 1 == list_offsets_.extent(0)) &&  //
+    RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size");
+    RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size");
+    RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size");
+    RAFT_EXPECTS(                                       //
+      (centers_.extent(0) == list_sizes_.extent(0)) &&  //
         (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)),
       "inconsistent number of lists (clusters)");
-    RAFT_EXPECTS(reinterpret_cast<size_t>(data_.data_handle()) % (veclen_ * sizeof(T)) == 0,
-                 "The data storage pointer is not aligned to the vector length");
   }
 
   static auto calculate_veclen(uint32_t dim) -> uint32_t
diff --git a/cpp/include/raft/neighbors/ivf_list.hpp b/cpp/include/raft/neighbors/ivf_list.hpp
index 4644143057..a0ba001f77 100644
--- a/cpp/include/raft/neighbors/ivf_list.hpp
+++ b/cpp/include/raft/neighbors/ivf_list.hpp
@@ -35,11 +35,13 @@ namespace raft::neighbors::ivf {
 
 /** The data for a single IVF list. */
-template