Skip to content

Commit

Permalink
IVF-Flat index splitting (#1271)
Browse files Browse the repository at this point in the history
Refactor of `ivf_flat::index` to split the cluster data into separate buffers
- Addressing #1170.
- Following #1249 in the index structure implementation.
- Adding serialization API with stream and filename overloads
- Moving `raft::spatial::knn::ivf_flat` namespace to `raft::neighbors::ivf_flat`

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #1271
  • Loading branch information
lowener authored Mar 15, 2023
1 parent 667d873 commit 03e26b5
Show file tree
Hide file tree
Showing 13 changed files with 797 additions and 434 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

#pragma once

#include "../ivf_flat_types.hpp"
#include "ann_utils.cuh"

#include <raft/core/cudart_utils.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
Expand All @@ -30,6 +27,8 @@
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_k.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/device_loads_stores.cuh>
#include <raft/util/integer_utils.hpp>
Expand All @@ -39,7 +38,7 @@
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

namespace raft::spatial::knn::ivf_flat::detail {
namespace raft::neighbors::ivf_flat::detail {

using namespace raft::spatial::knn::detail; // NOLINT

Expand Down Expand Up @@ -673,10 +672,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
const IdxT* list_indices,
const T* list_data,
const IdxT* const* list_indices_ptrs,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const IdxT* list_offsets,
const uint32_t n_probes,
const uint32_t k,
const uint32_t dim,
Expand Down Expand Up @@ -722,8 +720,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)

// Every CUDA block scans one cluster at a time.
for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) {
const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list)
const size_t list_offset = list_offsets[list_id];
const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list)

// The number of vectors in each cluster(list); [nlist]
const uint32_t list_length = list_sizes[list_id];
Expand All @@ -741,7 +738,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
group_id += kNumWarps) {
AccT dist = 0;
// This is where this warp begins reading data (start position of an interleaved group)
const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim;
const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim;

// This is the vector a given lane/thread handles
const uint32_t vec_id = group_id * WarpSize + lane_id;
Expand Down Expand Up @@ -778,7 +775,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)

// Enqueue one element per thread
const float val = valid ? static_cast<float>(dist) : block_sort_t::queue_t::kDummy;
const size_t idx = valid ? static_cast<size_t>(list_indices[list_offset + vec_id]) : 0;
const size_t idx = valid ? static_cast<size_t>(list_indices_ptrs[list_id][vec_id]) : 0;
queue.add(val, idx);
}
}
Expand Down Expand Up @@ -819,7 +816,7 @@ template <int Capacity,
typename PostLambda>
void launch_kernel(Lambda lambda,
PostLambda post_process,
const ivf_flat::index<T, IdxT>& index,
const index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_index,
const uint32_t num_queries,
Expand Down Expand Up @@ -869,10 +866,9 @@ void launch_kernel(Lambda lambda,
query_smem_elems,
queries,
coarse_index,
index.indices().data_handle(),
index.data().data_handle(),
index.inds_ptrs().data_handle(),
index.data_ptrs().data_handle(),
index.list_sizes().data_handle(),
index.list_offsets().data_handle(),
n_probes,
k,
index.dim(),
Expand Down Expand Up @@ -1056,7 +1052,7 @@ struct select_interleaved_scan_kernel {
* @param stream
*/
template <typename T, typename AccT, typename IdxT>
void ivfflat_interleaved_scan(const ivf_flat::index<T, IdxT>& index,
void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
Expand Down Expand Up @@ -1268,7 +1264,7 @@ inline bool is_min_close(distance::DistanceType metric)
return select_min;
}

/** See raft::spatial::knn::ivf_flat::search docs */
/** See raft::neighbors::ivf_flat::search docs */
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
const search_params& params,
Expand Down Expand Up @@ -1305,4 +1301,4 @@ inline void search(raft::device_resources const& handle,
mr);
}

} // namespace raft::spatial::knn::ivf_flat::detail
} // namespace raft::neighbors::ivf_flat::detail
176 changes: 176 additions & 0 deletions cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/mdarray.hpp>
#include <raft/core/serialize.hpp>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/neighbors/ivf_list.hpp>
#include <raft/neighbors/ivf_list_types.hpp>

