Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-Flat Python wrappers #1316

Merged
merged 37 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a0c57f4
Allow use of mdspan view in IVF-PQ API
viclafargue Feb 3, 2023
45553a1
restore legacy API
viclafargue Feb 6, 2023
c602e8d
row_major + assert tests
viclafargue Feb 6, 2023
3cffbb7
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Feb 6, 2023
126677c
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Feb 9, 2023
17ef73c
addressing review
viclafargue Feb 9, 2023
5eba07a
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Feb 11, 2023
d4e3660
fix style
viclafargue Feb 17, 2023
8d6c6bc
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Feb 17, 2023
22d0960
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Feb 21, 2023
cd7e9a1
Merge branch 'branch-23.04' into ivf-pq-mdspan
viclafargue Mar 1, 2023
c73cfb2
moving helper funcs around
viclafargue Mar 1, 2023
ef624a8
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Mar 6, 2023
6c5744a
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Mar 6, 2023
efebf6f
IVF-Flat python wrappers work in progress
tfeher Mar 7, 2023
f3f3cf7
fix refine
viclafargue Mar 7, 2023
cc8451e
Merge branch 'branch-23.04' into ivf-pq-mdspan
cjnolet Mar 7, 2023
2e08cb8
Merge branch 'ivf-pq-mdspan' of github.com:viclafargue/raft into fea-…
divyegala Mar 8, 2023
0815f74
add mdspan based apis to runtime
divyegala Mar 8, 2023
89c0b33
link to runtime, build successfully
divyegala Mar 8, 2023
ac0da62
Merge branch 'branch-23.04' into fea-ivf-flat-python-api
cjnolet Mar 9, 2023
2fd6a60
add index pointer API for mdspan build, remove k from mdspan search, …
divyegala Mar 9, 2023
ada8cce
Merge branch 'fea-ivf-flat-python-api' of github.com:tfeher/raft into…
divyegala Mar 9, 2023
6925e84
consolidate python wrapper, write pytests
divyegala Mar 9, 2023
17196fb
Merge remote-tracking branch 'upstream/branch-23.04' into fea-ivf-fla…
divyegala Mar 9, 2023
8faadf5
remove references to load/save
divyegala Mar 9, 2023
0a960b8
change uint64_t to int64_t
divyegala Mar 10, 2023
edf733f
merge upstream
divyegala Mar 10, 2023
b712842
rearrange api signature
divyegala Mar 14, 2023
0f5b44c
merging upstream
divyegala Mar 14, 2023
5acdef1
resolve bad merge, address review
divyegala Mar 15, 2023
3373ccf
Merge branch 'branch-23.04' into fea-ivf-flat-python-api
cjnolet Mar 15, 2023
94611b2
fix namespaces
divyegala Mar 15, 2023
a8a9441
add missing TU back
divyegala Mar 15, 2023
d1d1c79
fix index
divyegala Mar 15, 2023
49ccbc7
fix pytests, looking to fix gtests
divyegala Mar 16, 2023
da39e91
all tests passing
divyegala Mar 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/neighbors/ivfpq_search_float_uint64_t.cu
src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu
src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu
src/distance/neighbors/ivf_flat_search.cu
src/distance/neighbors/ivf_flat_build.cu
src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu
Expand Down
84 changes: 84 additions & 0 deletions cpp/include/raft_runtime/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/ivf_flat_types.hpp>

namespace raft::runtime::neighbors::ivf_flat {

// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
void build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx, \
divyegala marked this conversation as resolved.
Show resolved Hide resolved
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows);

RAFT_INST_BUILD_EXTEND(float, uint64_t)
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t)
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t)

#undef RAFT_INST_BUILD_EXTEND

// /**
// * Save the index to file.
// *
// * Experimental, both the API and the serialization format are subject to change.
// *
// * @param[in] handle the raft handle
// * @param[in] filename the filename for saving the index
// * @param[in] index IVF-PQ index
// *
// */
// void save(raft::device_resources const& handle,
// const std::string& filename,
// const raft::neighbors::ivf_flat::index<uint64_t>& index);

