Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-Flat Python wrappers #1316

Merged
merged 37 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a0c57f4
Allow use of mdspan view in IVF-PQ API
viclafargue Feb 3, 2023
45553a1
restore legacy API
viclafargue Feb 6, 2023
c602e8d
row_major + assert tests
viclafargue Feb 6, 2023
3cffbb7
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Feb 6, 2023
126677c
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Feb 9, 2023
17ef73c
addressing review
viclafargue Feb 9, 2023
5eba07a
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Feb 11, 2023
d4e3660
fix style
viclafargue Feb 17, 2023
8d6c6bc
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Feb 17, 2023
22d0960
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Feb 21, 2023
cd7e9a1
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Mar 1, 2023
c73cfb2
moving helper funcs around
viclafargue Mar 1, 2023
ef624a8
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Mar 6, 2023
6c5744a
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Mar 6, 2023
efebf6f
IVF-Flat python wrappers work in progress
tfeher Mar 7, 2023
f3f3cf7
fix refine
viclafargue Mar 7, 2023
cc8451e
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Mar 7, 2023
2e08cb8
Merge branch 'ivf-pq-mdspan' of github.com:viclafargue/raft into fea-…
divyegala Mar 8, 2023
0815f74
add mdspan based apis to runtime
divyegala Mar 8, 2023
89c0b33
link to runtime, build successfully
divyegala Mar 8, 2023
ac0da62
Merge branch 'branch-23.04' into fea-ivf-flat-python-api
cjnolet Mar 9, 2023
2fd6a60
add index pointer API for mdspan build, remove k from mdspan search, …
divyegala Mar 9, 2023
ada8cce
Merge branch 'fea-ivf-flat-python-api' of github.com:tfeher/raft into…
divyegala Mar 9, 2023
6925e84
consolidate python wrapper, write pytests
divyegala Mar 9, 2023
17196fb
Merge remote-tracking branch 'upstream/branch-23.04' into fea-ivf-fla…
divyegala Mar 9, 2023
8faadf5
remove references to load/save
divyegala Mar 9, 2023
0a960b8
change uint64_t to int64_t
divyegala Mar 10, 2023
edf733f
merge upstream
divyegala Mar 10, 2023
b712842
rearrange api signature
divyegala Mar 14, 2023
0f5b44c
merging upstream
divyegala Mar 14, 2023
5acdef1
resolve bad merge, address review
divyegala Mar 15, 2023
3373ccf
Merge branch 'branch-23.04' into fea-ivf-flat-python-api
cjnolet Mar 15, 2023
94611b2
fix namespaces
divyegala Mar 15, 2023
a8a9441
add missing TU back
divyegala Mar 15, 2023
d1d1c79
fix index
divyegala Mar 15, 2023
49ccbc7
fix pytests, looking to fix gtests
divyegala Mar 16, 2023
da39e91
all tests passing
divyegala Mar 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ log
.DS_Store
dask-worker-space/
*.egg-info/
*.bin

## scikit-build
_skbuild
Expand Down
11 changes: 11 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,17 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/neighbors/ivfpq_search_float_uint64_t.cu
src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu
src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu
src/distance/neighbors/ivf_flat_search.cu
src/distance/neighbors/ivf_flat_build.cu
src/distance/neighbors/specializations/ivfflat_build_float_uint64_t.cu
divyegala marked this conversation as resolved.
Show resolved Hide resolved
src/distance/neighbors/specializations/ivfflat_build_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_build_uint8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_extend_float_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_extend_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_extend_uint8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_search_float_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_search_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfflat_search_uint8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu
Expand Down
61 changes: 52 additions & 9 deletions cpp/include/raft/neighbors/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ auto build(raft::device_resources const& handle,
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] params configure the index building
*
* @return the constructed ivf-flat index
*/
Expand All @@ -119,6 +119,52 @@ auto build(raft::device_resources const& handle,
static_cast<idx_t>(dataset.extent(1)));
}

/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_flat::index_params index_params;
* // create and fill the index from a [N, D] dataset
* ivf_flat::index<decltype(dataset::value_type), decltype(dataset::index_type)> *index_ptr;
* ivf_flat::build(handle, dataset, index_params, index_ptr);
* // use default search parameters
* ivf_flat::search_params search_params;
* // search K nearest neighbours for each of the N queries
* ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k);
* @endcode
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
* @tparam int_t precision / type of integral arguments
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] params configure the index building
* @param[out] idx pointer to ivf_flat::index
*
*/
template <typename value_t, typename idx_t>
void build(raft::device_resources const& handle,
raft::device_matrix_view<const value_t, idx_t, row_major> dataset,
const index_params& params,
raft::neighbors::ivf_flat::index<value_t, idx_t>* idx)
{
*idx = raft::spatial::knn::ivf_flat::detail::build(handle,
params,
dataset.data_handle(),
static_cast<idx_t>(dataset.extent(0)),
static_cast<idx_t>(dataset.extent(1)));
}

/** @} */

