From 667d873876fb1bb8669a2b9a5d1105b5ad67ca6e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 14 Mar 2023 14:10:29 -0400 Subject: [PATCH 01/11] Small updates to docs (#1339) RAFT is getting a little more attention and I'm just updating a few things in the docs to make them look more polished. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1339 --- README.md | 2 +- docs/source/index.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 773d98e23a..a178d90008 100755 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ RAFT relies heavily on RMM which eases the burden of configuring different alloc ### Multi-dimensional Arrays -The APIs in RAFT currently accept raw pointers to device memory and we are in the process of simplifying the APIs with the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. +The APIs in RAFT accept the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: diff --git a/docs/source/index.rst b/docs/source/index.rst index 2418c6a767..814899c36b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,5 +1,5 @@ -Welcome to RAFT's documentation! 
-================================= +RAPIDS RAFT: Reusable Accelerated Functions and Tools +===================================================== RAFT contains fundamental widely-used algorithms and primitives for scientific computing, data science and machine learning. The algorithms are CUDA-accelerated and form building-blocks for rapidly composing analytics. From 03e26b520ea281d9b239938eda30dd5100559e24 Mon Sep 17 00:00:00 2001 From: Micka Date: Wed, 15 Mar 2023 02:31:02 +0100 Subject: [PATCH 02/11] IVF-Flat index splitting (#1271) Refactor of `ivf_flat::index` to split the cluster data in separate buffers - Addressing #1170. - Following #1249 in the index structure implementation. - Adding serialization API with stream and filename overloads - Moving `raft::spatial::knn::ivf_flat` namespace to `raft::neighbors::ivf_flat` Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - AJ Schmidt (https://github.com/ajschmidt8) - Artem M. 
Chirkin (https://github.com/achirkin) URL: https://github.com/rapidsai/raft/pull/1271 --- .../detail/ivf_flat_build.cuh | 369 +++++++----------- .../detail/ivf_flat_search.cuh | 32 +- .../neighbors/detail/ivf_flat_serialize.cuh | 176 +++++++++ .../raft/neighbors/detail/ivf_pq_build.cuh | 7 +- .../neighbors/detail/ivf_pq_serialize.cuh | 12 +- cpp/include/raft/neighbors/detail/refine.cuh | 22 +- cpp/include/raft/neighbors/ivf_flat.cuh | 39 +- .../raft/neighbors/ivf_flat_serialize.cuh | 156 ++++++++ cpp/include/raft/neighbors/ivf_flat_types.hpp | 236 +++++++---- cpp/include/raft/neighbors/ivf_list.hpp | 88 +++-- cpp/include/raft/neighbors/ivf_list_types.hpp | 38 +- cpp/include/raft/neighbors/ivf_pq_types.hpp | 7 +- cpp/test/neighbors/ann_ivf_flat.cuh | 49 ++- 13 files changed, 797 insertions(+), 434 deletions(-) rename cpp/include/raft/{spatial/knn => neighbors}/detail/ivf_flat_build.cuh (54%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/ivf_flat_search.cuh (98%) create mode 100644 cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh create mode 100644 cpp/include/raft/neighbors/ivf_flat_serialize.cuh diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh similarity index 54% rename from cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh rename to cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index c417a97531..bf7248b983 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -16,31 +16,68 @@ #pragma once -#include "../ivf_flat_types.hpp" -#include "ann_utils.cuh" - #include #include #include #include #include #include -#include #include #include #include +#include +#include +#include +#include #include #include #include #include -#include -namespace raft::spatial::knn::ivf_flat::detail { +namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT +template 
+auto clone(const raft::device_resources& res, const index& source) -> index +{ + auto stream = res.get_stream(); + + // Allocate the new index + index target(res, + source.metric(), + source.n_lists(), + source.adaptive_centers(), + source.conservative_memory_allocation(), + source.dim()); + + // Copy the independent parts + copy(target.list_sizes().data_handle(), + source.list_sizes().data_handle(), + source.list_sizes().size(), + stream); + copy(target.centers().data_handle(), + source.centers().data_handle(), + source.centers().size(), + stream); + if (source.center_norms().has_value()) { + target.allocate_center_norms(res); + copy(target.center_norms()->data_handle(), + source.center_norms()->data_handle(), + source.center_norms()->size(), + stream); + } + // Copy shared pointers + target.lists() = source.lists(); + + // Make sure the device pointers point to the new lists + target.recompute_internal_state(res); + + return target; +} + /** * @brief Record the dataset into the index, one source row at a time. * @@ -60,11 +97,11 @@ using namespace raft::spatial::knn::detail; // NOLINT * we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1. * * @param[in] labels device pointer to the cluster ids for each row [n_rows] - * @param[in] list_offsets device pointer to the cluster offsets in the output (index) [n_lists] * @param[in] source_vecs device pointer to the input data [n_rows, dim] * @param[in] source_ixs device pointer to the input indices [n_rows] - * @param[out] list_data device pointer to the output [index_size, dim] - * @param[out] list_index device pointer to the source ids corr. to the output [index_size] + * @param[out] list_data_ptrs device pointer to the index data of size [n_lists][index_size, dim] + * @param[out] list_index_ptrs device pointer to the source ids corr. 
to the output [n_lists] + * [index_size] * @param[out] list_sizes_ptr device pointer to the cluster sizes [n_lists]; * it's used as an atomic counter, and must be initialized with zeros. * @param n_rows source length @@ -74,11 +111,10 @@ using namespace raft::spatial::knn::detail; // NOLINT */ template __global__ void build_index_kernel(const LabelT* labels, - const IdxT* list_offsets, const T* source_vecs, const IdxT* source_ixs, - T* list_data, - IdxT* list_index, + T** list_data_ptrs, + IdxT** list_index_ptrs, uint32_t* list_sizes_ptr, IdxT n_rows, uint32_t dim, @@ -89,10 +125,11 @@ __global__ void build_index_kernel(const LabelT* labels, auto list_id = labels[i]; auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1); - auto list_offset = list_offsets[list_id]; + auto* list_index = list_index_ptrs[list_id]; + auto* list_data = list_data_ptrs[list_id]; // Record the source vector id in the index - list_index[list_offset + inlist_id] = source_ixs == nullptr ? i : source_ixs[i]; + list_index[inlist_id] = source_ixs == nullptr ? 
i : source_ixs[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = Pow2; @@ -100,7 +137,7 @@ __global__ void build_index_kernel(const LabelT* labels, auto ingroup_id = interleaved_group::mod(inlist_id) * veclen; // Point to the location of the interleaved group of vectors - list_data += (list_offset + group_offset) * dim; + list_data += group_offset * dim; // Point to the source vector if constexpr (gather_src) { @@ -117,58 +154,53 @@ __global__ void build_index_kernel(const LabelT* labels, } } -/** See raft::spatial::knn::ivf_flat::extend docs */ +/** See raft::neighbors::ivf_flat::extend docs */ template -inline auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) { using LabelT = uint32_t; + RAFT_EXPECTS(index != nullptr, "index cannot be empty."); auto stream = handle.get_stream(); - auto n_lists = orig_index.n_lists(); - auto dim = orig_index.dim(); + auto n_lists = index->n_lists(); + auto dim = index->dim(); + list_spec list_device_spec{index->dim(), + index->conservative_memory_allocation()}; common::nvtx::range fun_scope( "ivf_flat::extend(%zu, %u)", size_t(n_rows), dim); - RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, + RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, "You must pass data indices when the index is non-empty."); - rmm::device_uvector new_labels(n_rows, stream); + auto new_labels = raft::make_device_vector(handle, n_rows); raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = orig_index.metric(); - auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); - auto orig_centroids_view = raft::make_device_matrix_view( - orig_index.centers().data_handle(), n_lists, dim); - auto labels_view = 
raft::make_device_vector_view(new_labels.data(), n_rows); + kmeans_params.metric = index->metric(); + auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); + auto orig_centroids_view = + raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, new_vectors_view, orig_centroids_view, - labels_view, + new_labels.view(), utils::mapping{}); - index ext_index( - handle, orig_index.metric(), n_lists, orig_index.adaptive_centers(), dim); - - auto list_sizes_ptr = ext_index.list_sizes().data_handle(); - auto list_offsets_ptr = ext_index.list_offsets().data_handle(); - auto centers_ptr = ext_index.centers().data_handle(); + auto* list_sizes_ptr = index->list_sizes().data_handle(); + auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); + copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); // Calculate the centers and sizes on the new data, starting from the original values - raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream); - - if (ext_index.adaptive_centers()) { - raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); - auto centroids_view = raft::make_device_matrix_view(centers_ptr, n_lists, dim); + if (index->adaptive_centers()) { + auto centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); - auto const_labels_view = - raft::make_device_vector_view(new_labels.data(), n_rows); + auto const_labels_view = make_const_mdspan(new_labels.view()); raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, new_vectors_view, const_labels_view, @@ -180,90 +212,89 @@ inline auto extend(raft::device_resources const& handle, 
raft::stats::histogram(raft::stats::HistTypeAuto, reinterpret_cast(list_sizes_ptr), IdxT(n_lists), - new_labels.data(), + new_labels.data_handle(), n_rows, 1, stream); raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream); + list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); } - // Calculate new offsets - IdxT index_size = 0; - update_device(list_offsets_ptr, &index_size, 1, stream); - thrust::inclusive_scan( - rmm::exec_policy(stream), - list_sizes_ptr, - list_sizes_ptr + n_lists, - list_offsets_ptr + 1, - [] __device__(IdxT s, uint32_t l) { return s + Pow2::roundUp(l); }); - update_host(&index_size, list_offsets_ptr + n_lists, 1, stream); - handle.sync_stream(stream); - - ext_index.allocate(handle, index_size); - - // Populate index with the old data - if (orig_index.size() > 0) { - utils::block_copy(orig_index.list_offsets().data_handle(), - list_offsets_ptr, - IdxT(n_lists), - orig_index.data().data_handle(), - ext_index.data().data_handle(), - IdxT(dim), - stream); - - utils::block_copy(orig_index.list_offsets().data_handle(), - list_offsets_ptr, - IdxT(n_lists), - orig_index.indices().data_handle(), - ext_index.indices().data_handle(), - IdxT(1), - stream); + // Calculate and allocate new list data + std::vector new_list_sizes(n_lists); + std::vector old_list_sizes(n_lists); + { + copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); + copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); + handle.sync_stream(); + auto& lists = index->lists(); + for (uint32_t label = 0; label < n_lists; label++) { + ivf::resize_list(handle, + lists[label], + list_device_spec, + new_list_sizes[label], + Pow2::roundUp(old_list_sizes[label])); + } } - + // Update the pointers and the sizes + index->recompute_internal_state(handle); // Copy the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the 
following kernel, using it as an atomic counter. - raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); + raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + // Kernel to insert the new vectors const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x)); - build_index_kernel<<>>(new_labels.data(), - list_offsets_ptr, + build_index_kernel<<>>(new_labels.data_handle(), new_vectors, new_indices, - ext_index.data().data_handle(), - ext_index.indices().data_handle(), + index->data_ptrs().data_handle(), + index->inds_ptrs().data_handle(), list_sizes_ptr, n_rows, dim, - ext_index.veclen()); + index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Precompute the centers vector norms for L2Expanded distance - if (ext_index.center_norms().has_value()) { - if (!ext_index.adaptive_centers() && orig_index.center_norms().has_value()) { - raft::copy(ext_index.center_norms()->data_handle(), - orig_index.center_norms()->data_handle(), - orig_index.center_norms()->size(), - stream); - } else { - raft::linalg::rowNorm(ext_index.center_norms()->data_handle(), - ext_index.centers().data_handle(), + if (!index->center_norms().has_value()) { + index->allocate_center_norms(handle); + if (index->center_norms().has_value()) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), dim, n_lists, raft::linalg::L2Norm, true, stream); - RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } + } else if (index->center_norms().has_value() && index->adaptive_centers()) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream); + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } +} - // assemble the index +/** See 
raft::neighbors::ivf_flat::extend docs */ +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + auto ext_index = clone(handle, orig_index); + detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); return ext_index; } -/** See raft::spatial::knn::ivf_flat::build docs */ +/** See raft::neighbors::ivf_flat::build docs */ template inline auto build(raft::device_resources const& handle, const index_params& params, @@ -280,7 +311,8 @@ inline auto build(raft::device_resources const& handle, index index(handle, params, dim); utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.list_offsets().data_handle(), index.list_offsets().size(), stream); + utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); + utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); // Train the kmeans clustering { @@ -310,10 +342,9 @@ inline auto build(raft::device_resources const& handle, // add the data if necessary if (params.add_data_on_build) { - return detail::extend(handle, index, dataset, nullptr, n_rows); - } else { - return index; + detail::extend(handle, &index, dataset, nullptr, n_rows); } + return index; } /** @@ -356,20 +387,17 @@ inline void fill_refinement_index(raft::device_resources const& handle, new_labels_view, raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); - auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); - auto list_offsets_ptr = refinement_index->list_offsets().data_handle(); + auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); // We do not fill centers and center norms, since we will not run coarse search. 
- // Calculate new offsets - uint32_t n_roundup = Pow2::roundUp(n_candidates); - auto list_offsets_view = raft::make_device_vector_view( - list_offsets_ptr, refinement_index->list_offsets().size()); - linalg::map_offset(handle, - list_offsets_view, - raft::compose_op(raft::cast_op(), raft::mul_const_op(n_roundup))); - - IdxT index_size = n_roundup * n_lists; - refinement_index->allocate(handle, index_size); + // Allocate new memory + auto& lists = refinement_index->lists(); + list_spec list_device_spec{refinement_index->dim(), false}; + for (uint32_t label = 0; label < n_lists; label++) { + ivf::resize_list(handle, lists[label], list_device_spec, n_candidates, uint32_t(0)); + } + // Update the pointers and the sizes + refinement_index->recompute_internal_state(handle); RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -377,121 +405,14 @@ inline void fill_refinement_index(raft::device_resources const& handle, const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); build_index_kernel <<>>(new_labels.data(), - list_offsets_ptr, dataset, candidate_idx, - refinement_index->data().data_handle(), - refinement_index->indices().data_handle(), + refinement_index->data_ptrs().data_handle(), + refinement_index->inds_ptrs().data_handle(), list_sizes_ptr, n_queries * n_candidates, refinement_index->dim(), refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } - -// Serialization version 2 -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. -constexpr int serialization_version = 2; - -static_assert(sizeof(index) == 408, - "The size of the index struct has changed since the last update; " - "paste in the new size and consider updating the save/load logic"); - -/** - * Save the index to file. 
- * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index_ IVF-Flat index - * - */ -template -void serialize(raft::device_resources const& handle, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - RAFT_LOG_DEBUG( - "Saving IVF-PQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - serialize_scalar(handle, of, serialization_version); - serialize_scalar(handle, of, index_.size()); - serialize_scalar(handle, of, index_.dim()); - serialize_scalar(handle, of, index_.n_lists()); - serialize_scalar(handle, of, index_.metric()); - serialize_scalar(handle, of, index_.veclen()); - serialize_scalar(handle, of, index_.adaptive_centers()); - serialize_mdspan(handle, of, index_.data()); - serialize_mdspan(handle, of, index_.indices()); - serialize_mdspan(handle, of, index_.list_sizes()); - serialize_mdspan(handle, of, index_.list_offsets()); - serialize_mdspan(handle, of, index_.centers()); - if (index_.center_norms()) { - bool has_norms = true; - serialize_scalar(handle, of, has_norms); - serialize_mdspan(handle, of, *index_.center_norms()); - } else { - bool has_norms = false; - serialize_scalar(handle, of, has_norms); - } - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ IVF-Flat index - * - */ -template -auto deserialize(raft::device_resources const& handle, const std::string& filename) - -> index -{ - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - auto ver = deserialize_scalar(handle, infile); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(handle, infile); - auto dim = deserialize_scalar(handle, infile); - auto n_lists = deserialize_scalar(handle, infile); - auto metric = deserialize_scalar(handle, infile); - auto veclen = deserialize_scalar(handle, infile); - bool adaptive_centers = deserialize_scalar(handle, infile); - - index index_ = - raft::spatial::knn::ivf_flat::index(handle, metric, n_lists, adaptive_centers, dim); - - index_.allocate(handle, n_rows); - auto data = index_.data(); - deserialize_mdspan(handle, infile, data); - deserialize_mdspan(handle, infile, index_.indices()); - deserialize_mdspan(handle, infile, index_.list_sizes()); - deserialize_mdspan(handle, infile, index_.list_offsets()); - deserialize_mdspan(handle, infile, index_.centers()); - bool has_norms = deserialize_scalar(handle, infile); - if (has_norms) { - if (!index_.center_norms()) { - RAFT_FAIL("Error inconsistent center norms"); - } else { - auto center_norms = *index_.center_norms(); - deserialize_mdspan(handle, infile, center_norms); - } - } - infile.close(); - return index_; -} -} // namespace raft::spatial::knn::ivf_flat::detail +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh similarity index 98% rename from cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh rename to 
cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index 7f70d4b8a5..b2bfd18610 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -16,9 +16,6 @@ #pragma once -#include "../ivf_flat_types.hpp" -#include "ann_utils.cuh" - #include #include #include @@ -30,6 +27,8 @@ #include #include #include +#include +#include #include #include #include @@ -39,7 +38,7 @@ #include #include -namespace raft::spatial::knn::ivf_flat::detail { +namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT @@ -673,10 +672,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* list_indices, - const T* list_data, + const IdxT* const* list_indices_ptrs, + const T* const* list_data_ptrs, const uint32_t* list_sizes, - const IdxT* list_offsets, const uint32_t n_probes, const uint32_t k, const uint32_t dim, @@ -722,8 +720,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Every CUDA block scans one cluster at a time. 
for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - const size_t list_offset = list_offsets[list_id]; + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) // The number of vectors in each cluster(list); [nlist] const uint32_t list_length = list_sizes[list_id]; @@ -741,7 +738,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_id += kNumWarps) { AccT dist = 0; // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim; + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; // This is the vector a given lane/thread handles const uint32_t vec_id = group_id * WarpSize + lane_id; @@ -778,7 +775,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices[list_offset + vec_id]) : 0; + const size_t idx = valid ? 
static_cast(list_indices_ptrs[list_id][vec_id]) : 0; queue.add(val, idx); } } @@ -819,7 +816,7 @@ template void launch_kernel(Lambda lambda, PostLambda post_process, - const ivf_flat::index& index, + const index& index, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, @@ -869,10 +866,9 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.indices().data_handle(), - index.data().data_handle(), + index.inds_ptrs().data_handle(), + index.data_ptrs().data_handle(), index.list_sizes().data_handle(), - index.list_offsets().data_handle(), n_probes, k, index.dim(), @@ -1056,7 +1052,7 @@ struct select_interleaved_scan_kernel { * @param stream */ template -void ivfflat_interleaved_scan(const ivf_flat::index& index, +void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, @@ -1268,7 +1264,7 @@ inline bool is_min_close(distance::DistanceType metric) return select_min; } -/** See raft::spatial::knn::ivf_flat::search docs */ +/** See raft::neighbors::ivf_flat::search docs */ template inline void search(raft::device_resources const& handle, const search_params& params, @@ -1305,4 +1301,4 @@ inline void search(raft::device_resources const& handle, mr); } -} // namespace raft::spatial::knn::ivf_flat::detail +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh new file mode 100644 index 0000000000..dabd06fa89 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::ivf_flat::detail { + +// Serialization version 3 +// No backward compatibility yet; that is, can't add additional fields without breaking +// backward compatibility. +// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward +// compatible fashion. +constexpr int serialization_version = 3; + +// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error +// message. +template +struct check_index_layout { + static_assert(RealSize == ExpectedSize, + "The size of the index struct has changed since the last update; " + "paste in the new size and consider updating the serialization logic"); +}; + +template struct check_index_layout), 368>; + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index_ IVF-Flat index + * + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index_) +{ + RAFT_LOG_DEBUG( + "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + + serialize_scalar(handle, os, serialization_version); + serialize_scalar(handle, os, index_.size()); + serialize_scalar(handle, os, index_.dim()); + serialize_scalar(handle, os, index_.n_lists()); + serialize_scalar(handle, os, index_.metric()); + serialize_scalar(handle, os, index_.adaptive_centers()); + serialize_scalar(handle, os, index_.conservative_memory_allocation()); + serialize_mdspan(handle, os, index_.centers()); + if (index_.center_norms()) { + bool has_norms = true; + serialize_scalar(handle, os, has_norms); + serialize_mdspan(handle, os, *index_.center_norms()); + } else { + bool has_norms = false; + serialize_scalar(handle, os, has_norms); + } + auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); + copy(sizes_host.data_handle(), + index_.list_sizes().data_handle(), + sizes_host.size(), + handle.get_stream()); + handle.sync_stream(); + serialize_mdspan(handle, os, sizes_host.view()); + + list_spec list_store_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::serialize_list(handle, + os, + index_.lists()[label], + list_store_spec, + Pow2::roundUp(sizes_host(label))); + } + handle.sync_stream(); +} + +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize(handle, of, index_); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +/** Load an index from file. 
+ * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index_ IVF-Flat index + * + */ +template +auto deserialize(raft::device_resources const& handle, std::istream& is) -> index +{ + auto ver = deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto n_rows = deserialize_scalar(handle, is); + auto dim = deserialize_scalar(handle, is); + auto n_lists = deserialize_scalar(handle, is); + auto metric = deserialize_scalar(handle, is); + bool adaptive_centers = deserialize_scalar(handle, is); + bool cma = deserialize_scalar(handle, is); + + index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); + + deserialize_mdspan(handle, is, index_.centers()); + bool has_norms = deserialize_scalar(handle, is); + if (has_norms) { + index_.allocate_center_norms(handle); + if (!index_.center_norms()) { + RAFT_FAIL("Error inconsistent center norms"); + } else { + auto center_norms = index_.center_norms().value(); + deserialize_mdspan(handle, is, center_norms); + } + } + deserialize_mdspan(handle, is, index_.list_sizes()); + + list_spec list_device_spec{index_.dim(), cma}; + list_spec list_store_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); + } + handle.sync_stream(); + + index_.recompute_internal_state(handle); + + return index_; +} + +template +auto deserialize(raft::device_resources const& handle, const std::string& filename) + -> index +{ + std::ifstream is(filename, std::ios::in | std::ios::binary); + + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto index = detail::deserialize(handle, is); + + is.close(); + + return index; +} +} 
// namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index bf3014568a..1a563d213e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -664,7 +664,7 @@ __launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel( const uint32_t pq_len = pq_centers.extent(1); const uint32_t pq_dim = new_vectors.extent(1) / pq_len; - auto pq_extents = list_spec{PqBits, pq_dim, true}.make_list_extents(out_ix + 1); + auto pq_extents = list_spec{PqBits, pq_dim, true}.make_list_extents(out_ix + 1); auto pq_extents_vectorized = make_extents(pq_extents.extent(0), pq_extents.extent(1), pq_extents.extent(2)); auto pq_dataset = make_mdspan( @@ -899,14 +899,15 @@ void extend(raft::device_resources const& handle, &managed_memory_upstream, 1024 * 1024); // The spec defines how the clusters look like - auto spec = list_spec{index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; + auto spec = list_spec{ + index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; // Try to allocate an index with the same parameters and the projected new size // (which can be slightly larger than index->size() + n_rows, due to padding). // If this fails, the index would be too big to fit in the device anyway. 
std::optional> placeholder_list( std::in_place_t{}, handle, - list_spec(spec), + list_spec{spec}, n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); // Available device memory diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 0701b0feb5..826ed90db1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -89,10 +89,9 @@ void serialize(raft::device_resources const& handle_, std::ostream& os, const in handle_.get_stream()); handle_.sync_stream(); serialize_mdspan(handle_, os, sizes_host.view()); - auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; + auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { - ivf::serialize_list( - handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); + ivf::serialize_list(handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); } } @@ -162,11 +161,10 @@ auto deserialize(raft::device_resources const& handle_, std::istream& is) -> ind deserialize_mdspan(handle_, is, index.centers_rot()); deserialize_mdspan(handle_, is, index.rotation_matrix()); deserialize_mdspan(handle_, is, index.list_sizes()); - auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; - auto list_store_spec = list_spec{pq_bits, pq_dim, true}; + auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; + auto list_store_spec = list_spec{pq_bits, pq_dim, true}; for (auto& list : index.lists()) { - ivf::deserialize_list( - handle_, is, list, list_store_spec, list_device_spec); + ivf::deserialize_list(handle_, is, list, list_store_spec, list_device_spec); } handle_.sync_stream(); diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b264643584..b0aebe28b6 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ 
b/cpp/include/raft/neighbors/detail/refine.cuh @@ -20,9 +20,9 @@ #include #include #include +#include +#include #include -#include -#include #include #include @@ -108,17 +108,17 @@ void refine_device(raft::device_resources const& handle, handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries); raft::neighbors::ivf_flat::index refinement_index( - handle, metric, n_queries, false, dim); + handle, metric, n_queries, false, true, dim); - raft::spatial::knn::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); + raft::neighbors::ivf_flat::detail::fill_refinement_index(handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + n_queries, + n_candidates); uint32_t grid_dim_x = 1; - raft::spatial::knn::ivf_flat::detail::ivfflat_interleaved_scan< + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, idx_t>(refinement_index, @@ -128,7 +128,7 @@ void refine_device(raft::device_resources const& handle, refinement_index.metric(), 1, k, - raft::spatial::knn::ivf_flat::detail::is_min_close(metric), + raft::neighbors::ivf_flat::detail::is_min_close(metric), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f18611b9f1..f42bfe66c7 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -16,9 +16,10 @@ #pragma once +#include +#include +#include #include -#include -#include #include @@ -67,7 +68,7 @@ auto build(raft::device_resources const& handle, IdxT n_rows, uint32_t dim) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); + return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); } /** @@ -99,7 
+100,6 @@ auto build(raft::device_resources const& handle, * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] params configure the index building @@ -112,11 +112,11 @@ auto build(raft::device_resources const& handle, raft::device_matrix_view dataset, const index_params& params) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); + return raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); } /** @} */ @@ -160,7 +160,7 @@ auto extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT n_rows) -> index { - return raft::spatial::knn::ivf_flat::detail::extend( + return raft::neighbors::ivf_flat::detail::extend( handle, orig_index, new_vectors, new_indices, n_rows); } @@ -190,8 +190,6 @@ auto extend(raft::device_resources const& handle, * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] orig_index original index @@ -252,7 +250,7 @@ void extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT n_rows) { - *index = extend(handle, *index, new_vectors, new_indices, n_rows); + raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); } /** @@ -277,8 +275,6 @@ void extend(raft::device_resources const& handle, * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * 
@param[inout] index @@ -293,11 +289,11 @@ void extend(raft::device_resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices = std::nullopt) { - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); + extend(handle, + index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); } /** @} */ @@ -355,7 +351,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::spatial::knn::ivf_flat::detail::search( + return raft::neighbors::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr); } @@ -388,7 +384,6 @@ void search(raft::device_resources const& handle, * @tparam value_t data element type * @tparam idx_t type of the indices * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type * * @param[in] handle * @param[in] index ivf-flat constructed index diff --git a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh new file mode 100644 index 0000000000..d2ec9a39bd --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/ivf_flat_serialize.cuh" + +namespace raft::neighbors::ivf_flat { + +/** + * \defgroup ivf_flat_serialize IVF-Flat Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = ivf_flat::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index IVF-Flat index + * + * @return raft::neighbors::ivf_flat::index + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +{ + detail::serialize(handle, os, index); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_flat::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-Flat index + * + * @return raft::neighbors::ivf_flat::index + */ +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::ivf_flat::index + */ +template +index deserialize(raft::device_resources const& handle, std::istream& is) +{ + return detail::deserialize(handle, is); +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::ivf_flat::index + */ +template +index deserialize(raft::device_resources const& handle, const std::string& filename) +{ + return detail::deserialize(handle, filename); +} + +/**@}*/ + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index d234822a23..20bca6e3e6 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -20,10 +20,15 @@ #include #include +#include +#include #include +#include #include +#include #include +#include #include namespace raft::neighbors::ivf_flat { @@ -55,6 +60,16 @@ struct index_params : ann::index_params { * `index.centers()` "drift" together with the changing distribution of the newly added data. */ bool adaptive_centers = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This amortizes the cost of memory allocation and reduces the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. 
+ */ + bool conservative_memory_allocation = false; }; struct search_params : ann::search_params { @@ -65,6 +80,40 @@ struct search_params : ann::search_params { static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); +template +struct list_spec { + using value_type = ValueT; + using list_extents = matrix_extent; + using index_type = IdxT; + + SizeT align_max; + SizeT align_min; + uint32_t dim; + + constexpr list_spec(uint32_t dim, bool conservative_memory_allocation) + : dim(dim), + align_min(kIndexGroupSize), + align_max(conservative_memory_allocation ? kIndexGroupSize : 1024) + { + } + + // Allow casting between different size-types (for safer size and offset calculations) + template + constexpr explicit list_spec(const list_spec& other_spec) + : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} + { + } + + /** Determine the extents of an array enough to hold a given amount of data. */ + constexpr auto make_list_extents(SizeT n_rows) const -> list_extents + { + return make_extents(n_rows, dim); + } +}; + +template +using list_data = ivf::list; + /** * @brief IVF-flat index. * @@ -118,59 +167,27 @@ struct index : ann::index { * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... 
x[30, 4], x[30, 5], - , - , * */ - inline auto data() noexcept -> device_mdspan, row_major> - { - return data_.view(); - } - [[nodiscard]] inline auto data() const noexcept - -> device_mdspan, row_major> - { - return data_.view(); - } - - /** Inverted list indices: ids of items in the source data [size] */ - inline auto indices() noexcept -> device_mdspan, row_major> - { - return indices_.view(); - } - [[nodiscard]] inline auto indices() const noexcept - -> device_mdspan, row_major> - { - return indices_.view(); - } - - /** Sizes of the lists (clusters) [n_lists] */ - inline auto list_sizes() noexcept -> device_mdspan, row_major> + /** Sizes of the lists (clusters) [n_lists] + * NB: This may differ from the actual list size if the shared lists have been extended by another + * index + */ + inline auto list_sizes() noexcept -> device_vector_view { return list_sizes_.view(); } [[nodiscard]] inline auto list_sizes() const noexcept - -> device_mdspan, row_major> + -> device_vector_view { return list_sizes_.view(); } - /** - * Offsets into the lists [n_lists + 1]. - * The last value contains the total length of the index. - */ - inline auto list_offsets() noexcept -> device_mdspan, row_major> - { - return list_offsets_.view(); - } - [[nodiscard]] inline auto list_offsets() const noexcept - -> device_mdspan, row_major> - { - return list_offsets_.view(); - } - /** k-means cluster centers corresponding to the lists [n_lists, dim] */ - inline auto centers() noexcept -> device_mdspan, row_major> + inline auto centers() noexcept -> device_matrix_view { return centers_.view(); } [[nodiscard]] inline auto centers() const noexcept - -> device_mdspan, row_major> + -> device_matrix_view { return centers_.view(); } @@ -181,39 +198,33 @@ struct index : ann::index { * NB: this may be empty if the index is empty or if the metric does not require the center norms * calculation. 
*/ - inline auto center_norms() noexcept - -> std::optional, row_major>> + inline auto center_norms() noexcept -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional, row_major>>( - center_norms_->view()); + return std::make_optional>(center_norms_->view()); } else { return std::nullopt; } } [[nodiscard]] inline auto center_norms() const noexcept - -> std::optional, row_major>> + -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional, row_major>>( - center_norms_->view()); + return std::make_optional>(center_norms_->view()); } else { return std::nullopt; } } /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return indices_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return total_size_; } /** Dimensionality of the data. */ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return centers_.extent(1); } /** Number of clusters/inverted lists. */ - [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t - { - return centers_.extent(0); - } + [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t { return lists_.size(); } // Don't allow copying the index for performance reasons (try avoiding copying data) index(const index&) = delete; @@ -223,51 +234,111 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. It needs to be trained and then populated. 
*/ - index(raft::device_resources const& handle, + index(raft::device_resources const& res, raft::distance::DistanceType metric, uint32_t n_lists, bool adaptive_centers, + bool conservative_memory_allocation, uint32_t dim) : ann::index(), veclen_(calculate_veclen(dim)), metric_(metric), adaptive_centers_(adaptive_centers), - data_(make_device_mdarray(handle, make_extents(0, dim))), - indices_(make_device_mdarray(handle, make_extents(0))), - list_sizes_(make_device_mdarray(handle, make_extents(n_lists))), - list_offsets_(make_device_mdarray(handle, make_extents(n_lists + 1))), - centers_(make_device_mdarray(handle, make_extents(n_lists, dim))), - center_norms_(std::nullopt) + conservative_memory_allocation_{conservative_memory_allocation}, + centers_(make_device_matrix(res, n_lists, dim)), + center_norms_(std::nullopt), + lists_{n_lists}, + list_sizes_{make_device_vector(res, n_lists)}, + data_ptrs_{make_device_vector(res, n_lists)}, + inds_ptrs_{make_device_vector(res, n_lists)}, + total_size_{0} { check_consistency(); } /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::device_resources const& handle, const index_params& params, uint32_t dim) - : index(handle, params.metric, params.n_lists, params.adaptive_centers, dim) + index(raft::device_resources const& res, const index_params& params, uint32_t dim) + : index(res, + params.metric, + params.n_lists, + params.adaptive_centers, + params.conservative_memory_allocation, + dim) { } + /** Pointers to the inverted lists (clusters) data [n_lists]. */ + inline auto data_ptrs() noexcept -> device_vector_view { return data_ptrs_.view(); } + [[nodiscard]] inline auto data_ptrs() const noexcept -> device_vector_view + { + return data_ptrs_.view(); + } + + /** Pointers to the inverted lists (clusters) indices [n_lists]. 
*/ + inline auto inds_ptrs() noexcept -> device_vector_view + { + return inds_ptrs_.view(); + } + [[nodiscard]] inline auto inds_ptrs() const noexcept -> device_vector_view + { + return inds_ptrs_.view(); + } /** - * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount - * of data. + * Whether to use conservative memory allocation when extending the list (cluster) data + * (see index_params.conservative_memory_allocation). */ - void allocate(raft::device_resources const& handle, IdxT index_size) + [[nodiscard]] constexpr inline auto conservative_memory_allocation() const noexcept -> bool { - data_ = make_device_mdarray(handle, make_extents(index_size, dim())); - indices_ = make_device_mdarray(handle, make_extents(index_size)); + return conservative_memory_allocation_; + } + /** + * Update the state of the dependent index members. + */ + void recompute_internal_state(raft::device_resources const& res) + { + auto stream = res.get_stream(); + + // Actualize the list pointers + auto this_lists = lists(); + auto this_data_ptrs = data_ptrs(); + auto this_inds_ptrs = inds_ptrs(); + IdxT recompute_total_size = 0; + for (uint32_t label = 0; label < this_lists.size(); label++) { + auto& list = this_lists[label]; + const auto data_ptr = list ? list->data.data_handle() : nullptr; + const auto inds_ptr = list ? list->indices.data_handle() : nullptr; + const auto list_size = list ? 
IdxT(list->size) : 0; + copy(&this_data_ptrs(label), &data_ptr, 1, stream); + copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); + recompute_total_size += list_size; + } + total_size_ = recompute_total_size; + check_consistency(); + } + + void allocate_center_norms(raft::device_resources const& res) + { switch (metric_) { case raft::distance::DistanceType::L2Expanded: case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2Unexpanded: case raft::distance::DistanceType::L2SqrtUnexpanded: - center_norms_ = make_device_mdarray(handle, make_extents(n_lists())); + center_norms_ = make_device_vector(res, n_lists()); break; default: center_norms_ = std::nullopt; } + } - check_consistency(); + /** Lists' data and indices. */ + inline auto lists() noexcept -> std::vector>>& + { + return lists_; + } + [[nodiscard]] inline auto lists() const noexcept + -> const std::vector>>& + { + return lists_; } private: @@ -278,26 +349,29 @@ struct index : ann::index { uint32_t veclen_; raft::distance::DistanceType metric_; bool adaptive_centers_; - device_mdarray, row_major> data_; - device_mdarray, row_major> indices_; - device_mdarray, row_major> list_sizes_; - device_mdarray, row_major> list_offsets_; - device_mdarray, row_major> centers_; - std::optional, row_major>> center_norms_; + bool conservative_memory_allocation_; + std::vector>> lists_; + device_vector list_sizes_; + device_matrix centers_; + std::optional> center_norms_; + + // Computed members + device_vector data_ptrs_; + device_vector inds_ptrs_; + IdxT total_size_; /** Throw an error if the index content is inconsistent. 
*/ void check_consistency() { + auto n_lists = lists_.size(); RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen"); - RAFT_EXPECTS(data_.extent(0) == indices_.extent(0), "inconsistent index size"); - RAFT_EXPECTS(data_.extent(1) == IdxT(centers_.extent(1)), "inconsistent data dimensionality"); - RAFT_EXPECTS( // - (centers_.extent(0) == list_sizes_.extent(0)) && // - (centers_.extent(0) + 1 == list_offsets_.extent(0)) && // + RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS( // + (centers_.extent(0) == list_sizes_.extent(0)) && // (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), "inconsistent number of lists (clusters)"); - RAFT_EXPECTS(reinterpret_cast(data_.data_handle()) % (veclen_ * sizeof(T)) == 0, - "The data storage pointer is not aligned to the vector length"); } static auto calculate_veclen(uint32_t dim) -> uint32_t diff --git a/cpp/include/raft/neighbors/ivf_list.hpp b/cpp/include/raft/neighbors/ivf_list.hpp index 4644143057..e2aa661d12 100644 --- a/cpp/include/raft/neighbors/ivf_list.hpp +++ b/cpp/include/raft/neighbors/ivf_list.hpp @@ -35,10 +35,12 @@ namespace raft::neighbors::ivf { /** The data for a single IVF list. */ -template