Skip to content

Commit

Permalink
use dispatcher function to call HammingComputer
Browse files Browse the repository at this point in the history
Summary:
The HammingComputer class is optimized for several vector sizes. So far it's been the caller's responsiblity to instanciate the relevant optimized version.

This diff introduces a `dispatch_HammingComputer` function that can be called with a template class that is instanciated for all existing optimized HammingComputer's.

Differential Revision: D46858553

fbshipit-source-id: 6ebe0b784755642b416ce4d34fbdc45ef9ea4c49
  • Loading branch information
mdouze authored and facebook-github-bot committed Jun 20, 2023
1 parent f69b1db commit e410d23
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 436 deletions.
30 changes: 10 additions & 20 deletions faiss/IndexBinaryHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,21 @@ struct FlatHammingDis : DistanceComputer {
}
};

struct BuildDistanceComputer {
using T = DistanceComputer*;
template <class HammingComputer>
DistanceComputer* f(IndexBinaryFlat* flat_storage) {
return new FlatHammingDis<HammingComputer>(*flat_storage);
}
};

} // namespace

DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);

FAISS_ASSERT(flat_storage != nullptr);

switch (code_size) {
case 4:
return new FlatHammingDis<HammingComputer4>(*flat_storage);
case 8:
return new FlatHammingDis<HammingComputer8>(*flat_storage);
case 16:
return new FlatHammingDis<HammingComputer16>(*flat_storage);
case 20:
return new FlatHammingDis<HammingComputer20>(*flat_storage);
case 32:
return new FlatHammingDis<HammingComputer32>(*flat_storage);
case 64:
return new FlatHammingDis<HammingComputer64>(*flat_storage);
default:
break;
}

return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
BuildDistanceComputer bd;
return dispatch_HammingComputer(code_size, bd, flat_storage);
}

} // namespace faiss
68 changes: 22 additions & 46 deletions faiss/IndexBinaryHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ void search_single_query_template(
} while (fe.next());
}

struct Run_search_single_query {
using T = void;
template <class HammingComputer, class... Types>
T f(Types... args) {
search_single_query_template<HammingComputer>(args...);
}
};

template <class SearchResults>
void search_single_query(
const IndexBinaryHash& index,
Expand All @@ -184,29 +192,9 @@ void search_single_query(
size_t& n0,
size_t& nlist,
size_t& ndis) {
#define HC(name) \
search_single_query_template<name>(index, q, res, n0, nlist, ndis);
switch (index.code_size) {
case 4:
HC(HammingComputer4);
break;
case 8:
HC(HammingComputer8);
break;
case 16:
HC(HammingComputer16);
break;
case 20:
HC(HammingComputer20);
break;
case 32:
HC(HammingComputer32);
break;
default:
HC(HammingComputerDefault);
break;
}
#undef HC
Run_search_single_query r;
dispatch_HammingComputer(
index.code_size, r, index, q, res, n0, nlist, ndis);
}

} // anonymous namespace
Expand Down Expand Up @@ -365,6 +353,14 @@ static void verify_shortlist(
}
}

struct Run_verify_shortlist {
using T = void;
template <class HammingComputer, class... Types>
void f(Types... args) {
verify_shortlist<HammingComputer>(args...);
}
};

template <class SearchResults>
void search_1_query_multihash(
const IndexBinaryMultiHash& index,
Expand Down Expand Up @@ -405,29 +401,9 @@ void search_1_query_multihash(
ndis += shortlist.size();

// verify shortlist

#define HC(name) verify_shortlist<name>(*index.storage, xi, shortlist, res)
switch (index.code_size) {
case 4:
HC(HammingComputer4);
break;
case 8:
HC(HammingComputer8);
break;
case 16:
HC(HammingComputer16);
break;
case 20:
HC(HammingComputer20);
break;
case 32:
HC(HammingComputer32);
break;
default:
HC(HammingComputerDefault);
break;
}
#undef HC
Run_verify_shortlist r;
dispatch_HammingComputer(
index.code_size, r, *index.storage, xi, shortlist, res);
}

} // anonymous namespace
Expand Down
165 changes: 41 additions & 124 deletions faiss/IndexBinaryIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,151 +735,68 @@ void search_knn_hamming_per_invlist(
}
}

struct Run_search_knn_hamming_per_invlist {
using T = void;

template <class HammingComputer, class... Types>
void f(Types... args) {
search_knn_hamming_per_invlist<HammingComputer>(args...);
}
};

template <bool store_pairs>
void search_knn_hamming_count_1(
const IndexBinaryIVF& ivf,
size_t nx,
const uint8_t* x,
const idx_t* keys,
int k,
int32_t* distances,
idx_t* labels,
const IVFSearchParameters* params) {
switch (ivf.code_size) {
#define HANDLE_CS(cs) \
case cs: \
search_knn_hamming_count<HammingComputer##cs, store_pairs>( \
ivf, nx, x, keys, k, distances, labels, params); \
break;
HANDLE_CS(4);
HANDLE_CS(8);
HANDLE_CS(16);
HANDLE_CS(20);
HANDLE_CS(32);
HANDLE_CS(64);
#undef HANDLE_CS
default:
search_knn_hamming_count<HammingComputerDefault, store_pairs>(
ivf, nx, x, keys, k, distances, labels, params);
break;
struct Run_search_knn_hamming_count {
using T = void;

template <class HammingComputer, class... Types>
void f(Types... args) {
search_knn_hamming_count<HammingComputer, store_pairs>(args...);
}
}
};

