From 5d83d1e9463d6e86eda923e8e76724607f705269 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Fri, 12 Jul 2024 15:31:17 -0500 Subject: [PATCH 1/2] Add HNSW changes to support Faiss byte vector Signed-off-by: Naveen Tatikonda --- CHANGELOG.md | 1 + jni/include/commons.h | 24 ++- jni/include/faiss_index_service.h | 57 +++++++ jni/include/jni_util.h | 7 +- .../org_opensearch_knn_jni_FaissService.h | 30 ++++ .../org_opensearch_knn_jni_JNICommons.h | 20 ++- jni/src/commons.cpp | 31 +++- jni/src/faiss_index_service.cpp | 115 ++++++++++++++ jni/src/jni_util.cpp | 34 +++- .../org_opensearch_knn_jni_FaissService.cpp | 41 +++++ jni/src/org_opensearch_knn_jni_JNICommons.cpp | 22 +++ jni/tests/faiss_index_service_test.cpp | 44 ++++++ jni/tests/test_util.cpp | 8 +- jni/tests/test_util.h | 4 +- .../opensearch/knn/common/KNNConstants.java | 1 + .../transfer/OffHeapBinaryVectorTransfer.java | 14 +- .../OffHeapVectorTransferFactory.java | 2 +- .../index/engine/faiss/FaissHNSWMethod.java | 6 +- .../mapper/KNNVectorFieldMapperUtil.java | 2 +- .../knn/index/mapper/MethodFieldMapper.java | 24 ++- .../index/memory/NativeMemoryAllocation.java | 2 +- .../knn/index/query/KNNQueryBuilder.java | 60 ++++--- .../opensearch/knn/index/util/IndexUtil.java | 12 ++ .../org/opensearch/knn/jni/FaissService.java | 33 ++++ .../org/opensearch/knn/jni/JNICommons.java | 66 +++++++- .../org/opensearch/knn/jni/JNIService.java | 12 +- .../training/ByteTrainingDataConsumer.java | 2 +- .../knn/index/VectorDataTypeIT.java | 148 +++++++++++++++++- .../OffHeapVectorTransferFactoryTests.java | 2 +- .../index/engine/KNNMethodContextTests.java | 6 +- .../memory/NativeMemoryAllocationTests.java | 2 +- .../memory/NativeMemoryLoadStrategyTests.java | 2 +- .../knn/index/query/KNNQueryBuilderTests.java | 1 + .../java/org/opensearch/knn/TestUtils.java | 2 +- 34 files changed, 778 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56f207f46..349125beb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) * k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984) +* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823) ### Enhancements * Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950) ### Bug Fixes diff --git a/jni/include/commons.h b/jni/include/commons.h index 4cdaf28fc..e1aaacd9c 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -47,10 +47,24 @@ namespace knn_jni { * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called * * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D byte array containing data to be stored in native memory. + * @param data 2D byte array containing binary data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address of std::vector where the data is stored. */ + jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); + + /** + * This is utility function that can be used to store signed int8 data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing int8 data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address of std::vector where the data is stored. + */ jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** @@ -69,6 +83,14 @@ namespace knn_jni { */ void freeByteVectorData(jlong); + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long, long)} + * + * @param memoryAddress address to be freed. + */ + void freeBinaryVectorData(jlong); + /** * Extracts query time efSearch from method parameters **/ diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index c57309cfc..29ec90e80 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -125,6 +125,63 @@ class BinaryIndexService : public IndexService { virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; }; +/** + * A class to provide operations on index + * This class should evolve to have only cpp object but not jni object + */ +class ByteIndexService : public IndexService { +public: + //TODO Remove dependency on JNIUtilInterface and JNIEnv + //TODO Reduce the number of parameters + ByteIndexService(std::unique_ptr faissMethods); + +/** + * Initialize index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override; + /** + * Add vectors to index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numIds number of vectors + * @param threadCount number of thread count to be used while adding data + * @param vectorsAddress memory address which is holding vector data + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) override; + /** + * Write index to disk + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param threadCount number of thread count to be used while adding data + * @param indexPath path to write index + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; + virtual ~ByteIndexService() = default; +protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; +}; + } } diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 1579522d0..825471a3c 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -71,8 +71,10 @@ namespace knn_jni { virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect ) = 0; - virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + virtual void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect ) = 0; + virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect ) = 0; virtual std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0; @@ -173,7 +175,8 @@ namespace knn_jni { void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); - void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); + void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); + void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); private: std::unordered_map cachedClasses; diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 19e13d402..09f3ec8b7 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -34,6 +34,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEn JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, jlong numDocs, jint dimJ, jobject parametersJ); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: initByteIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); + /* * Class: org_opensearch_knn_jni_FaissService * Method: insertToIndex @@ -50,6 +60,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddress, jint threadCount); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToByteIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); + /* * Class: org_opensearch_knn_jni_FaissService * Method: writeIndex @@ -66,6 +86,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEn JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, jlong indexAddress, jstring indexPathJ); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeByteIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h index 03c0d023a..8bfbcc266 100644 --- a/jni/include/org_opensearch_knn_jni_JNICommons.h +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -28,7 +28,15 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData /* * Class: org_opensearch_knn_jni_JNICommons - * Method: storeVectorData + * Method: storeBinaryVectorData + * Signature: (J[[FJJ) + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeBinaryVectorData + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); + +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: storeByteVectorData * Signature: (J[[FJJ) */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData @@ -44,7 +52,15 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData /* * Class: org_opensearch_knn_jni_JNICommons -* Method: freeVectorData +* Method: freeBinaryVectorData +* Signature: (J)V +*/ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeBinaryVectorData +(JNIEnv *, jclass, jlong); + +/* +* Class: org_opensearch_knn_jni_JNICommons +* Method: freeByteVectorData * Signature: (J)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index f9764db73..38e3ac8a4 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -37,7 +37,7 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE return (jlong) vect; } -jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, +jlong knn_jni::commons::storeBinaryVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector *vect; if ((long) memoryAddressJ == 0) { @@ -51,6 +51,26 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, vect->clear(); } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); + jniUtil->Convert2dJavaObjectArrayAndStoreToBinaryVector(env, dataJ, dim, vect); + + return (jlong) vect; +} + +jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { + std::vector *vect; + if (memoryAddressJ == 0) { + vect = new std::vector(); + vect->reserve(static_cast(initialCapacityJ)); + } else { + vect = reinterpret_cast*>(memoryAddressJ); + } + + if (appendJ == JNI_FALSE) { + vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect); @@ -64,13 +84,20 @@ void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { } } -void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) { +void knn_jni::commons::freeBinaryVectorData(jlong memoryAddressJ) { if (memoryAddressJ != 0) { auto *vect = reinterpret_cast*>(memoryAddressJ); delete vect; } } +void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) { + if (memoryAddressJ != 0) { + auto *vect = reinterpret_cast*>(memoryAddressJ); + delete vect; + } +} + int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map methodParams, std::string methodParam, int defaultValue) { if (methodParams.empty()) { return defaultValue; diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index f76c54428..b6e465741 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -250,5 +250,120 @@ void BinaryIndexService::writeIndex( } } +ByteIndexService::ByteIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} + +void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexHNSWSQ = dynamic_cast(index)) { + if(auto * indexScalarQuantizer = dynamic_cast(indexHNSWSQ->storage)) { + indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors); + } + return; + } +} + +jlong ByteIndexService::initIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numVectors, + int threadCount, + std::unordered_map parameters + ) { + // Create index using Faiss factory method + std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, index.get()); + + // Check that the index does not need to be trained + if(!index->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); + //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor + idMap->own_fields = true; + + allocIndex(dynamic_cast(idMap->index), dim, numVectors); + + //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later + //in insert and write operations + index.release(); + return reinterpret_cast(idMap.release()); +} + +void ByteIndexService::insertToIndex( + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector & ids, + jlong idMapAddress + ) { + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddress); + + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = inputVectors->size() / dim; + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress); + + // Add vectors in batches by casting int8 vectors into float with a batch size of 1000 + int batchSize = 1000; + std::vector inputFloatVectors(batchSize * dim); + std::vector floatVectorsIds(batchSize); + int id = 0; + auto iter = inputVectors->begin(); + + for (int id = 0; id < numVectors; id += batchSize) { + if (numVectors - id < batchSize) { + batchSize = numVectors - id; + inputFloatVectors.resize(batchSize * dim); + floatVectorsIds.resize(batchSize); + } + + for (int i = 0; i < batchSize; ++i) { + floatVectorsIds[i] = ids[id + i]; + for (int j = 0; j < dim; ++j, ++iter) { + inputFloatVectors[i * dim + j] = static_cast(*iter); + } + } + idMap->add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data()); + } +} + +void ByteIndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + std::unique_ptr idMap (reinterpret_cast (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + throw std::runtime_error("Failed to write index to disk"); + } +} + } // namespace faiss_wrapper } // namesapce knn_jni \ No newline at end of file diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index ee4c382b5..82900b5ce 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -261,7 +261,7 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env env->DeleteLocalRef(array2dJ); } -void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, +void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect) { if (array2dJ == nullptr) { @@ -294,6 +294,38 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, env->DeleteLocalRef(array2dJ); } +void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect) { + + if (array2dJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + int numVectors = env->GetArrayLength(array2dJ); + this->HasExceptionInStack(env, "Unable to get array length"); + + for (int i = 0; i < numVectors; ++i) { + auto vectorArray = static_cast(env->GetObjectArrayElement(array2dJ, i)); + this->HasExceptionInStack(env, "Unable to get object array element"); + + if (dim != env->GetArrayLength(vectorArray)) { + env->DeleteLocalRef(array2dJ); + throw std::runtime_error("Dimension of vectors is inconsistent"); + } + + int8_t* vector = reinterpret_cast(env->GetByteArrayElements(vectorArray, nullptr)); + if (vector == nullptr) { + this->HasExceptionInStack(env); + throw std::runtime_error("Unable to get byte array elements"); + } + + vect->insert(vect->end(), vector, vector + dim); + env->ReleaseByteArrayElements(vectorArray, reinterpret_cast(vector), JNI_ABORT); + } + this->HasExceptionInStack(env); + env->DeleteLocalRef(array2dJ); +} + std::vector knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) { if (arrayJ == nullptr) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 663e18457..bcdc4f18b 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -67,6 +67,20 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex return (jlong)0; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &byteIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddress, jint threadCount) @@ -95,6 +109,20 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIn } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &byteIndexService); + } catch (...) { + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, jlong indexAddress, jstring indexPathJ) @@ -121,6 +149,19 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &byteIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp index 7432c44d3..906592b2d 100644 --- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -49,6 +49,18 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appen return (long)memoryAddressJ; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeBinaryVectorData(JNIEnv * env, jclass cls, +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) + +{ + try { + return knn_jni::commons::storeBinaryVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (long)memoryAddressJ; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls, jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) @@ -72,6 +84,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNI } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeBinaryVectorData(JNIEnv * env, jclass cls, + jlong memoryAddressJ) +{ + try { + return knn_jni::commons::freeBinaryVectorData(memoryAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData(JNIEnv * env, jclass cls, jlong memoryAddressJ) { diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index 1f00f6a1d..8d9e4bb43 100644 --- a/jni/tests/faiss_index_service_test.cpp +++ b/jni/tests/faiss_index_service_test.cpp @@ -113,4 +113,48 @@ TEST(CreateBinaryIndexTest, BasicAssertions) { long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); indexService.writeIndex(indexPath, indexAddress); +} + +TEST(CreateByteIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 8; + vectors.reserve(numIds * dim); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors.push_back(test_util::RandomInt(-128, 127)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::MetricType metricType = faiss::METRIC_L2; + std::string indexDescription = "HNSW16,SQ8_direct_signed"; + int threadCount = 1; + std::unordered_map parametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Setup faiss method mock + // This object is handled by unique_ptr inside indexService.createIndex() + MockIndex* index = new MockIndex(); + // This object is handled by unique_ptr inside indexService.createIndex() + faiss::IndexIDMap* indexIdMap = new faiss::IndexIDMap(index); + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + EXPECT_CALL(*mockFaissMethods, indexFactory(dim, ::testing::StrEq(indexDescription.c_str()), metricType)) + .WillOnce(Return(index)); + EXPECT_CALL(*mockFaissMethods, indexIdMap(index)) + .WillOnce(Return(indexIdMap)); + EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + .Times(1); + + // Create the index + knn_jni::faiss_wrapper::ByteIndexService indexService(std::move(mockFaissMethods)); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } \ No newline at end of file diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 2149f8a1a..4f8bd2c34 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -51,12 +51,18 @@ test_util::MockJNIUtil::MockJNIUtil() { (*reinterpret_cast> *>(array2dJ))) for (auto item : v) data->push_back(item); }); - ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToByteVector) + ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToBinaryVector) .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) { for (const auto &v : (*reinterpret_cast> *>(array2dJ))) for (auto item : v) data->push_back(item); }); + ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToByteVector) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) { + for (const auto &v : + (*reinterpret_cast> *>(array2dJ))) + for (auto item : v) data->push_back(item); + }); // arrayJ is re-interpreted as std::vector * diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index ba773fad3..a90d45dd9 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -46,8 +46,10 @@ namespace test_util { (JNIEnv * env, jobjectArray array2dJ, int dim)); MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToFloatVector, (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); - MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToByteVector, + MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToBinaryVector, (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); + MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToByteVector, + (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); MOCK_METHOD(std::vector, ConvertJavaIntArrayToCppIntVector, (JNIEnv * env, jintArray arrayJ)); MOCK_METHOD2(ConvertJavaMapToCppMap, diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index bef2ccf0c..ac6f69d31 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -111,6 +111,7 @@ public class KNNConstants { public static final String FAISS_SQ_TYPE = "type"; public static final String FAISS_SQ_ENCODER_FP16 = "fp16"; public static final List FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16); + public static final String FAISS_SIGNED_BYTE_SQ = "SQ8_direct_signed"; public static final String FAISS_SQ_CLIP = "clip"; // Parameter defaults/limits diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java index c9d4802fe..ffa12a231 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -5,6 +5,8 @@ package org.opensearch.knn.index.codec.transfer; +import org.opensearch.knn.jni.JNICommons; + import java.io.IOException; import java.util.List; @@ -21,12 +23,16 @@ public OffHeapBinaryVectorTransfer(int transferLimit) { @Override public void deallocate() { - // TODO: deallocate the memory location + JNICommons.freeBinaryVectorData(getVectorAddress()); } @Override - protected long transfer(List vectorsToTransfer, boolean append) throws IOException { - // TODO: call to JNIService to transfer vector - return 0L; + protected long transfer(List batch, boolean append) throws IOException { + return JNICommons.storeBinaryVectorData( + getVectorAddress(), + batch.toArray(new byte[][] {}), + (long) batch.get(0).length * transferLimit, + append + ); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java index bfcc13491..446b6ae80 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java @@ -27,7 +27,7 @@ public static OffHeapVectorTransfer getVectorTransfer(final VectorDataTyp case FLOAT: return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(transferLimit); case BINARY: - // TODO: Add binary here + return (OffHeapVectorTransfer) new OffHeapBinaryVectorTransfer(transferLimit); case BYTE: return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit); default: diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index 15e43cdd5..dbe89bd0a 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -35,7 +35,11 @@ */ public class FaissHNSWMethod extends AbstractFaissMethod { - private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BINARY); + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of( + VectorDataType.FLOAT, + VectorDataType.BINARY, + VectorDataType.BYTE + ); public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 57a4dd062..5ab2dd888 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -23,10 +23,10 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KnnCircuitBreakerException; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.util.IndexHyperParametersUtil; diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 90d4ca879..75f01ba24 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -11,6 +11,7 @@ import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; @@ -24,6 +25,8 @@ import java.util.Optional; import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; @@ -127,10 +130,23 @@ private MethodFieldMapper( this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); try { - this.fieldType.putAttribute( - PARAMETERS, - XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString() - ); + Map libParams = knnLibraryIndexingContext.getLibraryParameters(); + + // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer + // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" + if (VectorDataType.BYTE == vectorDataType && libParams.containsKey(INDEX_DESCRIPTION_PARAMETER)) { + String indexDescriptionValue = (String) libParams.get(INDEX_DESCRIPTION_PARAMETER); + if (indexDescriptionValue != null && indexDescriptionValue.isEmpty() == false) { + StringBuilder indexDescriptionBuilder = new StringBuilder(); + indexDescriptionBuilder.append(indexDescriptionValue.split(",")[0]); + indexDescriptionBuilder.append(","); + indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ); + + libParams.replace(INDEX_DESCRIPTION_PARAMETER, indexDescriptionBuilder.toString()); + libParams.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()); + } + } + this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString()); } catch (IOException ioe) { throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 4f92a9c4b..755b6b925 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -300,7 +300,7 @@ private void cleanup() { if (this.memoryAddress != 0) { if (IndexUtil.isBinaryIndex(vectorDataType)) { - JNICommons.freeByteVectorData(this.memoryAddress); + JNICommons.freeBinaryVectorData(this.memoryAddress); } else { JNICommons.freeVectorData(this.memoryAddress); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index cba5692c9..208e075eb 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -481,22 +481,32 @@ protected Query doToQuery(QueryShardContext context) { } byte[] byteVector = new byte[0]; - if (VectorDataType.BINARY == vectorDataType) { - byteVector = new byte[vector.length]; - for (int i = 0; i < vector.length; i++) { - validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); - byteVector[i] = (byte) vector[i]; - } - spaceType.validateVector(byteVector); - } else if (VectorDataType.BYTE == vectorDataType) { - byteVector = new byte[vector.length]; - for (int i = 0; i < vector.length; i++) { - validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); - byteVector[i] = (byte) vector[i]; - } - spaceType.validateVector(byteVector); - } else { - spaceType.validateVector(vector); + switch (vectorDataType) { + case BINARY: + byteVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); + byteVector[i] = (byte) vector[i]; + } + spaceType.validateVector(byteVector); + break; + case BYTE: + if (KNNEngine.LUCENE == knnEngine) { + byteVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); + byteVector[i] = (byte) vector[i]; + } + spaceType.validateVector(byteVector); + } else { + for (float v : vector) { + validateByteVectorValue(v, knnVectorFieldType.getVectorDataType()); + } + spaceType.validateVector(vector); + } + break; + default: + spaceType.validateVector(vector); } if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) @@ -512,8 +522,8 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) - .byteVector(VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType ? byteVector : null) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) + .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .k(this.k) .methodParameters(this.methodParameters) @@ -581,6 +591,20 @@ private void updateQueryStats(VectorQueryType vectorQueryType) { } } + private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) { + if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) { + return this.vector; + } + return null; + } + + private byte[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, byte[] byteVector) { + if (VectorDataType.BINARY == vectorDataType || (VectorDataType.BYTE == vectorDataType && KNNEngine.LUCENE == knnEngine)) { + return byteVector; + } + return null; + } + @Override protected boolean doEquals(KNNQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 84a4e8940..431579fae 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -415,4 +415,16 @@ private static Map initializeMinimalRequiredVersionMap() { } return Collections.unmodifiableMap(versionMap); } + + /** + * Tell if it is byte index or not + * + * @param parameters parameters associated with an index + * @return true if it is binary index + */ + public static boolean isByteIndex(Map parameters) { + return parameters.getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + .toString() + .equals(VectorDataType.BYTE.getValue()); + } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index a402be1f3..26c703eeb 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -69,6 +69,16 @@ class FaissService { */ public static native long initBinaryIndex(long numDocs, int dim, Map parameters); + /** + * Initialize a byte index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initByteIndex(long numDocs, int dim, Map parameters); + /** * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer @@ -95,6 +105,19 @@ class FaissService { */ public static native void insertToBinaryIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); + /** + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the + * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer + * created the memory address and that should only free up the memory. + * + * @param ids ids of documents + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion + */ + public static native void insertToByteIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); + /** * Writes a faiss index. * @@ -115,6 +138,16 @@ class FaissService { */ public static native void writeBinaryIndex(long indexAddress, String indexPath); + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. + * + * @param indexAddress address of native memory where index is stored + * @param indexPath path to save index file to + */ + public static native void writeByteIndex(long indexAddress, String indexPath); + /** * Create an index for the native library with a provided template index * diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index c7222738e..df1024db4 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -70,7 +70,28 @@ public static long storeVectorData(long memoryAddress, float[][] data, long init public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity, boolean append); /** - * This is utility function that can be used to store data in native memory. This function will allocate memory for + * This is utility function that can be used to store binary data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing binary data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static long storeBinaryVectorData(long memoryAddress, byte[][] data, long initialCapacity) { + return storeBinaryVectorData(memoryAddress, data, initialCapacity, true); + } + + /** + * This is utility function that can be used to store binary data in native memory. This function will allocate memory for * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location @@ -82,7 +103,27 @@ public static long storeVectorData(long memoryAddress, float[][] data, long init *

* * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D byte array containing data to be stored in native memory. + * @param data 2D byte array containing binary data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location + * @return memory address where the data is stored. + */ + public static native long storeBinaryVectorData(long memoryAddress, byte[][] data, long initialCapacity, boolean append); + + /** + * This is utility function that can be used to store byte data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing byte data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address where the data is stored. */ @@ -91,7 +132,7 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i } /** - * This is utility function that can be used to store data in native memory. This function will allocate memory for + * This is utility function that can be used to store byte data in native memory. This function will allocate memory for * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location @@ -103,7 +144,7 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i *

* * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D byte array containing data to be stored in native memory. + * @param data 2D byte array containing byte data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @param append append the data or rewrite the memory location * @return memory address where the data is stored. @@ -124,8 +165,21 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i public static native void freeVectorData(long memoryAddress); /** - * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeVectorData(long, float[][], long, boolean)} + * Free up the memory allocated for the binary data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long)} + * + *

+ * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. + *

+ * + * @param memoryAddress address to be freed. + */ + public static native void freeBinaryVectorData(long memoryAddress); + + /** + * Free up the memory allocated for the binary data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long)} * *

* The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index d1d5f6c11..1177d635e 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -39,9 +39,13 @@ public static long initIndex(long numDocs, int dim, Map paramete if (KNNEngine.FAISS == knnEngine) { if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { return FaissService.initBinaryIndex(numDocs, dim, parameters); - } else { - return FaissService.initIndex(numDocs, dim, parameters); } + if (IndexUtil.isByteIndex(parameters)) { + return FaissService.initByteIndex(numDocs, dim, parameters); + } + + return FaissService.initIndex(numDocs, dim, parameters); + } throw new IllegalArgumentException( @@ -71,6 +75,8 @@ public static void insertToIndex( if (KNNEngine.FAISS == knnEngine) { if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { FaissService.insertToBinaryIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } else if (IndexUtil.isByteIndex(parameters)) { + FaissService.insertToByteIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); } else { FaissService.insertToIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); } @@ -94,6 +100,8 @@ public static void writeIndex(String indexPath, long indexAddress, KNNEngine knn if (KNNEngine.FAISS == knnEngine) { if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { FaissService.writeBinaryIndex(indexAddress, indexPath); + } else if (IndexUtil.isByteIndex(parameters)) { + FaissService.writeByteIndex(indexAddress, indexPath); } else { FaissService.writeIndex(indexAddress, indexPath); } diff --git a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java index 70cfb4f4c..e838b5214 100644 --- a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java +++ b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java @@ -39,7 +39,7 @@ public ByteTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation tr @Override public void accept(List byteVectors) { long memoryAddress = trainingDataAllocation.getMemoryAddress(); - memoryAddress = JNICommons.storeByteVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); + memoryAddress = JNICommons.storeBinaryVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); trainingDataAllocation.setMemoryAddress(memoryAddress); } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index ea137f042..fe621d7d4 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -13,15 +13,15 @@ import org.opensearch.client.ResponseException; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.script.Script; import java.util.ArrayList; @@ -32,7 +32,13 @@ import java.util.Map; import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -453,6 +459,138 @@ public void testSearchWithMissingQueryVector() { assertTrue(ex.getMessage().contains("[knn] requires query vector")); } + @SneakyThrows + public void testAddDocWithByteVectorUsingFaissEngine() { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { 6, 6 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + refreshAllIndices(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + @SneakyThrows + public void testUpdateDocWithByteVectorUsingFaissEngine() { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { -36, 78 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + Byte[] updatedVector = { 89, -8 }; + updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector); + + refreshAllIndices(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + @SneakyThrows + public void testDeleteDocWithByteVectorUsingFaissEngine() { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { 35, -46 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + deleteKnnDoc(INDEX_NAME, DOC_ID); + refreshAllIndices(); + + assertEquals(0, getDocCount(INDEX_NAME)); + } + + @SneakyThrows + public void testSearchWithByteVectorUsingFaissEngine() { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), 4), 4); + + validateL2SearchResults(response); + } + + @SneakyThrows + public void testInvalidVectorDataUsingFaissEngine() { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Float[] vector = { -10.76f, 15.89f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + } + + // Create an index with byte vector data_type and add a doc with values out of byte range which should throw exception + @SneakyThrows + public void testInvalidByteVectorRangeUsingFaissEngine() { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Float[] vector = { -1000f, 155f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ) + ); + } + + // Create an index with byte vector data_type using faiss engine with an encoder which should throw an exception + @SneakyThrows + public void testByteVectorDataTypeWithFaissEngineUsingEncoderThrowsException() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 2) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, M) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping)); + } + + public void testDocValuesWithByteVectorDataTypeFaissEngine() throws Exception { + createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + ingestL2ByteTestData(); + + Byte[] queryVector = { 1, 1 }; + Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + validateL2SearchResults(response); + } + @SneakyThrows private void ingestL2ByteTestData() { Byte[] b1 = { 6, 6 }; @@ -491,6 +629,10 @@ private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spac createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.LUCENE.getName()); } + private void createKnnIndexMappingWithFaissEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { + createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.FAISS.getName()); + } + private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spaceType, String vectorDataType, String engine) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -504,7 +646,7 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac .field(KNNConstants.NAME, METHOD_HNSW) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNNConstants.KNN_ENGINE, engine) - .startObject(KNNConstants.PARAMETERS) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, M) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) .endObject() diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java index cef875cfc..39415d811 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java @@ -20,7 +20,7 @@ public void testOffHeapVectorTransferFactory() { assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10)); var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10); - assertEquals(OffHeapByteVectorTransfer.class, binaryVectorTransfer.getClass()); + assertEquals(OffHeapBinaryVectorTransfer.class, binaryVectorTransfer.getClass()); assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10)); } } diff --git a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java index 6defa4c50..f142a9770 100644 --- a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java @@ -455,12 +455,12 @@ public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() { ); } - public void testValidateVectorDataType_whenByteLucene_thenValid() { + public void testValidateVectorDataType_whenByte_thenValid() { validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, null); + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, null); } - public void testValidateVectorDataType_whenByteNonLucene_thenException() { - validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, "UnsupportedMethod"); + public void testValidateVectorDataType_whenByte_thenException() { validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_IVF, VectorDataType.BYTE, SpaceType.L2, "UnsupportedMethod"); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index 316582f6c..1e2134581 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -117,7 +117,7 @@ public void testClose_whenBinaryFiass_thenSuccess() { KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue() ); - long vectorMemoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors * dataLength); + long vectorMemoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors * dataLength); TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 8a38cadb5..29fbdb978 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -106,7 +106,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException { KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue() ); - long memoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors); + long memoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors); TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 762a36227..b7de89564 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -961,6 +961,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { Index dummyIndex = new Index("dummy", "dummy"); when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 6bbbc8a5b..b644c37f8 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -403,7 +403,7 @@ public long loadDataToMemoryAddress() { } public long loadBinaryDataToMemoryAddress() { - return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true); + return JNICommons.storeBinaryVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true); } @AllArgsConstructor From 3cd458b4e2b24162f7fc6e190316f714940b0a34 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Wed, 21 Aug 2024 23:09:41 -0500 Subject: [PATCH 2/2] Address Review Comments Signed-off-by: Naveen Tatikonda --- jni/include/commons.h | 2 ++ jni/src/faiss_index_service.cpp | 6 ++--- .../engine/faiss/AbstractFaissMethod.java | 17 ++++++++++++- .../index/engine/faiss/FaissHNSWMethod.java | 2 +- .../index/engine/faiss/FaissIVFMethod.java | 2 +- .../knn/index/mapper/MethodFieldMapper.java | 24 ++++--------------- 6 files changed, 26 insertions(+), 27 deletions(-) diff --git a/jni/include/commons.h b/jni/include/commons.h index e1aaacd9c..3f1ee19a3 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -49,6 +49,7 @@ namespace knn_jni { * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D byte array containing binary data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append whether to append or start from index 0 when called subsequently with the same address * @return memory address of std::vector where the data is stored. */ jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); @@ -63,6 +64,7 @@ namespace knn_jni { * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D byte array containing int8 data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append whether to append or start from index 0 when called subsequently with the same address * @return memory address of std::vector where the data is stored. */ jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index b6e465741..16ded4bcb 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -327,7 +327,8 @@ void ByteIndexService::insertToIndex( faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress); - // Add vectors in batches by casting int8 vectors into float with a batch size of 1000 + // Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike. + // Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255 int batchSize = 1000; std::vector inputFloatVectors(batchSize * dim); std::vector floatVectorsIds(batchSize); @@ -337,8 +338,6 @@ void ByteIndexService::insertToIndex( for (int id = 0; id < numVectors; id += batchSize) { if (numVectors - id < batchSize) { batchSize = numVectors - id; - inputFloatVectors.resize(batchSize * dim); - floatVectorsIds.resize(batchSize); } for (int i = 0; i < batchSize; ++i) { @@ -364,6 +363,5 @@ void ByteIndexService::writeIndex( throw std::runtime_error("Failed to write index to disk"); } } - } // namespace faiss_wrapper } // namesapce knn_jni \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 908671a21..8c2dfc126 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.engine.faiss; +import org.apache.commons.lang.StringUtils; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.AbstractKNNMethod; @@ -20,6 +21,7 @@ import java.util.Objects; import java.util.Set; +import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled; @@ -87,7 +89,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( throw new IllegalStateException("Unsupported vector data type " + vectorDataType); } - static KNNLibraryIndexingContext adjustPrefix( + static KNNLibraryIndexingContext adjustIndexDescription( MethodAsMapBuilder methodAsMapBuilder, MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext @@ -105,6 +107,19 @@ static KNNLibraryIndexingContext adjustPrefix( if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; } + if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) { + + // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer + // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" + String indexDescription = methodAsMapBuilder.indexDescription; + if (StringUtils.isNotEmpty(indexDescription)) { + StringBuilder indexDescriptionBuilder = new StringBuilder(); + indexDescriptionBuilder.append(indexDescription.split(",")[0]); + indexDescriptionBuilder.append(","); + indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ); + methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString(); + } + } methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription; return methodAsMapBuilder.build(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index dbe89bd0a..41db777e3 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -99,7 +99,7 @@ private static MethodComponent initMethodComponent() { methodComponentContext, knnMethodConfigContext ).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", ""); - return adjustPrefix(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext); + return adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext); })) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index bc30e372c..b3dd12c92 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -95,7 +95,7 @@ private static MethodComponent initMethodComponent() { methodComponentContext, knnMethodConfigContext ).addParameter(METHOD_PARAMETER_NLIST, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", ""); - return adjustPrefix(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext); + return adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext); })) .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { // Size estimate formula: (4 * nlists * d) / 1024 + 1 diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 75f01ba24..90d4ca879 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -11,7 +11,6 @@ import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; @@ -25,8 +24,6 @@ import java.util.Optional; import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ; -import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; @@ -130,23 +127,10 @@ private MethodFieldMapper( this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); try { - Map libParams = knnLibraryIndexingContext.getLibraryParameters(); - - // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer - // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" - if (VectorDataType.BYTE == vectorDataType && libParams.containsKey(INDEX_DESCRIPTION_PARAMETER)) { - String indexDescriptionValue = (String) libParams.get(INDEX_DESCRIPTION_PARAMETER); - if (indexDescriptionValue != null && indexDescriptionValue.isEmpty() == false) { - StringBuilder indexDescriptionBuilder = new StringBuilder(); - indexDescriptionBuilder.append(indexDescriptionValue.split(",")[0]); - indexDescriptionBuilder.append(","); - indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ); - - libParams.replace(INDEX_DESCRIPTION_PARAMETER, indexDescriptionBuilder.toString()); - libParams.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()); - } - } - this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString()); + this.fieldType.putAttribute( + PARAMETERS, + XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString() + ); } catch (IOException ioe) { throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); }