#include <fstream>

namespace raft::neighbors::ivf_flat::detail {

// Serialization version 3.
// There is no backward compatibility yet; that is, fields can't be added to the format without
// breaking files written by an earlier version.
// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward
// compatible fashion.
// This value is the first scalar written by serialize() and the first one read (and compared
// against) by deserialize().
constexpr int serialization_version = 3;

// Compile-time guard that fails the build when RealSize differs from ExpectedSize.
// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message (it shows up as a template argument of the failing instantiation).
template <size_t RealSize, size_t ExpectedSize>
struct check_index_layout {
static_assert(RealSize == ExpectedSize,
"The size of the index struct has changed since the last update; "
"paste in the new size and consider updating the serialization logic");
};

// Explicit instantiation: forces the static_assert in check_index_layout to be evaluated for
// index<double, std::uint64_t>; 368 is the size (in bytes) that the struct is expected to have.
template struct check_index_layout<sizeof(index<double, std::uint64_t>), 368>;

/**
 * Write the index to an output stream.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * The stream receives, in this order: the serialization version, the index size, the
 * dimensionality, the number of lists, the metric, the adaptive-centers flag, the
 * conservative-memory-allocation flag, the cluster centers, a boolean telling whether center
 * norms follow (and the norms themselves if so), the list sizes, and finally every list.
 *
 * @param[in] handle the raft handle
 * @param[in] os the output stream the index is written to
 * @param[in] index_ IVF-Flat index
 *
 */
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<T, IdxT>& index_)
{
  RAFT_LOG_DEBUG(
    "Saving IVF-Flat index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

  // Format version first, then the scalar parameters of the index.
  serialize_scalar(handle, os, serialization_version);
  serialize_scalar(handle, os, index_.size());
  serialize_scalar(handle, os, index_.dim());
  serialize_scalar(handle, os, index_.n_lists());
  serialize_scalar(handle, os, index_.metric());
  serialize_scalar(handle, os, index_.adaptive_centers());
  serialize_scalar(handle, os, index_.conservative_memory_allocation());

  serialize_mdspan(handle, os, index_.centers());

  // The center norms are optional: always write the presence flag, and the norms only if set.
  bool has_norms = static_cast<bool>(index_.center_norms());
  serialize_scalar(handle, os, has_norms);
  if (has_norms) { serialize_mdspan(handle, os, *index_.center_norms()); }

  // Copy the list sizes into a host vector; the copy is enqueued on the handle's stream, so
  // synchronize before the values are read below.
  auto host_sizes = make_host_vector<uint32_t, uint32_t>(index_.list_sizes().extent(0));
  copy(host_sizes.data_handle(),
       index_.list_sizes().data_handle(),
       host_sizes.size(),
       handle.get_stream());
  handle.sync_stream();
  serialize_mdspan(handle, os, host_sizes.view());

  // Each list is written with its size rounded up to a multiple of kIndexGroupSize.
  list_spec<uint32_t, T, IdxT> list_store_spec{index_.dim(), true};
  for (uint32_t list_id = 0; list_id < index_.n_lists(); ++list_id) {
    const auto padded_size = Pow2<kIndexGroupSize>::roundUp(host_sizes(list_id));
    ivf::serialize_list(handle, os, index_.lists()[list_id], list_store_spec, padded_size);
  }
  handle.sync_stream();
}

/**
 * Save the index to a file.
 *
 * Opens `filename` for binary output, forwards to the stream overload of serialize, then closes
 * the file. Fails via RAFT_FAIL if the file cannot be opened or if the stream is in a failed
 * state after closing.
 *
 * @param[in] handle the raft handle
 * @param[in] filename the file name for saving the index
 * @param[in] index_ IVF-Flat index
 */
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle,
               const std::string& filename,
               const index<T, IdxT>& index_)
{
  std::ofstream out_file(filename, std::ios::out | std::ios::binary);
  if (out_file.fail()) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

  detail::serialize(handle, out_file, index_);

  // Close explicitly so that a failure while flushing is visible in the stream state below.
  out_file.close();
  if (out_file.fail()) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
}

