Skip to content

Commit

Permalink
add range_search() to IndexRefine (facebookresearch#4022)
Browse files Browse the repository at this point in the history
Summary:
This is very convenient to have `range_seach()` in `IndexRefine`. Unlike the plain `search()` method, `range_search()` just reevaluates the computed distances from the baseline index. The labels are not re-sorted according to new distances, because this is not listed as a requirement in a method description
https://github.com/facebookresearch/faiss/blob/adb188411a98c3af5b7295c7016e5f46fee9eb07/faiss/Index.h#L150-L161
https://github.com/facebookresearch/faiss/blob/adb188411a98c3af5b7295c7016e5f46fee9eb07/faiss/impl/AuxIndexStructures.h#L35

Pull Request resolved: facebookresearch#4022

Reviewed By: mnorris11

Differential Revision: D66116082

Pulled By: gtwang01

fbshipit-source-id: 915aca2570d5863c876c9497d4c885e270b9b220
  • Loading branch information
alexanderguzhva authored and facebook-github-bot committed Jan 6, 2025
1 parent 9590ad2 commit 162e6ce
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
39 changes: 39 additions & 0 deletions faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,45 @@ void IndexRefine::search(
}
}

void IndexRefine::range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params_in) const {
const IndexRefineSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(
params, "IndexRefine params have incorrect type");
}

SearchParameters* base_index_params =
(params != nullptr) ? params->base_index_params : nullptr;

base_index->range_search(n, x, radius, result, base_index_params);

#pragma omp parallel if (n > 1)
{
std::unique_ptr<DistanceComputer> dc(
refine_index->get_distance_computer());

#pragma omp for
for (idx_t i = 0; i < n; i++) {
dc->set_query(x + i * d);

// reevaluate distances
const size_t idx_start = result->lims[i];
const size_t idx_end = result->lims[i + 1];

for (size_t j = idx_start; j < idx_end; j++) {
const auto label = result->labels[j];
result->distances[j] = (*dc)(label);
}
}
}
}

void IndexRefine::reconstruct(idx_t key, float* recons) const {
refine_index->reconstruct(key, recons);
}
Expand Down
7 changes: 7 additions & 0 deletions faiss/IndexRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ struct IndexRefine : Index {
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// reconstruct is routed to the refine_index
void reconstruct(idx_t key, float* recons) const override;

Expand Down
52 changes: 51 additions & 1 deletion tests/test_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest
import faiss

from faiss.contrib import datasets
from faiss.contrib import datasets, evaluation


class TestDistanceComputer(unittest.TestCase):
Expand Down Expand Up @@ -119,3 +119,53 @@ def test_rflat(self):
def test_refine_sq8(self):
# this case uses the IndexRefine class
self.do_test("IVF8,PQ2x4np,Refine(SQ8)")


class TestIndexRefineRangeSearch(unittest.TestCase):

def do_test(self, factory_string):
d = 32
radius = 8

ds = datasets.SyntheticDataset(d, 1024, 512, 256)

index = faiss.index_factory(d, factory_string)
index.train(ds.get_train())
index.add(ds.get_database())
xq = ds.get_queries()
xb = ds.get_database()

# perform a range_search
lims_1, D1, I1 = index.range_search(xq, radius)

# create a baseline (FlatL2)
index_flat = faiss.IndexFlatL2(d)
index_flat.train(ds.get_train())
index_flat.add(ds.get_database())

lims_ref, Dref, Iref = index_flat.range_search(xq, radius)

# add a refine index on top of the index
index_r = faiss.IndexRefine(index, index_flat)
lims_2, D2, I2 = index_r.range_search(xq, radius)

# validate: refined range_search() keeps indices untouched
precision_1, recall_1 = evaluation.range_PR(lims_ref, Iref, lims_1, I1)

precision_2, recall_2 = evaluation.range_PR(lims_ref, Iref, lims_2, I2)

self.assertAlmostEqual(recall_1, recall_2)

# validate: refined range_search() updates distances, and new distances are correct L2 distances
for iq in range(0, ds.nq):
start_lim = lims_2[iq]
end_lim = lims_2[iq + 1]
for i_lim in range(start_lim, end_lim):
idx = I2[i_lim]
l2_dis = np.sum(np.square(xq[iq : iq + 1,] - xb[idx : idx + 1,]))

self.assertAlmostEqual(l2_dis, D2[i_lim], places=4)


def test_refine_1(self):
self.do_test("SQ4")

0 comments on commit 162e6ce

Please sign in to comment.