From 3ca7eacc5cb411facdfb08ff27663a3402486bc4 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Tue, 7 Mar 2023 12:05:44 -0500 Subject: [PATCH] Add stream overloads to `ivf_pq` serialize/deserialize methods (#1315) Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - William Hicks (https://github.com/wphicks) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1315 --- .../neighbors/detail/ivf_pq_serialize.cuh | 127 +++++++++------ cpp/include/raft/neighbors/ivf_pq.cuh | 2 +- .../raft/neighbors/ivf_pq_serialize.cuh | 150 ++++++++++++++++++ .../distance/neighbors/ivfpq_deserialize.cu | 2 +- cpp/src/distance/neighbors/ivfpq_serialize.cu | 2 +- cpp/test/neighbors/ann_ivf_pq.cuh | 4 +- docs/source/cpp_api/neighbors_ivf_pq.rst | 14 ++ 7 files changed, 249 insertions(+), 52 deletions(-) create mode 100644 cpp/include/raft/neighbors/ivf_pq_serialize.cuh diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 33d9b363ba..0701b0feb5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -48,44 +48,39 @@ struct check_index_layout { template struct check_index_layout), 536>; /** - * Save the index to file. + * Write the index to an output stream * * Experimental, both the API and the serialization format are subject to change. 
* * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index + * @param[in] os output stream * @param[in] index IVF-PQ index * */ template -void serialize(raft::device_resources const& handle_, - const std::string& filename, - const index& index) +void serialize(raft::device_resources const& handle_, std::ostream& os, const index& index) { - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", static_cast(index.size()), static_cast(index.dim()), static_cast(index.pq_dim()), static_cast(index.pq_bits())); - serialize_scalar(handle_, of, kSerializationVersion); - serialize_scalar(handle_, of, index.size()); - serialize_scalar(handle_, of, index.dim()); - serialize_scalar(handle_, of, index.pq_bits()); - serialize_scalar(handle_, of, index.pq_dim()); - serialize_scalar(handle_, of, index.conservative_memory_allocation()); + serialize_scalar(handle_, os, kSerializationVersion); + serialize_scalar(handle_, os, index.size()); + serialize_scalar(handle_, os, index.dim()); + serialize_scalar(handle_, os, index.pq_bits()); + serialize_scalar(handle_, os, index.pq_dim()); + serialize_scalar(handle_, os, index.conservative_memory_allocation()); - serialize_scalar(handle_, of, index.metric()); - serialize_scalar(handle_, of, index.codebook_kind()); - serialize_scalar(handle_, of, index.n_lists()); + serialize_scalar(handle_, os, index.metric()); + serialize_scalar(handle_, os, index.codebook_kind()); + serialize_scalar(handle_, os, index.n_lists()); - serialize_mdspan(handle_, of, index.pq_centers()); - serialize_mdspan(handle_, of, index.centers()); - serialize_mdspan(handle_, of, index.centers_rot()); - serialize_mdspan(handle_, of, index.rotation_matrix()); + serialize_mdspan(handle_, os, index.pq_centers()); + serialize_mdspan(handle_, os, index.centers()); + serialize_mdspan(handle_, os, 
index.centers_rot()); + serialize_mdspan(handle_, os, index.rotation_matrix()); auto sizes_host = make_host_mdarray(index.list_sizes().extents()); copy(sizes_host.data_handle(), @@ -93,12 +88,33 @@ void serialize(raft::device_resources const& handle_, sizes_host.size(), handle_.get_stream()); handle_.sync_stream(); - serialize_mdspan(handle_, of, sizes_host.view()); + serialize_mdspan(handle_, os, sizes_host.view()); auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { ivf::serialize_list( - handle_, of, index.lists()[label], list_store_spec, sizes_host(label)); + handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); } +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-PQ index + * + */ +template +void serialize(raft::device_resources const& handle_, + const std::string& filename, + const index& index) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize(handle_, of, index); of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } @@ -106,35 +122,30 @@ void serialize(raft::device_resources const& handle_, } /** - * Load index from file. + * Load index from input stream * * Experimental, both the API and the serialization format are subject to change. 
* * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * @param[in] index IVF-PQ index + * @param[in] is input stream * */ template -auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index +auto deserialize(raft::device_resources const& handle_, std::istream& is) -> index { - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto ver = deserialize_scalar(handle_, infile); + auto ver = deserialize_scalar(handle_, is); if (ver != kSerializationVersion) { RAFT_FAIL("serialization version mismatch %d vs. %d", ver, kSerializationVersion); } - auto n_rows = deserialize_scalar(handle_, infile); - auto dim = deserialize_scalar(handle_, infile); - auto pq_bits = deserialize_scalar(handle_, infile); - auto pq_dim = deserialize_scalar(handle_, infile); - auto cma = deserialize_scalar(handle_, infile); + auto n_rows = deserialize_scalar(handle_, is); + auto dim = deserialize_scalar(handle_, is); + auto pq_bits = deserialize_scalar(handle_, is); + auto pq_dim = deserialize_scalar(handle_, is); + auto cma = deserialize_scalar(handle_, is); - auto metric = deserialize_scalar(handle_, infile); - auto codebook_kind = deserialize_scalar(handle_, infile); - auto n_lists = deserialize_scalar(handle_, infile); + auto metric = deserialize_scalar(handle_, is); + auto codebook_kind = deserialize_scalar(handle_, is); + auto n_lists = deserialize_scalar(handle_, is); RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d", static_cast(n_rows), @@ -146,24 +157,46 @@ auto deserialize(raft::device_resources const& handle_, const std::string& filen auto index = raft::neighbors::ivf_pq::index( handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma); - deserialize_mdspan(handle_, infile, index.pq_centers()); - deserialize_mdspan(handle_, infile, index.centers()); - 
deserialize_mdspan(handle_, infile, index.centers_rot()); - deserialize_mdspan(handle_, infile, index.rotation_matrix()); - deserialize_mdspan(handle_, infile, index.list_sizes()); + deserialize_mdspan(handle_, is, index.pq_centers()); + deserialize_mdspan(handle_, is, index.centers()); + deserialize_mdspan(handle_, is, index.centers_rot()); + deserialize_mdspan(handle_, is, index.rotation_matrix()); + deserialize_mdspan(handle_, is, index.list_sizes()); auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; auto list_store_spec = list_spec{pq_bits, pq_dim, true}; for (auto& list : index.lists()) { ivf::deserialize_list( - handle_, infile, list, list_store_spec, list_device_spec); + handle_, is, list, list_store_spec, list_device_spec); } handle_.sync_stream(); - infile.close(); recompute_internal_state(handle_, index); return index; } +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + */ +template +auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index +{ + std::ifstream infile(filename, std::ios::in | std::ios::binary); + + if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto index = detail::deserialize(handle_, infile); + + infile.close(); + + return index; +} + } // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index e2cc3c4728..053fe634da 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh new file mode 100644 index 0000000000..9a719c69d4 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh @@ 
-0,0 +1,150 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/ivf_pq_serialize.cuh" + +namespace raft::neighbors::ivf_pq { + +/** + * \defgroup ivf_pq_serialize IVF-PQ Serialize + * @{ + */ + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = ivf_pq::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam IdxT type of the index + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index IVF-PQ index + * + * @return raft::neighbors::ivf_pq::index + */ +template +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +{ + detail::serialize(handle, os, index); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_pq::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam IdxT type of the index + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-PQ index + * + * @return raft::neighbors::ivf_pq::index + */ +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam IdxT type of the index + * + * @param[in] handle the raft handle + * @param[in] is input stream + * + * @return raft::neighbors::ivf_pq::index + */ +template +index deserialize(raft::device_resources const& handle, std::istream& is) +{ + return detail::deserialize(handle, is); +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam IdxT type of the index + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * + * @return raft::neighbors::ivf_pq::index + */ +template +index deserialize(raft::device_resources const& handle, const std::string& filename) +{ + return detail::deserialize(handle, filename); +} + +/**@}*/ + +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_deserialize.cu b/cpp/src/distance/neighbors/ivfpq_deserialize.cu index 403f80c9fc..8f71e5622b 100644 --- a/cpp/src/distance/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/distance/neighbors/ivfpq_deserialize.cu @@ -24,6 +24,6 @@ void deserialize(raft::device_resources const& handle, raft::neighbors::ivf_pq::index* index) { if (!index) { RAFT_FAIL("Invalid index pointer"); } - *index = raft::neighbors::ivf_pq::detail::deserialize(handle, filename); + *index = raft::neighbors::ivf_pq::deserialize(handle, filename); }; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_serialize.cu b/cpp/src/distance/neighbors/ivfpq_serialize.cu index f6fd70be82..b7ceb9150a 100644 --- a/cpp/src/distance/neighbors/ivfpq_serialize.cu +++ b/cpp/src/distance/neighbors/ivfpq_serialize.cu @@ -23,7 +23,7 @@ void serialize(raft::device_resources const& handle, const std::string& filename, const raft::neighbors::ivf_pq::index& index) { - raft::neighbors::ivf_pq::detail::serialize(handle, filename, index); + raft::neighbors::ivf_pq::serialize(handle, filename, index); }; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index df295b8bcb..91294a859a 100644 --- 
a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -212,8 +212,8 @@ class ivf_pq_test : public ::testing::TestWithParam { auto build_serialize() { - ivf_pq::detail::serialize(handle_, "ivf_pq_index", build_only()); - return ivf_pq::detail::deserialize(handle_, "ivf_pq_index"); + ivf_pq::serialize(handle_, "ivf_pq_index", build_only()); + return ivf_pq::deserialize(handle_, "ivf_pq_index"); } template diff --git a/docs/source/cpp_api/neighbors_ivf_pq.rst b/docs/source/cpp_api/neighbors_ivf_pq.rst index d22ea6231f..228833983c 100644 --- a/docs/source/cpp_api/neighbors_ivf_pq.rst +++ b/docs/source/cpp_api/neighbors_ivf_pq.rst @@ -14,4 +14,18 @@ namespace *raft::neighbors::ivf_pq* :members: :content-only: +Serializer Methods +------------------ +``#include `` +.. doxygenfunction:: serialize(raft::device_resources const& handle, std::ostream& os, const index& index) + :project: RAFT + +.. doxygenfunction:: serialize(raft::device_resources const& handle, const std::string& filename, const index& index) + :project: RAFT + +.. doxygenfunction:: deserialize(raft::device_resources const& handle, std::istream& is) + :project: RAFT + +.. doxygenfunction:: deserialize(raft::device_resources const& handle, const std::string& filename) + :project: RAFT \ No newline at end of file