From 164861cf7d879b72363dcdaaa988fa66eb77406b Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 4 May 2023 06:31:25 -0700 Subject: [PATCH] improve code_distance() for avx2 for 4 and 8 subquantizers (#2831) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2831 Differential Revision: D45329803 fbshipit-source-id: 540d6e85f36eb537d7cddf84028e370284f2df86 --- faiss/IndexIVFPQ.cpp | 45 ++- faiss/IndexPQ.cpp | 2 +- faiss/impl/code_distance/code_distance-avx2.h | 302 ++++++++++++++++-- .../code_distance/code_distance-generic.h | 35 +- faiss/impl/code_distance/code_distance.h | 34 +- tests/CMakeLists.txt | 1 + tests/test_code_distance.cpp | 241 ++++++++++++++ 7 files changed, 586 insertions(+), 74 deletions(-) create mode 100644 tests/test_code_distance.cpp diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index defbf9d5ec..fd91738ad1 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -937,7 +937,8 @@ struct IVFPQScannerT : QueryTables { float distance_2 = 0; float distance_3 = 0; distance_four_codes( - pq, + pq.M, + pq.nbits, sim_table, codes + saved_j[0] * pq.code_size, codes + saved_j[1] * pq.code_size, @@ -957,24 +958,30 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 1) { - float dis = - dis0 + + float dis = dis0 + distance_single_code( - pq, sim_table, codes + saved_j[0] * pq.code_size); + pq.M, + pq.nbits, + sim_table, + codes + saved_j[0] * pq.code_size); res.add(saved_j[0], dis); } if (counter >= 2) { - float dis = - dis0 + + float dis = dis0 + distance_single_code( - pq, sim_table, codes + saved_j[1] * pq.code_size); + pq.M, + pq.nbits, + sim_table, + codes + saved_j[1] * pq.code_size); res.add(saved_j[1], dis); } if (counter >= 3) { - float dis = - dis0 + + float dis = dis0 + distance_single_code( - pq, sim_table, codes + saved_j[2] * pq.code_size); + pq.M, + pq.nbits, + sim_table, + codes + saved_j[2] * pq.code_size); res.add(saved_j[2], dis); } } @@ -1137,7 +1144,8 @@ struct IVFPQScannerT : QueryTables { float distance_2 = dis0; float distance_3 = dis0; distance_four_codes( - pq, + pq.M, + pq.nbits, sim_table, codes + saved_j[0] * pq.code_size, codes + saved_j[1] * pq.code_size, @@ -1165,10 +1173,12 @@ struct IVFPQScannerT : QueryTables { for (size_t kk = 0; kk < counter; kk++) { n_hamming_pass++; - float dis = - dis0 + + float dis = dis0 + distance_single_code( - pq, sim_table, codes + saved_j[kk] * pq.code_size); + pq.M, + pq.nbits, + sim_table, + codes + saved_j[kk] * pq.code_size); res.add(saved_j[kk], dis); } @@ -1185,7 +1195,10 @@ struct IVFPQScannerT : QueryTables { float dis = dis0 + distance_single_code( - pq, sim_table, codes + j * code_size); + pq.M, + pq.nbits, + sim_table, + codes + j * code_size); res.add(j, dis); } @@ -1263,7 +1276,7 @@ struct IVFPQScanner : IVFPQScannerT, assert(precompute_mode == 2); float dis = this->dis0 + distance_single_code( - this->pq, this->sim_table, code); + this->pq.M, this->pq.nbits, this->sim_table, code); return dis; } diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 7df08899a4..7b1c28f8fd 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -86,7 +86,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { ndis++; float dis = distance_single_code( - pq, precomputed_table.data(), code); + pq.M, pq.nbits, precomputed_table.data(), code); return dis; } diff --git a/faiss/impl/code_distance/code_distance-avx2.h b/faiss/impl/code_distance/code_distance-avx2.h index 3202025fb7..0aa1535b28 100644 --- a/faiss/impl/code_distance/code_distance-avx2.h +++ b/faiss/impl/code_distance/code_distance-avx2.h @@ -13,6 +13,7 @@ #include +#include #include namespace { @@ -32,6 +33,200 @@ inline float horizontal_sum(const __m256 v) { return horizontal_sum(v0); } +// processes a single code for M=4, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSum; + + // load 4 uint8 values + const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes a single code for M=8, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum; + + // load 8 uint8 values + const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes four codes for M=4, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSums[N]; + + // load 4 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); + mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); + mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); + mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 4 values, similar to 4 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +// processes four codes for M=8, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + + // load 8 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); + mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); + mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); + mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + } // namespace namespace faiss { @@ -39,36 +234,48 @@ namespace faiss { template typename std::enable_if::value, float>:: type inline distance_single_code_avx2( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, const uint8_t* code) { // default implementation - return distance_single_code_generic(pq, sim_table, code); + return distance_single_code_generic(M, nbits, sim_table, code); } template typename std::enable_if::value, float>:: type inline distance_single_code_avx2( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, const uint8_t* code) { + if (M == 4) { + return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); + } + if (M == 8) { + return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); + } + float result = 0; + constexpr size_t ksub = 1 << 8; size_t m = 0; - const size_t pqM16 = pq.M / 16; + const size_t pqM16 = M / 16; const float* tab = sim_table; if (pqM16 > 0) { // process 16 values per loop - const __m256i ksub = _mm256_set1_epi32(pq.ksub); + const __m256i vksub = _mm256_set1_epi32(ksub); __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, ksub); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); // accumulators of partial sums __m256 partialSum = _mm256_setzero_ps(); @@ -89,7 +296,7 @@ typename std::enable_if::value, float>:: // gather 8 values, similar to 8 operations of tab[idx] __m256 collected = _mm256_i32gather_ps( tab, indices_to_read_from, sizeof(float)); - tab += pq.ksub * 8; + tab += ksub * 8; // collect partial sums partialSum = _mm256_add_ps(partialSum, collected); @@ -109,7 +316,7 @@ typename std::enable_if::value, float>:: // gather 8 values, similar to 8 operations of tab[idx] __m256 collected = _mm256_i32gather_ps( tab, indices_to_read_from, sizeof(float)); - tab += pq.ksub * 8; + tab += ksub * 8; // collect partial sums partialSum = _mm256_add_ps(partialSum, collected); @@ -121,13 +328,13 @@ typename std::enable_if::value, float>:: } // - if (m < pq.M) { + if (m < M) { // process leftovers - PQDecoder8 decoder(code + m, pq.nbits); + PQDecoder8 decoder(code + m, nbits); - for (; m < pq.M; m++) { + for (; m < M; m++) { result += tab[decoder.decode()]; - tab += pq.ksub; + tab += ksub; } } @@ -138,8 +345,10 @@ template typename std::enable_if::value, void>:: type distance_four_codes_avx2( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // codes @@ -153,7 +362,8 @@ typename std::enable_if::value, void>:: float& result2, float& result3) { distance_four_codes_generic( - pq, + M, + nbits, sim_table, code0, code1, @@ -169,8 +379,10 @@ typename std::enable_if::value, void>:: template typename std::enable_if::value, void>::type distance_four_codes_avx2( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // codes @@ -183,13 +395,41 @@ distance_four_codes_avx2( float& result1, float& result2, float& result3) { + if (M == 4) { + distance_four_codes_avx2_pqdecoder8_m4( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + if (M == 8) { + distance_four_codes_avx2_pqdecoder8_m8( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + result0 = 0; result1 = 0; result2 = 0; result3 = 0; + constexpr size_t ksub = 1 << 8; size_t m = 0; - const size_t pqM16 = pq.M / 16; + const size_t pqM16 = M / 16; constexpr intptr_t N = 4; @@ -197,9 +437,9 @@ distance_four_codes_avx2( if (pqM16 > 0) { // process 16 values per loop - const __m256i ksub = _mm256_set1_epi32(pq.ksub); + const __m256i vksub = _mm256_set1_epi32(ksub); __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, ksub); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); // accumulators of partial sums __m256 partialSums[N]; @@ -233,7 +473,7 @@ distance_four_codes_avx2( // collect partial sums partialSums[j] = _mm256_add_ps(partialSums[j], collected); } - tab += pq.ksub * 8; + tab += ksub * 8; // process next 8 codes for (intptr_t j = 0; j < N; j++) { @@ -257,7 +497,7 @@ distance_four_codes_avx2( partialSums[j] = _mm256_add_ps(partialSums[j], collected); } - tab += pq.ksub * 8; + tab += ksub * 8; } // horizontal sum for partialSum @@ -268,18 +508,18 @@ distance_four_codes_avx2( } // - if (m < pq.M) { + if (m < M) { // process leftovers - PQDecoder8 decoder0(code0 + m, pq.nbits); - PQDecoder8 decoder1(code1 + m, pq.nbits); - PQDecoder8 decoder2(code2 + m, pq.nbits); - PQDecoder8 decoder3(code3 + m, pq.nbits); - for (; m < pq.M; m++) { + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { result0 += tab[decoder0.decode()]; result1 += tab[decoder1.decode()]; result2 += tab[decoder2.decode()]; result3 += tab[decoder3.decode()]; - tab += pq.ksub; + tab += ksub; } } } diff --git a/faiss/impl/code_distance/code_distance-generic.h b/faiss/impl/code_distance/code_distance-generic.h index f17287695a..31f18d277d 100644 --- a/faiss/impl/code_distance/code_distance-generic.h +++ b/faiss/impl/code_distance/code_distance-generic.h @@ -7,27 +7,31 @@ #pragma once -#include +#include +#include namespace faiss { /// Returns the distance to a single code. template inline float distance_single_code_generic( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // the code const uint8_t* code) { - PQDecoderT decoder(code, pq.nbits); + PQDecoderT decoder(code, nbits); + const size_t ksub = 1 << nbits; const float* tab = sim_table; float result = 0; - for (size_t m = 0; m < pq.M; m++) { + for (size_t m = 0; m < M; m++) { result += tab[decoder.decode()]; - tab += pq.ksub; + tab += ksub; } return result; @@ -37,8 +41,10 @@ inline float distance_single_code_generic( /// General-purpose version. template inline void distance_four_codes_generic( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // codes @@ -51,10 +57,11 @@ inline void distance_four_codes_generic( float& result1, float& result2, float& result3) { - PQDecoderT decoder0(code0, pq.nbits); - PQDecoderT decoder1(code1, pq.nbits); - PQDecoderT decoder2(code2, pq.nbits); - PQDecoderT decoder3(code3, pq.nbits); + PQDecoderT decoder0(code0, nbits); + PQDecoderT decoder1(code1, nbits); + PQDecoderT decoder2(code2, nbits); + PQDecoderT decoder3(code3, nbits); + const size_t ksub = 1 << nbits; const float* tab = sim_table; result0 = 0; @@ -62,12 +69,12 @@ inline void distance_four_codes_generic( result2 = 0; result3 = 0; - for (size_t m = 0; m < pq.M; m++) { + for (size_t m = 0; m < M; m++) { result0 += tab[decoder0.decode()]; result1 += tab[decoder1.decode()]; result2 += tab[decoder2.decode()]; result3 += tab[decoder3.decode()]; - tab += pq.ksub; + tab += ksub; } } diff --git a/faiss/impl/code_distance/code_distance.h b/faiss/impl/code_distance/code_distance.h index e36be567a6..7cdf932f50 100644 --- a/faiss/impl/code_distance/code_distance.h +++ b/faiss/impl/code_distance/code_distance.h @@ -32,19 +32,23 @@ namespace faiss { template inline float distance_single_code( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // the code const uint8_t* code) { - return distance_single_code_avx2(pq, sim_table, code); + return distance_single_code_avx2(M, nbits, sim_table, code); } template inline void distance_four_codes( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // codes @@ -58,7 +62,8 @@ inline void distance_four_codes( float& result2, float& result3) { distance_four_codes_avx2( - pq, + M, + nbits, sim_table, code0, code1, @@ -80,19 +85,23 @@ namespace faiss { template inline float distance_single_code( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // the code const uint8_t* code) { - return distance_single_code_generic(pq, sim_table, code); + return distance_single_code_generic(M, nbits, sim_table, code); } template inline void distance_four_codes( - // the product quantizer - const ProductQuantizer& pq, + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, // precomputed distances, layout (M, ksub) const float* sim_table, // codes @@ -106,7 +115,8 @@ inline void distance_four_codes( float& result2, float& result3) { distance_four_codes_generic( - pq, + M, + nbits, sim_table, code0, code1, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8167c7dfc7..ecf45cde50 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -27,6 +27,7 @@ set(FAISS_TEST_SRC test_RCQ_cropping.cpp test_distances_simd.cpp test_heap.cpp + test_code_distance.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_code_distance.cpp b/tests/test_code_distance.cpp new file mode 100644 index 0000000000..c144807630 --- /dev/null +++ b/tests/test_code_distance.cpp @@ -0,0 +1,241 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +size_t nMismatches( + const std::vector& ref, + const std::vector& candidate) { + size_t count = 0; + for (size_t i = 0; i < count; i++) { + double abs = std::abs(ref[i] - candidate[i]); + if (abs >= 1e-5) { + count += 1; + } + } + + return count; +} + +void test( + // dimensionality of the data + const size_t dim, + // number of subquantizers + const size_t subq, + // bits per subquantizer + const size_t nbits, + // number of codes to process + const size_t n) { + FAISS_THROW_IF_NOT(nbits == 8); + + // remove if benchmarking is needed + omp_set_num_threads(1); + + // rng + std::minstd_rand rng(123); + std::uniform_int_distribution u(0, 255); + std::uniform_real_distribution uf(0, 1); + + // initialize lookup + std::vector lookup(256 * subq, 0); + for (size_t i = 0; i < lookup.size(); i++) { + lookup[i] = uf(rng); + } + + // initialize codes + std::vector codes(n * subq); +#pragma omp parallel + { + std::minstd_rand rng0(123); + std::uniform_int_distribution u1(0, 255); + +#pragma omp for schedule(guided) + for (size_t i = 0; i < codes.size(); i++) { + codes[i] = u1(rng0); + } + } + + // warmup. compute reference results + std::vector resultsRef(n, 0); + for (size_t k = 0; k < 10; k++) { +#pragma omp parallel for schedule(guided) + for (size_t i = 0; i < n; i++) { + resultsRef[i] = + faiss::distance_single_code_generic( + subq, 8, lookup.data(), codes.data() + subq * i); + } + } + + // generic, 1 code per step + std::vector resultsNewGeneric1x(n, 0); + double generic1xMsec = 0; + { + const auto startingTimepoint = std::chrono::steady_clock::now(); + for (size_t k = 0; k < 1000; k++) { +#pragma omp parallel for schedule(guided) + for (size_t i = 0; i < n; i++) { + resultsNewGeneric1x[i] = + faiss::distance_single_code_generic( + subq, + 8, + lookup.data(), + codes.data() + subq * i); + } + } + const auto endingTimepoint = std::chrono::steady_clock::now(); + + std::chrono::duration duration = + endingTimepoint - startingTimepoint; + generic1xMsec = (duration.count() * 1000.0); + } + + // generic, 4 codes per step + std::vector resultsNewGeneric4x(n, 0); + double generic4xMsec = 0; + { + const auto startingTimepoint = std::chrono::steady_clock::now(); + for (size_t k = 0; k < 1000; k++) { +#pragma omp parallel for schedule(guided) + for (size_t i = 0; i < n; i += 4) { + faiss::distance_four_codes_generic( + subq, + 8, + lookup.data(), + codes.data() + subq * (i + 0), + codes.data() + subq * (i + 1), + codes.data() + subq * (i + 2), + codes.data() + subq * (i + 3), + resultsNewGeneric4x[i + 0], + resultsNewGeneric4x[i + 1], + resultsNewGeneric4x[i + 2], + resultsNewGeneric4x[i + 3]); + } + } + + const auto endingTimepoint = std::chrono::steady_clock::now(); + + std::chrono::duration duration = + endingTimepoint - startingTimepoint; + generic4xMsec = (duration.count() * 1000.0); + } + + // generic, 1 code per step + std::vector resultsNewCustom1x(n, 0); + double custom1xMsec = 0; + { + const auto startingTimepoint = std::chrono::steady_clock::now(); + for (size_t k = 0; k < 1000; k++) { +#pragma omp parallel for schedule(guided) + for (size_t i = 0; i < n; i++) { + resultsNewCustom1x[i] = + faiss::distance_single_code( + subq, + 8, + lookup.data(), + codes.data() + subq * i); + } + } + const auto endingTimepoint = std::chrono::steady_clock::now(); + + std::chrono::duration duration = + endingTimepoint - startingTimepoint; + custom1xMsec = (duration.count() * 1000.0); + } + + // generic, 4 codes per step + std::vector resultsNewCustom4x(n, 0); + double custom4xMsec = 0; + { + const auto startingTimepoint = std::chrono::steady_clock::now(); + for (size_t k = 0; k < 1000; k++) { +#pragma omp parallel for schedule(guided) + for (size_t i = 0; i < n; i += 4) { + faiss::distance_four_codes( + subq, + 8, + lookup.data(), + codes.data() + subq * (i + 0), + codes.data() + subq * (i + 1), + codes.data() + subq * (i + 2), + codes.data() + subq * (i + 3), + resultsNewCustom4x[i + 0], + resultsNewCustom4x[i + 1], + resultsNewCustom4x[i + 2], + resultsNewCustom4x[i + 3]); + } + } + + const auto endingTimepoint = std::chrono::steady_clock::now(); + + std::chrono::duration duration = + endingTimepoint - startingTimepoint; + custom4xMsec = (duration.count() * 1000.0); + } + + const size_t nMismatchesG1 = nMismatches(resultsRef, resultsNewGeneric1x); + const size_t nMismatchesG4 = nMismatches(resultsRef, resultsNewGeneric4x); + const size_t nMismatchesCustom1 = + nMismatches(resultsRef, resultsNewCustom1x); + const size_t nMismatchesCustom4 = + nMismatches(resultsRef, resultsNewCustom4x); + + std::cout << "Dim = " << dim << ", subq = " << subq << ", nbits = " << nbits + << ", n = " << n << std::endl; + std::cout << "Generic 1x code: " << generic1xMsec << " msec, " + << nMismatchesG1 << " mismatches" << std::endl; + std::cout << "Generic 4x code: " << generic4xMsec << " msec, " + << nMismatchesG4 << " mismatches" << std::endl; + std::cout << "custom 1x code: " << custom1xMsec << " msec, " + << nMismatchesCustom1 << " mismatches" << std::endl; + std::cout << "custom 4x code: " << custom4xMsec << " msec, " + << nMismatchesCustom4 << " mismatches" << std::endl; + std::cout << std::endl; + + ASSERT_EQ(nMismatchesG1, 0); + ASSERT_EQ(nMismatchesG4, 0); + ASSERT_EQ(nMismatchesCustom1, 0); + ASSERT_EQ(nMismatchesCustom4, 0); +} + +// this test can be used as a benchmark. +// 1. Increase the value of NELEMENTS +// 2. Remove omp_set_num_threads() + +constexpr size_t NELEMENTS = 10000; + +TEST(TEST_CODE_DISTANCE, SUBQ4_NBITS8) { + test(256, 4, 8, NELEMENTS); +} + +TEST(TEST_CODE_DISTANCE, SUBQ8_NBITS8) { + test(256, 8, 8, NELEMENTS); +} + +TEST(TEST_CODE_DISTANCE, SUBQ16_NBITS8) { + test(256, 16, 8, NELEMENTS); +} + +TEST(TEST_CODE_DISTANCE, SUBQ32_NBITS8) { + test(256, 32, 8, NELEMENTS); +}