Skip to content

Commit

Permalink
change helper to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Apr 23, 2024
1 parent 54061d0 commit 0e498a8
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 13 deletions.
3 changes: 1 addition & 2 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ namespace raft::neighbors::cagra {
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* build_params.initialize_from_dataset(dataset);
* ivf_pq::index_params build_params(dataset);
* ivf_pq::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
Expand Down
5 changes: 1 addition & 4 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ void build_knn_graph(raft::resources const& res,
size_t(dataset.extent(1)),
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
build_params.value().initialize_from_dataset(dataset);
}
if (!build_params) { build_params = ivf_pq::index_params(dataset); }

// Make model name
const std::string model_name = [&]() {
Expand Down
8 changes: 3 additions & 5 deletions cpp/include/raft/neighbors/ivf_pq_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,11 @@ struct index_params : ann::index_params {
bool conservative_memory_allocation = false;

/**
* Helper that sets values according to the extents of the dataset mdspan.
* Constructor that sets values according to the extents of the dataset mdspan.
*/
template <typename DataT, typename Accessor>
void initialize_from_dataset(
mdspan<const DataT, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = raft::distance::L2Expanded)
explicit index_params(mdspan<const DataT, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = raft::distance::L2Expanded)
{
n_lists =
dataset.extent(0) < 4 * 2500 ? 4 : static_cast<uint32_t>(std::sqrt(dataset.extent(0)));
Expand Down Expand Up @@ -161,7 +160,6 @@ struct search_params : ann::search_params {
double preferred_shmem_carveout = 1.0;
};

static_assert(std::is_aggregate_v<index_params>);
static_assert(std::is_aggregate_v<search_params>);

/** Size of the interleaved group. */
Expand Down
3 changes: 1 addition & 2 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
raft::make_host_matrix<IdxT, int64_t>(ps.n_rows, index_params.intermediate_graph_degree);

if (ps.build_algo == graph_build_algo::IVF_PQ) {
auto build_params = ivf_pq::index_params{};
build_params.initialize_from_dataset(database_view, ps.metric);
auto build_params = ivf_pq::index_params(database_view, ps.metric);
if (ps.host_dataset) {
cagra::build_knn_graph<DataT, IdxT>(
handle_, database_host_view, knn_graph.view(), 2, build_params);
Expand Down

0 comments on commit 0e498a8

Please sign in to comment.