Skip to content

Commit

Permalink
Fix radius search with HSNW and IP (facebookresearch#3698)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3698

Addressed issue

facebookresearch#3684

I forgot to negate the threshold of the radius search.
This diff adds a test and fixes the issue.

Reviewed By: mengdilin

Differential Revision: D60373054

fbshipit-source-id: 70f3daa8292177a4038846a94aff6221f88077e8
  • Loading branch information
mdouze authored and ketor committed Aug 20, 2024
1 parent 7951e54 commit 4efde4f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
22 changes: 1 addition & 21 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,6 @@
#include <faiss/utils/random.h>
#include <faiss/utils/sorting.h>

extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

int sgemm_(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc);
}

namespace faiss {

using MinimaxHeap = HNSW::MinimaxHeap;
Expand Down Expand Up @@ -340,7 +320,7 @@ void IndexHNSW::range_search(
RangeSearchResult* result,
const SearchParameters* params) const {
using RH = RangeSearchBlockResultHandler<HNSW::C>;
RH bres(result, radius);
RH bres(result, is_similarity_metric(metric_type) ? -radius : radius);

hnsw_search(this, n, x, bres, params);

Expand Down
28 changes: 28 additions & 0 deletions tests/test_graph_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,34 @@ def test_abs_inner_product(self):
# 4769 vs. 500*10
self.assertGreater(inter, Iref.size * 0.9)


class Issue3684(unittest.TestCase):

def test_issue3684(self):
np.random.seed(1234) # For reproducibility
d = 256 # Example dimension
nb = 10 # Number of database vectors
nq = 2 # Number of query vectors
xb = np.random.random((nb, d)).astype('float32')
xq = np.random.random((nq, d)).astype('float32')

faiss.normalize_L2(xb) # Normalize both query and database vectors
faiss.normalize_L2(xq)

hnsw_index_ip = faiss.IndexHNSWFlat(256, 16, faiss.METRIC_INNER_PRODUCT)
hnsw_index_ip.hnsw.efConstruction = 512
hnsw_index_ip.hnsw.efSearch = 512
hnsw_index_ip.add(xb)

# test knn
D, I = hnsw_index_ip.search(xq, 10)
self.assertTrue(np.all(D[:, :-1] >= D[:, 1:]))

# test range search
radius = 0.74 # Cosine similarity threshold
lims, D, I = hnsw_index_ip.range_search(xq, radius)
self.assertTrue(np.all(D >= radius))


class TestNSG(unittest.TestCase):

Expand Down

0 comments on commit 4efde4f

Please sign in to comment.