diff --git a/faiss/IndexBinaryHNSW.cpp b/faiss/IndexBinaryHNSW.cpp index 1f034009f8..9481fe67f2 100644 --- a/faiss/IndexBinaryHNSW.cpp +++ b/faiss/IndexBinaryHNSW.cpp @@ -281,31 +281,21 @@ struct FlatHammingDis : DistanceComputer { } }; +struct BuildDistanceComputer { + using T = DistanceComputer*; + template + DistanceComputer* f(IndexBinaryFlat* flat_storage) { + return new FlatHammingDis(*flat_storage); + } +}; + } // namespace DistanceComputer* IndexBinaryHNSW::get_distance_computer() const { IndexBinaryFlat* flat_storage = dynamic_cast(storage); - FAISS_ASSERT(flat_storage != nullptr); - - switch (code_size) { - case 4: - return new FlatHammingDis(*flat_storage); - case 8: - return new FlatHammingDis(*flat_storage); - case 16: - return new FlatHammingDis(*flat_storage); - case 20: - return new FlatHammingDis(*flat_storage); - case 32: - return new FlatHammingDis(*flat_storage); - case 64: - return new FlatHammingDis(*flat_storage); - default: - break; - } - - return new FlatHammingDis(*flat_storage); + BuildDistanceComputer bd; + return dispatch_HammingComputer(code_size, bd, flat_storage); } } // namespace faiss diff --git a/faiss/IndexBinaryHash.cpp b/faiss/IndexBinaryHash.cpp index 0e449bab77..22d5fa2936 100644 --- a/faiss/IndexBinaryHash.cpp +++ b/faiss/IndexBinaryHash.cpp @@ -176,6 +176,14 @@ void search_single_query_template( } while (fe.next()); } +struct Run_search_single_query { + using T = void; + template + T f(Types... args) { + search_single_query_template(args...); + } +}; + template void search_single_query( const IndexBinaryHash& index, @@ -184,29 +192,9 @@ void search_single_query( size_t& n0, size_t& nlist, size_t& ndis) { -#define HC(name) \ - search_single_query_template(index, q, res, n0, nlist, ndis); - switch (index.code_size) { - case 4: - HC(HammingComputer4); - break; - case 8: - HC(HammingComputer8); - break; - case 16: - HC(HammingComputer16); - break; - case 20: - HC(HammingComputer20); - break; - case 32: - HC(HammingComputer32); - break; - default: - HC(HammingComputerDefault); - break; - } -#undef HC + Run_search_single_query r; + dispatch_HammingComputer( + index.code_size, r, index, q, res, n0, nlist, ndis); } } // anonymous namespace @@ -365,6 +353,14 @@ static void verify_shortlist( } } +struct Run_verify_shortlist { + using T = void; + template + void f(Types... args) { + verify_shortlist(args...); + } +}; + template void search_1_query_multihash( const IndexBinaryMultiHash& index, @@ -405,29 +401,9 @@ void search_1_query_multihash( ndis += shortlist.size(); // verify shortlist - -#define HC(name) verify_shortlist(*index.storage, xi, shortlist, res) - switch (index.code_size) { - case 4: - HC(HammingComputer4); - break; - case 8: - HC(HammingComputer8); - break; - case 16: - HC(HammingComputer16); - break; - case 20: - HC(HammingComputer20); - break; - case 32: - HC(HammingComputer32); - break; - default: - HC(HammingComputerDefault); - break; - } -#undef HC + Run_verify_shortlist r; + dispatch_HammingComputer( + index.code_size, r, *index.storage, xi, shortlist, res); } } // anonymous namespace diff --git a/faiss/IndexBinaryIVF.cpp b/faiss/IndexBinaryIVF.cpp index 65b98280dc..3a0332ee54 100644 --- a/faiss/IndexBinaryIVF.cpp +++ b/faiss/IndexBinaryIVF.cpp @@ -735,151 +735,68 @@ void search_knn_hamming_per_invlist( } } +struct Run_search_knn_hamming_per_invlist { + using T = void; + + template + void f(Types... args) { + search_knn_hamming_per_invlist(args...); + } +}; + template -void search_knn_hamming_count_1( - const IndexBinaryIVF& ivf, - size_t nx, - const uint8_t* x, - const idx_t* keys, - int k, - int32_t* distances, - idx_t* labels, - const IVFSearchParameters* params) { - switch (ivf.code_size) { -#define HANDLE_CS(cs) \ - case cs: \ - search_knn_hamming_count( \ - ivf, nx, x, keys, k, distances, labels, params); \ - break; - HANDLE_CS(4); - HANDLE_CS(8); - HANDLE_CS(16); - HANDLE_CS(20); - HANDLE_CS(32); - HANDLE_CS(64); -#undef HANDLE_CS - default: - search_knn_hamming_count( - ivf, nx, x, keys, k, distances, labels, params); - break; +struct Run_search_knn_hamming_count { + using T = void; + + template + void f(Types... args) { + search_knn_hamming_count(args...); } -} +}; -void search_knn_hamming_per_invlist_1( - const IndexBinaryIVF& ivf, - size_t n, - const uint8_t* x, - idx_t k, - const idx_t* keys, - const int32_t* coarse_dis, - int32_t* distances, - idx_t* labels, - bool store_pairs, - const IVFSearchParameters* params) { - switch (ivf.code_size) { -#define HANDLE_CS(cs) \ - case cs: \ - search_knn_hamming_per_invlist( \ - ivf, \ - n, \ - x, \ - k, \ - keys, \ - coarse_dis, \ - distances, \ - labels, \ - store_pairs, \ - params); \ - break; - HANDLE_CS(4); - HANDLE_CS(8); - HANDLE_CS(16); - HANDLE_CS(20); - HANDLE_CS(32); - HANDLE_CS(64); -#undef HANDLE_CS - default: - search_knn_hamming_per_invlist( - ivf, - n, - x, - k, - keys, - coarse_dis, - distances, - labels, - store_pairs, - params); - break; +struct BuildScanner { + using T = BinaryInvertedListScanner*; + + template + T f(size_t code_size, bool store_pairs) { + return new IVFBinaryScannerL2(code_size, store_pairs); } -} +}; } // anonymous namespace BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner( bool store_pairs) const { -#define HC(name) return new IVFBinaryScannerL2(code_size, store_pairs) - switch (code_size) { - case 4: - HC(HammingComputer4); - case 8: - HC(HammingComputer8); - case 16: - HC(HammingComputer16); - case 20: - HC(HammingComputer20); - case 32: - HC(HammingComputer32); - case 64: - HC(HammingComputer64); - default: - HC(HammingComputerDefault); - } -#undef HC + BuildScanner bs; + return dispatch_HammingComputer(code_size, bs, code_size, store_pairs); } void IndexBinaryIVF::search_preassigned( idx_t n, const uint8_t* x, idx_t k, - const idx_t* idx, - const int32_t* coarse_dis, - int32_t* distances, - idx_t* labels, + const idx_t* cidx, + const int32_t* cdis, + int32_t* dis, + idx_t* idx, bool store_pairs, const IVFSearchParameters* params) const { if (per_invlist_search) { - search_knn_hamming_per_invlist_1( - *this, - n, - x, - k, - idx, - coarse_dis, - distances, - labels, - store_pairs, - params); + Run_search_knn_hamming_per_invlist r; + dispatch_HammingComputer( + code_size, r, *this, n, x, k, + cidx, cdis, dis, idx, store_pairs, params); } else if (use_heap) { search_knn_hamming_heap( - *this, - n, - x, - k, - idx, - coarse_dis, - distances, - labels, - store_pairs, - params); + *this, n, x, k, cidx, cdis, dis, idx, store_pairs, params); + } else if (store_pairs) { + Run_search_knn_hamming_count r; + dispatch_HammingComputer( + code_size, r, *this, n, x, idx, k, dis, idx, params); } else { - if (store_pairs) { - search_knn_hamming_count_1( - *this, n, x, idx, k, distances, labels, params); - } else { - search_knn_hamming_count_1( - *this, n, x, idx, k, distances, labels, params); - } + Run_search_knn_hamming_count r; + dispatch_HammingComputer( + code_size, r, *this, n, x, idx, k, dis, idx, params); } } diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index 60633cc41b..058798b15c 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -1154,30 +1154,23 @@ struct IVFPQScannerT : QueryTables { { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; } } + template + struct Run_scan_list_polysemous_hc { + using T = void; + template + void f(const IVFPQScannerT* scanner, Types... args) { + scanner->scan_list_polysemous_hc( + args...); + } + }; + template void scan_list_polysemous( size_t ncode, const uint8_t* codes, SearchResultType& res) const { - switch (pq.code_size) { -#define HANDLE_CODE_SIZE(cs) \ - case cs: \ - scan_list_polysemous_hc( \ - ncode, codes, res); \ - break - HANDLE_CODE_SIZE(4); - HANDLE_CODE_SIZE(8); - HANDLE_CODE_SIZE(16); - HANDLE_CODE_SIZE(20); - HANDLE_CODE_SIZE(32); - HANDLE_CODE_SIZE(64); -#undef HANDLE_CODE_SIZE - default: - scan_list_polysemous_hc< - HammingComputerDefault, - SearchResultType>(ncode, codes, res); - break; - } + Run_scan_list_polysemous_hc r; + dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res); } }; diff --git a/faiss/IndexIVFSpectralHash.cpp b/faiss/IndexIVFSpectralHash.cpp index 443c45dee6..d9a51fbe64 100644 --- a/faiss/IndexIVFSpectralHash.cpp +++ b/faiss/IndexIVFSpectralHash.cpp @@ -288,26 +288,23 @@ struct IVFScanner : InvertedListScanner { } }; +struct BuildScanner { + using T = InvertedListScanner*; + + template + static T f(const IndexIVFSpectralHash* index, bool store_pairs) { + return new IVFScanner(index, store_pairs); + } +}; + } // anonymous namespace InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner( bool store_pairs, const IDSelector* sel) const { FAISS_THROW_IF_NOT(!sel); - switch (code_size) { -#define HANDLE_CODE_SIZE(cs) \ - case cs: \ - return new IVFScanner(this, store_pairs) - HANDLE_CODE_SIZE(4); - HANDLE_CODE_SIZE(8); - HANDLE_CODE_SIZE(16); - HANDLE_CODE_SIZE(20); - HANDLE_CODE_SIZE(32); - HANDLE_CODE_SIZE(64); -#undef HANDLE_CODE_SIZE - default: - return new IVFScanner(this, store_pairs); - } + BuildScanner bs; + return dispatch_HammingComputer(code_size, bs, this, store_pairs); } void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) { diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 7b1c28f8fd..6326155356 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -263,8 +263,10 @@ void IndexPQStats::reset() { IndexPQStats indexPQ_stats; +namespace { + template -static size_t polysemous_inner_loop( +size_t polysemous_inner_loop( const IndexPQ& index, const float* dis_table_qi, const uint8_t* q_code, @@ -305,6 +307,16 @@ static size_t polysemous_inner_loop( return n_pass_i; } +struct Run_polysemous_inner_loop { + using T = size_t; + template + size_t f(Types... args) { + return polysemous_inner_loop(args...); + } +}; + +} // anonymous namespace + void IndexPQ::search_core_polysemous( idx_t n, const float* x, @@ -355,39 +367,18 @@ void IndexPQ::search_core_polysemous( maxheap_heapify(k, heap_dis, heap_ids); if (!generalized_hamming) { - switch (pq.code_size) { -#define DISPATCH(cs) \ - case cs: \ - n_pass += polysemous_inner_loop( \ - *this, \ - dis_table_qi, \ - q_code, \ - k, \ - heap_dis, \ - heap_ids, \ - polysemous_ht); \ - break; - DISPATCH(4) - DISPATCH(8) - DISPATCH(16) - DISPATCH(32) - DISPATCH(20) - default: - if (pq.code_size % 4 == 0) { - n_pass += polysemous_inner_loop( - *this, - dis_table_qi, - q_code, - k, - heap_dis, - heap_ids, - polysemous_ht); - } else { - bad_code_size++; - } - break; - } -#undef DISPATCH + Run_polysemous_inner_loop r; + n_pass += dispatch_HammingComputer( + pq.code_size, + r, + *this, + dis_table_qi, + q_code, + k, + heap_dis, + heap_ids, + polysemous_ht); + } else { // generalized hamming switch (pq.code_size) { #define DISPATCH(cs) \ diff --git a/faiss/utils/hamming.cpp b/faiss/utils/hamming.cpp index 7019183bd0..f61e2635d7 100644 --- a/faiss/utils/hamming.cpp +++ b/faiss/utils/hamming.cpp @@ -5,14 +5,13 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - /* * Implementation of Hamming related functions (distances, smallest distance * selection with regular heap|radix and probabilistic heap|radix. * * IMPLEMENTATION NOTES - * Bitvectors are generally assumed to be multiples of 64 bits. + * Optiomal speed is typically obtained for vector sizes of multiples of 64 + * bits. * * hamdis_t is used for distances because at this time * it is not clear how we will need to balance @@ -20,8 +19,6 @@ * - memory usage * - cache-misses when dealing with large volumes of data (lower bits is better) * - * The hamdis_t should optimally be compatibe with one of the Torch Storage - * (Byte,Short,Long) and therefore should be signed for 2-bytes and 4-bytes */ #include @@ -165,9 +162,11 @@ size_t match_hamming_thres( return posm; } +namespace { + /* Return closest neighbors w.r.t Hamming distance, using a heap. */ template -static void hammings_knn_hc( +void hammings_knn_hc( int bytes_per_code, int_maxheap_array_t* __restrict ha, const uint8_t* __restrict bs1, @@ -234,7 +233,7 @@ static void hammings_knn_hc( /* Return closest neighbors w.r.t Hamming distance, using max count. */ template -static void hammings_knn_mc( +void hammings_knn_mc( int bytes_per_code, const uint8_t* __restrict a, const uint8_t* __restrict b, @@ -287,6 +286,63 @@ static void hammings_knn_mc( } } +template +void hamming_range_search_template( + const uint8_t* a, + const uint8_t* b, + size_t na, + size_t nb, + int radius, + size_t code_size, + RangeSearchResult* res) { +#pragma omp parallel + { + RangeSearchPartialResult pres(res); + +#pragma omp for + for (int64_t i = 0; i < na; i++) { + HammingComputer hc(a + i * code_size, code_size); + const uint8_t* yi = b; + RangeQueryResult& qres = pres.new_result(i); + + for (size_t j = 0; j < nb; j++) { + int dis = hc.hamming(yi); + if (dis < radius) { + qres.add(dis, j); + } + yi += code_size; + } + } + pres.finalize(); + } +} + +struct Run_hammings_knn_hc { + using T = void; + template + void f(Types... args) { + hammings_knn_hc(args...); + } +}; + +struct Run_hammings_knn_mc { + using T = void; + template + void f(Types... args) { + hammings_knn_mc(args...); + } +}; + +struct Run_hamming_range_search_template { + using T = void; + template + void f(Types... args) { + hamming_range_search_template(args...); + } +}; + +} // namespace + /* Functions to maps vectors to bits. Assume proper allocation done beforehand, meaning that b should be be able to receive as many bits as x may produce. */ @@ -437,28 +493,9 @@ void hammings_knn_hc( size_t ncodes, int order, ApproxTopK_mode_t approx_topk_mode) { - switch (ncodes) { - case 4: - hammings_knn_hc( - 4, ha, a, b, nb, order, true, approx_topk_mode); - break; - case 8: - hammings_knn_hc( - 8, ha, a, b, nb, order, true, approx_topk_mode); - break; - case 16: - hammings_knn_hc( - 16, ha, a, b, nb, order, true, approx_topk_mode); - break; - case 32: - hammings_knn_hc( - 32, ha, a, b, nb, order, true, approx_topk_mode); - break; - default: - hammings_knn_hc( - ncodes, ha, a, b, nb, order, true, approx_topk_mode); - break; - } + Run_hammings_knn_hc r; + dispatch_HammingComputer( + ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode); } void hammings_knn_mc( @@ -470,58 +507,9 @@ void hammings_knn_mc( size_t ncodes, int32_t* __restrict distances, int64_t* __restrict labels) { - switch (ncodes) { - case 4: - hammings_knn_mc( - 4, a, b, na, nb, k, distances, labels); - break; - case 8: - hammings_knn_mc( - 8, a, b, na, nb, k, distances, labels); - break; - case 16: - hammings_knn_mc( - 16, a, b, na, nb, k, distances, labels); - break; - case 32: - hammings_knn_mc( - 32, a, b, na, nb, k, distances, labels); - break; - default: - hammings_knn_mc( - ncodes, a, b, na, nb, k, distances, labels); - break; - } -} -template -static void hamming_range_search_template( - const uint8_t* a, - const uint8_t* b, - size_t na, - size_t nb, - int radius, - size_t code_size, - RangeSearchResult* res) { -#pragma omp parallel - { - RangeSearchPartialResult pres(res); - -#pragma omp for - for (int64_t i = 0; i < na; i++) { - HammingComputer hc(a + i * code_size, code_size); - const uint8_t* yi = b; - RangeQueryResult& qres = pres.new_result(i); - - for (size_t j = 0; j < nb; j++) { - int dis = hc.hamming(yi); - if (dis < radius) { - qres.add(dis, j); - } - yi += code_size; - } - } - pres.finalize(); - } + Run_hammings_knn_mc r; + dispatch_HammingComputer( + ncodes, r, ncodes, a, b, na, nb, k, distances, labels); } void hamming_range_search( @@ -532,27 +520,9 @@ void hamming_range_search( int radius, size_t code_size, RangeSearchResult* result) { -#define HC(name) \ - hamming_range_search_template(a, b, na, nb, radius, code_size, result) - - switch (code_size) { - case 4: - HC(HammingComputer4); - break; - case 8: - HC(HammingComputer8); - break; - case 16: - HC(HammingComputer16); - break; - case 32: - HC(HammingComputer32); - break; - default: - HC(HammingComputerDefault); - break; - } -#undef HC + Run_hamming_range_search_template r; + dispatch_HammingComputer( + code_size, r, a, b, na, nb, radius, code_size, result); } /* Count number of matches given a max threshold */ diff --git a/faiss/utils/hamming_distance/avx2-inl.h b/faiss/utils/hamming_distance/avx2-inl.h index 2393b75778..4c007477d1 100644 --- a/faiss/utils/hamming_distance/avx2-inl.h +++ b/faiss/utils/hamming_distance/avx2-inl.h @@ -405,32 +405,7 @@ struct HammingComputerM4 { } }; -/*************************************************************************** - * Equivalence with a template class when code size is known at compile time - **************************************************************************/ - -// default template -template -struct HammingComputer : HammingComputerDefault { - HammingComputer(const uint8_t* a, int code_size) - : HammingComputerDefault(a, code_size) {} -}; - -#define SPECIALIZED_HC(CODE_SIZE) \ - template <> \ - struct HammingComputer : HammingComputer##CODE_SIZE { \ - HammingComputer(const uint8_t* a) \ - : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ - } - -SPECIALIZED_HC(4); -SPECIALIZED_HC(8); -SPECIALIZED_HC(16); -SPECIALIZED_HC(20); -SPECIALIZED_HC(32); -SPECIALIZED_HC(64); -#undef SPECIALIZED_HC /*************************************************************************** * generalized Hamming = number of bytes that are different between diff --git a/faiss/utils/hamming_distance/generic-inl.h b/faiss/utils/hamming_distance/generic-inl.h index 8e9356c9ab..b9be24479c 100644 --- a/faiss/utils/hamming_distance/generic-inl.h +++ b/faiss/utils/hamming_distance/generic-inl.h @@ -389,32 +389,7 @@ struct HammingComputerM4 { } }; -/*************************************************************************** - * Equivalence with a template class when code size is known at compile time - **************************************************************************/ - -// default template -template -struct HammingComputer : HammingComputerDefault { - HammingComputer(const uint8_t* a, int code_size) - : HammingComputerDefault(a, code_size) {} -}; - -#define SPECIALIZED_HC(CODE_SIZE) \ - template <> \ - struct HammingComputer : HammingComputer##CODE_SIZE { \ - HammingComputer(const uint8_t* a) \ - : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ - } - -SPECIALIZED_HC(4); -SPECIALIZED_HC(8); -SPECIALIZED_HC(16); -SPECIALIZED_HC(20); -SPECIALIZED_HC(32); -SPECIALIZED_HC(64); -#undef SPECIALIZED_HC /*************************************************************************** * generalized Hamming = number of bytes that are different between diff --git a/faiss/utils/hamming_distance/hamdis-inl.h b/faiss/utils/hamming_distance/hamdis-inl.h index aaea84735e..b830df38b6 100644 --- a/faiss/utils/hamming_distance/hamdis-inl.h +++ b/faiss/utils/hamming_distance/hamdis-inl.h @@ -23,4 +23,61 @@ #include #endif +namespace faiss { + +/*************************************************************************** + * Equivalence with a template class when code size is known at compile time + **************************************************************************/ + +// default template +template +struct HammingComputer : HammingComputerDefault { + HammingComputer(const uint8_t* a, int code_size) + : HammingComputerDefault(a, code_size) {} +}; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template <> \ + struct HammingComputer : HammingComputer##CODE_SIZE { \ + HammingComputer(const uint8_t* a) \ + : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ + } + +SPECIALIZED_HC(4); +SPECIALIZED_HC(8); +SPECIALIZED_HC(16); +SPECIALIZED_HC(20); +SPECIALIZED_HC(32); +SPECIALIZED_HC(64); + +#undef SPECIALIZED_HC + +/*************************************************************************** + * Dispatching function that takes a code size and a consumer object + * the consumer object should contain a retun type t and a operation template + * function f() that to be called to perform the operation. + **************************************************************************/ + +template +typename Consumer::T dispatch_HammingComputer( + int code_size, + Consumer& consumer, + Types... args) { + switch (code_size) { +#define DISPATCH_HC(CODE_SIZE) \ + case CODE_SIZE: \ + return consumer.template f(args...); + DISPATCH_HC(4); + DISPATCH_HC(8); + DISPATCH_HC(16); + DISPATCH_HC(20); + DISPATCH_HC(32); + DISPATCH_HC(64); + default: + return consumer.template f(args...); + } +} + +} // namespace faiss + #endif diff --git a/faiss/utils/hamming_distance/neon-inl.h b/faiss/utils/hamming_distance/neon-inl.h index 38b5aa6af2..98528b0062 100644 --- a/faiss/utils/hamming_distance/neon-inl.h +++ b/faiss/utils/hamming_distance/neon-inl.h @@ -468,33 +468,6 @@ struct HammingComputerM4 { } }; -/*************************************************************************** - * Equivalence with a template class when code size is known at compile time - **************************************************************************/ - -// default template -template -struct HammingComputer : HammingComputerDefault { - HammingComputer(const uint8_t* a, int code_size) - : HammingComputerDefault(a, code_size) {} -}; - -#define SPECIALIZED_HC(CODE_SIZE) \ - template <> \ - struct HammingComputer : HammingComputer##CODE_SIZE { \ - HammingComputer(const uint8_t* a) \ - : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ - } - -SPECIALIZED_HC(4); -SPECIALIZED_HC(8); -SPECIALIZED_HC(16); -SPECIALIZED_HC(20); -SPECIALIZED_HC(32); -SPECIALIZED_HC(64); - -#undef SPECIALIZED_HC - /*************************************************************************** * generalized Hamming = number of bytes that are different between * two codes.