Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up batch comments + obey IO_FLAG_SKIP_PRECOMPUTE_TABLE #3013

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 155 additions & 157 deletions faiss/IndexAdditiveQuantizer.cpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions faiss/impl/AdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ void AdditiveQuantizer::compute_LUT(

namespace {

/* compute inner products of one query with all centroids, given a look-up
* table of all inner producst with codebook entries */
void compute_inner_prod_with_LUT(
const AdditiveQuantizer& aq,
const float* LUT,
Expand Down
6 changes: 4 additions & 2 deletions faiss/impl/AdditiveQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ struct AdditiveQuantizer : Quantizer {
/// encode a norm into norm_bits bits
uint64_t encode_norm(float norm) const;

/// encode norm by non-uniform scalar quantization
uint32_t encode_qcint(
float x) const; ///< encode norm by non-uniform scalar quantization
float x) const;

/// decode norm by non-uniform scalar quantization
float decode_qcint(uint32_t c)
const; ///< decode norm by non-uniform scalar quantization
const;

/// Encodes how search is performed and how vectors are encoded
enum Search_type_t {
Expand Down
67 changes: 26 additions & 41 deletions faiss/impl/ResidualQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ void ResidualQuantizer::initialize_from(
}
}

/****************************************************************
* Encoding steps, used both for training and search
*/

void beam_search_encode_step(
size_t d,
size_t K,
Expand Down Expand Up @@ -277,6 +281,10 @@ void beam_search_encode_step(
}
}

/****************************************************************
* Training
****************************************************************/

void ResidualQuantizer::train(size_t n, const float* x) {
codebooks.resize(d * codebook_offsets.back());

Expand Down Expand Up @@ -568,7 +576,12 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
return mem;
}

// a namespace full of preallocated buffers
/****************************************************************
* Encoding
****************************************************************/

// a namespace full of preallocated buffers. This speeds up
// computations, instead of re-allocating them at every encoing step
namespace {

// Preallocated memory chunk for refine_beam_mp() call
Expand Down Expand Up @@ -609,8 +622,6 @@ struct ComputeCodesAddCentroidsLUT1MemoryPool {
RefineBeamLUTMemoryPool refine_beam_lut_pool;
};

} // namespace

// forward declaration
void refine_beam_mp(
const ResidualQuantizer& rq,
Expand Down Expand Up @@ -743,6 +754,8 @@ void compute_codes_add_centroids_mp_lut1(
centroids);
}

} // namespace

void ResidualQuantizer::compute_codes_add_centroids(
const float* x,
uint8_t* codes_out,
Expand All @@ -769,11 +782,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
cent = centroids + i0 * d;
}

// compute_codes_add_centroids(
// x + i0 * d,
// codes_out + i0 * code_size,
// i1 - i0,
// cent);
if (use_beam_LUT == 0) {
compute_codes_add_centroids_mp_lut0(
*this,
Expand All @@ -794,6 +802,8 @@ void ResidualQuantizer::compute_codes_add_centroids(
}
}

namespace {

void refine_beam_mp(
const ResidualQuantizer& rq,
size_t n,
Expand Down Expand Up @@ -873,15 +883,11 @@ void refine_beam_mp(
codebooks_m,
n,
cur_beam_size,
// residuals.data(),
residuals_ptr,
m,
// codes.data(),
codes_ptr,
new_beam_size,
// new_codes.data(),
new_codes_ptr,
// new_residuals.data(),
new_residuals_ptr,
pool.distances.data(),
assign_index.get(),
Expand All @@ -896,9 +902,6 @@ void refine_beam_mp(

if (rq.verbose) {
float sum_distances = 0;
// for (int j = 0; j < distances.size(); j++) {
// sum_distances += distances[j];
// }
for (int j = 0; j < distances_size; j++) {
sum_distances += pool.distances[j];
}
Expand All @@ -914,27 +917,22 @@ void refine_beam_mp(
}

if (out_codes) {
// memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
}
if (out_residuals) {
// memcpy(out_residuals,
// residuals.data(),
// residuals.size() * sizeof(residuals[0]));
memcpy(out_residuals,
residuals_ptr,
residuals_size * sizeof(*residuals_ptr));
}
if (out_distances) {
// memcpy(out_distances,
// distances.data(),
// distances.size() * sizeof(distances[0]));
memcpy(out_distances,
pool.distances.data(),
distances_size * sizeof(pool.distances[0]));
}
}

} // anonymous namespace

void ResidualQuantizer::refine_beam(
size_t n,
size_t beam_size,
Expand Down Expand Up @@ -1165,7 +1163,7 @@ void accum_and_finalize_tab(
}
}

} // namespace
} // anonymous namespace

void beam_search_encode_step_tab(
size_t K,
Expand Down Expand Up @@ -1390,6 +1388,8 @@ void beam_search_encode_step_tab(
}
}

namespace {

//
void refine_beam_LUT_mp(
const ResidualQuantizer& rq,
Expand Down Expand Up @@ -1443,13 +1443,9 @@ void refine_beam_LUT_mp(
for (int m = 0; m < rq.M; m++) {
int K = 1 << rq.nbits[m];

// it is guaranteed that (new_beam_size <= than max_beam_size) ==
// true
// it is guaranteed that (new_beam_size <= max_beam_size)
int new_beam_size = std::min(beam_size * K, out_beam_size);

// std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
// std::vector<float> new_distances(n * new_beam_size);

codes_size = n * new_beam_size * (m + 1);
distances_size = n * new_beam_size;

Expand All @@ -1464,29 +1460,20 @@ void refine_beam_LUT_mp(
rq.total_codebook_size,
rq.cent_norms.data() + rq.codebook_offsets[m],
m,
// codes.data(),
codes_ptr,
// distances.data(),
distances_ptr,
new_beam_size,
// new_codes.data(),
new_codes_ptr,
// new_distances.data()
new_distances_ptr,
rq.approx_topk_mode);

// codes.swap(new_codes);
std::swap(codes_ptr, new_codes_ptr);
// distances.swap(new_distances);
std::swap(distances_ptr, new_distances_ptr);

beam_size = new_beam_size;

if (rq.verbose) {
float sum_distances = 0;
// for (int j = 0; j < distances.size(); j++) {
// sum_distances += distances[j];
// }
for (int j = 0; j < distances_size; j++) {
sum_distances += distances_ptr[j];
}
Expand All @@ -1501,19 +1488,17 @@ void refine_beam_LUT_mp(
}

if (out_codes) {
// memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
}
if (out_distances) {
// memcpy(out_distances,
// distances.data(),
// distances.size() * sizeof(distances[0]));
memcpy(out_distances,
distances_ptr,
distances_size * sizeof(*distances_ptr));
}
}

} // namespace

void ResidualQuantizer::refine_beam_LUT(
size_t n,
const float* query_norms, // size n
Expand Down
13 changes: 10 additions & 3 deletions faiss/impl/ResidualQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
*/
size_t memory_per_point(int beam_size = -1) const;

/** Cross products used in codebook tables
*
* These are used to keep trak of norms of centroids.
/** Cross products used in codebook tables used for beam_LUT = 1
*/
void compute_codebook_tables();

Expand Down Expand Up @@ -194,6 +192,15 @@ void beam_search_encode_step(

/** Encode a set of vectors using their dot products with the codebooks
*
* @param K number of vectors in the codebook
* @param n nb of vectors to encode
* @param beam_size input beam size
* @param codebook_cross_norms inner product of this codebook with the m
* previously encoded codebooks
* @param codebook_offsets offsets into codebook_cross_norms for each
* previous codebook
* @param query_cp dot products of query vectors with ???
* @param cent_norms_i norms of centroids
*/
void beam_search_encode_step_tab(
size_t K,
Expand Down
39 changes: 26 additions & 13 deletions faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,17 @@ static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) {
aq->set_derived_values();
}

static void read_ResidualQuantizer(ResidualQuantizer* rq, IOReader* f) {
static void read_ResidualQuantizer(
ResidualQuantizer* rq,
IOReader* f,
int io_flags) {
read_AdditiveQuantizer(rq, f);
READ1(rq->train_type);
READ1(rq->max_beam_size);
if (!(rq->train_type & ResidualQuantizer::Skip_codebook_tables)) {
if ((rq->train_type & ResidualQuantizer::Skip_codebook_tables) ||
(io_flags & IO_FLAG_SKIP_PRECOMPUTE_TABLE)) {
// don't precompute the tables
} else {
rq->compute_codebook_tables();
}
}
Expand Down Expand Up @@ -325,12 +331,13 @@ static void read_ProductAdditiveQuantizer(

static void read_ProductResidualQuantizer(
ProductResidualQuantizer* prq,
IOReader* f) {
IOReader* f,
int io_flags) {
read_ProductAdditiveQuantizer(prq, f);

for (size_t i = 0; i < prq->nsplits; i++) {
auto rq = new ResidualQuantizer();
read_ResidualQuantizer(rq, f);
read_ResidualQuantizer(rq, f, io_flags);
prq->quantizers.push_back(rq);
}
}
Expand Down Expand Up @@ -601,7 +608,7 @@ Index* read_index(IOReader* f, int io_flags) {
if (h == fourcc("IxRQ")) {
read_ResidualQuantizer_old(&idxr->rq, f);
} else {
read_ResidualQuantizer(&idxr->rq, f);
read_ResidualQuantizer(&idxr->rq, f, io_flags);
}
READ1(idxr->code_size);
READVECTOR(idxr->codes);
Expand All @@ -616,7 +623,7 @@ Index* read_index(IOReader* f, int io_flags) {
} else if (h == fourcc("IxPR")) {
auto idxpr = new IndexProductResidualQuantizer();
read_index_header(idxpr, f);
read_ProductResidualQuantizer(&idxpr->prq, f);
read_ProductResidualQuantizer(&idxpr->prq, f, io_flags);
READ1(idxpr->code_size);
READVECTOR(idxpr->codes);
idx = idxpr;
Expand All @@ -630,8 +637,13 @@ Index* read_index(IOReader* f, int io_flags) {
} else if (h == fourcc("ImRQ")) {
ResidualCoarseQuantizer* idxr = new ResidualCoarseQuantizer();
read_index_header(idxr, f);
read_ResidualQuantizer(&idxr->rq, f);
read_ResidualQuantizer(&idxr->rq, f, io_flags);
READ1(idxr->beam_factor);
if (io_flags & IO_FLAG_SKIP_PRECOMPUTE_TABLE) {
// then we force the beam factor to -1
// which skips the table precomputation.
idxr->beam_factor = -1;
}
idxr->set_beam_factor(idxr->beam_factor);
idx = idxr;
} else if (
Expand All @@ -656,13 +668,14 @@ Index* read_index(IOReader* f, int io_flags) {
if (is_LSQ) {
read_LocalSearchQuantizer((LocalSearchQuantizer*)idxaqfs->aq, f);
} else if (is_RQ) {
read_ResidualQuantizer((ResidualQuantizer*)idxaqfs->aq, f);
read_ResidualQuantizer(
(ResidualQuantizer*)idxaqfs->aq, f, io_flags);
} else if (is_PLSQ) {
read_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)idxaqfs->aq, f);
} else {
read_ProductResidualQuantizer(
(ProductResidualQuantizer*)idxaqfs->aq, f);
(ProductResidualQuantizer*)idxaqfs->aq, f, io_flags);
}

READ1(idxaqfs->implem);
Expand Down Expand Up @@ -704,13 +717,13 @@ Index* read_index(IOReader* f, int io_flags) {
if (is_LSQ) {
read_LocalSearchQuantizer((LocalSearchQuantizer*)ivaqfs->aq, f);
} else if (is_RQ) {
read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f);
read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f, io_flags);
} else if (is_PLSQ) {
read_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)ivaqfs->aq, f);
} else {
read_ProductResidualQuantizer(
(ProductResidualQuantizer*)ivaqfs->aq, f);
(ProductResidualQuantizer*)ivaqfs->aq, f, io_flags);
}

READ1(ivaqfs->by_residual);
Expand Down Expand Up @@ -832,13 +845,13 @@ Index* read_index(IOReader* f, int io_flags) {
if (is_LSQ) {
read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
} else if (is_RQ) {
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f, io_flags);
} else if (is_PLSQ) {
read_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)iva->aq, f);
} else {
read_ProductResidualQuantizer(
(ProductResidualQuantizer*)iva->aq, f);
(ProductResidualQuantizer*)iva->aq, f, io_flags);
}
READ1(iva->by_residual);
READ1(iva->use_precomputed_table);
Expand Down
4 changes: 2 additions & 2 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,10 @@ def serialize_index(index):
return vector_to_array(writer.data)


def deserialize_index(data):
def deserialize_index(data, io_flags=0):
reader = VectorIOReader()
copy_array_to_vector(data, reader.data)
return read_index(reader)
return read_index(reader, io_flags)


def serialize_index_binary(index):
Expand Down
Loading