Skip to content

Commit

Permalink
Add stream overloads to ivf_pq serialize/deserialize methods (#1315)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - William Hicks (https://github.com/wphicks)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1315
  • Loading branch information
divyegala authored Mar 7, 2023
1 parent e7f0268 commit 3ca7eac
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 52 deletions.
127 changes: 80 additions & 47 deletions cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,93 +48,104 @@ struct check_index_layout {
template struct check_index_layout<sizeof(index<std::uint64_t>), 536>;

/**
* Save the index to file.
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] os output stream
* @param[in] index IVF-PQ index
*
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle_,
const std::string& filename,
const index<IdxT>& index)
void serialize(raft::device_resources const& handle_, std::ostream& os, const index<IdxT>& index)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d",
static_cast<size_t>(index.size()),
static_cast<int>(index.dim()),
static_cast<int>(index.pq_dim()),
static_cast<int>(index.pq_bits()));

serialize_scalar(handle_, of, kSerializationVersion);
serialize_scalar(handle_, of, index.size());
serialize_scalar(handle_, of, index.dim());
serialize_scalar(handle_, of, index.pq_bits());
serialize_scalar(handle_, of, index.pq_dim());
serialize_scalar(handle_, of, index.conservative_memory_allocation());
serialize_scalar(handle_, os, kSerializationVersion);
serialize_scalar(handle_, os, index.size());
serialize_scalar(handle_, os, index.dim());
serialize_scalar(handle_, os, index.pq_bits());
serialize_scalar(handle_, os, index.pq_dim());
serialize_scalar(handle_, os, index.conservative_memory_allocation());

serialize_scalar(handle_, of, index.metric());
serialize_scalar(handle_, of, index.codebook_kind());
serialize_scalar(handle_, of, index.n_lists());
serialize_scalar(handle_, os, index.metric());
serialize_scalar(handle_, os, index.codebook_kind());
serialize_scalar(handle_, os, index.n_lists());

serialize_mdspan(handle_, of, index.pq_centers());
serialize_mdspan(handle_, of, index.centers());
serialize_mdspan(handle_, of, index.centers_rot());
serialize_mdspan(handle_, of, index.rotation_matrix());
serialize_mdspan(handle_, os, index.pq_centers());
serialize_mdspan(handle_, os, index.centers());
serialize_mdspan(handle_, os, index.centers_rot());
serialize_mdspan(handle_, os, index.rotation_matrix());

auto sizes_host = make_host_mdarray<uint32_t, uint32_t, row_major>(index.list_sizes().extents());
copy(sizes_host.data_handle(),
index.list_sizes().data_handle(),
sizes_host.size(),
handle_.get_stream());
handle_.sync_stream();
serialize_mdspan(handle_, of, sizes_host.view());
serialize_mdspan(handle_, os, sizes_host.view());
auto list_store_spec = list_spec<uint32_t>{index.pq_bits(), index.pq_dim(), true};
for (uint32_t label = 0; label < index.n_lists(); label++) {
ivf::serialize_list<list_spec, IdxT, uint32_t>(
handle_, of, index.lists()[label], list_store_spec, sizes_host(label));
handle_, os, index.lists()[label], list_store_spec, sizes_host(label));
}
}

/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle_,
const std::string& filename,
const index<IdxT>& index)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

detail::serialize(handle_, of, index);

of.close();
if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
return;
}

/**
* Load index from file.
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[in] index IVF-PQ index
* @param[in] is input stream
*
*/
template <typename IdxT>
auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index<IdxT>
auto deserialize(raft::device_resources const& handle_, std::istream& is) -> index<IdxT>
{
std::ifstream infile(filename, std::ios::in | std::ios::binary);

if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

auto ver = deserialize_scalar<int>(handle_, infile);
auto ver = deserialize_scalar<int>(handle_, is);
if (ver != kSerializationVersion) {
RAFT_FAIL("serialization version mismatch %d vs. %d", ver, kSerializationVersion);
}
auto n_rows = deserialize_scalar<IdxT>(handle_, infile);
auto dim = deserialize_scalar<std::uint32_t>(handle_, infile);
auto pq_bits = deserialize_scalar<std::uint32_t>(handle_, infile);
auto pq_dim = deserialize_scalar<std::uint32_t>(handle_, infile);
auto cma = deserialize_scalar<bool>(handle_, infile);
auto n_rows = deserialize_scalar<IdxT>(handle_, is);
auto dim = deserialize_scalar<std::uint32_t>(handle_, is);
auto pq_bits = deserialize_scalar<std::uint32_t>(handle_, is);
auto pq_dim = deserialize_scalar<std::uint32_t>(handle_, is);
auto cma = deserialize_scalar<bool>(handle_, is);

auto metric = deserialize_scalar<raft::distance::DistanceType>(handle_, infile);
auto codebook_kind = deserialize_scalar<raft::neighbors::ivf_pq::codebook_gen>(handle_, infile);
auto n_lists = deserialize_scalar<std::uint32_t>(handle_, infile);
auto metric = deserialize_scalar<raft::distance::DistanceType>(handle_, is);
auto codebook_kind = deserialize_scalar<raft::neighbors::ivf_pq::codebook_gen>(handle_, is);
auto n_lists = deserialize_scalar<std::uint32_t>(handle_, is);

RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d",
static_cast<std::size_t>(n_rows),
Expand All @@ -146,24 +157,46 @@ auto deserialize(raft::device_resources const& handle_, const std::string& filen
auto index = raft::neighbors::ivf_pq::index<IdxT>(
handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma);

deserialize_mdspan(handle_, infile, index.pq_centers());
deserialize_mdspan(handle_, infile, index.centers());
deserialize_mdspan(handle_, infile, index.centers_rot());
deserialize_mdspan(handle_, infile, index.rotation_matrix());
deserialize_mdspan(handle_, infile, index.list_sizes());
deserialize_mdspan(handle_, is, index.pq_centers());
deserialize_mdspan(handle_, is, index.centers());
deserialize_mdspan(handle_, is, index.centers_rot());
deserialize_mdspan(handle_, is, index.rotation_matrix());
deserialize_mdspan(handle_, is, index.list_sizes());
auto list_device_spec = list_spec<uint32_t>{pq_bits, pq_dim, cma};
auto list_store_spec = list_spec<uint32_t>{pq_bits, pq_dim, true};
for (auto& list : index.lists()) {
ivf::deserialize_list<list_spec, IdxT, uint32_t>(
handle_, infile, list, list_store_spec, list_device_spec);
handle_, is, list, list_store_spec, list_device_spec);
}

handle_.sync_stream();
infile.close();

recompute_internal_state(handle_, index);

return index;
}

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
*
*/
template <typename IdxT>
auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index<IdxT>
{
std::ifstream infile(filename, std::ios::in | std::ios::binary);

if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

auto index = detail::deserialize<IdxT>(handle_, infile);

infile.close();

return index;
}

} // namespace raft::neighbors::ivf_pq::detail
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <raft/neighbors/detail/ivf_pq_build.cuh>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
#include <raft/neighbors/detail/ivf_pq_serialize.cuh>
#include <raft/neighbors/ivf_pq_serialize.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>

#include <raft/core/device_resources.hpp>
Expand Down
150 changes: 150 additions & 0 deletions cpp/include/raft/neighbors/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "detail/ivf_pq_serialize.cuh"

namespace raft::neighbors::ivf_pq {

/**
* \defgroup ivf_pq_serialize IVF-PQ Serialize
* @{
*/

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = ivf_pq::build(...);`
* raft::serailize(handle, os, index);
* @endcode
*
* @tparam IdxT type of the index
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<IdxT>& index)
{
detail::serialize(handle, os, index);
}

/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = ivf_pq::build(...);`
* raft::serailize(handle, filename, index);
* @endcode
*
* @tparam IdxT type of the index
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle,
const std::string& filename,
const index<IdxT>& index)
{
detail::serialize(handle, filename, index);
}

/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using IdxT = int; // type of the index
* auto index = raft::deserialize<IdxT>(handle, is);
* @endcode
*
* @tparam IdxT type of the index
*
* @param[in] handle the raft handle
* @param[in] is input stream
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
index<IdxT> deserialize(raft::device_resources const& handle, std::istream& is)
{
return detail::deserialize<IdxT>(handle, is);
}

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using IdxT = int; // type of the index
* auto index = raft::deserialize<IdxT>(handle, filename);
* @endcode
*
* @tparam IdxT type of the index
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
index<IdxT> deserialize(raft::device_resources const& handle, const std::string& filename)
{
return detail::deserialize<IdxT>(handle, filename);
}

/**@}*/

} // namespace raft::neighbors::ivf_pq
2 changes: 1 addition & 1 deletion cpp/src/distance/neighbors/ivfpq_deserialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ void deserialize(raft::device_resources const& handle,
raft::neighbors::ivf_pq::index<uint64_t>* index)
{
if (!index) { RAFT_FAIL("Invalid index pointer"); }
*index = raft::neighbors::ivf_pq::detail::deserialize<uint64_t>(handle, filename);
*index = raft::neighbors::ivf_pq::deserialize<uint64_t>(handle, filename);
};
} // namespace raft::runtime::neighbors::ivf_pq
2 changes: 1 addition & 1 deletion cpp/src/distance/neighbors/ivfpq_serialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void serialize(raft::device_resources const& handle,
const std::string& filename,
const raft::neighbors::ivf_pq::index<uint64_t>& index)
{
raft::neighbors::ivf_pq::detail::serialize(handle, filename, index);
raft::neighbors::ivf_pq::serialize(handle, filename, index);
};

} // namespace raft::runtime::neighbors::ivf_pq
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {

auto build_serialize()
{
ivf_pq::detail::serialize<IdxT>(handle_, "ivf_pq_index", build_only());
return ivf_pq::detail::deserialize<IdxT>(handle_, "ivf_pq_index");
ivf_pq::serialize<IdxT>(handle_, "ivf_pq_index", build_only());
return ivf_pq::deserialize<IdxT>(handle_, "ivf_pq_index");
}

template <typename BuildIndex>
Expand Down
Loading

0 comments on commit 3ca7eac

Please sign in to comment.