/**
 * Load an index from an input stream.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * The fields are read in the same order in which the stream overload of serialize writes them.
 * Fails via RAFT_FAIL when the stored version differs from serialization_version.
 *
 * @param[in] handle the raft handle
 * @param[in] is the input stream that stores the index
 *
 * @return the deserialized IVF-Flat index
 */
template <typename T, typename IdxT>
auto deserialize(raft::device_resources const& handle, std::istream& is) -> index<T, IdxT>
{
  auto stored_version = deserialize_scalar<int>(handle, is);
  if (stored_version != serialization_version) {
    RAFT_FAIL("serialization version mismatch, expected %d, got %d ",
              serialization_version,
              stored_version);
  }

  // Scalar parameters; the order must mirror the writer.
  // n_rows is read to consume its slot in the stream; it is not used further in this function.
  auto n_rows                 = deserialize_scalar<IdxT>(handle, is);
  auto dim                    = deserialize_scalar<std::uint32_t>(handle, is);
  auto n_lists                = deserialize_scalar<std::uint32_t>(handle, is);
  auto metric                 = deserialize_scalar<raft::distance::DistanceType>(handle, is);
  bool adaptive_centers       = deserialize_scalar<bool>(handle, is);
  bool conservative_alloc     = deserialize_scalar<bool>(handle, is);

  index<T, IdxT> index_(handle, metric, n_lists, adaptive_centers, conservative_alloc, dim);

  deserialize_mdspan(handle, is, index_.centers());

  // The center norms are only present in the stream when the flag says so.
  bool has_norms = deserialize_scalar<bool>(handle, is);
  if (has_norms) {
    index_.allocate_center_norms(handle);
    if (!index_.center_norms()) { RAFT_FAIL("Error inconsistent center norms"); }
    auto center_norms = index_.center_norms().value();
    deserialize_mdspan(handle, is, center_norms);
  }

  deserialize_mdspan(handle, is, index_.list_sizes());

  // Lists are stored with the `true` flag in their spec, while the in-memory lists follow the
  // flag that was read from the stream.
  list_spec<uint32_t, T, IdxT> list_device_spec{index_.dim(), conservative_alloc};
  list_spec<uint32_t, T, IdxT> list_store_spec{index_.dim(), true};
  for (uint32_t list_id = 0; list_id < index_.n_lists(); ++list_id) {
    ivf::deserialize_list(handle, is, index_.lists()[list_id], list_store_spec, list_device_spec);
  }
  handle.sync_stream();

  index_.recompute_internal_state(handle);

  return index_;
}

/**
 * Load an index from a file.
 *
 * Opens `filename` for binary input, forwards to the stream overload of deserialize and closes
 * the file. Fails via RAFT_FAIL if the file cannot be opened.
 *
 * @param[in] handle the raft handle
 * @param[in] filename the name of the file that stores the index
 *
 * @return the deserialized IVF-Flat index
 */