void search_knn_hamming_per_invlist_1(
const IndexBinaryIVF& ivf,
size_t n,
const uint8_t* x,
idx_t k,
const idx_t* keys,
const int32_t* coarse_dis,
int32_t* distances,
idx_t* labels,
bool store_pairs,
const IVFSearchParameters* params) {
switch (ivf.code_size) {
#define HANDLE_CS(cs) \
case cs: \
search_knn_hamming_per_invlist<HammingComputer##cs>( \
ivf, \
n, \
x, \
k, \
keys, \
coarse_dis, \
distances, \
labels, \
store_pairs, \
params); \
break;
HANDLE_CS(4);
HANDLE_CS(8);
HANDLE_CS(16);
HANDLE_CS(20);
HANDLE_CS(32);
HANDLE_CS(64);
#undef HANDLE_CS
default:
search_knn_hamming_per_invlist<HammingComputerDefault>(
ivf,
n,
x,
k,
keys,
coarse_dis,
distances,
labels,
store_pairs,
params);
break;
struct BuildScanner {
using T = BinaryInvertedListScanner*;

template <class HammingComputer>
T f(size_t code_size, bool store_pairs) {
return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs);
}
}
};

} // anonymous namespace

BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
bool store_pairs) const {
#define HC(name) return new IVFBinaryScannerL2<name>(code_size, store_pairs)
switch (code_size) {
case 4:
HC(HammingComputer4);
case 8:
HC(HammingComputer8);
case 16:
HC(HammingComputer16);
case 20:
HC(HammingComputer20);
case 32:
HC(HammingComputer32);
case 64:
HC(HammingComputer64);
default:
HC(HammingComputerDefault);
}
#undef HC
BuildScanner bs;
return dispatch_HammingComputer(code_size, bs, code_size, store_pairs);
}

void IndexBinaryIVF::search_preassigned(
idx_t n,
const uint8_t* x,
idx_t k,
const idx_t* idx,
const int32_t* coarse_dis,
int32_t* distances,
idx_t* labels,
const idx_t* cidx,
const int32_t* cdis,
int32_t* dis,
idx_t* idx,
bool store_pairs,
const IVFSearchParameters* params) const {
if (per_invlist_search) {
search_knn_hamming_per_invlist_1(
*this,
n,
x,
k,
idx,
coarse_dis,
distances,
labels,
store_pairs,
params);
Run_search_knn_hamming_per_invlist r;
dispatch_HammingComputer(
code_size, r, *this, n, x, k,
cidx, cdis, dis, idx, store_pairs, params);
} else if (use_heap) {
search_knn_hamming_heap(
*this,
n,
x,
k,
idx,
coarse_dis,
distances,
labels,
store_pairs,
params);
*this, n, x, k, cidx, cdis, dis, idx, store_pairs, params);
} else if (store_pairs) {
Run_search_knn_hamming_count<true> r;
dispatch_HammingComputer(
code_size, r, *this, n, x, idx, k, dis, idx, params);
} else {
if (store_pairs) {
search_knn_hamming_count_1<true>(
*this, n, x, idx, k, distances, labels, params);
} else {
search_knn_hamming_count_1<false>(
*this, n, x, idx, k, distances, labels, params);
}
Run_search_knn_hamming_count<false> r;
dispatch_HammingComputer(
code_size, r, *this, n, x, idx, k, dis, idx, params);
}
}

Expand Down
31 changes: 12 additions & 19 deletions faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1154,30 +1154,23 @@ struct IVFPQScannerT : QueryTables {
{ indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
}

template <class SearchResultType>
struct Run_scan_list_polysemous_hc {
using T = void;
template <class HammingComputer, class... Types>
void f(const IVFPQScannerT* scanner, Types... args) {
scanner->scan_list_polysemous_hc<HammingComputer, SearchResultType>(
args...);
}
};

template <class SearchResultType>
void scan_list_polysemous(
size_t ncode,
const uint8_t* codes,
SearchResultType& res) const {
switch (pq.code_size) {
#define HANDLE_CODE_SIZE(cs) \
case cs: \
scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
ncode, codes, res); \
break
HANDLE_CODE_SIZE(4);
HANDLE_CODE_SIZE(8);
HANDLE_CODE_SIZE(16);
HANDLE_CODE_SIZE(20);
HANDLE_CODE_SIZE(32);
HANDLE_CODE_SIZE(64);
#undef HANDLE_CODE_SIZE
default:
scan_list_polysemous_hc<
HammingComputerDefault,
SearchResultType>(ncode, codes, res);
break;
}
Run_scan_list_polysemous_hc<SearchResultType> r;
dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res);
}
};

Expand Down
Loading

0 comments on commit e410d23

Please sign in to comment.