From 58051630f1d3e2ca7568b160ece33a09370fc654 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 21 Apr 2023 17:55:59 -0700 Subject: [PATCH] fix ivf_pq n_probes The ivf-pq search code was including a guard like ```auto n_probes = std::min(params.n_probes, index.n_lists());``` to make sure that we weren't selecting more values than are available. However, this wasn't being used and instead just `params.n_probes` was being passed to functions like `select_k`. This led to asking select_k to select, say, 100 items when there were only 90 to choose from, and caused some issues downstream when trying to update the select_k algorithm. Fix this by using the clamped `n_probes` in place of `params.n_probes`. --- cpp/include/raft/neighbors/detail/ivf_pq_search.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 4b6e6f5e31..9a94458748 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -1613,7 +1613,7 @@ inline void search(raft::device_resources const& handle, rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr); rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); - rmm::device_uvector clusters_to_probe(max_queries * params.n_probes, stream, mr); + rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); auto search_instance = ivfpq_search::fun(params, index.metric()); @@ -1624,7 +1624,7 @@ inline void search(raft::device_resources const& handle, clusters_to_probe.data(), float_queries.data(), queries_batch, - params.n_probes, + n_probes, index.n_lists(), dim, dim_ext, @@ -1661,10 +1661,10 @@ inline void search(raft::device_resources const& handle, search_instance(handle, index, max_samples, - params.n_probes, + n_probes, k, batch_size, - clusters_to_probe.data() + uint64_t(params.n_probes) * offset_b, + clusters_to_probe.data() + uint64_t(n_probes) * offset_b, 
rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, neighbors + uint64_t(k) * (offset_q + offset_b), distances + uint64_t(k) * (offset_q + offset_b),