/**
Expand Down Expand Up @@ -397,24 +443,21 @@ void search(raft::device_resources const& handle,
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] params configure the search
* @param[in] k the number of neighbors to find for each query.
*/
template <typename value_t, typename idx_t, typename int_t>
template <typename value_t, typename idx_t>
void search(raft::device_resources const& handle,
const index<value_t, idx_t>& index,
raft::device_matrix_view<const value_t, idx_t, row_major> queries,
raft::device_matrix_view<idx_t, idx_t, row_major> neighbors,
raft::device_matrix_view<float, idx_t, row_major> distances,
const search_params& params,
int_t k)
const search_params& params)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(
neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast<idx_t>(k),
"Number of columns in output neighbors and distances matrices must equal k");
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must be equal");

RAFT_EXPECTS(queries.extent(1) == index.dim(),
"Number of query dimensions should equal number of dimensions in the index.");
Expand All @@ -424,7 +467,7 @@ void search(raft::device_resources const& handle,
index,
queries.data_handle(),
static_cast<std::uint32_t>(queries.extent(0)),
static_cast<std::uint32_t>(k),
static_cast<std::uint32_t>(neighbors.extent(1)),
neighbors.data_handle(),
distances.data_handle(),
nullptr);
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/specializations.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <raft/neighbors/specializations/ball_cover.cuh>
#include <raft/neighbors/specializations/fused_l2_knn.cuh>
#include <raft/neighbors/specializations/ivf_flat.cuh>
#include <raft/neighbors/specializations/ivf_pq.cuh>
#include <raft/neighbors/specializations/knn.cuh>

Expand Down
59 changes: 59 additions & 0 deletions cpp/include/raft/neighbors/specializations/ivf_flat.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/ivf_flat.cuh>

namespace raft::neighbors::ivf_flat {

#define RAFT_INST(T, IdxT) \
extern template auto build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const index_params& params) \
->index<T, IdxT>; \
\
extern template auto extend( \
raft::device_resources const& handle, \
const index<T, IdxT>& orig_index, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices) \
->index<T, IdxT>; \
\
extern template void build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const index_params& params, \
index<T, IdxT>* idx); \
\
extern template void extend( \
raft::device_resources const& handle, \
index<T, IdxT>* idx, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices); \
\
extern template void search(raft::device_resources const&, \
const index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>, \
search_params const&);

RAFT_INST(float, uint64_t);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
RAFT_INST(int8_t, uint64_t);
RAFT_INST(uint8_t, uint64_t);

#undef RAFT_INST
} // namespace raft::neighbors::ivf_flat
68 changes: 68 additions & 0 deletions cpp/include/raft_runtime/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/ivf_flat_types.hpp>

namespace raft::runtime::neighbors::ivf_flat {

// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
void build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices);

RAFT_INST_BUILD_EXTEND(float, uint64_t)
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t)
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t)

#undef RAFT_INST_BUILD_EXTEND

#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>, \
raft::neighbors::ivf_flat::search_params const&);

RAFT_INST_SEARCH(float, uint64_t);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
RAFT_INST_SEARCH(int8_t, uint64_t);
RAFT_INST_SEARCH(uint8_t, uint64_t);

#undef RAFT_INST_SEARCH

} // namespace raft::runtime::neighbors::ivf_flat
62 changes: 62 additions & 0 deletions cpp/src/distance/neighbors/ivf_flat_build.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/specializations/ivf_flat.cuh>
divyegala marked this conversation as resolved.
Show resolved Hide resolved
#include <raft_runtime/neighbors/ivf_flat.hpp>

namespace raft::runtime::neighbors::ivf_flat {

#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params) \
->raft::neighbors::ivf_flat::index<T, IdxT> \
{ \
return raft::neighbors::ivf_flat::build<T, IdxT>(handle, dataset, params); \
} \
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices) \
->raft::neighbors::ivf_flat::index<T, IdxT> \
{ \
return raft::neighbors::ivf_flat::extend<T, IdxT>( \
handle, orig_index, new_vectors, new_indices); \
} \
\
void build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx) \
{ \
*idx = raft::neighbors::ivf_flat::build<T, IdxT>(handle, dataset, params); \
} \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices) \
{ \
raft::neighbors::ivf_flat::extend<T, IdxT>(handle, idx, new_vectors, new_indices); \
}

RAFT_INST_BUILD_EXTEND(float, uint64_t);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t);

#undef RAFT_INST_BUILD_EXTEND

} // namespace raft::runtime::neighbors::ivf_flat
40 changes: 40 additions & 0 deletions cpp/src/distance/neighbors/ivf_flat_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/specializations/ivf_flat.cuh>
#include <raft_runtime/neighbors/ivf_flat.hpp>

namespace raft::runtime::neighbors::ivf_flat {

#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
raft::device_matrix_view<const T, IdxT, row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::device_matrix_view<float, IdxT, row_major> distances, \
raft::neighbors::ivf_flat::search_params const& params) \
{ \
raft::neighbors::ivf_flat::search<T, IdxT>( \
handle, index, queries, neighbors, distances, params); \
}

RAFT_INST_SEARCH(float, uint64_t);
RAFT_INST_SEARCH(int8_t, uint64_t);
RAFT_INST_SEARCH(uint8_t, uint64_t);

#undef RAFT_INST_SEARCH

} // namespace raft::runtime::neighbors::ivf_flat
Loading