Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a context parameter to InvertedLists and InvertedListsIterator #3247

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ void IndexIVF::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
const idx_t* coarse_idx,
void* inverted_list_context) {
// do some blocking to avoid excessive allocs
idx_t bs = 65536;
if (n > bs) {
Expand All @@ -218,7 +219,8 @@ void IndexIVF::add_core(
i1 - i0,
x + i0 * d,
xids ? xids + i0 : nullptr,
coarse_idx + i0);
coarse_idx + i0,
inverted_list_context);
}
return;
}
Expand Down Expand Up @@ -249,7 +251,10 @@ void IndexIVF::add_core(
if (list_no >= 0 && list_no % nt == rank) {
idx_t id = xids ? xids[i] : ntotal + i;
size_t ofs = invlists->add_entry(
list_no, id, flat_codes.get() + i * code_size);
list_no,
id,
flat_codes.get() + i * code_size,
inverted_list_context);

dm_adder.add(i, list_no, ofs);

Expand Down Expand Up @@ -445,6 +450,9 @@ void IndexIVF::search_preassigned(
: pmode == 1 ? nprobe > 1
: nprobe * n > 1);

void* inverted_list_context =
params ? params->inverted_list_context : nullptr;

#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
{
std::unique_ptr<InvertedListScanner> scanner(
Expand Down Expand Up @@ -507,7 +515,7 @@ void IndexIVF::search_preassigned(
nlist);

// don't waste time on empty lists
if (invlists->is_empty(key)) {
if (invlists->is_empty(key, inverted_list_context)) {
return (size_t)0;
}

Expand All @@ -520,7 +528,7 @@ void IndexIVF::search_preassigned(
size_t list_size = 0;

std::unique_ptr<InvertedListsIterator> it(
invlists->get_iterator(key));
invlists->get_iterator(key, inverted_list_context));

nheap += scanner->iterate_codes(
it.get(), simi, idxi, k, list_size);
Expand Down Expand Up @@ -783,6 +791,9 @@ void IndexIVF::range_search_preassigned(
: pmode == 1 ? nprobe > 1
: nprobe * nx > 1);

void* inverted_list_context =
params ? params->inverted_list_context : nullptr;

#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
{
RangeSearchPartialResult pres(result);
Expand All @@ -804,7 +815,7 @@ void IndexIVF::range_search_preassigned(
ik,
nlist);

if (invlists->is_empty(key)) {
if (invlists->is_empty(key, inverted_list_context)) {
return;
}

Expand All @@ -813,7 +824,7 @@ void IndexIVF::range_search_preassigned(
scanner->set_list(key, coarse_dis[i * nprobe + ik]);
if (invlists->use_iterator) {
std::unique_ptr<InvertedListsIterator> it(
invlists->get_iterator(key));
invlists->get_iterator(key, inverted_list_context));

scanner->iterate_codes_range(
it.get(), radius, qres, list_size);
Expand Down
5 changes: 4 additions & 1 deletion faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ struct SearchParametersIVF : SearchParameters {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
SearchParameters* quantizer_params = nullptr;
/// context object to pass to InvertedLists
void* inverted_list_context = nullptr;

virtual ~SearchParametersIVF() {}
};
Expand Down Expand Up @@ -232,7 +234,8 @@ struct IndexIVF : Index, IndexIVFInterface {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx);
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr);

/** Encodes a set of vectors as they would appear in the inverted lists
*
Expand Down
7 changes: 4 additions & 3 deletions faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ void IndexIVFFlat::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
const idx_t* coarse_idx,
void* inverted_list_context) {
FAISS_THROW_IF_NOT(is_trained);
FAISS_THROW_IF_NOT(coarse_idx);
FAISS_THROW_IF_NOT(!by_residual);
Expand All @@ -70,8 +71,8 @@ void IndexIVFFlat::add_core(
if (list_no >= 0 && list_no % nt == rank) {
idx_t id = xids ? xids[i] : ntotal + i;
const float* xi = x + i * d;
size_t offset =
invlists->add_entry(list_no, id, (const uint8_t*)xi);
size_t offset = invlists->add_entry(
list_no, id, (const uint8_t*)xi, inverted_list_context);
dm_adder.add(i, list_no, offset);
n_add++;
} else if (rank == 0 && list_no == -1) {
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVFFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ struct IndexIVFFlat : IndexIVF {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

void encode_vectors(
idx_t n,
Expand Down
14 changes: 9 additions & 5 deletions faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ void IndexIVFPQ::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
add_core_o(n, x, xids, nullptr, coarse_idx);
const idx_t* coarse_idx,
void* inverted_list_context) {
add_core_o(n, x, xids, nullptr, coarse_idx, inverted_list_context);
}

static std::unique_ptr<float[]> compute_residuals(
Expand Down Expand Up @@ -212,7 +213,8 @@ void IndexIVFPQ::add_core_o(
const float* x,
const idx_t* xids,
float* residuals_2,
const idx_t* precomputed_idx) {
const idx_t* precomputed_idx,
void* inverted_list_context) {
idx_t bs = index_ivfpq_add_core_o_bs;
if (n > bs) {
for (idx_t i0 = 0; i0 < n; i0 += bs) {
Expand All @@ -229,7 +231,8 @@ void IndexIVFPQ::add_core_o(
x + i0 * d,
xids ? xids + i0 : nullptr,
residuals_2 ? residuals_2 + i0 * d : nullptr,
precomputed_idx ? precomputed_idx + i0 : nullptr);
precomputed_idx ? precomputed_idx + i0 : nullptr,
inverted_list_context);
}
return;
}
Expand Down Expand Up @@ -281,7 +284,8 @@ void IndexIVFPQ::add_core_o(
}

uint8_t* code = xcodes.get() + i * code_size;
size_t offset = invlists->add_entry(key, id, code);
size_t offset =
invlists->add_entry(key, id, code, inverted_list_context);

if (residuals_2) {
float* res2 = residuals_2 + i * d;
Expand Down
6 changes: 4 additions & 2 deletions faiss/IndexIVFPQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ struct IndexIVFPQ : IndexIVF {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

/// same as add_core, also:
/// - output 2nd level residuals if residuals_2 != NULL
Expand All @@ -81,7 +82,8 @@ struct IndexIVFPQ : IndexIVF {
const float* x,
const idx_t* xids,
float* residuals_2,
const idx_t* precomputed_idx = nullptr);
const idx_t* precomputed_idx = nullptr,
void* inverted_list_context = nullptr);

/// trains the product quantizer
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVFPQR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void IndexIVFPQR::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) {
const idx_t* precomputed_idx,
void* /*inverted_list_context*/) {
std::unique_ptr<float[]> residual_2(new float[n * d]);

idx_t n0 = ntotal;
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVFPQR.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ struct IndexIVFPQR : IndexIVFPQ {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
const override;
Expand Down
6 changes: 4 additions & 2 deletions faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ void IndexIVFScalarQuantizer::add_core(
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* coarse_idx) {
const idx_t* coarse_idx,
void* inverted_list_context) {
FAISS_THROW_IF_NOT(is_trained);

std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
Expand Down Expand Up @@ -236,7 +237,8 @@ void IndexIVFScalarQuantizer::add_core(
memset(one_code.data(), 0, code_size);
squant->encode_vector(xi, one_code.data());

size_t ofs = invlists->add_entry(list_no, id, one_code.data());
size_t ofs = invlists->add_entry(
list_no, id, one_code.data(), inverted_list_context);

dm_add.add(i, list_no, ofs);

Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
idx_t n,
const float* x,
const idx_t* xids,
const idx_t* precomputed_idx) override;
const idx_t* precomputed_idx,
void* inverted_list_context = nullptr) override;

InvertedListScanner* get_InvertedListScanner(
bool store_pairs,
Expand Down
18 changes: 11 additions & 7 deletions faiss/invlists/InvertedLists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ InvertedLists::InvertedLists(size_t nlist, size_t code_size)

InvertedLists::~InvertedLists() {}

bool InvertedLists::is_empty(size_t list_no) const {
return use_iterator
? !std::unique_ptr<InvertedListsIterator>(get_iterator(list_no))
->is_available()
: list_size(list_no) == 0;
bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context)
const {
return use_iterator ? !std::unique_ptr<InvertedListsIterator>(
get_iterator(list_no, inverted_list_context))
->is_available()
: list_size(list_no) == 0;
}

idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
Expand All @@ -58,7 +59,8 @@ const uint8_t* InvertedLists::get_single_code(size_t list_no, size_t offset)
size_t InvertedLists::add_entry(
size_t list_no,
idx_t theid,
const uint8_t* code) {
const uint8_t* code,
void* /*inverted_list_context*/) {
return add_entries(list_no, 1, &theid, code);
}

Expand All @@ -76,7 +78,9 @@ void InvertedLists::reset() {
}
}

InvertedListsIterator* InvertedLists::get_iterator(size_t /*list_no*/) const {
InvertedListsIterator* InvertedLists::get_iterator(
size_t /*list_no*/,
void* /*inverted_list_context*/) const {
FAISS_THROW_MSG("get_iterator is not supported");
}

Expand Down
12 changes: 9 additions & 3 deletions faiss/invlists/InvertedLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ struct InvertedLists {
* Read only functions */

// check if the list is empty
bool is_empty(size_t list_no) const;
bool is_empty(size_t list_no, void* inverted_list_context) const;

/// get the size of a list
virtual size_t list_size(size_t list_no) const = 0;

/// get iterable for lists that use_iterator
virtual InvertedListsIterator* get_iterator(size_t list_no) const;
virtual InvertedListsIterator* get_iterator(
size_t list_no,
void* inverted_list_context) const;

/** get the codes for an inverted list
* must be released by release_codes
Expand Down Expand Up @@ -94,7 +96,11 @@ struct InvertedLists {
* writing functions */

/// add one entry to an inverted list
virtual size_t add_entry(size_t list_no, idx_t theid, const uint8_t* code);
virtual size_t add_entry(
size_t list_no,
idx_t theid,
const uint8_t* code,
void* inverted_list_context = nullptr);

virtual size_t add_entries(
size_t list_no,
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(FAISS_TEST_SRC
test_ivfpq_codec.cpp
test_ivfpq_indexing.cpp
test_lowlevel_ivf.cpp
test_ivf_index.cpp
test_merge.cpp
test_omp_threads.cpp
test_ondisk_ivf.cpp
Expand Down
Loading