From c4f4b6331d6e4db28229eeac7024781aad13ed28 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Fri, 29 Mar 2024 00:07:57 -0700 Subject: [PATCH] Add the interface for streaming the vectors from java to jni layer with initial capacity Signed-off-by: Navneet Verma --- jni/CMakeLists.txt | 3 +- jni/include/commons.h | 37 + jni/include/jni_util.h | 4 + .../org_opensearch_knn_jni_FaissService.h | 8 - .../org_opensearch_knn_jni_JNICommons.h | 40 + ...Custom-patch-to-support-multi-vector.patch | 1050 ----------------- ...ble-precomp-table-to-be-shared-ivfpq.patch | 512 -------- ...vel-during-add-from-enterpoint-level.patch | 31 - jni/src/commons.cpp | 41 + jni/src/jni_util.cpp | 11 +- .../org_opensearch_knn_jni_FaissService.cpp | 19 - jni/src/org_opensearch_knn_jni_JNICommons.cpp | 60 + jni/tests/commons_test.cpp | 76 ++ jni/tests/test_util.cpp | 8 + jni/tests/test_util.h | 2 + .../knn/TransferVectorsBenchmarks.java | 24 +- .../opensearch/knn/common/KNNConstants.java | 3 + .../org/opensearch/knn/jni/FaissService.java | 19 +- .../org/opensearch/knn/jni/JNICommons.java | 52 + .../org/opensearch/knn/jni/JNIService.java | 30 +- .../opensearch/knn/jni/JNICommonsTest.java | 40 + 21 files changed, 404 insertions(+), 1666 deletions(-) create mode 100644 jni/include/commons.h create mode 100644 jni/include/org_opensearch_knn_jni_JNICommons.h delete mode 100644 jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch delete mode 100644 jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch delete mode 100644 jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch create mode 100644 jni/src/commons.cpp create mode 100644 jni/src/org_opensearch_knn_jni_JNICommons.cpp create mode 100644 jni/tests/commons_test.cpp create mode 100644 src/main/java/org/opensearch/knn/jni/JNICommons.java create mode 100644 src/test/java/org/opensearch/knn/jni/JNICommonsTest.java diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 
60321ed1bf..4f32c87b99 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -61,7 +61,7 @@ endif() # ---------------------------------------------------------------------------- # ---------------------------------- COMMON ---------------------------------- -add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp) +add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_JNICommons.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/commons.cpp) target_include_directories(${TARGET_LIB_COMMON} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE}) set_target_properties(${TARGET_LIB_COMMON} PROPERTIES SUFFIX ${LIB_EXT}) set_target_properties(${TARGET_LIB_COMMON} PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -236,6 +236,7 @@ if ("${WIN32}" STREQUAL "") tests/faiss_util_test.cpp tests/nmslib_wrapper_test.cpp tests/test_util.cpp + tests/commons_test.cpp ) target_link_libraries( diff --git a/jni/include/commons.h b/jni/include/commons.h new file mode 100644 index 0000000000..05367a6939 --- /dev/null +++ b/jni/include/commons.h @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +#include "jni_util.h" +#include +namespace knn_jni { + namespace commons { + /** + * This is a utility function that can be used to store data in native memory. This function will allocate memory for + * the data (rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for the first time, use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. 
If the data cannot be stored in the memory location, the function + * will throw an exception. + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * + * @param memoryAddress address to be freed. + */ + void freeVectorData(jlong); + } +} diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 52b08a202b..b3d55f1c1c 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -69,6 +69,9 @@ namespace knn_jni { virtual std::vector Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) = 0; + virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect ) = 0; + virtual std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0; // -------------------------------------------------------------------------- @@ -164,6 +167,7 @@ namespace knn_jni { void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode); void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); + void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); private: std::unordered_map cachedClasses; diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index ec1f46bc37..64a858f844 
100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -122,14 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors (JNIEnv *, jclass, jlong, jobjectArray); -/* - * Class: org_opensearch_knn_jni_FaissService - * Method: transferVectorsV2 - * Signature: (J[[F)J - */ -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2 - (JNIEnv *, jclass, jlong, jobjectArray); - /* * Class: org_opensearch_knn_jni_FaissService * Method: freeVectors diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h new file mode 100644 index 0000000000..d0758d7c8c --- /dev/null +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include <jni.h> +/* Header for class org_opensearch_knn_jni_JNICommons */ + +#ifndef _Included_org_opensearch_knn_jni_JNICommons +#define _Included_org_opensearch_knn_jni_JNICommons +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: storeVectorData + * Signature: (J[[FJ)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData + (JNIEnv *, jclass, jlong, jobjectArray, jlong); + +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: freeVectorData + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData + (JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch deleted file mode 100644 index a22e281305..0000000000 --- a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch +++ /dev/null @@ -1,1050 +0,0 @@ -From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001 -From: Heemin Kim -Date: Tue, 30 Jan 2024 14:43:56 -0800 -Subject: [PATCH] Add IDGrouper for HNSW - -Signed-off-by: Heemin Kim ---- - faiss/CMakeLists.txt | 3 + - faiss/Index.h | 8 +- - faiss/IndexHNSW.cpp | 13 ++- - faiss/IndexIDMap.cpp | 29 ++++++ - faiss/IndexIDMap.h | 22 +++++ - faiss/impl/HNSW.cpp | 10 +- - faiss/impl/IDGrouper.cpp | 51 ++++++++++ - faiss/impl/IDGrouper.h | 51 ++++++++++ - faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++ - faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++ - tests/CMakeLists.txt | 2 + - tests/test_group_heap.cpp | 98 +++++++++++++++++++ - tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++ - 13 files changed, 838 insertions(+), 7 deletions(-) - create mode 100644 faiss/impl/IDGrouper.cpp - create mode 100644 faiss/impl/IDGrouper.h - create mode 100644 
faiss/utils/GroupHeap.h - create mode 100644 tests/test_group_heap.cpp - create mode 100644 tests/test_id_grouper.cpp - -diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt -index a890a46f..137e68d4 100644 ---- a/faiss/CMakeLists.txt -+++ b/faiss/CMakeLists.txt -@@ -54,6 +54,7 @@ set(FAISS_SRC - impl/AuxIndexStructures.cpp - impl/CodePacker.cpp - impl/IDSelector.cpp -+ impl/IDGrouper.cpp - impl/FaissException.cpp - impl/HNSW.cpp - impl/NSG.cpp -@@ -149,6 +150,7 @@ set(FAISS_HEADERS - impl/AuxIndexStructures.h - impl/CodePacker.h - impl/IDSelector.h -+ impl/IDGrouper.h - impl/DistanceComputer.h - impl/FaissAssert.h - impl/FaissException.h -@@ -183,6 +185,7 @@ set(FAISS_HEADERS - invlists/InvertedLists.h - invlists/InvertedListsIOHook.h - utils/AlignedTable.h -+ utils/GroupHeap.h - utils/Heap.h - utils/WorkerThread.h - utils/distances.h -diff --git a/faiss/Index.h b/faiss/Index.h -index 4b4b302b..3b673d1e 100644 ---- a/faiss/Index.h -+++ b/faiss/Index.h -@@ -38,9 +38,10 @@ - - namespace faiss { - --/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and --/// impl/DistanceComputer.h -+/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h -+/// ,impl/IDGrouper.h and impl/DistanceComputer.h - struct IDSelector; -+struct IDGrouper; - struct RangeSearchResult; - struct DistanceComputer; - -@@ -52,6 +53,9 @@ struct DistanceComputer; - struct SearchParameters { - /// if non-null, only these IDs will be considered during search. - IDSelector* sel = nullptr; -+ /// if non-null, only best matched ID per group will be included in the -+ /// result. 
-+ IDGrouper* grp = nullptr; - /// make sure we can dynamic_cast this - virtual ~SearchParameters() {} - }; -diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp -index 9a67332d..a5e0fea0 100644 ---- a/faiss/IndexHNSW.cpp -+++ b/faiss/IndexHNSW.cpp -@@ -354,10 +354,17 @@ void IndexHNSW::search( - const SearchParameters* params_in) const { - FAISS_THROW_IF_NOT(k > 0); - -- using RH = HeapBlockResultHandler; -- RH bres(n, distances, labels, k); -+ if (params_in && params_in->grp) { -+ using RH = GroupedHeapBlockResultHandler; -+ RH bres(n, distances, labels, k, params_in->grp); - -- hnsw_search(this, n, x, bres, params_in); -+ hnsw_search(this, n, x, bres, params_in); -+ } else { -+ using RH = HeapBlockResultHandler; -+ RH bres(n, distances, labels, k); -+ -+ hnsw_search(this, n, x, bres, params_in); -+ } - - if (is_similarity_metric(this->metric_type)) { - // we need to revert the negated distances -diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp -index e093bbda..e24365d5 100644 ---- a/faiss/IndexIDMap.cpp -+++ b/faiss/IndexIDMap.cpp -@@ -102,6 +102,23 @@ struct ScopedSelChange { - } - }; - -+/// RAII object to reset the IDGrouper in the params object -+struct ScopedGrpChange { -+ SearchParameters* params = nullptr; -+ IDGrouper* old_grp = nullptr; -+ -+ void set(SearchParameters* params_2, IDGrouper* new_grp) { -+ this->params = params_2; -+ old_grp = params_2->grp; -+ params_2->grp = new_grp; -+ } -+ ~ScopedGrpChange() { -+ if (params) { -+ params->grp = old_grp; -+ } -+ } -+}; -+ - } // namespace - - template -@@ -114,6 +131,8 @@ void IndexIDMapTemplate::search( - const SearchParameters* params) const { - IDSelectorTranslated this_idtrans(this->id_map, nullptr); - ScopedSelChange sel_change; -+ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr); -+ ScopedGrpChange grp_change; - - if (params && params->sel) { - auto idtrans = dynamic_cast(params->sel); -@@ -131,6 +150,16 @@ void IndexIDMapTemplate::search( - sel_change.set(params_non_const, 
&this_idtrans); - } - } -+ -+ if (params && params->grp) { -+ auto idtrans = dynamic_cast(params->grp); -+ -+ if (!idtrans) { -+ auto params_non_const = const_cast(params); -+ this_idgrptrans.grp = params->grp; -+ grp_change.set(params_non_const, &this_idgrptrans); -+ } -+ } - index->search(n, x, k, distances, labels, params); - idx_t* li = labels; - #pragma omp parallel for -diff --git a/faiss/IndexIDMap.h b/faiss/IndexIDMap.h -index 2d164123..a68887bd 100644 ---- a/faiss/IndexIDMap.h -+++ b/faiss/IndexIDMap.h -@@ -9,6 +9,7 @@ - - #include - #include -+#include - #include - - #include -@@ -124,4 +125,25 @@ struct IDSelectorTranslated : IDSelector { - } - }; - -+// IDGrouper that translates the ids using an IDMap -+struct IDGrouperTranslated : IDGrouper { -+ const std::vector& id_map; -+ const IDGrouper* grp; -+ -+ IDGrouperTranslated( -+ const std::vector& id_map, -+ const IDGrouper* grp) -+ : id_map(id_map), grp(grp) {} -+ -+ IDGrouperTranslated(IndexBinaryIDMap& index_idmap, const IDGrouper* grp) -+ : id_map(index_idmap.id_map), grp(grp) {} -+ -+ IDGrouperTranslated(IndexIDMap& index_idmap, const IDGrouper* grp) -+ : id_map(index_idmap.id_map), grp(grp) {} -+ -+ idx_t get_group(idx_t id) const override { -+ return grp->get_group(id_map[id]); -+ } -+}; -+ - } // namespace faiss -diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp -index fb4de678..b6f602a0 100644 ---- a/faiss/impl/HNSW.cpp -+++ b/faiss/impl/HNSW.cpp -@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const { - level, - nb_neighbors(level)); - size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; --#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ -- reduction(+: tot_reciprocal) reduction(+: n_node) -+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ -+ reduction(+ : tot_reciprocal) reduction(+ : n_node) - for (int i = 0; i < levels.size(); i++) { - if (levels[i] > level) { - n_node++; -@@ -804,6 +804,12 @@ int 
extract_k_from_ResultHandler(ResultHandler& res) { - if (auto hres = dynamic_cast(&res)) { - return hres->k; - } -+ -+ if (auto hres = dynamic_cast< -+ GroupedHeapBlockResultHandler::SingleResultHandler*>(&res)) { -+ return hres->k; -+ } -+ - return 1; - } - -diff --git a/faiss/impl/IDGrouper.cpp b/faiss/impl/IDGrouper.cpp -new file mode 100644 -index 00000000..ca9f5fda ---- /dev/null -+++ b/faiss/impl/IDGrouper.cpp -@@ -0,0 +1,51 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#include -+#include -+#include -+ -+namespace faiss { -+ -+/*********************************************************************** -+ * IDGrouperBitmap -+ ***********************************************************************/ -+ -+IDGrouperBitmap::IDGrouperBitmap(size_t n, uint64_t* bitmap) -+ : n(n), bitmap(bitmap) {} -+ -+idx_t IDGrouperBitmap::get_group(idx_t id) const { -+ assert(id >= 0 && "id shouldn't be less than zero"); -+ assert(id < this->n * 64 && "is should be less than total number of bits"); -+ -+ idx_t index = id >> 6; // div by 64 -+ uint64_t block = this->bitmap[index] >> -+ (id & 63); // Equivalent of words[i] >> (index % 64) -+ // block is non zero after right shift, it means, next set bit is in current -+ // block The index of set bit is "given index" + "trailing zero in the right -+ // shifted word" -+ if (block != 0) { -+ return id + __builtin_ctzll(block); -+ } -+ -+ while (++index < this->n) { -+ block = this->bitmap[index]; -+ if (block != 0) { -+ return (index << 6) + __builtin_ctzll(block); -+ } -+ } -+ -+ return NO_MORE_DOCS; -+} -+ -+void IDGrouperBitmap::set_group(idx_t group_id) { -+ idx_t index = group_id >> 6; -+ this->bitmap[index] |= 1ULL -+ << (group_id & 63); // Equivalent of 1ULL << (value % 64) -+} -+ -+} // namespace faiss -diff --git a/faiss/impl/IDGrouper.h b/faiss/impl/IDGrouper.h -new 
file mode 100644 -index 00000000..d56113d9 ---- /dev/null -+++ b/faiss/impl/IDGrouper.h -@@ -0,0 +1,51 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+ -+/** IDGrouper is intended to define a group of vectors to include only -+ * the nearest vector of each group during search */ -+ -+namespace faiss { -+ -+/** Encapsulates a group id of ids */ -+struct IDGrouper { -+ const idx_t NO_MORE_DOCS = std::numeric_limits::max(); -+ virtual idx_t get_group(idx_t id) const = 0; -+ virtual ~IDGrouper() {} -+}; -+ -+/** One bit per element. Constructed with a bitmap, size ceil(n / 8). -+ */ -+struct IDGrouperBitmap : IDGrouper { -+ // length of the bitmap array -+ size_t n; -+ -+ // Array of uint64_t holding the bits -+ // Using uint64_t to leverage function __builtin_ctzll which is defined in -+ // faiss/impl/platform_macros.h Group id of a given id is next set bit in -+ // the bitmap -+ uint64_t* bitmap; -+ -+ /** Construct with a binary mask -+ * -+ * @param n size of the bitmap array -+ * @param bitmap group id of a given id is next set bit in the bitmap -+ */ -+ IDGrouperBitmap(size_t n, uint64_t* bitmap); -+ idx_t get_group(idx_t id) const final; -+ void set_group(idx_t group_id); -+ ~IDGrouperBitmap() override {} -+}; -+ -+} // namespace faiss -diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h -index 270de8dc..2f7f3e7f 100644 ---- a/faiss/impl/ResultHandler.h -+++ b/faiss/impl/ResultHandler.h -@@ -12,6 +12,8 @@ - #pragma once - - #include -+#include -+#include - #include - #include - -@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler { - } - }; - -+/***************************************************************** -+ * Heap based result handler with grouping -+ 
*****************************************************************/ -+ -+template -+struct GroupedHeapBlockResultHandler : BlockResultHandler { -+ using T = typename C::T; -+ using TI = typename C::TI; -+ using BlockResultHandler::i0; -+ using BlockResultHandler::i1; -+ -+ T* heap_dis_tab; -+ TI* heap_ids_tab; -+ int64_t k; // number of results to keep -+ -+ IDGrouper* id_grouper; -+ TI* heap_group_ids_tab; -+ std::unordered_map* group_id_to_index_in_heap_tab; -+ -+ GroupedHeapBlockResultHandler( -+ size_t nq, -+ T* heap_dis_tab, -+ TI* heap_ids_tab, -+ size_t k, -+ IDGrouper* id_grouper) -+ : BlockResultHandler(nq), -+ heap_dis_tab(heap_dis_tab), -+ heap_ids_tab(heap_ids_tab), -+ k(k), -+ id_grouper(id_grouper) {} -+ -+ /****************************************************** -+ * API for 1 result at a time (each SingleResultHandler is -+ * called from 1 thread) -+ */ -+ -+ struct SingleResultHandler : ResultHandler { -+ GroupedHeapBlockResultHandler& hr; -+ using ResultHandler::threshold; -+ size_t k; -+ -+ T* heap_dis; -+ TI* heap_ids; -+ TI* heap_group_ids; -+ std::unordered_map group_id_to_index_in_heap; -+ -+ explicit SingleResultHandler(GroupedHeapBlockResultHandler& hr) -+ : hr(hr), k(hr.k) {} -+ -+ /// begin results for query # i -+ void begin(size_t i) { -+ heap_dis = hr.heap_dis_tab + i * k; -+ heap_ids = hr.heap_ids_tab + i * k; -+ heap_heapify(k, heap_dis, heap_ids); -+ threshold = heap_dis[0]; -+ heap_group_ids = new TI[hr.k]; -+ for (size_t i = 0; i < hr.k; i++) { -+ heap_group_ids[i] = -1; -+ } -+ } -+ -+ /// add one result for query i -+ bool add_result(T dis, TI idx) final { -+ if (!C::cmp(threshold, dis)) { -+ return false; -+ } -+ -+ idx_t group_id = hr.id_grouper->get_group(idx); -+ typename std::unordered_map::const_iterator it_pos = -+ group_id_to_index_in_heap.find(group_id); -+ if (it_pos == group_id_to_index_in_heap.end()) { -+ group_heap_replace_top( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ dis, -+ idx, -+ group_id, -+ 
&group_id_to_index_in_heap); -+ return true; -+ } else { -+ size_t pos = it_pos->second; -+ if (!C::cmp(heap_dis[pos], dis)) { -+ return false; -+ } -+ group_heap_replace_at( -+ pos, -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ dis, -+ idx, -+ group_id, -+ &group_id_to_index_in_heap); -+ return true; -+ } -+ } -+ -+ /// series of results for query i is done -+ void end() { -+ heap_reorder(k, heap_dis, heap_ids); -+ delete heap_group_ids; -+ } -+ }; -+ -+ /****************************************************** -+ * API for multiple results (called from 1 thread) -+ */ -+ -+ /// begin -+ void begin_multiple(size_t i0_2, size_t i1_2) final { -+ this->i0 = i0_2; -+ this->i1 = i1_2; -+ for (size_t i = i0; i < i1; i++) { -+ heap_heapify(k, heap_dis_tab + i * k, heap_ids_tab + i * k); -+ } -+ size_t size = (i1 - i0) * k; -+ heap_group_ids_tab = new TI[size]; -+ for (size_t i = 0; i < size; i++) { -+ heap_group_ids_tab[i] = -1; -+ } -+ group_id_to_index_in_heap_tab = -+ new std::unordered_map[i1 - i0]; -+ } -+ -+ /// add results for query i0..i1 and j0..j1 -+ void add_results(size_t j0, size_t j1, const T* dis_tab) final { -+#pragma omp parallel for -+ for (int64_t i = i0; i < i1; i++) { -+ T* heap_dis = heap_dis_tab + i * k; -+ TI* heap_ids = heap_ids_tab + i * k; -+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0; -+ T thresh = heap_dis[0]; // NOLINT(*-use-default-none) -+ for (size_t j = j0; j < j1; j++) { -+ T dis = dis_tab_i[j]; -+ if (C::cmp(thresh, dis)) { -+ idx_t group_id = id_grouper->get_group(j); -+ typename std::unordered_map::const_iterator -+ it_pos = group_id_to_index_in_heap_tab[i - i0].find( -+ group_id); -+ if (it_pos == group_id_to_index_in_heap_tab[i - i0].end()) { -+ group_heap_replace_top( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids_tab + ((i - i0) * k), -+ dis, -+ j, -+ group_id, -+ &group_id_to_index_in_heap_tab[i - i0]); -+ thresh = heap_dis[0]; -+ } else { -+ size_t pos = it_pos->first; -+ if (C::cmp(heap_dis[pos], dis)) { 
-+ group_heap_replace_at( -+ pos, -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids_tab + ((i - i0) * k), -+ dis, -+ j, -+ group_id, -+ &group_id_to_index_in_heap_tab[i - i0]); -+ thresh = heap_dis[0]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// series of results for queries i0..i1 is done -+ void end_multiple() final { -+ // maybe parallel for -+ for (size_t i = i0; i < i1; i++) { -+ heap_reorder(k, heap_dis_tab + i * k, heap_ids_tab + i * k); -+ } -+ delete group_id_to_index_in_heap_tab; -+ delete heap_group_ids_tab; -+ } -+}; -+ - /***************************************************************** - * Reservoir result handler - * -diff --git a/faiss/utils/GroupHeap.h b/faiss/utils/GroupHeap.h -new file mode 100644 -index 00000000..3b7078da ---- /dev/null -+++ b/faiss/utils/GroupHeap.h -@@ -0,0 +1,182 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+namespace faiss { -+ -+/** -+ * From start_index, it compare its value with parent node's and swap if needed. -+ * Continue until either there is no swap or it reaches the top node. 
-+ */ -+template -+static inline void group_up_heap( -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ std::unordered_map* group_id_to_index_in_heap, -+ size_t start_index) { -+ heap_dis--; /* Use 1-based indexing for easier node->child translation */ -+ heap_ids--; -+ heap_group_ids--; -+ size_t i = start_index + 1, i_father; -+ typename C::T target_dis = heap_dis[i]; -+ typename C::TI target_id = heap_ids[i]; -+ typename C::TI target_group_id = heap_group_ids[i]; -+ -+ while (i > 1) { -+ i_father = i >> 1; -+ if (!C::cmp2( -+ target_dis, -+ heap_dis[i_father], -+ target_id, -+ heap_ids[i_father])) { -+ /* the heap structure is ok */ -+ break; -+ } -+ heap_dis[i] = heap_dis[i_father]; -+ heap_ids[i] = heap_ids[i_father]; -+ heap_group_ids[i] = heap_group_ids[i_father]; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+ i = i_father; -+ } -+ heap_dis[i] = target_dis; -+ heap_ids[i] = target_id; -+ heap_group_ids[i] = target_group_id; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+} -+ -+/** -+ * From start_index, it compare its value with child node's and swap if needed. -+ * Continue until either there is no swap or it reaches the leaf node. 
-+ */ -+template -+static inline void group_down_heap( -+ size_t k, -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ std::unordered_map* group_id_to_index_in_heap, -+ size_t start_index) { -+ heap_dis--; /* Use 1-based indexing for easier node->child translation */ -+ heap_ids--; -+ heap_group_ids--; -+ size_t i = start_index + 1, i1, i2; -+ typename C::T target_dis = heap_dis[i]; -+ typename C::TI target_id = heap_ids[i]; -+ typename C::TI target_group_id = heap_group_ids[i]; -+ -+ while (1) { -+ i1 = i << 1; -+ i2 = i1 + 1; -+ if (i1 > k) { -+ break; -+ } -+ -+ // Note that C::cmp2() is a bool function answering -+ // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max -+ // heap and same with the `<` sign for min heap. -+ if ((i2 == k + 1) || -+ C::cmp2(heap_dis[i1], heap_dis[i2], heap_ids[i1], heap_ids[i2])) { -+ if (C::cmp2(target_dis, heap_dis[i1], target_id, heap_ids[i1])) { -+ break; -+ } -+ heap_dis[i] = heap_dis[i1]; -+ heap_ids[i] = heap_ids[i1]; -+ heap_group_ids[i] = heap_group_ids[i1]; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+ i = i1; -+ } else { -+ if (C::cmp2(target_dis, heap_dis[i2], target_id, heap_ids[i2])) { -+ break; -+ } -+ heap_dis[i] = heap_dis[i2]; -+ heap_ids[i] = heap_ids[i2]; -+ heap_group_ids[i] = heap_group_ids[i2]; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+ i = i2; -+ } -+ } -+ heap_dis[i] = target_dis; -+ heap_ids[i] = target_id; -+ heap_group_ids[i] = target_group_id; -+ (*group_id_to_index_in_heap)[heap_group_ids[i]] = i - 1; -+} -+ -+template -+static inline void group_heap_replace_top( -+ size_t k, -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ typename C::T dis, -+ typename C::TI id, -+ typename C::TI group_id, -+ std::unordered_map* group_id_to_index_in_heap) { -+ assert(group_id_to_index_in_heap->find(group_id) == -+ group_id_to_index_in_heap->end() && -+ "group id should not exist in the binary 
heap"); -+ -+ group_id_to_index_in_heap->erase(heap_group_ids[0]); -+ heap_group_ids[0] = group_id; -+ heap_dis[0] = dis; -+ heap_ids[0] = id; -+ (*group_id_to_index_in_heap)[group_id] = 0; -+ group_down_heap( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ group_id_to_index_in_heap, -+ 0); -+} -+ -+template -+static inline void group_heap_replace_at( -+ size_t pos, -+ size_t k, -+ typename C::T* heap_dis, -+ typename C::TI* heap_ids, -+ typename C::TI* heap_group_ids, -+ typename C::T dis, -+ typename C::TI id, -+ typename C::TI group_id, -+ std::unordered_map* group_id_to_index_in_heap) { -+ assert(group_id_to_index_in_heap->find(group_id) != -+ group_id_to_index_in_heap->end() && -+ "group id should exist in the binary heap"); -+ assert(group_id_to_index_in_heap->find(group_id)->second == pos && -+ "index of group id in the heap should be same as pos"); -+ -+ heap_dis[pos] = dis; -+ heap_ids[pos] = id; -+ group_up_heap( -+ heap_dis, heap_ids, heap_group_ids, group_id_to_index_in_heap, pos); -+ group_down_heap( -+ k, -+ heap_dis, -+ heap_ids, -+ heap_group_ids, -+ group_id_to_index_in_heap, -+ pos); -+} -+ -+} // namespace faiss -\ No newline at end of file -diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt -index cc0a4f4c..96e19328 100644 ---- a/tests/CMakeLists.txt -+++ b/tests/CMakeLists.txt -@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC - test_approx_topk.cpp - test_RCQ_cropping.cpp - test_distances_simd.cpp -+ test_id_grouper.cpp -+ test_group_heap.cpp - test_heap.cpp - test_code_distance.cpp - test_hnsw.cpp -diff --git a/tests/test_group_heap.cpp b/tests/test_group_heap.cpp -new file mode 100644 -index 00000000..0e8fe7a7 ---- /dev/null -+++ b/tests/test_group_heap.cpp -@@ -0,0 +1,98 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. 
-+ */ -+#include -+#include -+#include -+#include -+ -+using namespace faiss; -+ -+TEST(GroupHeap, group_heap_replace_top) { -+ using C = CMax; -+ const int k = 100; -+ float binary_heap_values[k]; -+ int64_t binary_heap_ids[k]; -+ heap_heapify(k, binary_heap_values, binary_heap_ids); -+ int64_t binary_heap_group_ids[k]; -+ for (size_t i = 0; i < k; i++) { -+ binary_heap_group_ids[i] = -1; -+ } -+ std::unordered_map group_id_to_index_in_heap; -+ for (int i = 1000; i > 0; i--) { -+ group_heap_replace_top( -+ k, -+ binary_heap_values, -+ binary_heap_ids, -+ binary_heap_group_ids, -+ i * 10.0, -+ i, -+ i, -+ &group_id_to_index_in_heap); -+ } -+ -+ heap_reorder(k, binary_heap_values, binary_heap_ids); -+ -+ for (int i = 0; i < k; i++) { -+ ASSERT_EQ((i + 1) * 10.0, binary_heap_values[i]); -+ ASSERT_EQ(i + 1, binary_heap_ids[i]); -+ } -+} -+ -+TEST(GroupHeap, group_heap_replace_at) { -+ using C = CMax; -+ const int k = 10; -+ float binary_heap_values[k]; -+ int64_t binary_heap_ids[k]; -+ heap_heapify(k, binary_heap_values, binary_heap_ids); -+ int64_t binary_heap_group_ids[k]; -+ for (size_t i = 0; i < k; i++) { -+ binary_heap_group_ids[i] = -1; -+ } -+ std::unordered_map group_id_to_index_in_heap; -+ -+ std::unordered_map group_id_to_id; -+ for (int i = 1000; i > 0; i--) { -+ int64_t group_id = rand() % 100; -+ group_id_to_id[group_id] = i; -+ if (group_id_to_index_in_heap.find(group_id) == -+ group_id_to_index_in_heap.end()) { -+ group_heap_replace_top( -+ k, -+ binary_heap_values, -+ binary_heap_ids, -+ binary_heap_group_ids, -+ i * 10.0, -+ i, -+ group_id, -+ &group_id_to_index_in_heap); -+ } else { -+ group_heap_replace_at( -+ group_id_to_index_in_heap.at(group_id), -+ k, -+ binary_heap_values, -+ binary_heap_ids, -+ binary_heap_group_ids, -+ i * 10.0, -+ i, -+ group_id, -+ &group_id_to_index_in_heap); -+ } -+ } -+ -+ heap_reorder(k, binary_heap_values, binary_heap_ids); -+ -+ std::vector sorted_ids; -+ for (const auto& pair : group_id_to_id) { -+ 
sorted_ids.push_back(pair.second); -+ } -+ std::sort(sorted_ids.begin(), sorted_ids.end()); -+ -+ for (int i = 0; i < k && binary_heap_ids[i] != -1; i++) { -+ ASSERT_EQ(sorted_ids[i] * 10.0, binary_heap_values[i]); -+ ASSERT_EQ(sorted_ids[i], binary_heap_ids[i]); -+ } -+} -diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp -new file mode 100644 -index 00000000..2aed5500 ---- /dev/null -+++ b/tests/test_id_grouper.cpp -@@ -0,0 +1,189 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. -+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+#include -+#include -+#include -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+// 64-bit int -+using idx_t = faiss::idx_t; -+ -+using namespace faiss; -+ -+TEST(IdGrouper, get_group) { -+ uint64_t ids1[1] = {0b1000100010001000}; -+ IDGrouperBitmap bitmap(1, ids1); -+ -+ ASSERT_EQ(3, bitmap.get_group(0)); -+ ASSERT_EQ(3, bitmap.get_group(1)); -+ ASSERT_EQ(3, bitmap.get_group(2)); -+ ASSERT_EQ(3, bitmap.get_group(3)); -+ ASSERT_EQ(7, bitmap.get_group(4)); -+ ASSERT_EQ(7, bitmap.get_group(5)); -+ ASSERT_EQ(7, bitmap.get_group(6)); -+ ASSERT_EQ(7, bitmap.get_group(7)); -+ ASSERT_EQ(11, bitmap.get_group(8)); -+ ASSERT_EQ(11, bitmap.get_group(9)); -+ ASSERT_EQ(11, bitmap.get_group(10)); -+ ASSERT_EQ(11, bitmap.get_group(11)); -+ ASSERT_EQ(15, bitmap.get_group(12)); -+ ASSERT_EQ(15, bitmap.get_group(13)); -+ ASSERT_EQ(15, bitmap.get_group(14)); -+ ASSERT_EQ(15, bitmap.get_group(15)); -+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(16)); -+} -+ -+TEST(IdGrouper, set_group) { -+ idx_t group_ids[] = {64, 127, 128, 1022}; -+ uint64_t ids[16] = {}; // 1023 / 64 + 1 -+ IDGrouperBitmap bitmap(16, ids); -+ -+ for (int i = 0; i < 4; i++) { -+ bitmap.set_group(group_ids[i]); -+ } -+ -+ int group_id_index = 0; -+ for (int i = 0; i <= group_ids[3]; i++) { -+ ASSERT_EQ(group_ids[group_id_index], 
bitmap.get_group(i)); -+ if (group_ids[group_id_index] == i) { -+ group_id_index++; -+ } -+ } -+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); -+} -+ -+TEST(IdGrouper, bitmap_with_hnsw) { -+ int d = 1; // dimension -+ int nb = 10; // database size -+ -+ std::mt19937 rng; -+ std::uniform_real_distribution<> distrib; -+ -+ float* xb = new float[d * nb]; -+ -+ for (int i = 0; i < nb; i++) { -+ for (int j = 0; j < d; j++) -+ xb[d * i + j] = distrib(rng); -+ xb[d * i] += i / 1000.; -+ } -+ -+ uint64_t bitmap[1] = {}; -+ faiss::IDGrouperBitmap id_grouper(1, bitmap); -+ for (int i = 0; i < nb; i++) { -+ if (i % 2 == 1) { -+ id_grouper.set_group(i); -+ } -+ } -+ -+ int k = 10; -+ int m = 8; -+ faiss::Index* index = -+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); -+ index->add(nb, xb); // add vectors to the index -+ -+ // search -+ idx_t* I = new idx_t[k]; -+ float* D = new float[k]; -+ -+ auto pSearchParameters = new faiss::SearchParametersHNSW(); -+ pSearchParameters->grp = &id_grouper; -+ -+ index->search(1, xb, k, D, I, pSearchParameters); -+ -+ std::unordered_set group_ids; -+ ASSERT_EQ(0, I[0]); -+ ASSERT_EQ(0, D[0]); -+ group_ids.insert(id_grouper.get_group(I[0])); -+ for (int j = 1; j < 5; j++) { -+ ASSERT_NE(-1, I[j]); -+ ASSERT_NE(std::numeric_limits::max(), D[j]); -+ group_ids.insert(id_grouper.get_group(I[j])); -+ } -+ for (int j = 5; j < k; j++) { -+ ASSERT_EQ(-1, I[j]); -+ ASSERT_EQ(std::numeric_limits::max(), D[j]); -+ } -+ ASSERT_EQ(5, group_ids.size()); -+ -+ delete[] I; -+ delete[] D; -+ delete[] xb; -+} -+ -+TEST(IdGrouper, bitmap_with_hnswn_idmap) { -+ int d = 1; // dimension -+ int nb = 10; // database size -+ -+ std::mt19937 rng; -+ std::uniform_real_distribution<> distrib; -+ -+ float* xb = new float[d * nb]; -+ idx_t* xids = new idx_t[d * nb]; -+ -+ for (int i = 0; i < nb; i++) { -+ for (int j = 0; j < d; j++) -+ xb[d * i + j] = distrib(rng); -+ xb[d * i] += i / 1000.; -+ } -+ -+ uint64_t bitmap[1] = {}; -+ 
faiss::IDGrouperBitmap id_grouper(1, bitmap); -+ int num_grp = 0; -+ int grp_size = 2; -+ int id_in_grp = 0; -+ for (int i = 0; i < nb; i++) { -+ xids[i] = i + num_grp; -+ id_in_grp++; -+ if (id_in_grp == grp_size) { -+ id_grouper.set_group(i + num_grp + 1); -+ num_grp++; -+ id_in_grp = 0; -+ } -+ } -+ -+ int k = 10; -+ int m = 8; -+ faiss::Index* index = -+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); -+ faiss::IndexIDMap id_map = -+ faiss::IndexIDMap(index); // add vectors to the index -+ id_map.add_with_ids(nb, xb, xids); -+ -+ // search -+ idx_t* I = new idx_t[k]; -+ float* D = new float[k]; -+ -+ auto pSearchParameters = new faiss::SearchParametersHNSW(); -+ pSearchParameters->grp = &id_grouper; -+ -+ id_map.search(1, xb, k, D, I, pSearchParameters); -+ -+ std::unordered_set group_ids; -+ ASSERT_EQ(0, I[0]); -+ ASSERT_EQ(0, D[0]); -+ group_ids.insert(id_grouper.get_group(I[0])); -+ for (int j = 1; j < 5; j++) { -+ ASSERT_NE(-1, I[j]); -+ ASSERT_NE(std::numeric_limits::max(), D[j]); -+ group_ids.insert(id_grouper.get_group(I[j])); -+ } -+ for (int j = 5; j < k; j++) { -+ ASSERT_EQ(-1, I[j]); -+ ASSERT_EQ(std::numeric_limits::max(), D[j]); -+ } -+ ASSERT_EQ(5, group_ids.size()); -+ -+ delete[] I; -+ delete[] D; -+ delete[] xb; -+} --- -2.39.3 (Apple Git-145) - diff --git a/jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch b/jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch deleted file mode 100644 index dfc5099aaa..0000000000 --- a/jni/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch +++ /dev/null @@ -1,512 +0,0 @@ -From c5ca07299b427dedafc738b98bd20f8f286f6783 Mon Sep 17 00:00:00 2001 -From: John Mazanec -Date: Wed, 21 Feb 2024 15:34:15 -0800 -Subject: [PATCH] Enable precomp table to be shared ivfpq - -Changes IVFPQ and IVFPQFastScan indices to be able to share the -precomputed table amongst other instances. Switches var to a pointer and -add necessary functions to set them correctly. 
- -Adds a tests to validate the behavior. - -Signed-off-by: John Mazanec ---- - faiss/IndexIVFPQ.cpp | 47 +++++++- - faiss/IndexIVFPQ.h | 16 ++- - faiss/IndexIVFPQFastScan.cpp | 47 ++++++-- - faiss/IndexIVFPQFastScan.h | 11 +- - tests/CMakeLists.txt | 1 + - tests/test_disable_pq_sdc_tables.cpp | 4 +- - tests/test_ivfpq_share_table.cpp | 173 +++++++++++++++++++++++++++ - 7 files changed, 284 insertions(+), 15 deletions(-) - create mode 100644 tests/test_ivfpq_share_table.cpp - -diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp -index 0b7f4d05..07bc7e83 100644 ---- a/faiss/IndexIVFPQ.cpp -+++ b/faiss/IndexIVFPQ.cpp -@@ -59,6 +59,29 @@ IndexIVFPQ::IndexIVFPQ( - polysemous_training = nullptr; - do_polysemous_training = false; - polysemous_ht = 0; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; -+} -+ -+IndexIVFPQ::IndexIVFPQ(const IndexIVFPQ& orig) : IndexIVF(orig), pq(orig.pq) { -+ code_size = orig.pq.code_size; -+ invlists->code_size = code_size; -+ is_trained = orig.is_trained; -+ by_residual = orig.by_residual; -+ use_precomputed_table = orig.use_precomputed_table; -+ scan_table_threshold = orig.scan_table_threshold; -+ -+ polysemous_training = orig.polysemous_training; -+ do_polysemous_training = orig.do_polysemous_training; -+ polysemous_ht = orig.polysemous_ht; -+ precomputed_table = new AlignedTable(*orig.precomputed_table); -+ owns_precomputed_table = true; -+} -+ -+IndexIVFPQ::~IndexIVFPQ() { -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } - } - - /**************************************************************** -@@ -466,11 +489,23 @@ void IndexIVFPQ::precompute_table() { - use_precomputed_table, - quantizer, - pq, -- precomputed_table, -+ *precomputed_table, - by_residual, - verbose); - } - -+void IndexIVFPQ::set_precomputed_table( -+ AlignedTable* _precompute_table, -+ int _use_precomputed_table) { -+ // Clean up old pre-computed table -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } 
-+ owns_precomputed_table = false; -+ precomputed_table = _precompute_table; -+ use_precomputed_table = _use_precomputed_table; -+} -+ - namespace { - - #define TIC t0 = get_cycles() -@@ -650,7 +685,7 @@ struct QueryTables { - - fvec_madd( - pq.M * pq.ksub, -- ivfpq.precomputed_table.data() + key * pq.ksub * pq.M, -+ ivfpq.precomputed_table->data() + key * pq.ksub * pq.M, - -2.0, - sim_table_2, - sim_table); -@@ -679,7 +714,7 @@ struct QueryTables { - k >>= cpq.nbits; - - // get corresponding table -- const float* pc = ivfpq.precomputed_table.data() + -+ const float* pc = ivfpq.precomputed_table->data() + - (ki * pq.M + cm * Mf) * pq.ksub; - - if (polysemous_ht == 0) { -@@ -709,7 +744,7 @@ struct QueryTables { - dis0 = coarse_dis; - - const float* s = -- ivfpq.precomputed_table.data() + key * pq.ksub * pq.M; -+ ivfpq.precomputed_table->data() + key * pq.ksub * pq.M; - for (int m = 0; m < pq.M; m++) { - sim_table_ptrs[m] = s; - s += pq.ksub; -@@ -729,7 +764,7 @@ struct QueryTables { - int ki = k & ((uint64_t(1) << cpq.nbits) - 1); - k >>= cpq.nbits; - -- const float* pc = ivfpq.precomputed_table.data() + -+ const float* pc = ivfpq.precomputed_table->data() + - (ki * pq.M + cm * Mf) * pq.ksub; - - for (int m = m0; m < m0 + Mf; m++) { -@@ -1346,6 +1381,8 @@ IndexIVFPQ::IndexIVFPQ() { - do_polysemous_training = false; - polysemous_ht = 0; - polysemous_training = nullptr; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; - } - - struct CodeCmp { -diff --git a/faiss/IndexIVFPQ.h b/faiss/IndexIVFPQ.h -index d5d21da4..850bbe44 100644 ---- a/faiss/IndexIVFPQ.h -+++ b/faiss/IndexIVFPQ.h -@@ -48,7 +48,8 @@ struct IndexIVFPQ : IndexIVF { - - /// if use_precompute_table - /// size nlist * pq.M * pq.ksub -- AlignedTable precomputed_table; -+ bool owns_precomputed_table; -+ AlignedTable* precomputed_table; - - IndexIVFPQ( - Index* quantizer, -@@ -58,6 +59,10 @@ struct IndexIVFPQ : IndexIVF { - size_t nbits_per_idx, - MetricType metric = METRIC_L2); - -+ 
IndexIVFPQ(const IndexIVFPQ& orig); -+ -+ ~IndexIVFPQ(); -+ - void encode_vectors( - idx_t n, - const float* x, -@@ -139,6 +144,15 @@ struct IndexIVFPQ : IndexIVF { - /// build precomputed table - void precompute_table(); - -+ /** -+ * Initialize the precomputed table -+ * @param precompute_table -+ * @param _use_precomputed_table -+ */ -+ void set_precomputed_table( -+ AlignedTable* precompute_table, -+ int _use_precomputed_table); -+ - IndexIVFPQ(); - }; - -diff --git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp -index d069db13..09a335ff 100644 ---- a/faiss/IndexIVFPQFastScan.cpp -+++ b/faiss/IndexIVFPQFastScan.cpp -@@ -46,6 +46,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan( - : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) { - by_residual = false; // set to false by default because it's faster - -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; - init_fastscan(M, nbits, nlist, metric, bbs); - } - -@@ -53,6 +55,17 @@ IndexIVFPQFastScan::IndexIVFPQFastScan() { - by_residual = false; - bbs = 0; - M2 = 0; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; -+} -+ -+IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQFastScan& orig) -+ : IndexIVFFastScan(orig), pq(orig.pq) { -+ by_residual = orig.by_residual; -+ bbs = orig.bbs; -+ M2 = orig.M2; -+ precomputed_table = new AlignedTable(*orig.precomputed_table); -+ owns_precomputed_table = true; - } - - IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) -@@ -71,13 +84,15 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) - ntotal = orig.ntotal; - is_trained = orig.is_trained; - nprobe = orig.nprobe; -+ precomputed_table = new AlignedTable(); -+ owns_precomputed_table = true; - -- precomputed_table.resize(orig.precomputed_table.size()); -+ precomputed_table->resize(orig.precomputed_table->size()); - -- if (precomputed_table.nbytes() > 0) { -- memcpy(precomputed_table.get(), -- 
orig.precomputed_table.data(), -- precomputed_table.nbytes()); -+ if (precomputed_table->nbytes() > 0) { -+ memcpy(precomputed_table->get(), -+ orig.precomputed_table->data(), -+ precomputed_table->nbytes()); - } - - for (size_t i = 0; i < nlist; i++) { -@@ -102,6 +117,12 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) - orig_invlists = orig.invlists; - } - -+IndexIVFPQFastScan::~IndexIVFPQFastScan() { -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } -+} -+ - /********************************************************* - * Training - *********************************************************/ -@@ -127,11 +148,23 @@ void IndexIVFPQFastScan::precompute_table() { - use_precomputed_table, - quantizer, - pq, -- precomputed_table, -+ *precomputed_table, - by_residual, - verbose); - } - -+void IndexIVFPQFastScan::set_precomputed_table( -+ AlignedTable* _precompute_table, -+ int _use_precomputed_table) { -+ // Clean up old pre-computed table -+ if (owns_precomputed_table) { -+ delete precomputed_table; -+ } -+ owns_precomputed_table = false; -+ precomputed_table = _precompute_table; -+ use_precomputed_table = _use_precomputed_table; -+} -+ - /********************************************************* - * Code management functions - *********************************************************/ -@@ -229,7 +262,7 @@ void IndexIVFPQFastScan::compute_LUT( - if (cij >= 0) { - fvec_madd_simd( - dim12, -- precomputed_table.get() + cij * dim12, -+ precomputed_table->get() + cij * dim12, - -2, - ip_table.get() + i * dim12, - tab); -diff --git a/faiss/IndexIVFPQFastScan.h b/faiss/IndexIVFPQFastScan.h -index 00dd2f11..91f35a6e 100644 ---- a/faiss/IndexIVFPQFastScan.h -+++ b/faiss/IndexIVFPQFastScan.h -@@ -38,7 +38,8 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { - /// precomputed tables management - int use_precomputed_table = 0; - /// if use_precompute_table size (nlist, pq.M, pq.ksub) -- AlignedTable precomputed_table; -+ bool 
owns_precomputed_table; -+ AlignedTable* precomputed_table; - - IndexIVFPQFastScan( - Index* quantizer, -@@ -51,6 +52,10 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { - - IndexIVFPQFastScan(); - -+ IndexIVFPQFastScan(const IndexIVFPQFastScan& orig); -+ -+ ~IndexIVFPQFastScan(); -+ - // built from an IndexIVFPQ - explicit IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs = 32); - -@@ -60,6 +65,10 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { - - /// build precomputed table, possibly updating use_precomputed_table - void precompute_table(); -+ /// Pass in externally a precomputed -+ void set_precomputed_table( -+ AlignedTable* precompute_table, -+ int _use_precomputed_table); - - /// same as the regular IVFPQ encoder. The codes are not reorganized by - /// blocks a that point -diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt -index 9017edc5..0889bf72 100644 ---- a/tests/CMakeLists.txt -+++ b/tests/CMakeLists.txt -@@ -33,6 +33,7 @@ set(FAISS_TEST_SRC - test_partitioning.cpp - test_fastscan_perf.cpp - test_disable_pq_sdc_tables.cpp -+ test_ivfpq_share_table.cpp - ) - - add_executable(faiss_test ${FAISS_TEST_SRC}) -diff --git a/tests/test_disable_pq_sdc_tables.cpp b/tests/test_disable_pq_sdc_tables.cpp -index b211a5c4..a27973d5 100644 ---- a/tests/test_disable_pq_sdc_tables.cpp -+++ b/tests/test_disable_pq_sdc_tables.cpp -@@ -15,7 +15,9 @@ - #include "faiss/index_io.h" - #include "test_util.h" - --pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; -+namespace { -+ pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; -+} - - TEST(IO, TestReadHNSWPQ_whenSDCDisabledFlagPassed_thenDisableSDCTable) { - Tempfilename index_filename(&temp_file_mutex, "/tmp/faiss_TestReadHNSWPQ"); -diff --git a/tests/test_ivfpq_share_table.cpp b/tests/test_ivfpq_share_table.cpp -new file mode 100644 -index 00000000..f827315d ---- /dev/null -+++ b/tests/test_ivfpq_share_table.cpp -@@ -0,0 +1,173 @@ -+/** -+ * Copyright (c) Facebook, Inc. and its affiliates. 
-+ * -+ * This source code is licensed under the MIT license found in the -+ * LICENSE file in the root directory of this source tree. -+ */ -+ -+#include -+ -+#include -+ -+#include "faiss/Index.h" -+#include "faiss/IndexHNSW.h" -+#include "faiss/IndexIVFPQFastScan.h" -+#include "faiss/index_factory.h" -+#include "faiss/index_io.h" -+#include "test_util.h" -+ -+namespace { -+ pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; -+} -+ -+std::vector generate_data( -+ int d, -+ int n, -+ std::default_random_engine rng, -+ std::uniform_real_distribution u) { -+ std::vector vectors(n * d); -+ for (size_t i = 0; i < n * d; i++) { -+ vectors[i] = u(rng); -+ } -+ return vectors; -+} -+ -+void assert_float_vectors_almost_equal( -+ std::vector a, -+ std::vector b) { -+ float margin = 0.000001; -+ ASSERT_EQ(a.size(), b.size()); -+ for (int i = 0; i < a.size(); i++) { -+ ASSERT_NEAR(a[i], b[i], margin); -+ } -+} -+ -+/// Test case test precomputed table sharing for IVFPQ indices. -+template /// T represents class cast to use for index -+void test_ivfpq_table_sharing( -+ const std::string& index_description, -+ const std::string& filename, -+ faiss::MetricType metric) { -+ // Setup the index: -+ // 1. Build an index -+ // 2. ingest random data -+ // 3. serialize to disk -+ int d = 32, n = 1000; -+ std::default_random_engine rng( -+ std::chrono::system_clock::now().time_since_epoch().count()); -+ std::uniform_real_distribution u(0, 100); -+ -+ std::vector index_vectors = generate_data(d, n, rng, u); -+ std::vector query_vectors = generate_data(d, n, rng, u); -+ -+ Tempfilename index_filename(&temp_file_mutex, filename); -+ { -+ std::unique_ptr index_writer( -+ faiss::index_factory(d, index_description.c_str(), metric)); -+ -+ index_writer->train(n, index_vectors.data()); -+ index_writer->add(n, index_vectors.data()); -+ faiss::write_index(index_writer.get(), index_filename.c_str()); -+ } -+ -+ // Load index from disk. 
Confirm that the sdc table is equal to 0 when -+ // disable sdc is set -+ std::unique_ptr> sharedAlignedTable( -+ new faiss::AlignedTable()); -+ int shared_use_precomputed_table = 0; -+ int k = 10; -+ std::vector distances_test_a(k * n); -+ std::vector labels_test_a(k * n); -+ { -+ std::vector distances_baseline(k * n); -+ std::vector labels_baseline(k * n); -+ -+ std::unique_ptr index_read_pq_table_enabled( -+ dynamic_cast(faiss::read_index( -+ index_filename.c_str(), faiss::IO_FLAG_READ_ONLY))); -+ std::unique_ptr index_read_pq_table_disabled( -+ dynamic_cast(faiss::read_index( -+ index_filename.c_str(), -+ faiss::IO_FLAG_READ_ONLY | -+ faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE))); -+ faiss::initialize_IVFPQ_precomputed_table( -+ shared_use_precomputed_table, -+ index_read_pq_table_disabled->quantizer, -+ index_read_pq_table_disabled->pq, -+ *sharedAlignedTable, -+ index_read_pq_table_disabled->by_residual, -+ index_read_pq_table_disabled->verbose); -+ index_read_pq_table_disabled->set_precomputed_table( -+ sharedAlignedTable.get(), shared_use_precomputed_table); -+ -+ ASSERT_TRUE(index_read_pq_table_enabled->owns_precomputed_table); -+ ASSERT_FALSE(index_read_pq_table_disabled->owns_precomputed_table); -+ index_read_pq_table_enabled->search( -+ n, -+ query_vectors.data(), -+ k, -+ distances_baseline.data(), -+ labels_baseline.data()); -+ index_read_pq_table_disabled->search( -+ n, -+ query_vectors.data(), -+ k, -+ distances_test_a.data(), -+ labels_test_a.data()); -+ -+ assert_float_vectors_almost_equal(distances_baseline, distances_test_a); -+ ASSERT_EQ(labels_baseline, labels_test_a); -+ } -+ -+ // The precomputed table should only be set for L2 metric type -+ if (metric == faiss::METRIC_L2) { -+ ASSERT_EQ(shared_use_precomputed_table, 1); -+ } else { -+ ASSERT_EQ(shared_use_precomputed_table, 0); -+ } -+ -+ // At this point, the original has gone out of scope, the destructor has -+ // been called. 
Confirm that initializing a new index from the table -+ // preserves the functionality. -+ { -+ std::vector distances_test_b(k * n); -+ std::vector labels_test_b(k * n); -+ -+ std::unique_ptr index_read_pq_table_disabled( -+ dynamic_cast(faiss::read_index( -+ index_filename.c_str(), -+ faiss::IO_FLAG_READ_ONLY | -+ faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE))); -+ index_read_pq_table_disabled->set_precomputed_table( -+ sharedAlignedTable.get(), shared_use_precomputed_table); -+ ASSERT_FALSE(index_read_pq_table_disabled->owns_precomputed_table); -+ index_read_pq_table_disabled->search( -+ n, -+ query_vectors.data(), -+ k, -+ distances_test_b.data(), -+ labels_test_b.data()); -+ assert_float_vectors_almost_equal(distances_test_a, distances_test_b); -+ ASSERT_EQ(labels_test_a, labels_test_b); -+ } -+} -+ -+TEST(TestIVFPQTableSharing, L2) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4", "/tmp/ivfpql2", faiss::METRIC_L2); -+} -+ -+TEST(TestIVFPQTableSharing, IP) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4", "/tmp/ivfpqip", faiss::METRIC_INNER_PRODUCT); -+} -+ -+TEST(TestIVFPQTableSharing, FastScanL2) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4fsr", "/tmp/ivfpqfsl2", faiss::METRIC_L2); -+} -+ -+TEST(TestIVFPQTableSharing, FastScanIP) { -+ test_ivfpq_table_sharing( -+ "IVF16,PQ8x4fsr", "/tmp/ivfpqfsip", faiss::METRIC_INNER_PRODUCT); -+} --- -2.39.3 (Apple Git-145) - diff --git a/jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch b/jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch deleted file mode 100644 index a9d9381f9b..0000000000 --- a/jni/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch +++ /dev/null @@ -1,31 +0,0 @@ -From aa1ca485c0ab8b79dae1fb5c1567149c5f61b533 Mon Sep 17 00:00:00 2001 -From: John Mazanec -Date: Thu, 14 Mar 2024 12:22:06 -0700 -Subject: [PATCH] Initialize maxlevel during add from enterpoint->level - -Signed-off-by: John Mazanec ---- - 
similarity_search/src/method/hnsw.cc | 6 +++++- - 1 file changed, 5 insertions(+), 1 deletion(-) - -diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc -index 35b372c..e9a725e 100644 ---- a/similarity_search/src/method/hnsw.cc -+++ b/similarity_search/src/method/hnsw.cc -@@ -542,8 +542,12 @@ namespace similarity { - - NewElement->init(curlevel, maxM_, maxM0_); - -- int maxlevelcopy = maxlevel_; -+ // Get the enterpoint at this moment and then use it to set the -+ // max level that is used. Copying maxlevel from this->maxlevel_ -+ // can lead to race conditions during concurrent insertion. See: -+ // https://github.com/nmslib/nmslib/issues/544 - HnswNode *ep = enterpoint_; -+ int maxlevelcopy = ep->level; - if (curlevel < maxlevelcopy) { - const Object *currObj = ep->getData(); - --- -2.39.3 (Apple Git-146) - diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp new file mode 100644 index 0000000000..3c03ac49d9 --- /dev/null +++ b/jni/src/commons.cpp @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ +#include <jni.h> +#include <vector> + +#include "jni_util.h" +#include "commons.h" + +// Appends every row of dataJ to a native std::vector<float> and returns that vector's address as a handle. +// memoryAddressJ == 0 creates the vector (reserving initialCapacityJ floats); any other value is reused as-is. +jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, + jobjectArray dataJ, jlong initialCapacityJ) { + const bool isNewVector = (memoryAddressJ == 0); + auto *vect = isNewVector ? new std::vector<float>() + : reinterpret_cast<std::vector<float>*>(memoryAddressJ); + try { + if (isNewVector) vect->reserve(static_cast<std::size_t>(initialCapacityJ)); + int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ); + jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect); + } catch (...) { + // The address only reaches the caller through the return value, so a vector created here would leak. + if (isNewVector) delete vect; + throw; + } + return reinterpret_cast<jlong>(vect); +} + +// Deletes a vector created by storeVectorData; an address of 0 is ignored. +void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { + if (memoryAddressJ != 0) { + delete reinterpret_cast<std::vector<float>*>(memoryAddressJ); + } +} \ No newline at end of file diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index a0c1d57336..a1faa4894f 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -223,6 +223,13 @@ int knn_jni::JNIUtil::ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ std::vector knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) { + std::vector vect; + Convert2dJavaObjectArrayAndStoreToFloatVector(env, array2dJ, dim, &vect); + return vect; +} + +void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect) { if (array2dJ == nullptr) { throw std::runtime_error("Array cannot be null"); @@ -231,7 +238,6 @@ std::vector knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN int numVectors = env->GetArrayLength(array2dJ); this->HasExceptionInStack(env); - std::vector floatVectorCpp; for (int i = 0; i < numVectors; ++i) { auto vectorArray = (jfloatArray)env->GetObjectArrayElement(array2dJ, i); this->HasExceptionInStack(env, "Unable to get object array element"); @@ -247,13
+253,12 @@ std::vector knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN } for(int j = 0; j < dim; ++j) { - floatVectorCpp.push_back(vector[j]); + vect->push_back(vector[j]); } env->ReleaseFloatArrayElements(vectorArray, vector, JNI_ABORT); } this->HasExceptionInStack(env); env->DeleteLocalRef(array2dJ); - return floatVectorCpp; } std::vector knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 3d9624c250..ae9d461ffb 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -13,7 +13,6 @@ #include -#include #include #include "faiss_wrapper.h" @@ -191,24 +190,6 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors return (jlong) vect; } -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2(JNIEnv * env, jclass cls, -jlong vectorsPointerJ, - jobjectArray vectorsJ) -{ - std::vector *vect; - if ((long) vectorsPointerJ == 0) { - vect = new std::vector; - } else { - vect = reinterpret_cast*>(vectorsPointerJ); - } - - int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); - auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); - vect->insert(vect->end(), dataset.begin(), dataset.end()); - - return (jlong) vect; -} - JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ) { diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp new file mode 100644 index 0000000000..ccdd118826 --- /dev/null +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#include "org_opensearch_knn_jni_JNICommons.h" + +#include <jni.h> +#include "commons.h" +#include "jni_util.h" + +static knn_jni::JNIUtil jniUtil; +static const jint KNN_JNICOMMONS_JNI_VERSION = JNI_VERSION_1_1; + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + // Obtain the JNIEnv from the VM and confirm JNI_VERSION + JNIEnv* env; + if (vm->GetEnv((void**)&env, KNN_JNICOMMONS_JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + jniUtil.Initialize(env); + + return KNN_JNICOMMONS_JNI_VERSION; +} + +void JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv* env = nullptr; + // If no JNIEnv can be obtained there is nothing to uninitialize through; never pass an unset pointer on + if (vm->GetEnv((void**)&env, KNN_JNICOMMONS_JNI_VERSION) != JNI_OK) { return; } + jniUtil.Uninitialize(env); +} + +// JNI bridge: delegates to knn_jni::commons::storeVectorData and routes any C++ exception through jniUtil +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls, +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +{ + try { + return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return memoryAddressJ; // reached only via the catch above; hand the caller's address back without narrowing it to long +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls, + jlong memoryAddressJ) +{ + try { + return knn_jni::commons::freeVectorData(memoryAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp new file mode 100644 index 0000000000..064a9f1f19 --- /dev/null +++ b/jni/tests/commons_test.cpp @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +#include "test_util.h" +#include <vector> +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" +#include "commons.h" + +// Drives storeVectorData/freeVectorData through MockJNIUtil; each jobjectArray passed below is really a std::vector<std::vector<float>>*. +TEST(CommonsTests, BasicAssertions) { + long dim = 3; + long totalNumberOfVector = 5; + std::vector<std::vector<float>> data; + for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) { + std::vector<float> vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((float)j); + } + data.push_back(vector); + } + JNIEnv *jniEnv = nullptr; + testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil; + jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0, + reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim)); + ASSERT_NE(memoryAddress, 0); + auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddress); + ASSERT_EQ(vect->size(), data.size() * dim); + ASSERT_GE(vect->capacity(), totalNumberOfVector * dim); + + // Check by inserting more vectors at same memory location + jlong oldMemoryAddress = memoryAddress; + std::vector<std::vector<float>> data2; + std::vector<float> vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((float)j); + } + data2.push_back(vector); + memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim)); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast<std::vector<float>*>(memoryAddress); + int currentIndex = 0; + ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + + // Validate if all vectors data are at correct location + for(auto & i : data) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + for(auto & i : data2) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // Storing past the reserved capacity still succeeds: the vector grows while its own address stays the same + memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim)); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + ASSERT_EQ(vect->size(), (totalNumberOfVector + 1) * dim); + knn_jni::commons::freeVectorData(memoryAddress); +}
\ No newline at end of file diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 89b19f9aa5..92532b9e26 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -45,6 +45,14 @@ test_util::MockJNIUtil::MockJNIUtil() { return data; }); + ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToFloatVector) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) { + for (const auto &v : + (*reinterpret_cast> *>(array2dJ))) + for (auto item : v) data->push_back(item); + }); + + // arrayJ is re-interpreted as std::vector * ON_CALL(*this, ConvertJavaIntArrayToCppIntVector) .WillByDefault([this](JNIEnv *env, jintArray arrayJ) { diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 1e32ad3c30..8e73a8ab0c 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -44,6 +44,8 @@ namespace test_util { // TODO: Figure out why this cant use "new" MOCK_METHOD MOCK_METHOD(std::vector, Convert2dJavaObjectArrayToCppFloatVector, (JNIEnv * env, jobjectArray array2dJ, int dim)); + MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToFloatVector, + (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); MOCK_METHOD(std::vector, ConvertJavaIntArrayToCppIntVector, (JNIEnv * env, jintArray arrayJ)); MOCK_METHOD2(ConvertJavaMapToCppMap, diff --git a/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java b/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java index ad1076484f..2bce54ee61 100644 --- a/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java +++ b/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java @@ -23,7 +23,7 @@ import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; -import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.jni.JNICommons; import java.util.ArrayList; import java.util.List; @@ -42,9 +42,9 @@ 
@State(Scope.Benchmark) public class TransferVectorsBenchmarks { private static final Random random = new Random(1212121212); - private static final int TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED = 1000000; + private static final long TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED = 1000000; - @Param({ "128", "256", "384", "512" }) + @Param({ "128", "256", "384", "512", "960", "1024", "1536" }) private int dimension; @Param({ "100000", "500000", "1000000" }) @@ -61,20 +61,30 @@ public void setup() { } @Benchmark - public void transferVectors() { + public void transferVectors_withCapacity() { long vectorsAddress = 0; List vectorToTransfer = new ArrayList<>(); + long startingIndex = 0; for (float[] floats : vectorList) { if (vectorToTransfer.size() == vectorsPerTransfer) { - vectorsAddress = JNIService.transferVectorsV2(vectorsAddress, vectorToTransfer.toArray(new float[][] {})); + vectorsAddress = JNICommons.storeVectorData( + vectorsAddress, + vectorToTransfer.toArray(new float[][] {}), + dimension * TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED + ); + startingIndex += vectorsPerTransfer; vectorToTransfer = new ArrayList<>(); } vectorToTransfer.add(floats); } if (!vectorToTransfer.isEmpty()) { - vectorsAddress = JNIService.transferVectorsV2(vectorsAddress, vectorToTransfer.toArray(new float[][] {})); + vectorsAddress = JNICommons.storeVectorData( + vectorsAddress, + vectorToTransfer.toArray(new float[][] {}), + dimension * TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED + ); } - JNIService.freeVectors(vectorsAddress); + JNICommons.freeVectorData(vectorsAddress); } private float[] generateRandomVector(int dimensions) { diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 34805b7e5a..98622ee85f 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -72,6 +72,7 @@ public class KNNConstants { // nmslib specific constants 
public static final String NMSLIB_NAME = "nmslib"; + public static final String COMMONS_NAME = "common"; public static final String SPACE_TYPE = "spaceType"; // used as field info key public static final String HNSW_ALGO_M = "M"; public static final String HNSW_ALGO_EF_CONSTRUCTION = "efConstruction"; @@ -121,6 +122,8 @@ public class KNNConstants { public static final String FAISS_AVX2_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME + "_avx2"; public static final String NMSLIB_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + NMSLIB_NAME; + public static final String COMMON_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + COMMONS_NAME; + // API Constants public static final String CLEAR_CACHE = "clear_cache"; diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 4b50453598..517945968e 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -173,33 +173,26 @@ public static native KNNQueryResult[] queryIndexWithFilter( public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); /** + *

+ * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + *

* Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well * @param trainingData data to be transferred * @return pointer to native memory location of training data */ + @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); /** - * Transfer vectors from Java to native layer. This is the version 2 of transfer vector functionality. The - * difference between this and the version 1 is, this version puts vectors at the end rather than in front. - * Keeping this name as V2 for now, will come up with better name going forward. *

- * TODO: Rename the function - *
- * TODO: Make this function native function and use a common cpp file to host these functions. + * The function is deprecated. Use {@link JNICommons#freeVectorData(long)} *

- * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param data data to be transferred - * @return pointer to native memory location for data - */ - public static native long transferVectorsV2(long vectorsPointer, float[][] data); - - /** * Free vectors from memory * * @param vectorsPointer to be freed */ + @Deprecated(since = "2.14.0", forRemoval = true) public static native void freeVectors(long vectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java new file mode 100644 index 0000000000..f7c6974a01 --- /dev/null +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.jni; + +import org.opensearch.knn.common.KNNConstants; + +import java.security.AccessController; +import java.security.PrivilegedAction; + +/** + * Common class for providing the JNI related functionality to various JNIServices. + */ +public class JNICommons { + + static { + AccessController.doPrivileged((PrivilegedAction) () -> { + System.loadLibrary(KNNConstants.COMMON_JNI_LIBRARY_NAME); + return null; + }); + } + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. 
+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * + * @param memoryAddress address to be freed. + */ + public static native void freeVectorData(long memoryAddress); +} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 80b56b1736..a850895170 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -235,44 +235,30 @@ public static byte[] trainIndex(Map indexParameters, int dimensi } /** + *

+ * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + *

* Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well * @param trainingData data to be transferred * @return pointer to native memory location of training data */ + @Deprecated(since = "2.14.0", forRemoval = true) public static long transferVectors(long vectorsPointer, float[][] trainingData) { return FaissService.transferVectors(vectorsPointer, trainingData); } /** + *

+ * The function is deprecated. Use {@link JNICommons#freeVectorData(long)} + *

* Free vectors from memory * * @param vectorsPointer to be freed */ + @Deprecated(since = "2.14.0", forRemoval = true) public static void freeVectors(long vectorsPointer) { FaissService.freeVectors(vectorsPointer); } - - /** - * Experimental: Transfer vectors from Java to native layer. This is the version 2 of transfer vector - * functionality. The difference between this and the version 1 is, this version puts vectors at the end rather - * than in front. Keeping this name as V2 for now, will come up with better name going forward. - *

- * This is not a production ready function for now. Adding this to ensure that we are able to run atleast 1 - * micro-benchmarks. - *

- *

- * TODO: Rename the function - *
- * TODO: Make this function native function and use a common cpp file to host these functions. - *

- * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param data data to be transferred - * @return pointer to native memory location for data - * - */ - public static long transferVectorsV2(long vectorsPointer, float[][] data) { - return FaissService.transferVectorsV2(vectorsPointer, data); - } } diff --git a/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java b/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java new file mode 100644 index 0000000000..4b92c69062 --- /dev/null +++ b/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.jni; + +import org.opensearch.knn.KNNTestCase; + +public class JNICommonsTest extends KNNTestCase { + + public void testStoreVectorData_whenVaildInputThenSuccess() { + float[][] data = new float[2][2]; + for(int i = 0 ; i < 2 ; i++) { + for(int j = 0; j < 2; j++) { + data[i][j] = i + j; + } + } + long memoryAddress = JNICommons.storeVectorData(0, data, 8); + assertTrue(memoryAddress > 0); + assertEquals(memoryAddress, JNICommons.storeVectorData(memoryAddress, data, 8)); + } + + public void testFreeVectorData_whenValidInput_ThenSuccess() { + float[][] data = new float[2][2]; + for(int i = 0 ; i < 2 ; i++) { + for(int j = 0; j < 2; j++) { + data[i][j] = i + j; + } + } + long memoryAddress = JNICommons.storeVectorData(0, data, 8); + JNICommons.freeVectorData(memoryAddress); + } +}