// /**
// * Load index from file.
// *
// * Experimental, both the API and the serialization format are subject to change.
// *
// * @param[in] handle the raft handle
// * @param[in] filename the name of the file that stores the index
// * @param[in] index IVF-PQ index
// *
// */
// void load(raft::device_resources const& handle,
// const std::string& filename,
// raft::neighbors::ivf_flat::index<uint64_t>* index);

} // namespace raft::runtime::neighbors::ivf_flat
78 changes: 78 additions & 0 deletions cpp/src/distance/neighbors/ivf_flat_build.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/ivf_flat.cuh>
#include <raft_runtime/neighbors/ivf_flat.hpp>

namespace raft::runtime::neighbors::ivf_flat {

#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params) \
->raft::neighbors::ivf_flat::index<T, IdxT> \
{ \
return raft::neighbors::ivf_flat::build<T, IdxT>(handle, dataset, params); \
} \
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->raft::neighbors::ivf_flat::index<T, IdxT> \
{ \
return raft::neighbors::ivf_flat::extend<T, IdxT>( \
handle, orig_index, new_vectors, new_indices, n_rows); \
} \
\
void build(raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx) \
{ \
*idx = raft::neighbors::ivf_flat::build<T, IdxT>(handle, dataset, params); \
} \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
{ \
raft::neighbors::ivf_flat::extend<T, IdxT>(handle, idx, new_vectors, new_indices, n_rows); \
}

RAFT_INST_BUILD_EXTEND(float, uint64_t);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t);

#undef RAFT_INST_BUILD_EXTEND

// void save(raft::device_resources const& handle,
// const std::string& filename,
// const raft::neighbors::ivf_flat::index<T, uint64_t>& index)
// {
// raft::spatial::knn::ivf_flat::detail::save(handle, filename, index);
// };

// void load(raft::device_resources const& handle,
// const std::string& filename,
// raft::neighbors::ivf_flat::index<T, uint64_t>* index)
// {
// if (!index) { RAFT_FAIL("Invalid index pointer"); }
// *index = raft::spatial::knn::ivf_flat::detail::load<T, uint64_t>(handle, filename);
// };
} // namespace raft::runtime::neighbors::ivf_flat
38 changes: 38 additions & 0 deletions cpp/src/distance/neighbors/ivf_flat_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/ivf_pq.cuh>
// #include <raft/neighbors/specializations/detail/ivf_flat_search.cuh>
#include <raft_runtime/neighbors/ivf_flat.hpp>

namespace raft::runtime::neighbors::ivf_flat {

#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
const raft::neighbors::ivf_flat::search_params&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::device_matrix_view<T, IdxT, row_major> distances, \
uint32_t k);

RAFT_INST_SEARCH(float, uint64_t);
RAFT_INST_SEARCH(int8_t, uint64_t);
RAFT_INST_SEARCH(uint8_t, uint64_t);

#undef RAFT_INST_SEARCH

} // namespace raft::runtime::neighbors::ivf_flat
3 changes: 2 additions & 1 deletion python/pylibraft/pylibraft/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand All @@ -23,4 +23,5 @@ rapids_cython_create_modules(
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX neighbors_
)

add_subdirectory(ivf_flat)
add_subdirectory(ivf_pq)
24 changes: 24 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# Set the list of Cython files to build
set(cython_sources ivf_flat.pyx)
set(linked_libraries raft::raft raft::distance)

# Build all of the Cython targets
rapids_cython_create_modules(
CXX
SOURCE_FILES "${cython_sources}"
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX neighbors_ivfflat_
)
Empty file.
25 changes: 25 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .ivf_flat import Index, IndexParams, SearchParams, build, extend, search

__all__ = [
"Index",
"IndexParams",
"SearchParams",
"build",
"extend",
"search",
]
Empty file.
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Loading