diff --git a/.gitignore b/.gitignore index 80709dbb96..c2528d2cd0 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ log .DS_Store dask-worker-space/ *.egg-info/ +*.bin ## scikit-build _skbuild diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d6e85d786b..2999045a0c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -368,6 +368,17 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/matrix/specializations/detail/select_k_float_int64_t.cu src/distance/matrix/specializations/detail/select_k_half_uint32_t.cu src/distance/matrix/specializations/detail/select_k_half_int64_t.cu + src/distance/neighbors/ivf_flat_search.cu + src/distance/neighbors/ivf_flat_build.cu + src/distance/neighbors/specializations/ivfflat_build_float_int64_t.cu + src/distance/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_extend_float_int64_t.cu + src/distance/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_search_float_int64_t.cu + src/distance/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu src/distance/neighbors/ivfpq_build.cu src/distance/neighbors/ivfpq_deserialize.cu src/distance/neighbors/ivfpq_serialize.cu diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f42bfe66c7..c573676504 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -109,8 +109,9 @@ auto build(raft::device_resources const& handle, */ template auto build(raft::device_resources const& handle, - raft::device_matrix_view dataset, - const index_params& params) -> index + const index_params& params, + raft::device_matrix_view dataset) + -> index { return raft::neighbors::ivf_flat::detail::build(handle, params, @@ -119,6 +120,52 @@ auto build(raft::device_resources const& handle, static_cast(dataset.extent(1))); } +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, dataset, index_params, index); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +template +void build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) +{ + idx = raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + /** @} */ /** @@ -192,20 +239,19 @@ auto extend(raft::device_resources const& handle, * @tparam idx_t type of the indices in the source dataset * * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_matrix_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index original index * * @return the constructed extended ivf-flat index */ template auto extend(raft::device_resources const& handle, - const index& orig_index, raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) - -> index + std::optional> new_indices, + const index& orig_index) -> index { return extend( handle, @@ -270,24 +316,25 @@ void extend(raft::device_resources const& handle, * // train the index from a [N, D] dataset * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset); + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_opt, &index_empty); * @endcode * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset * * @param[in] handle - * @param[inout] index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_matrix_view to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] index pointer to index, to be overwritten in-place */ template void extend(raft::device_resources const& handle, - index* index, raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) + std::optional> new_indices, + index* index) { extend(handle, index, @@ -386,30 +433,27 @@ void search(raft::device_resources const& handle, * @tparam int_t precision / type of integral arguments * * @param[in] handle + * @param[in] params configure the search * @param[in] index ivf-flat constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] params configure the search - * @param[in] k the number of neighbors to find for each query. */ -template +template void search(raft::device_resources const& handle, + const search_params& params, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - const search_params& params, - int_t k) + raft::device_matrix_view distances) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), "Number of rows in output neighbors and distances matrices must equal the number of queries."); - RAFT_EXPECTS( - neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast(k), - "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number of query dimensions should equal number of dimensions in the index."); @@ -419,7 +463,7 @@ void search(raft::device_resources const& handle, index, queries.data_handle(), static_cast(queries.extent(0)), - static_cast(k), + static_cast(neighbors.extent(1)), neighbors.data_handle(), distances.data_handle(), nullptr); diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index f9bdda4e49..27105b6eab 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh new file mode 100644 index 0000000000..02e1cbebb0 --- /dev/null +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_INST(T, IdxT) \ + extern template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; \ + \ + extern template auto extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const index& orig_index) \ + ->index; \ + \ + extern template void extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* idx); \ + \ + extern template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); + +RAFT_INST(float, uint64_t); +RAFT_INST(int8_t, uint64_t); +RAFT_INST(uint8_t, uint64_t); + +#undef RAFT_INST +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp new file mode 100644 index 0000000000..18ea064015 --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::runtime::neighbors::ivf_flat { + +// We define overloads for build and extend with void return type. This is used in the Cython +// wrappers, where exception handling is not compatible with return type that has nontrivial +// constructor. +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + auto extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* idx); + +RAFT_INST_BUILD_EXTEND(float, int64_t) +RAFT_INST_BUILD_EXTEND(int8_t, int64_t) +RAFT_INST_BUILD_EXTEND(uint8_t, int64_t) + +#undef RAFT_INST_BUILD_EXTEND + +#define RAFT_INST_SEARCH(T, IdxT) \ + void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + raft::neighbors::ivf_flat::index const&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); + +RAFT_INST_SEARCH(float, int64_t); +RAFT_INST_SEARCH(int8_t, int64_t); +RAFT_INST_SEARCH(uint8_t, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/ivf_flat_build.cu b/cpp/src/distance/neighbors/ivf_flat_build.cu new file mode 100644 index 0000000000..0d82fdbb08 --- /dev/null +++ b/cpp/src/distance/neighbors/ivf_flat_build.cu @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors::ivf_flat { + +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index \ + { \ + return raft::neighbors::ivf_flat::build(handle, params, dataset); \ + } \ + auto extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index \ + { \ + return raft::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, orig_index); \ + } \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx) \ + { \ + idx = build(handle, params, dataset); \ + } \ + \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* idx) \ + { \ + raft::neighbors::ivf_flat::extend(handle, new_vectors, new_indices, idx); \ + } + +RAFT_INST_BUILD_EXTEND(float, int64_t); +RAFT_INST_BUILD_EXTEND(int8_t, int64_t); +RAFT_INST_BUILD_EXTEND(uint8_t, int64_t); + +#undef RAFT_INST_BUILD_EXTEND + +} // namespace raft::runtime::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/ivf_flat_search.cu b/cpp/src/distance/neighbors/ivf_flat_search.cu new file mode 100644 index 0000000000..b843ee7c30 --- /dev/null +++ b/cpp/src/distance/neighbors/ivf_flat_search.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors::ivf_flat { + +#define RAFT_INST_SEARCH(T, IdxT) \ + void search(raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::search_params const& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_flat::search( \ + handle, params, index, queries, neighbors, distances); \ + } + +RAFT_INST_SEARCH(float, int64_t); +RAFT_INST_SEARCH(int8_t, int64_t); +RAFT_INST_SEARCH(uint8_t, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_build_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_build_float_int64_t.cu new file mode 100644 index 0000000000..7082873d76 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_build_float_int64_t.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; + +RAFT_MAKE_INSTANCE(float, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..ebc1a7fefa --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; + +RAFT_MAKE_INSTANCE(int8_t, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..870db6e97e --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; + +RAFT_MAKE_INSTANCE(uint8_t, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_extend_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_extend_float_int64_t.cu new file mode 100644 index 0000000000..71af06ad71 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_extend_float_int64_t.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* idx); + +RAFT_MAKE_INSTANCE(float, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..bb7bb6e7eb --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* idx); + +RAFT_MAKE_INSTANCE(int8_t, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..607b4b0913 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* idx); + +RAFT_MAKE_INSTANCE(uint8_t, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_search_float_int64_t.cu new file mode 100644 index 0000000000..6de65546c8 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_search_float_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); + +RAFT_MAKE_INSTANCE(float, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..8eda240ccd --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); + +RAFT_MAKE_INSTANCE(int8_t, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/distance/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..8ff6533628 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors::ivf_flat { + +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); + +RAFT_MAKE_INSTANCE(uint8_t, int64_t); + +#undef RAFT_MAKE_INSTANCE + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 002e4f07d2..486ff61724 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -166,7 +166,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_flat::build(handle_, database_view, index_params); + auto idx = ivf_flat::build(handle_, index_params, database_view); rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -179,7 +179,8 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto half_of_data_view = raft::make_device_matrix_view( (const DataT*)database.data(), half_of_data, ps.dim); - auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); + const std::optional> no_opt = std::nullopt; + index index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); auto new_half_of_data_view = raft::make_device_matrix_view( database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); @@ -188,10 +189,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); ivf_flat::extend(handle_, - &index_2, new_half_of_data_view, std::make_optional>( - new_half_of_data_indices_view)); + new_half_of_data_indices_view), + &index_2); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); @@ -204,12 +205,11 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto index_loaded = ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); ivf_flat::search(handle_, + search_params, index_loaded, search_queries_view, indices_out_view, - dists_out_view, - search_params, - ps.k); + dists_out_view); update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); @@ -248,7 +248,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } else { // The centers must be immutable ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(), - index.centers().data_handle(), + idx.centers().data_handle(), index_2.centers().size(), raft::Compare(), stream_)); diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index ae5fae8201..572ea47f4e 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources refine.pyx) +set(cython_sources common.pyx refine.pyx) set(linked_libraries raft::raft raft::distance) # Build all of the Cython targets @@ -23,4 +23,5 @@ rapids_cython_create_modules( LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX neighbors_ ) +add_subdirectory(ivf_flat) add_subdirectory(ivf_pq) diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index dd8cdd8445..f7510ba2db 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,4 +14,4 @@ # from .refine import refine -__all__ = ["refine"] +__all__ = ["common", "refine"] diff --git a/python/pylibraft/pylibraft/neighbors/common.pxd b/python/pylibraft/pylibraft/neighbors/common.pxd new file mode 100644 index 0000000000..b11ef3176e --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/common.pxd @@ -0,0 +1,24 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from pylibraft.distance.distance_type cimport DistanceType + + +cdef _get_metric_string(DistanceType metric) diff --git a/python/pylibraft/pylibraft/neighbors/common.pyx b/python/pylibraft/pylibraft/neighbors/common.pyx new file mode 100644 index 0000000000..a8380b589b --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/common.pyx @@ -0,0 +1,61 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import warnings + +from pylibraft.distance.distance_type cimport DistanceType + + +def _get_metric(metric): + SUPPORTED_DISTANCES = { + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, + "inner_product": DistanceType.InnerProduct + } + if metric not in SUPPORTED_DISTANCES: + if metric == "l2_expanded": + warnings.warn("Using l2_expanded as a metric name is deprecated," + " use sqeuclidean instead", FutureWarning) + return DistanceType.L2Expanded + + raise ValueError("metric %s is not supported" % metric) + return SUPPORTED_DISTANCES[metric] + + +cdef _get_metric_string(DistanceType metric): + return {DistanceType.L2Expanded : "sqeuclidean", + DistanceType.InnerProduct: "inner_product", + DistanceType.L2SqrtExpanded: "euclidean"}[metric] + + +def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None): + if cai.dtype not in exp_dt: + raise TypeError("dtype %s not supported" % cai.dtype) + + if not cai.c_contiguous: + raise ValueError("Row major input is expected") + + if exp_cols is not None and cai.shape[1] != exp_cols: + raise ValueError("Incorrect number of columns, expected {} got {}" + .format(exp_cols, cai.shape[1])) + + if exp_rows is not None and cai.shape[0] != exp_rows: + raise ValueError("Incorrect number of rows, expected {} , got {}" + .format(exp_rows, cai.shape[0])) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/ivf_flat/CMakeLists.txt new file mode 100644 index 0000000000..f183e17157 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/CMakeLists.txt @@ -0,0 +1,24 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources ivf_flat.pyx) +set(linked_libraries raft::raft raft::distance) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX neighbors_ivfflat_ +) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py b/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py new file mode 100644 index 0000000000..58fd88b873 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .ivf_flat import Index, IndexParams, SearchParams, build, extend, search + +__all__ = [ + "Index", + "IndexParams", + "SearchParams", + "build", + "extend", + "search", +] diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/__init__.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/__init__.py b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/__init__.py new file mode 100644 index 0000000000..8f2cc34855 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd new file mode 100644 index 0000000000..31a251e7c2 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -0,0 +1,135 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +import pylibraft.common.handle + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uintptr_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string + +from rmm._lib.memory_resource cimport device_memory_resource + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + device_vector_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( + ann_index, + ann_index_params, + ann_search_params, +) + + +cdef extern from "raft/neighbors/ivf_flat_types.hpp" \ + namespace "raft::neighbors::ivf_flat" nogil: + + cpdef cppclass index_params(ann_index_params): + uint32_t n_lists + uint32_t kmeans_n_iters + double kmeans_trainset_fraction + bool adaptive_centers + bool conservative_memory_allocation + + cdef cppclass index[T, IdxT](ann_index): + index(const device_resources& handle, + DistanceType metric, + uint32_t n_lists, + bool adaptive_centers, + bool conservative_memory_allocation, + uint32_t dim) + IdxT size() + uint32_t dim() + DistanceType metric() + uint32_t n_lists() + bool adaptive_centers() + + cpdef cppclass search_params(ann_search_params): + uint32_t n_probes + + +cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \ + namespace "raft::runtime::neighbors::ivf_flat" nogil: + + cdef void build(const device_resources&, + const index_params& params, + device_matrix_view[float, int64_t, row_major] dataset, + index[float, int64_t]& index) except + + + cdef void build(const device_resources& handle, + const index_params& params, + device_matrix_view[int8_t, int64_t, row_major] dataset, + index[int8_t, int64_t]& index) except + + + cdef void build(const device_resources& handle, + const index_params& params, + device_matrix_view[uint8_t, int64_t, row_major] dataset, + index[uint8_t, int64_t]& index) except + + + cdef void extend( + const device_resources& handle, + device_matrix_view[float, int64_t, row_major] new_vectors, + optional[device_vector_view[int64_t, int64_t]] new_indices, + index[float, int64_t]* index) except + + + cdef void extend( + const device_resources& handle, + device_matrix_view[int8_t, int64_t, row_major] new_vectors, + optional[device_vector_view[int64_t, int64_t]] new_indices, + index[int8_t, int64_t]* index) except + + + cdef void extend( + const device_resources& handle, + device_matrix_view[uint8_t, int64_t, row_major] new_vectors, + optional[device_vector_view[int64_t, int64_t]] new_indices, + index[uint8_t, int64_t]* index) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[float, int64_t]& index, + device_matrix_view[float, int64_t, row_major] queries, + device_matrix_view[int64_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[int8_t, int64_t]& index, + device_matrix_view[int8_t, int64_t, row_major] queries, + device_matrix_view[int64_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[uint8_t, int64_t]& index, + device_matrix_view[uint8_t, int64_t, row_major] queries, + device_matrix_view[int64_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx new file mode 100644 index 0000000000..db279ad2db --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -0,0 +1,710 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import warnings + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uintptr_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string + +from pylibraft.distance.distance_type cimport DistanceType + +from pylibraft.common import ( + DeviceResources, + ai_wrapper, + auto_convert_output, + device_ndarray, +) +from pylibraft.common.cai_wrapper import cai_wrapper + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + device_vector_view, + make_device_vector_view, + row_major, +) + +from pylibraft.common.interruptible import cuda_interruptible + +from pylibraft.common.handle cimport device_resources + +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.input_validation import is_c_contiguous + +from rmm._lib.memory_resource cimport ( + DeviceMemoryResource, + device_memory_resource, +) + +cimport pylibraft.neighbors.ivf_flat.cpp.c_ivf_flat as c_ivf_flat +from pylibraft.common.cpp.optional cimport optional + +from pylibraft.neighbors.common import _check_input_array, _get_metric + +from pylibraft.common.mdspan cimport ( + get_dmv_float, + get_dmv_int8, + get_dmv_int64, + get_dmv_uint8, +) +from pylibraft.neighbors.common cimport _get_metric_string +from pylibraft.neighbors.ivf_flat.cpp.c_ivf_flat cimport ( + index_params, + search_params, +) + + +cdef class IndexParams: + cdef c_ivf_flat.index_params params + + def __init__(self, *, + n_lists=1024, + metric="sqeuclidean", + kmeans_n_iters=20, + kmeans_trainset_fraction=0.5, + add_data_on_build=True, + bool adaptive_centers=False): + """" + Parameters to build index for IVF-FLAT nearest neighbor search + + Parameters + ---------- + n_list : int, default = 1024 + The number of clusters used in the coarse quantizer. + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product", + "euclidean"], where + - sqeuclidean is the euclidean distance without the square root + operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - euclidean is the euclidean distance + - inner product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. + kmeans_n_iters : int, default = 20 + The number of iterations searching for kmeans centers during index + building. + kmeans_trainset_fraction : int, default = 0.5 + If kmeans_trainset_fraction is less than 1, then the dataset is + subsampled, and only n_samples * kmeans_trainset_fraction rows + are used for training. + add_data_on_build : bool, default = True + After training the coarse and fine quantizers, we will populate + the index with the dataset if add_data_on_build == True, otherwise + the index is left empty, and the extend method can be used + to add new vectors to the index. + adaptive_centers : bool, default = False + By default (adaptive_centers = False), the cluster centers are + trained in `ivf_flat::build`, and and never modified in + `ivf_flat::extend`. The alternative behavior (adaptive_centers + = true) is to update the cluster centers for new data when it is + added. In this case, `index.centers()` are always exactly the + centroids of the data in the corresponding clusters. The drawback + of this behavior is that the centroids depend on the order of + adding new data (through the classification of the added data); + that is, `index.centers()` "drift" together with the changing + distribution of the newly added data. + """ + self.params.n_lists = n_lists + self.params.metric = _get_metric(metric) + self.params.metric_arg = 0 + self.params.kmeans_n_iters = kmeans_n_iters + self.params.kmeans_trainset_fraction = kmeans_trainset_fraction + self.params.add_data_on_build = add_data_on_build + self.params.adaptive_centers = adaptive_centers + + @property + def n_lists(self): + return self.params.n_lists + + @property + def metric(self): + return self.params.metric + + @property + def kmeans_n_iters(self): + return self.params.kmeans_n_iters + + @property + def kmeans_trainset_fraction(self): + return self.params.kmeans_trainset_fraction + + @property + def add_data_on_build(self): + return self.params.add_data_on_build + + @property + def adaptive_centers(self): + return self.params.adaptive_centers + + +cdef class Index: + cdef readonly bool trained + cdef str active_index_type + + def __cinit__(self): + self.trained = False + self.active_index_type = None + + +cdef class IndexFloat(Index): + cdef c_ivf_flat.index[float, int64_t] * index + + def __cinit__(self, handle=None): + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + # this is to keep track of which index type is being used + # We create a placeholder object. The actual parameter values do + # not matter, it will be replaced with a built index object later. + self.index = new c_ivf_flat.index[float, int64_t]( + deref(handle_), _get_metric("sqeuclidean"), + 1, + False, + False, + 4) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [ + attr + "=" + str(getattr(self, attr)) + for attr in ["size", "dim", "n_lists", "adaptive_centers"] + ] + attr_str = [m_str] + attr_str + return "Index(type=IVF-FLAT, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def size(self): + return self.index[0].size() + + @property + def metric(self): + return self.index[0].metric() + + @property + def n_lists(self): + return self.index[0].n_lists() + + @property + def adaptive_centers(self): + return self.index[0].adaptive_centers() + + +cdef class IndexInt8(Index): + cdef c_ivf_flat.index[int8_t, int64_t] * index + + def __cinit__(self, handle=None): + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + # this is to keep track of which index type is being used + # We create a placeholder object. The actual parameter values do + # not matter, it will be replaced with a built index object later. + self.index = new c_ivf_flat.index[int8_t, int64_t]( + deref(handle_), _get_metric("sqeuclidean"), + 1, + False, + False, + 4) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [ + attr + "=" + str(getattr(self, attr)) + for attr in ["size", "dim", "n_lists", "adaptive_centers"] + ] + attr_str = [m_str] + attr_str + return "Index(type=IVF-FLAT, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def size(self): + return self.index[0].size() + + @property + def metric(self): + return self.index[0].metric() + + @property + def n_lists(self): + return self.index[0].n_lists() + + @property + def adaptive_centers(self): + return self.index[0].adaptive_centers() + + +cdef class IndexUint8(Index): + cdef c_ivf_flat.index[uint8_t, int64_t] * index + + def __cinit__(self, handle=None): + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + # this is to keep track of which index type is being used + # We create a placeholder object. The actual parameter values do + # not matter, it will be replaced with a built index object later. + self.index = new c_ivf_flat.index[uint8_t, int64_t]( + deref(handle_), _get_metric("sqeuclidean"), + 1, + False, + False, + 4) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [ + attr + "=" + str(getattr(self, attr)) + for attr in ["size", "dim", "n_lists", "adaptive_centers"] + ] + attr_str = [m_str] + attr_str + return "Index(type=IVF-FLAT, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def size(self): + return self.index[0].size() + + @property + def metric(self): + return self.index[0].metric() + + @property + def n_lists(self): + return self.index[0].n_lists() + + @property + def adaptive_centers(self): + return self.index[0].adaptive_centers() + + +@auto_sync_handle +@auto_convert_output +def build(IndexParams index_params, dataset, handle=None): + """ + Builds an IVF-FLAT index that can be used for nearest neighbor search. + + Parameters + ---------- + index_params : IndexParams object + dataset : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + {handle_docstring} + + Returns + ------- + index: ivf_flat.Index + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import ivf_flat + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> handle = DeviceResources() + >>> index_params = ivf_flat.IndexParams( + ... n_lists=1024, + ... metric="sqeuclidean") + + >>> index = ivf_flat.build(index_params, dataset, handle=handle) + + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> distances, neighbors = ivf_flat.search(ivf_flat.SearchParams(), index, + ... queries, k, handle=handle) + + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + dataset_cai = cai_wrapper(dataset) + dataset_dt = dataset_cai.dtype + _check_input_array(dataset_cai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')]) + + cdef int64_t n_rows = dataset_cai.shape[0] + cdef uint32_t dim = dataset_cai.shape[1] + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if dataset_dt == np.float32: + idx_float = IndexFloat(handle) + idx_float.active_index_type = "float32" + with cuda_interruptible(): + c_ivf_flat.build(deref(handle_), + index_params.params, + get_dmv_float(dataset_cai, check_shape=True), + deref(idx_float.index)) + idx_float.trained = True + return idx_float + elif dataset_dt == np.byte: + idx_int8 = IndexInt8(handle) + idx_int8.active_index_type = "byte" + with cuda_interruptible(): + c_ivf_flat.build(deref(handle_), + index_params.params, + get_dmv_int8(dataset_cai, check_shape=True), + deref(idx_int8.index)) + idx_int8.trained = True + return idx_int8 + elif dataset_dt == np.ubyte: + idx_uint8 = IndexUint8(handle) + idx_uint8.active_index_type = "ubyte" + with cuda_interruptible(): + c_ivf_flat.build(deref(handle_), + index_params.params, + get_dmv_uint8(dataset_cai, check_shape=True), + deref(idx_uint8.index)) + idx_uint8.trained = True + return idx_uint8 + else: + raise TypeError("dtype %s not supported" % dataset_dt) + + +@auto_sync_handle +@auto_convert_output +def extend(Index index, new_vectors, new_indices, handle=None): + """ + Extend an existing index with new vectors. + + Parameters + ---------- + index : ivf_flat.Index + Trained ivf_flat object. + new_vectors : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + new_indices : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [int64] + {handle_docstring} + + Returns + ------- + index: ivf_flat.Index + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import ivf_flat + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> handle = DeviceResources() + >>> index = ivf_flat.build(ivf_flat.IndexParams(), dataset, handle=handle) + + >>> n_rows = 100 + >>> more_data = cp.random.random_sample((n_rows, n_features), + ... dtype=cp.float32) + >>> indices = index.size + cp.arange(n_rows, dtype=cp.int64) + >>> index = ivf_flat.extend(index, more_data, indices) + + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> distances, neighbors = ivf_flat.search(ivf_flat.SearchParams(), + ... index, queries, + ... k, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + """ + if not index.trained: + raise ValueError("Index need to be built before calling extend.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + vecs_cai = cai_wrapper(new_vectors) + vecs_dt = vecs_cai.dtype + cdef int64_t n_rows = vecs_cai.shape[0] + cdef uint32_t dim = vecs_cai.shape[1] + + _check_input_array(vecs_cai, [np.dtype(index.active_index_type)], + exp_cols=index.dim) + + idx_cai = cai_wrapper(new_indices) + _check_input_array(idx_cai, [np.dtype('int64')], exp_rows=n_rows) + if len(idx_cai.shape)!=1: + raise ValueError("Indices array is expected to be 1D") + + cdef optional[device_vector_view[int64_t, int64_t]] new_indices_opt + + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if vecs_dt == np.float32: + idx_float = index + if idx_float.index.size() > 0: + new_indices_opt = make_device_vector_view( + idx_cai.data, + idx_cai.shape[0]) + with cuda_interruptible(): + c_ivf_flat.extend(deref(handle_), + get_dmv_float(vecs_cai, check_shape=True), + new_indices_opt, + idx_float.index) + elif vecs_dt == np.int8: + idx_int8 = index + if idx_int8.index[0].size() > 0: + new_indices_opt = make_device_vector_view( + idx_cai.data, + idx_cai.shape[0]) + with cuda_interruptible(): + c_ivf_flat.extend(deref(handle_), + get_dmv_int8(vecs_cai, check_shape=True), + new_indices_opt, + idx_int8.index) + elif vecs_dt == np.uint8: + idx_uint8 = index + if idx_uint8.index[0].size() > 0: + new_indices_opt = make_device_vector_view( + idx_cai.data, + idx_cai.shape[0]) + with cuda_interruptible(): + c_ivf_flat.extend(deref(handle_), + get_dmv_uint8(vecs_cai, check_shape=True), + new_indices_opt, + idx_uint8.index) + else: + raise TypeError("query dtype %s not supported" % vecs_dt) + + return index + + +cdef class SearchParams: + cdef c_ivf_flat.search_params params + + def __init__(self, *, n_probes=20): + """ + IVF-FLAT search parameters + + Parameters + ---------- + n_probes: int, default = 1024 + The number of course clusters to select for the fine search. + """ + self.params.n_probes = n_probes + + def __repr__(self): + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["n_probes"]] + return "SearchParams(type=IVF-FLAT, " + (", ".join(attr_str)) + ")" + + @property + def n_probes(self): + return self.params.n_probes + + +@auto_sync_handle +@auto_convert_output +def search(SearchParams search_params, + Index index, + queries, + k, + neighbors=None, + distances=None, + handle=None): + """ + Find the k nearest neighbors for each query. + + Parameters + ---------- + search_params : SearchParams + index : Index + Trained IVF-FLAT index. + queries : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + k : int + The number of neighbors. + neighbors : Optional CUDA array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + distances : Optional CUDA array interface compliant matrix shape + (n_queries, k) If supplied, the distances to the + neighbors will be written here in-place. (default None) + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import ivf_flat + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> # Build index + >>> handle = DeviceResources() + >>> index = ivf_flat.build(ivf_flat.IndexParams(), dataset, handle=handle) + + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> search_params = ivf_flat.SearchParams( + ... n_probes=20, + ... lut_dtype=cp.float16, + ... internal_distance_dtype=cp.float32 + ... ) + + # TODO update example to set default pool allocator + # (instead of passing an mr) + + >>> # Using a pooling allocator reduces overhead of temporary array + >>> # creation during search. This is useful if multiple searches + >>> # are performad with same query size. + >>> import rmm + >>> mr = rmm.mr.PoolMemoryResource( + ... rmm.mr.CudaMemoryResource(), + ... initial_pool_size=2**29, + ... maximum_pool_size=2**31 + ... ) + >>> distances, neighbors = ivf_flat.search(search_params, index, queries, + ... k, memory_resource=mr, + ... handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + >>> neighbors = cp.asarray(neighbors) + >>> distances = cp.asarray(distances) + """ + + if not index.trained: + raise ValueError("Index need to be built before calling search.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + queries_cai = cai_wrapper(queries) + queries_dt = queries_cai.dtype + cdef uint32_t n_queries = queries_cai.shape[0] + + _check_input_array(queries_cai, [np.dtype(index.active_index_type)], + exp_cols=index.dim) + + if neighbors is None: + neighbors = device_ndarray.empty((n_queries, k), dtype='int64') + + neighbors_cai = cai_wrapper(neighbors) + _check_input_array(neighbors_cai, [np.dtype('int64')], + exp_rows=n_queries, exp_cols=k) + + if distances is None: + distances = device_ndarray.empty((n_queries, k), dtype='float32') + + distances_cai = cai_wrapper(distances) + _check_input_array(distances_cai, [np.dtype('float32')], + exp_rows=n_queries, exp_cols=k) + + cdef c_ivf_flat.search_params params = search_params.params + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if queries_dt == np.float32: + idx_float = index + with cuda_interruptible(): + c_ivf_flat.search(deref(handle_), + params, + deref(idx_float.index), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_int64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) + elif queries_dt == np.byte: + idx_int8 = index + with cuda_interruptible(): + c_ivf_flat.search(deref(handle_), + params, + deref(idx_int8.index), + get_dmv_int8(queries_cai, check_shape=True), + get_dmv_int64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) + elif queries_dt == np.ubyte: + idx_uint8 = index + with cuda_interruptible(): + c_ivf_flat.search(deref(handle_), + params, + deref(idx_uint8.index), + get_dmv_uint8(queries_cai, check_shape=True), + get_dmv_int64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) + else: + raise ValueError("query dtype %s not supported" % queries_dt) + + return (distances, neighbors) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index f0edfbd9c0..1906c569f6 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -49,7 +49,11 @@ from rmm._lib.memory_resource cimport ( device_memory_resource, ) +cimport pylibraft.neighbors.ivf_flat.cpp.c_ivf_flat as c_ivf_flat cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq + +from pylibraft.neighbors.common import _check_input_array, _get_metric + from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.common.mdspan cimport ( get_dmv_float, @@ -58,34 +62,13 @@ from pylibraft.common.mdspan cimport ( get_dmv_uint8, make_optional_view_int64, ) +from pylibraft.neighbors.common cimport _get_metric_string from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, search_params, ) -def _get_metric(metric): - SUPPORTED_DISTANCES = { - "sqeuclidean": DistanceType.L2Expanded, - "euclidean": DistanceType.L2SqrtExpanded, - "inner_product": DistanceType.InnerProduct - } - if metric not in SUPPORTED_DISTANCES: - if metric == "l2_expanded": - warnings.warn("Using l2_expanded as a metric name is deprecated," - " use sqeuclidean instead", FutureWarning) - return DistanceType.L2Expanded - - raise ValueError("metric %s is not supported" % metric) - return SUPPORTED_DISTANCES[metric] - - -cdef _get_metric_string(DistanceType metric): - return {DistanceType.L2Expanded : "sqeuclidean", - DistanceType.InnerProduct: "inner_product", - DistanceType.L2SqrtExpanded: "euclidean"}[metric] - - cdef _get_codebook_string(c_ivf_pq.codebook_gen codebook): return {c_ivf_pq.codebook_gen.PER_SUBSPACE: "subspace", c_ivf_pq.codebook_gen.PER_CLUSTER: "cluster"}[codebook] @@ -105,22 +88,6 @@ cdef _get_dtype_string(dtype): c_ivf_pq.cudaDataType_t.CUDA_R_8U: np.uint8}[dtype]) -def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None): - if cai.dtype not in exp_dt: - raise TypeError("dtype %s not supported" % cai["typestr"]) - - if not cai.c_contiguous: - raise ValueError("Row major input is expected") - - if exp_cols is not None and cai.shape[1] != exp_cols: - raise ValueError("Incorrect number of columns, expected {} got {}" - .format(exp_cols, cai.shape[1])) - - if exp_rows is not None and cai.shape[0] != exp_rows: - raise ValueError("Incorrect number of rows, expected {} , got {}" - .format(exp_rows, cai.shape[0])) - - cdef class IndexParams: cdef c_ivf_pq.index_params params diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index 8eb468c805..20f5327226 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -18,6 +18,7 @@ # cython: embedsignature = True # cython: language_level = 3 +import cupy as cp import numpy as np from cython.operator cimport dereference as deref @@ -42,7 +43,7 @@ from pylibraft.common.interruptible import cuda_interruptible from pylibraft.distance.distance_type cimport DistanceType import pylibraft.neighbors.ivf_pq as ivf_pq -from pylibraft.neighbors.ivf_pq.ivf_pq import _get_metric +from pylibraft.neighbors.common import _get_metric cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq from pylibraft.common.cpp.mdspan cimport ( @@ -57,6 +58,7 @@ from pylibraft.common.mdspan cimport ( get_dmv_int64, get_dmv_uint8, ) +from pylibraft.neighbors.common cimport _get_metric_string from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, search_params, diff --git a/python/pylibraft/pylibraft/test/test_ivf_flat.py b/python/pylibraft/pylibraft/test/test_ivf_flat.py new file mode 100644 index 0000000000..593980f7c8 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_ivf_flat.py @@ -0,0 +1,463 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# h ttp://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from sklearn.metrics import pairwise_distances +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize + +from pylibraft.common import device_ndarray +from pylibraft.neighbors import ivf_flat + + +def generate_data(shape, dtype): + if dtype == np.byte: + x = np.random.randint(-127, 128, size=shape, dtype=np.byte) + elif dtype == np.ubyte: + x = np.random.randint(0, 255, size=shape, dtype=np.ubyte) + else: + x = np.random.random_sample(shape).astype(dtype) + + return x + + +def calc_recall(ann_idx, true_nn_idx): + assert ann_idx.shape == true_nn_idx.shape + n = 0 + for i in range(ann_idx.shape[0]): + n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size + recall = n / ann_idx.size + return recall + + +def check_distances(dataset, queries, metric, out_idx, out_dist, eps=None): + """ + Calculate the real distance between queries and dataset[out_idx], + and compare it to out_dist. + """ + if eps is None: + # Quantization leads to errors in the distance calculation. + # The aim of this test is not to test precision, but to catch obvious + # errors. + eps = 0.1 + + dist = np.empty(out_dist.shape, out_dist.dtype) + for i in range(queries.shape[0]): + X = queries[np.newaxis, i, :] + Y = dataset[out_idx[i, :], :] + if metric == "sqeuclidean": + dist[i, :] = pairwise_distances(X, Y, "sqeuclidean") + elif metric == "euclidean": + dist[i, :] = pairwise_distances(X, Y, "euclidean") + elif metric == "inner_product": + dist[i, :] = np.matmul(X, Y.T) + else: + raise ValueError("Invalid metric") + + dist_eps = abs(dist) + dist_eps[dist < 1e-3] = 1e-3 + diff = abs(out_dist - dist) / dist_eps + + assert np.mean(diff) < eps + + +def run_ivf_flat_build_search_test( + n_rows, + n_cols, + n_queries, + k, + n_lists, + metric, + dtype, + add_data_on_build=True, + n_probes=100, + kmeans_trainset_fraction=1, + kmeans_n_iters=20, + compare=True, + inplace=True, + array_type="device", +): + dataset = generate_data((n_rows, n_cols), dtype) + if metric == "inner_product": + dataset = normalize(dataset, norm="l2", axis=1) + dataset_device = device_ndarray(dataset) + + build_params = ivf_flat.IndexParams( + n_lists=n_lists, + metric=metric, + kmeans_n_iters=kmeans_n_iters, + kmeans_trainset_fraction=kmeans_trainset_fraction, + add_data_on_build=add_data_on_build, + ) + + if array_type == "device": + index = ivf_flat.build(build_params, dataset_device) + else: + index = ivf_flat.build(build_params, dataset) + + assert index.trained + + assert index.metric == build_params.metric + assert index.n_lists == build_params.n_lists + + if not add_data_on_build: + dataset_1 = dataset[: n_rows // 2, :] + dataset_2 = dataset[n_rows // 2 :, :] + indices_1 = np.arange(n_rows // 2, dtype=np.int64) + indices_2 = np.arange(n_rows // 2, n_rows, dtype=np.int64) + if array_type == "device": + dataset_1_device = device_ndarray(dataset_1) + dataset_2_device = device_ndarray(dataset_2) + indices_1_device = device_ndarray(indices_1) + indices_2_device = device_ndarray(indices_2) + index = ivf_flat.extend(index, dataset_1_device, indices_1_device) + index = ivf_flat.extend(index, dataset_2_device, indices_2_device) + else: + index = ivf_flat.extend(index, dataset_1, indices_1) + index = ivf_flat.extend(index, dataset_2, indices_2) + + assert index.size >= n_rows + + queries = generate_data((n_queries, n_cols), dtype) + out_idx = np.zeros((n_queries, k), dtype=np.int64) + out_dist = np.zeros((n_queries, k), dtype=np.float32) + + queries_device = device_ndarray(queries) + out_idx_device = device_ndarray(out_idx) if inplace else None + out_dist_device = device_ndarray(out_dist) if inplace else None + + search_params = ivf_flat.SearchParams(n_probes=n_probes) + + ret_output = ivf_flat.search( + search_params, + index, + queries_device, + k, + neighbors=out_idx_device, + distances=out_dist_device, + ) + + if not inplace: + out_dist_device, out_idx_device = ret_output + + if not compare: + return + + out_idx = out_idx_device.copy_to_host() + out_dist = out_dist_device.copy_to_host() + + # Calculate reference values with sklearn + skl_metric = { + "sqeuclidean": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] + nn_skl = NearestNeighbors( + n_neighbors=k, algorithm="brute", metric=skl_metric + ) + nn_skl.fit(dataset) + skl_idx = nn_skl.kneighbors(queries, return_distance=False) + + recall = calc_recall(out_idx, skl_idx) + assert recall > 0.7 + + check_distances(dataset, queries, metric, out_idx, out_dist) + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("n_rows", [10000]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_queries", [100]) +@pytest.mark.parametrize("n_lists", [100]) +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("array_type", ["device"]) +def test_ivf_pq_dtypes( + n_rows, n_cols, n_queries, n_lists, dtype, inplace, array_type +): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_ivf_flat_build_search_test( + n_rows=n_rows, + n_cols=n_cols, + n_queries=n_queries, + k=10, + n_lists=n_lists, + metric="sqeuclidean", + dtype=dtype, + inplace=inplace, + array_type=array_type, + ) + + +@pytest.mark.parametrize( + "params", + [ + pytest.param( + { + "n_rows": 0, + "n_cols": 10, + "n_queries": 10, + "k": 1, + "n_lists": 10, + }, + marks=pytest.mark.xfail(reason="empty dataset"), + ), + {"n_rows": 1, "n_cols": 10, "n_queries": 10, "k": 1, "n_lists": 1}, + {"n_rows": 10, "n_cols": 1, "n_queries": 10, "k": 10, "n_lists": 10}, + # {"n_rows": 999, "n_cols": 42, "n_queries": 453, "k": 137, + # "n_lists": 53}, + ], +) +def test_ivf_flat_n(params): + # We do not test recall, just confirm that we can handle edge cases for + # certain parameters + run_ivf_flat_build_search_test( + n_rows=params["n_rows"], + n_cols=params["n_cols"], + n_queries=params["n_queries"], + k=params["k"], + n_lists=params["n_lists"], + metric="sqeuclidean", + dtype=np.float32, + compare=False, + ) + + +@pytest.mark.parametrize( + "metric", ["sqeuclidean", "inner_product", "euclidean"] +) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_ivf_flat_build_params(metric, dtype): + run_ivf_flat_build_search_test( + n_rows=10000, + n_cols=10, + n_queries=1000, + k=10, + n_lists=100, + metric=metric, + dtype=dtype, + add_data_on_build=True, + n_probes=100, + ) + + +@pytest.mark.parametrize( + "params", + [ + { + "n_lists": 100, + "trainset_fraction": 0.9, + "n_iters": 30, + }, + ], +) +def test_ivf_flat_params(params): + run_ivf_flat_build_search_test( + n_rows=10000, + n_cols=16, + n_queries=1000, + k=10, + n_lists=params["n_lists"], + metric="sqeuclidean", + dtype=np.float32, + kmeans_trainset_fraction=params.get("trainset_fraction", 1.0), + kmeans_n_iters=params.get("n_iters", 20), + ) + + +@pytest.mark.parametrize( + "params", + [ + { + "k": 10, + "n_probes": 100, + }, + { + "k": 10, + "n_probes": 99, + }, + { + "k": 10, + "n_probes": 100, + }, + { + "k": 129, + "n_probes": 100, + }, + ], +) +def test_ivf_pq_search_params(params): + run_ivf_flat_build_search_test( + n_rows=10000, + n_cols=16, + n_queries=1000, + k=params["k"], + n_lists=100, + n_probes=params["n_probes"], + metric="sqeuclidean", + dtype=np.float32, + ) + + +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("array_type", ["device"]) +def test_extend(dtype, array_type): + run_ivf_flat_build_search_test( + n_rows=10000, + n_cols=10, + n_queries=100, + k=10, + n_lists=100, + metric="sqeuclidean", + dtype=dtype, + add_data_on_build=False, + array_type=array_type, + ) + + +def test_build_assertions(): + with pytest.raises(TypeError): + run_ivf_flat_build_search_test( + n_rows=1000, + n_cols=10, + n_queries=100, + k=10, + n_lists=100, + metric="sqeuclidean", + dtype=np.float64, + ) + + n_rows = 1000 + n_cols = 100 + n_queries = 212 + k = 10 + dataset = generate_data((n_rows, n_cols), np.float32) + dataset_device = device_ndarray(dataset) + + index_params = ivf_flat.IndexParams( + n_lists=50, + metric="sqeuclidean", + kmeans_n_iters=20, + kmeans_trainset_fraction=1, + add_data_on_build=False, + ) + + index = ivf_flat.Index() + + queries = generate_data((n_queries, n_cols), np.float32) + out_idx = np.zeros((n_queries, k), dtype=np.int64) + out_dist = np.zeros((n_queries, k), dtype=np.float32) + + queries_device = device_ndarray(queries) + out_idx_device = device_ndarray(out_idx) + out_dist_device = device_ndarray(out_dist) + + search_params = ivf_flat.SearchParams(n_probes=50) + + with pytest.raises(ValueError): + # Index must be built before search + ivf_flat.search( + search_params, + index, + queries_device, + k, + out_idx_device, + out_dist_device, + ) + + index = ivf_flat.build(index_params, dataset_device) + assert index.trained + + indices = np.arange(n_rows + 1, dtype=np.int64) + indices_device = device_ndarray(indices) + + with pytest.raises(ValueError): + # Dataset dimension mismatch + ivf_flat.extend(index, queries_device, indices_device) + + with pytest.raises(ValueError): + # indices dimension mismatch + ivf_flat.extend(index, dataset_device, indices_device) + + +@pytest.mark.parametrize( + "params", + [ + {"q_dt": np.float64}, + {"q_order": "F"}, + {"q_cols": 101}, + {"idx_dt": np.uint32}, + {"idx_order": "F"}, + {"idx_rows": 42}, + {"idx_cols": 137}, + {"dist_dt": np.float64}, + {"dist_order": "F"}, + {"dist_rows": 42}, + {"dist_cols": 137}, + ], +) +def test_search_inputs(params): + """Test with invalid input dtype, order, or dimension.""" + n_rows = 1000 + n_cols = 100 + n_queries = 256 + k = 10 + dtype = np.float32 + + q_dt = params.get("q_dt", np.float32) + q_order = params.get("q_order", "C") + queries = generate_data( + (n_queries, params.get("q_cols", n_cols)), q_dt + ).astype(q_dt, order=q_order) + queries_device = device_ndarray(queries) + + idx_dt = params.get("idx_dt", np.int64) + idx_order = params.get("idx_order", "C") + out_idx = np.zeros( + (params.get("idx_rows", n_queries), params.get("idx_cols", k)), + dtype=idx_dt, + order=idx_order, + ) + out_idx_device = device_ndarray(out_idx) + + dist_dt = params.get("dist_dt", np.float32) + dist_order = params.get("dist_order", "C") + out_dist = np.zeros( + (params.get("dist_rows", n_queries), params.get("dist_cols", k)), + dtype=dist_dt, + order=dist_order, + ) + out_dist_device = device_ndarray(out_dist) + + index_params = ivf_flat.IndexParams( + n_lists=50, metric="sqeuclidean", add_data_on_build=True + ) + + dataset = generate_data((n_rows, n_cols), dtype) + dataset_device = device_ndarray(dataset) + index = ivf_flat.build(index_params, dataset_device) + assert index.trained + + with pytest.raises(Exception): + search_params = ivf_flat.SearchParams(n_probes=50) + ivf_flat.search( + search_params, + index, + queries_device, + k, + out_idx_device, + out_dist_device, + )