template <typename T, typename IdxT>
auto deserialize(raft::device_resources const& handle, const std::string& filename)
  -> index<T, IdxT>
{
  std::ifstream in_file(filename, std::ios::in | std::ios::binary);
  if (in_file.fail()) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

  auto loaded_index = detail::deserialize<T, IdxT>(handle, in_file);

  in_file.close();

  return loaded_index;
}
} // namespace raft::neighbors::ivf_flat::detail
7 changes: 4 additions & 3 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ __launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel(
const uint32_t pq_len = pq_centers.extent(1);
const uint32_t pq_dim = new_vectors.extent(1) / pq_len;

auto pq_extents = list_spec<uint32_t>{PqBits, pq_dim, true}.make_list_extents(out_ix + 1);
auto pq_extents = list_spec<uint32_t, IdxT>{PqBits, pq_dim, true}.make_list_extents(out_ix + 1);
auto pq_extents_vectorized =
make_extents<uint32_t>(pq_extents.extent(0), pq_extents.extent(1), pq_extents.extent(2));
auto pq_dataset = make_mdspan<pq_vec_t, uint32_t, row_major, false, true>(
Expand Down Expand Up @@ -899,14 +899,15 @@ void extend(raft::device_resources const& handle,
&managed_memory_upstream, 1024 * 1024);

// The spec defines how the clusters look like
auto spec = list_spec{index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()};
auto spec = list_spec<uint32_t, IdxT>{
index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()};
// Try to allocate an index with the same parameters and the projected new size
// (which can be slightly larger than index->size() + n_rows, due to padding).
// If this fails, the index would be too big to fit in the device anyway.
std::optional<list_data<IdxT, size_t>> placeholder_list(
std::in_place_t{},
handle,
list_spec<size_t>(spec),
list_spec<size_t, IdxT>{spec},
n_rows + (kIndexGroupSize - 1) * std::min<IdxT>(n_clusters, n_rows));

// Available device memory
Expand Down
12 changes: 5 additions & 7 deletions cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,9 @@ void serialize(raft::device_resources const& handle_, std::ostream& os, const in
handle_.get_stream());
handle_.sync_stream();
serialize_mdspan(handle_, os, sizes_host.view());
auto list_store_spec = list_spec<uint32_t>{index.pq_bits(), index.pq_dim(), true};
auto list_store_spec = list_spec<uint32_t, IdxT>{index.pq_bits(), index.pq_dim(), true};
for (uint32_t label = 0; label < index.n_lists(); label++) {
ivf::serialize_list<list_spec, IdxT, uint32_t>(
handle_, os, index.lists()[label], list_store_spec, sizes_host(label));
ivf::serialize_list(handle_, os, index.lists()[label], list_store_spec, sizes_host(label));
}
}

Expand Down Expand Up @@ -162,11 +161,10 @@ auto deserialize(raft::device_resources const& handle_, std::istream& is) -> ind
deserialize_mdspan(handle_, is, index.centers_rot());
deserialize_mdspan(handle_, is, index.rotation_matrix());
deserialize_mdspan(handle_, is, index.list_sizes());
auto list_device_spec = list_spec<uint32_t>{pq_bits, pq_dim, cma};
auto list_store_spec = list_spec<uint32_t>{pq_bits, pq_dim, true};
auto list_device_spec = list_spec<uint32_t, IdxT>{pq_bits, pq_dim, cma};
auto list_store_spec = list_spec<uint32_t, IdxT>{pq_bits, pq_dim, true};
for (auto& list : index.lists()) {
ivf::deserialize_list<list_spec, IdxT, uint32_t>(
handle_, is, list, list_store_spec, list_device_spec);
ivf::deserialize_list(handle_, is, list, list_store_spec, list_device_spec);
}

handle_.sync_stream();
Expand Down
22 changes: 11 additions & 11 deletions cpp/include/raft/neighbors/detail/refine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/neighbors/detail/ivf_flat_build.cuh>
#include <raft/neighbors/detail/ivf_flat_search.cuh>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/spatial/knn/detail/ivf_flat_build.cuh>
#include <raft/spatial/knn/detail/ivf_flat_search.cuh>

#include <cstdlib>
#include <omp.h>
Expand Down Expand Up @@ -108,17 +108,17 @@ void refine_device(raft::device_resources const& handle,
handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries);

raft::neighbors::ivf_flat::index<data_t, idx_t> refinement_index(
handle, metric, n_queries, false, dim);
handle, metric, n_queries, false, true, dim);

raft::spatial::knn::ivf_flat::detail::fill_refinement_index(handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
n_queries,
n_candidates);
raft::neighbors::ivf_flat::detail::fill_refinement_index(handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
n_queries,
n_candidates);

uint32_t grid_dim_x = 1;
raft::spatial::knn::ivf_flat::detail::ivfflat_interleaved_scan<
raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<
data_t,
typename raft::spatial::knn::detail::utils::config<data_t>::value_t,
idx_t>(refinement_index,
Expand All @@ -128,7 +128,7 @@ void refine_device(raft::device_resources const& handle,
refinement_index.metric(),
1,
k,
raft::spatial::knn::ivf_flat::detail::is_min_close(metric),
raft::neighbors::ivf_flat::detail::is_min_close(metric),
indices.data_handle(),
distances.data_handle(),
grid_dim_x,
Expand Down
Loading

0 comments on commit 03e26b5

Please sign in to comment.