From 5d83d1e9463d6e86eda923e8e76724607f705269 Mon Sep 17 00:00:00 2001
From: Naveen Tatikonda
Date: Fri, 12 Jul 2024 15:31:17 -0500
Subject: [PATCH 1/2] Add HNSW changes to support Faiss byte vector
Signed-off-by: Naveen Tatikonda
---
CHANGELOG.md | 1 +
jni/include/commons.h | 24 ++-
jni/include/faiss_index_service.h | 57 +++++++
jni/include/jni_util.h | 7 +-
.../org_opensearch_knn_jni_FaissService.h | 30 ++++
.../org_opensearch_knn_jni_JNICommons.h | 20 ++-
jni/src/commons.cpp | 31 +++-
jni/src/faiss_index_service.cpp | 115 ++++++++++++++
jni/src/jni_util.cpp | 34 +++-
.../org_opensearch_knn_jni_FaissService.cpp | 41 +++++
jni/src/org_opensearch_knn_jni_JNICommons.cpp | 22 +++
jni/tests/faiss_index_service_test.cpp | 44 ++++++
jni/tests/test_util.cpp | 8 +-
jni/tests/test_util.h | 4 +-
.../opensearch/knn/common/KNNConstants.java | 1 +
.../transfer/OffHeapBinaryVectorTransfer.java | 14 +-
.../OffHeapVectorTransferFactory.java | 2 +-
.../index/engine/faiss/FaissHNSWMethod.java | 6 +-
.../mapper/KNNVectorFieldMapperUtil.java | 2 +-
.../knn/index/mapper/MethodFieldMapper.java | 24 ++-
.../index/memory/NativeMemoryAllocation.java | 2 +-
.../knn/index/query/KNNQueryBuilder.java | 60 ++++---
.../opensearch/knn/index/util/IndexUtil.java | 12 ++
.../org/opensearch/knn/jni/FaissService.java | 33 ++++
.../org/opensearch/knn/jni/JNICommons.java | 66 +++++++-
.../org/opensearch/knn/jni/JNIService.java | 12 +-
.../training/ByteTrainingDataConsumer.java | 2 +-
.../knn/index/VectorDataTypeIT.java | 148 +++++++++++++++++-
.../OffHeapVectorTransferFactoryTests.java | 2 +-
.../index/engine/KNNMethodContextTests.java | 6 +-
.../memory/NativeMemoryAllocationTests.java | 2 +-
.../memory/NativeMemoryLoadStrategyTests.java | 2 +-
.../knn/index/query/KNNQueryBuilderTests.java | 1 +
.../java/org/opensearch/knn/TestUtils.java | 2 +-
34 files changed, 778 insertions(+), 59 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56f207f46..349125beb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
+* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
### Bug Fixes
diff --git a/jni/include/commons.h b/jni/include/commons.h
index 4cdaf28fc..e1aaacd9c 100644
--- a/jni/include/commons.h
+++ b/jni/include/commons.h
@@ -47,10 +47,24 @@ namespace knn_jni {
* CAUTION: The behavior is undefined if the memory address is deallocated and the method is called
*
* @param memoryAddress The address of the memory location where data will be stored.
- * @param data 2D byte array containing data to be stored in native memory.
+ * @param data 2D byte array containing binary data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address of std::vector where the data is stored.
*/
+ jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
+
+ /**
+ * This is a utility function that can be used to store signed int8 data in native memory. This function will allocate memory for
+ * the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
+ * If you are using this function for the first time, use memoryAddress = 0 to ensure that a new memory location is created.
+ * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location,
+ * an exception is thrown.
+ *
+ * @param memoryAddress The address of the memory location where data will be stored.
+ * @param data 2D byte array containing int8 data to be stored in native memory.
+ * @param initialCapacity The initial capacity of the memory location.
+ * @return memory address of std::vector where the data is stored.
+ */
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
/**
@@ -69,6 +83,14 @@ namespace knn_jni {
*/
void freeByteVectorData(jlong);
+ /**
+ * Free up the memory allocated for the data stored in memory address. This function should be used with the memory
+ * address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long, boolean)}
+ *
+ * @param memoryAddress address to be freed.
+ */
+ void freeBinaryVectorData(jlong);
+
/**
* Extracts query time efSearch from method parameters
**/
diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h
index c57309cfc..29ec90e80 100644
--- a/jni/include/faiss_index_service.h
+++ b/jni/include/faiss_index_service.h
@@ -125,6 +125,63 @@ class BinaryIndexService : public IndexService {
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override;
};
+/**
+ * A class to provide operations on index
+ * This class should evolve to have only cpp object but not jni object
+ */
+class ByteIndexService : public IndexService {
+public:
+ //TODO Remove dependency on JNIUtilInterface and JNIEnv
+ //TODO Reduce the number of parameters
+ ByteIndexService(std::unique_ptr faissMethods);
+
+/**
+ * Initialize index
+ *
+ * @param jniUtil jni util
+ * @param env jni environment
+ * @param metric space type for distance calculation
+ * @param indexDescription index description to be used by faiss index factory
+ * @param dim dimension of vectors
+ * @param numVectors number of vectors
+ * @param threadCount number of thread count to be used while adding data
+ * @param parameters parameters to be applied to faiss index
+ * @return memory address of the native index object
+ */
+ virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override;
+ /**
+ * Add vectors to index
+ *
+ * The vectors stored at vectorsAddress are read as int8 values and cast to float in batches
+ * before they are added to the index along with their ids.
+ *
+ * @param dim dimension of vectors
+ * @param numIds number of vectors
+ * @param threadCount number of thread count to be used while adding data
+ * @param vectorsAddress memory address which is holding vector data
+ * @param ids ids to be added to the index along with the vectors
+ * @param idMapAddress memory address of the native index object
+ * @throws std::runtime_error if there are no vectors or numIds does not match the number of vectors
+ */
+ virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) override;
+ /**
+ * Write index to disk
+ *
+ * The index object at idMapAddress is taken over by this call and deleted before it returns,
+ * whether or not the write succeeds, so the address must not be used afterwards.
+ *
+ * @param indexPath path to write index
+ * @param idMapAddress memory address of the native index object
+ *
+ * @throws std::runtime_error with message "Failed to write index to disk" if the underlying
+ * write throws
+ */
+ virtual void writeIndex(std::string indexPath, jlong idMapAddress) override;
+ virtual ~ByteIndexService() = default;
+protected:
+ virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override;
+};
+
}
}
diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h
index 1579522d0..825471a3c 100644
--- a/jni/include/jni_util.h
+++ b/jni/include/jni_util.h
@@ -71,8 +71,10 @@ namespace knn_jni {
virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector *vect ) = 0;
- virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
+ virtual void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector *vect ) = 0;
+ virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
+ int dim, std::vector *vect ) = 0;
virtual std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;
@@ -173,7 +175,8 @@ namespace knn_jni {
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect);
- void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect);
+ void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect);
+ void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect);
private:
std::unordered_map cachedClasses;
diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h
index 19e13d402..09f3ec8b7 100644
--- a/jni/include/org_opensearch_knn_jni_FaissService.h
+++ b/jni/include/org_opensearch_knn_jni_FaissService.h
@@ -34,6 +34,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEn
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ);
+
+/*
+ * Class: org_opensearch_knn_jni_FaissService
+ * Method: initByteIndex
+ * Signature: (JILjava/util/Map;)J
+ */
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls,
+ jlong numDocs, jint dimJ,
+ jobject parametersJ);
+
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: insertToIndex
@@ -50,6 +60,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount);
+
+/*
+ * Class: org_opensearch_knn_jni_FaissService
+ * Method: insertToByteIndex
+ * Signature: ([IJIJI)V
+ */
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ,
+ jlong vectorsAddressJ, jint dimJ,
+ jlong indexAddress, jint threadCount);
+
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: writeIndex
@@ -66,6 +86,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEn
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls,
jlong indexAddress,
jstring indexPathJ);
+
+/*
+ * Class: org_opensearch_knn_jni_FaissService
+ * Method: writeByteIndex
+ * Signature: (JLjava/lang/String;)V
+ */
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls,
+ jlong indexAddress,
+ jstring indexPathJ);
+
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h
index 03c0d023a..8bfbcc266 100644
--- a/jni/include/org_opensearch_knn_jni_JNICommons.h
+++ b/jni/include/org_opensearch_knn_jni_JNICommons.h
@@ -28,7 +28,15 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
/*
* Class: org_opensearch_knn_jni_JNICommons
- * Method: storeVectorData
+ * Method: storeBinaryVectorData
+ * Signature: (J[[BJZ)J
+ */
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeBinaryVectorData
+ (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);
+
+/*
+ * Class: org_opensearch_knn_jni_JNICommons
+ * Method: storeByteVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
@@ -44,7 +52,15 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
/*
* Class: org_opensearch_knn_jni_JNICommons
-* Method: freeVectorData
+* Method: freeBinaryVectorData
+* Signature: (J)V
+*/
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeBinaryVectorData
+(JNIEnv *, jclass, jlong);
+
+/*
+* Class: org_opensearch_knn_jni_JNICommons
+* Method: freeByteVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData
diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp
index f9764db73..38e3ac8a4 100644
--- a/jni/src/commons.cpp
+++ b/jni/src/commons.cpp
@@ -37,7 +37,7 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE
return (jlong) vect;
}
-jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
+jlong knn_jni::commons::storeBinaryVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector *vect;
if ((long) memoryAddressJ == 0) {
@@ -51,6 +51,26 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil,
vect->clear();
}
+ int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
+ jniUtil->Convert2dJavaObjectArrayAndStoreToBinaryVector(env, dataJ, dim, vect);
+
+ return (jlong) vect;
+}
+
+jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
+ jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
+ std::vector *vect;
+ if (memoryAddressJ == 0) {
+ vect = new std::vector();
+ vect->reserve(static_cast(initialCapacityJ));
+ } else {
+ vect = reinterpret_cast*>(memoryAddressJ);
+ }
+
+ if (appendJ == JNI_FALSE) {
+ vect->clear();
+ }
+
int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect);
@@ -64,13 +84,20 @@ void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
}
}
-void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) {
+void knn_jni::commons::freeBinaryVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast*>(memoryAddressJ);
delete vect;
}
}
+void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) {
+ if (memoryAddressJ != 0) {
+ auto *vect = reinterpret_cast*>(memoryAddressJ);
+ delete vect;
+ }
+}
+
int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map methodParams, std::string methodParam, int defaultValue) {
if (methodParams.empty()) {
return defaultValue;
diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp
index f76c54428..b6e465741 100644
--- a/jni/src/faiss_index_service.cpp
+++ b/jni/src/faiss_index_service.cpp
@@ -250,5 +250,120 @@ void BinaryIndexService::writeIndex(
}
}
+ByteIndexService::ByteIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {}
+
+void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
+ if(auto * indexHNSWSQ = dynamic_cast(index)) {
+ if(auto * indexScalarQuantizer = dynamic_cast(indexHNSWSQ->storage)) {
+ indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors);
+ }
+ return;
+ }
+}
+
+jlong ByteIndexService::initIndex(
+ knn_jni::JNIUtilInterface * jniUtil,
+ JNIEnv * env,
+ faiss::MetricType metric,
+ std::string indexDescription,
+ int dim,
+ int numVectors,
+ int threadCount,
+ std::unordered_map parameters
+ ) {
+ // Create index using Faiss factory method
+ std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));
+
+ // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+ if(threadCount != 0) {
+ omp_set_num_threads(threadCount);
+ }
+
+ // Add extra parameters that can't be configured with the index factory
+ SetExtraParameters(jniUtil, env, parameters, index.get());
+
+ // Check that the index does not need to be trained
+ if(!index->is_trained) {
+ throw std::runtime_error("Index is not trained");
+ }
+
+ std::unique_ptr idMap (faissMethods->indexIdMap(index.get()));
+ //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
+ idMap->own_fields = true;
+
+ allocIndex(dynamic_cast(idMap->index), dim, numVectors);
+
+ //Release the ownership to make sure the underlying index that is created is not deleted here. The index is needed later
+ //in insert and write operations
+ index.release();
+ return reinterpret_cast(idMap.release());
+}
+
+void ByteIndexService::insertToIndex(
+ int dim,
+ int numIds,
+ int threadCount,
+ int64_t vectorsAddress,
+ std::vector & ids,
+ jlong idMapAddress
+ ) {
+ // Read vectors from memory address
+ auto *inputVectors = reinterpret_cast*>(vectorsAddress);
+
+ // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
+ int numVectors = inputVectors->size() / dim;
+ if(numVectors == 0) {
+ throw std::runtime_error("Number of vectors cannot be 0");
+ }
+
+ if (numIds != numVectors) {
+ throw std::runtime_error("Number of IDs does not match number of vectors");
+ }
+
+ // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
+ if(threadCount != 0) {
+ omp_set_num_threads(threadCount);
+ }
+
+ faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress);
+
+ // Add vectors in batches by casting int8 vectors into float with a batch size of 1000
+ int batchSize = 1000;
+ std::vector inputFloatVectors(batchSize * dim);
+ std::vector floatVectorsIds(batchSize);
+ int id = 0;
+ auto iter = inputVectors->begin();
+
+ for (int id = 0; id < numVectors; id += batchSize) {
+ if (numVectors - id < batchSize) {
+ batchSize = numVectors - id;
+ inputFloatVectors.resize(batchSize * dim);
+ floatVectorsIds.resize(batchSize);
+ }
+
+ for (int i = 0; i < batchSize; ++i) {
+ floatVectorsIds[i] = ids[id + i];
+ for (int j = 0; j < dim; ++j, ++iter) {
+ inputFloatVectors[i * dim + j] = static_cast(*iter);
+ }
+ }
+ idMap->add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data());
+ }
+}
+
+void ByteIndexService::writeIndex(
+ std::string indexPath,
+ jlong idMapAddress
+ ) {
+ std::unique_ptr idMap (reinterpret_cast (idMapAddress));
+
+ try {
+ // Write the index to disk
+ faissMethods->writeIndex(idMap.get(), indexPath.c_str());
+ } catch(std::exception &e) {
+ throw std::runtime_error("Failed to write index to disk");
+ }
+}
+
} // namespace faiss_wrapper
} // namesapce knn_jni
\ No newline at end of file
diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp
index ee4c382b5..82900b5ce 100644
--- a/jni/src/jni_util.cpp
+++ b/jni/src/jni_util.cpp
@@ -261,7 +261,7 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env
env->DeleteLocalRef(array2dJ);
}
-void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
+void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector *vect) {
if (array2dJ == nullptr) {
@@ -294,6 +294,38 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env,
env->DeleteLocalRef(array2dJ);
}
+void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
+ int dim, std::vector *vect) {
+
+ if (array2dJ == nullptr) {
+ throw std::runtime_error("Array cannot be null");
+ }
+
+ int numVectors = env->GetArrayLength(array2dJ);
+ this->HasExceptionInStack(env, "Unable to get array length");
+
+ for (int i = 0; i < numVectors; ++i) {
+ auto vectorArray = static_cast(env->GetObjectArrayElement(array2dJ, i));
+ this->HasExceptionInStack(env, "Unable to get object array element");
+
+ if (dim != env->GetArrayLength(vectorArray)) {
+ env->DeleteLocalRef(array2dJ);
+ throw std::runtime_error("Dimension of vectors is inconsistent");
+ }
+
+ int8_t* vector = reinterpret_cast(env->GetByteArrayElements(vectorArray, nullptr));
+ if (vector == nullptr) {
+ this->HasExceptionInStack(env);
+ throw std::runtime_error("Unable to get byte array elements");
+ }
+
+ vect->insert(vect->end(), vector, vector + dim);
+ env->ReleaseByteArrayElements(vectorArray, reinterpret_cast(vector), JNI_ABORT);
+ }
+ this->HasExceptionInStack(env);
+ env->DeleteLocalRef(array2dJ);
+}
+
std::vector knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) {
if (arrayJ == nullptr) {
diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp
index 663e18457..bcdc4f18b 100644
--- a/jni/src/org_opensearch_knn_jni_FaissService.cpp
+++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp
@@ -67,6 +67,20 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex
return (jlong)0;
}
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls,
+ jlong numDocs, jint dimJ,
+ jobject parametersJ)
+{
+ try {
+ std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+ knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods));
+ return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &byteIndexService);
+ } catch (...) {
+ jniUtil.CatchCppExceptionAndThrowJava(env);
+ }
+ return (jlong)0;
+}
+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount)
@@ -95,6 +109,20 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIn
}
}
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ,
+ jlong vectorsAddressJ, jint dimJ,
+ jlong indexAddress, jint threadCount)
+{
+ try {
+ std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+ knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods));
+ knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &byteIndexService);
+ } catch (...) {
+ // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH!
+ jniUtil.CatchCppExceptionAndThrowJava(env);
+ }
+}
+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls,
jlong indexAddress,
jstring indexPathJ)
@@ -121,6 +149,19 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex
}
}
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls,
+ jlong indexAddress,
+ jstring indexPathJ)
+{
+ try {
+ std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
+ knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods));
+ knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &byteIndexService);
+ } catch (...) {
+ jniUtil.CatchCppExceptionAndThrowJava(env);
+ }
+}
+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp
index 7432c44d3..906592b2d 100644
--- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp
+++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp
@@ -49,6 +49,18 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appen
return (long)memoryAddressJ;
}
+JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeBinaryVectorData(JNIEnv * env, jclass cls,
+jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ)
+
+{
+ try {
+ return knn_jni::commons::storeBinaryVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ);
+ } catch (...) {
+ jniUtil.CatchCppExceptionAndThrowJava(env);
+ }
+ return (long)memoryAddressJ;
+}
+
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ)
@@ -72,6 +84,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNI
}
+JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeBinaryVectorData(JNIEnv * env, jclass cls,
+ jlong memoryAddressJ)
+{
+ try {
+ return knn_jni::commons::freeBinaryVectorData(memoryAddressJ);
+ } catch (...) {
+ jniUtil.CatchCppExceptionAndThrowJava(env);
+ }
+}
+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ)
{
diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp
index 1f00f6a1d..8d9e4bb43 100644
--- a/jni/tests/faiss_index_service_test.cpp
+++ b/jni/tests/faiss_index_service_test.cpp
@@ -113,4 +113,48 @@ TEST(CreateBinaryIndexTest, BasicAssertions) {
long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap);
indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress);
indexService.writeIndex(indexPath, indexAddress);
+}
+
+TEST(CreateByteIndexTest, BasicAssertions) {
+ // Define the data
+ faiss::idx_t numIds = 200;
+ std::vector ids;
+ std::vector vectors;
+ int dim = 8;
+ vectors.reserve(numIds * dim);
+ for (int64_t i = 0; i < numIds; ++i) {
+ ids.push_back(i);
+ for (int j = 0; j < dim; ++j) {
+ vectors.push_back(test_util::RandomInt(-128, 127));
+ }
+ }
+
+ std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
+ faiss::MetricType metricType = faiss::METRIC_L2;
+ std::string indexDescription = "HNSW16,SQ8_direct_signed";
+ int threadCount = 1;
+ std::unordered_map parametersMap;
+
+ // Set up jni
+ JNIEnv *jniEnv = nullptr;
+ NiceMock mockJNIUtil;
+
+ // Setup faiss method mock
+ // This object is handled by unique_ptr inside indexService.initIndex()
+ MockIndex* index = new MockIndex();
+ // This object is handled by unique_ptr inside indexService.initIndex()
+ faiss::IndexIDMap* indexIdMap = new faiss::IndexIDMap(index);
+ std::unique_ptr mockFaissMethods(new MockFaissMethods());
+ EXPECT_CALL(*mockFaissMethods, indexFactory(dim, ::testing::StrEq(indexDescription.c_str()), metricType))
+ .WillOnce(Return(index));
+ EXPECT_CALL(*mockFaissMethods, indexIdMap(index))
+ .WillOnce(Return(indexIdMap));
+ EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::StrEq(indexPath.c_str())))
+ .Times(1);
+
+ // Create the index
+ knn_jni::faiss_wrapper::ByteIndexService indexService(std::move(mockFaissMethods));
+ long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap);
+ indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress);
+ indexService.writeIndex(indexPath, indexAddress);
}
\ No newline at end of file
diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp
index 2149f8a1a..4f8bd2c34 100644
--- a/jni/tests/test_util.cpp
+++ b/jni/tests/test_util.cpp
@@ -51,12 +51,18 @@ test_util::MockJNIUtil::MockJNIUtil() {
(*reinterpret_cast> *>(array2dJ)))
for (auto item : v) data->push_back(item);
});
- ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToByteVector)
+ ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToBinaryVector)
.WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) {
for (const auto &v :
(*reinterpret_cast> *>(array2dJ)))
for (auto item : v) data->push_back(item);
});
+ ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToByteVector)
+ .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) {
+ for (const auto &v :
+ (*reinterpret_cast> *>(array2dJ)))
+ for (auto item : v) data->push_back(item);
+ });
// arrayJ is re-interpreted as std::vector *
diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h
index ba773fad3..a90d45dd9 100644
--- a/jni/tests/test_util.h
+++ b/jni/tests/test_util.h
@@ -46,8 +46,10 @@ namespace test_util {
(JNIEnv * env, jobjectArray array2dJ, int dim));
MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToFloatVector,
(JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect));
- MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToByteVector,
+ MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToBinaryVector,
(JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect));
+ MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToByteVector,
+ (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect));
MOCK_METHOD(std::vector, ConvertJavaIntArrayToCppIntVector,
(JNIEnv * env, jintArray arrayJ));
MOCK_METHOD2(ConvertJavaMapToCppMap,
diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java
index bef2ccf0c..ac6f69d31 100644
--- a/src/main/java/org/opensearch/knn/common/KNNConstants.java
+++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java
@@ -111,6 +111,7 @@ public class KNNConstants {
public static final String FAISS_SQ_TYPE = "type";
public static final String FAISS_SQ_ENCODER_FP16 = "fp16";
public static final List FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16);
+ public static final String FAISS_SIGNED_BYTE_SQ = "SQ8_direct_signed";
public static final String FAISS_SQ_CLIP = "clip";
// Parameter defaults/limits
diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java
index c9d4802fe..ffa12a231 100644
--- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java
+++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java
@@ -5,6 +5,8 @@
package org.opensearch.knn.index.codec.transfer;
+import org.opensearch.knn.jni.JNICommons;
+
import java.io.IOException;
import java.util.List;
@@ -21,12 +23,16 @@ public OffHeapBinaryVectorTransfer(int transferLimit) {
@Override
public void deallocate() {
- // TODO: deallocate the memory location
+ JNICommons.freeBinaryVectorData(getVectorAddress());
}
@Override
- protected long transfer(List vectorsToTransfer, boolean append) throws IOException {
- // TODO: call to JNIService to transfer vector
- return 0L;
+ protected long transfer(List batch, boolean append) throws IOException {
+ return JNICommons.storeBinaryVectorData(
+ getVectorAddress(),
+ batch.toArray(new byte[][] {}),
+ (long) batch.get(0).length * transferLimit,
+ append
+ );
}
}
diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java
index bfcc13491..446b6ae80 100644
--- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java
+++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java
@@ -27,7 +27,7 @@ public static OffHeapVectorTransfer getVectorTransfer(final VectorDataTyp
case FLOAT:
return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(transferLimit);
case BINARY:
- // TODO: Add binary here
+ return (OffHeapVectorTransfer) new OffHeapBinaryVectorTransfer(transferLimit);
case BYTE:
return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit);
default:
diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java
index 15e43cdd5..dbe89bd0a 100644
--- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java
+++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java
@@ -35,7 +35,11 @@
*/
public class FaissHNSWMethod extends AbstractFaissMethod {
- private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BINARY);
+ private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(
+ VectorDataType.FLOAT,
+ VectorDataType.BINARY,
+ VectorDataType.BYTE
+ );
public final static List SUPPORTED_SPACES = Arrays.asList(
SpaceType.UNDEFINED,
diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java
index 57a4dd062..5ab2dd888 100644
--- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java
+++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java
@@ -23,10 +23,10 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KnnCircuitBreakerException;
import org.opensearch.knn.index.SpaceType;
-import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.engine.KNNEngine;
+import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java
index 90d4ca879..75f01ba24 100644
--- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java
+++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java
@@ -11,6 +11,7 @@
import org.opensearch.common.Explicit;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.SpaceType;
+import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
@@ -24,6 +25,8 @@
import java.util.Optional;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
+import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ;
+import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;
@@ -127,10 +130,23 @@ private MethodFieldMapper(
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());
try {
- this.fieldType.putAttribute(
- PARAMETERS,
- XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString()
- );
+ Map libParams = knnLibraryIndexingContext.getLibraryParameters();
+
+ // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer
+ // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed"
+ if (VectorDataType.BYTE == vectorDataType && libParams.containsKey(INDEX_DESCRIPTION_PARAMETER)) {
+ String indexDescriptionValue = (String) libParams.get(INDEX_DESCRIPTION_PARAMETER);
+ if (indexDescriptionValue != null && indexDescriptionValue.isEmpty() == false) {
+ StringBuilder indexDescriptionBuilder = new StringBuilder();
+ indexDescriptionBuilder.append(indexDescriptionValue.split(",")[0]);
+ indexDescriptionBuilder.append(",");
+ indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ);
+
+ libParams.replace(INDEX_DESCRIPTION_PARAMETER, indexDescriptionBuilder.toString());
+ libParams.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue());
+ }
+ }
+ this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString());
} catch (IOException ioe) {
throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe));
}
diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java
index 4f92a9c4b..755b6b925 100644
--- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java
+++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java
@@ -300,7 +300,7 @@ private void cleanup() {
if (this.memoryAddress != 0) {
if (IndexUtil.isBinaryIndex(vectorDataType)) {
- JNICommons.freeByteVectorData(this.memoryAddress);
+ JNICommons.freeBinaryVectorData(this.memoryAddress);
} else {
JNICommons.freeVectorData(this.memoryAddress);
}
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
index cba5692c9..208e075eb 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
@@ -481,22 +481,32 @@ protected Query doToQuery(QueryShardContext context) {
}
byte[] byteVector = new byte[0];
- if (VectorDataType.BINARY == vectorDataType) {
- byteVector = new byte[vector.length];
- for (int i = 0; i < vector.length; i++) {
- validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType());
- byteVector[i] = (byte) vector[i];
- }
- spaceType.validateVector(byteVector);
- } else if (VectorDataType.BYTE == vectorDataType) {
- byteVector = new byte[vector.length];
- for (int i = 0; i < vector.length; i++) {
- validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType());
- byteVector[i] = (byte) vector[i];
- }
- spaceType.validateVector(byteVector);
- } else {
- spaceType.validateVector(vector);
+ switch (vectorDataType) {
+ case BINARY:
+ byteVector = new byte[vector.length];
+ for (int i = 0; i < vector.length; i++) {
+ validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType());
+ byteVector[i] = (byte) vector[i];
+ }
+ spaceType.validateVector(byteVector);
+ break;
+ case BYTE:
+ if (KNNEngine.LUCENE == knnEngine) {
+ byteVector = new byte[vector.length];
+ for (int i = 0; i < vector.length; i++) {
+ validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType());
+ byteVector[i] = (byte) vector[i];
+ }
+ spaceType.validateVector(byteVector);
+ } else {
+ for (float v : vector) {
+ validateByteVectorValue(v, knnVectorFieldType.getVectorDataType());
+ }
+ spaceType.validateVector(vector);
+ }
+ break;
+ default:
+ spaceType.validateVector(vector);
}
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
@@ -512,8 +522,8 @@ protected Query doToQuery(QueryShardContext context) {
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
- .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null)
- .byteVector(VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType ? byteVector : null)
+ .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine))
+ .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector))
.vectorDataType(vectorDataType)
.k(this.k)
.methodParameters(this.methodParameters)
@@ -581,6 +591,20 @@ private void updateQueryStats(VectorQueryType vectorQueryType) {
}
}
+ private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) {
+ if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) {
+ return this.vector;
+ }
+ return null;
+ }
+
+ private byte[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, byte[] byteVector) {
+ if (VectorDataType.BINARY == vectorDataType || (VectorDataType.BYTE == vectorDataType && KNNEngine.LUCENE == knnEngine)) {
+ return byteVector;
+ }
+ return null;
+ }
+
@Override
protected boolean doEquals(KNNQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java
index 84a4e8940..431579fae 100644
--- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java
+++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java
@@ -415,4 +415,16 @@ private static Map initializeMinimalRequiredVersionMap() {
}
return Collections.unmodifiableMap(versionMap);
}
+
+ /**
+ * Tell if it is a byte index or not
+ *
+ * @param parameters parameters associated with an index
+ * @return true if it is a byte index
+ */
+ public static boolean isByteIndex(Map parameters) {
+ return parameters.getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())
+ .toString()
+ .equals(VectorDataType.BYTE.getValue());
+ }
}
diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java
index a402be1f3..26c703eeb 100644
--- a/src/main/java/org/opensearch/knn/jni/FaissService.java
+++ b/src/main/java/org/opensearch/knn/jni/FaissService.java
@@ -69,6 +69,16 @@ class FaissService {
*/
public static native long initBinaryIndex(long numDocs, int dim, Map parameters);
+ /**
+ * Initialize a byte index for the native library. Takes in numDocs to
+ * allocate the correct amount of memory.
+ *
+ * @param numDocs number of documents to be added
+ * @param dim dimension of the vector to be indexed
+ * @param parameters parameters to build index
+ */
+ public static native long initByteIndex(long numDocs, int dim, Map parameters);
+
/**
* Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
@@ -95,6 +105,19 @@ class FaissService {
*/
public static native void insertToBinaryIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount);
+ /**
+ * Inserts to a faiss byte index. The memory occupied by the vectorsAddress will be freed up during the
+ * function call, so the Java layer doesn't need to free up the memory. This is not ideal behavior, because the Java
+ * layer created the memory address and should be the only one to free up the memory.
+ *
+ * @param ids ids of documents
+ * @param vectorsAddress address of native memory where vectors are stored
+ * @param dim dimension of the vector to be indexed
+ * @param indexAddress address of native memory where index is stored
+ * @param threadCount number of threads to use for insertion
+ */
+ public static native void insertToByteIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount);
+
/**
* Writes a faiss index.
*
@@ -115,6 +138,16 @@ class FaissService {
*/
public static native void writeBinaryIndex(long indexAddress, String indexPath);
+ /**
+ * Writes a faiss byte index.
+ *
+ * NOTE: This will always free the index. Do not call free after this.
+ *
+ * @param indexAddress address of native memory where index is stored
+ * @param indexPath path to save index file to
+ */
+ public static native void writeByteIndex(long indexAddress, String indexPath);
+
/**
* Create an index for the native library with a provided template index
*
diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java
index c7222738e..df1024db4 100644
--- a/src/main/java/org/opensearch/knn/jni/JNICommons.java
+++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java
@@ -70,7 +70,28 @@ public static long storeVectorData(long memoryAddress, float[][] data, long init
public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity, boolean append);
/**
- * This is utility function that can be used to store data in native memory. This function will allocate memory for
+ * This is utility function that can be used to store binary data in native memory. This function will allocate memory for
+ * the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
+ * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
+ * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
+ * will throw Exception.
+ *
+ *
+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can
+ * lead to data corruption.
+ *
+ *
+ * @param memoryAddress The address of the memory location where data will be stored.
+ * @param data 2D byte array containing binary data to be stored in native memory.
+ * @param initialCapacity The initial capacity of the memory location.
+ * @return memory address where the data is stored.
+ */
+ public static long storeBinaryVectorData(long memoryAddress, byte[][] data, long initialCapacity) {
+ return storeBinaryVectorData(memoryAddress, data, initialCapacity, true);
+ }
+
+ /**
+ * This is utility function that can be used to store binary data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
@@ -82,7 +103,27 @@ public static long storeVectorData(long memoryAddress, float[][] data, long init
*
*
* @param memoryAddress The address of the memory location where data will be stored.
- * @param data 2D byte array containing data to be stored in native memory.
+ * @param data 2D byte array containing binary data to be stored in native memory.
+ * @param initialCapacity The initial capacity of the memory location.
+ * @param append append the data or rewrite the memory location
+ * @return memory address where the data is stored.
+ */
+ public static native long storeBinaryVectorData(long memoryAddress, byte[][] data, long initialCapacity, boolean append);
+
+ /**
+ * This is a utility function that can be used to store byte data in native memory. This function will allocate memory for
+ * the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
+ * If you are using this function for the first time, use memoryAddress = 0 to ensure that a new memory location is created.
+ * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location,
+ * an exception will be thrown.
+ *
+ *
+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can
+ * lead to data corruption.
+ *
+ *
+ * @param memoryAddress The address of the memory location where data will be stored.
+ * @param data 2D byte array containing byte data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
*/
@@ -91,7 +132,7 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i
}
/**
- * This is utility function that can be used to store data in native memory. This function will allocate memory for
+ * This is utility function that can be used to store byte data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
@@ -103,7 +144,7 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i
*
*
* @param memoryAddress The address of the memory location where data will be stored.
- * @param data 2D byte array containing data to be stored in native memory.
+ * @param data 2D byte array containing byte data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append append the data or rewrite the memory location
* @return memory address where the data is stored.
@@ -124,8 +165,21 @@ public static long storeByteVectorData(long memoryAddress, byte[][] data, long i
public static native void freeVectorData(long memoryAddress);
/**
- * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory
- * address returned by {@link JNICommons#storeVectorData(long, float[][], long, boolean)}
+ * Free up the memory allocated for the binary data stored in memory address. This function should be used with the memory
+ * address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long)}
+ *
+ *
+ * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can
+ * lead to errors.
+ *
+ *
+ * @param memoryAddress address to be freed.
+ */
+ public static native void freeBinaryVectorData(long memoryAddress);
+
+ /**
+ * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory
+ * address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long)}
*
*
* The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can
diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java
index d1d5f6c11..1177d635e 100644
--- a/src/main/java/org/opensearch/knn/jni/JNIService.java
+++ b/src/main/java/org/opensearch/knn/jni/JNIService.java
@@ -39,9 +39,13 @@ public static long initIndex(long numDocs, int dim, Map paramete
if (KNNEngine.FAISS == knnEngine) {
if (IndexUtil.isBinaryIndex(knnEngine, parameters)) {
return FaissService.initBinaryIndex(numDocs, dim, parameters);
- } else {
- return FaissService.initIndex(numDocs, dim, parameters);
}
+ if (IndexUtil.isByteIndex(parameters)) {
+ return FaissService.initByteIndex(numDocs, dim, parameters);
+ }
+
+ return FaissService.initIndex(numDocs, dim, parameters);
+
}
throw new IllegalArgumentException(
@@ -71,6 +75,8 @@ public static void insertToIndex(
if (KNNEngine.FAISS == knnEngine) {
if (IndexUtil.isBinaryIndex(knnEngine, parameters)) {
FaissService.insertToBinaryIndex(docs, vectorsAddress, dimension, indexAddress, threadCount);
+ } else if (IndexUtil.isByteIndex(parameters)) {
+ FaissService.insertToByteIndex(docs, vectorsAddress, dimension, indexAddress, threadCount);
} else {
FaissService.insertToIndex(docs, vectorsAddress, dimension, indexAddress, threadCount);
}
@@ -94,6 +100,8 @@ public static void writeIndex(String indexPath, long indexAddress, KNNEngine knn
if (KNNEngine.FAISS == knnEngine) {
if (IndexUtil.isBinaryIndex(knnEngine, parameters)) {
FaissService.writeBinaryIndex(indexAddress, indexPath);
+ } else if (IndexUtil.isByteIndex(parameters)) {
+ FaissService.writeByteIndex(indexAddress, indexPath);
} else {
FaissService.writeIndex(indexAddress, indexPath);
}
diff --git a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java
index 70cfb4f4c..e838b5214 100644
--- a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java
+++ b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java
@@ -39,7 +39,7 @@ public ByteTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation tr
@Override
public void accept(List> byteVectors) {
long memoryAddress = trainingDataAllocation.getMemoryAddress();
- memoryAddress = JNICommons.storeByteVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size());
+ memoryAddress = JNICommons.storeBinaryVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size());
trainingDataAllocation.setMemoryAddress(memoryAddress);
}
diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
index ea137f042..fe621d7d4 100644
--- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
+++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
@@ -13,15 +13,15 @@
import org.opensearch.client.ResponseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
-import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.engine.KNNEngine;
-import org.opensearch.core.rest.RestStatus;
+import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.script.Script;
import java.util.ArrayList;
@@ -32,7 +32,13 @@
import java.util.Map;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
+import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
+import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
+import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
+import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
+import static org.opensearch.knn.common.KNNConstants.NAME;
+import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
@@ -453,6 +459,138 @@ public void testSearchWithMissingQueryVector() {
assertTrue(ex.getMessage().contains("[knn] requires query vector"));
}
+ @SneakyThrows
+ public void testAddDocWithByteVectorUsingFaissEngine() {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ Byte[] vector = { 6, 6 };
+ addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
+
+ refreshAllIndices();
+ assertEquals(1, getDocCount(INDEX_NAME));
+ }
+
+ @SneakyThrows
+ public void testUpdateDocWithByteVectorUsingFaissEngine() {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ Byte[] vector = { -36, 78 };
+ addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
+
+ Byte[] updatedVector = { 89, -8 };
+ updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector);
+
+ refreshAllIndices();
+ assertEquals(1, getDocCount(INDEX_NAME));
+ }
+
+ @SneakyThrows
+ public void testDeleteDocWithByteVectorUsingFaissEngine() {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ Byte[] vector = { 35, -46 };
+ addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
+
+ deleteKnnDoc(INDEX_NAME, DOC_ID);
+ refreshAllIndices();
+
+ assertEquals(0, getDocCount(INDEX_NAME));
+ }
+
+ @SneakyThrows
+ public void testSearchWithByteVectorUsingFaissEngine() {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ ingestL2ByteTestData();
+
+ Byte[] queryVector = { 1, 1 };
+ Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), 4), 4);
+
+ validateL2SearchResults(response);
+ }
+
+ @SneakyThrows
+ public void testInvalidVectorDataUsingFaissEngine() {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ Float[] vector = { -10.76f, 15.89f };
+
+ ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector));
+ assertTrue(
+ ex.getMessage()
+ .contains(
+ String.format(
+ Locale.ROOT,
+ "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
+ VECTOR_DATA_TYPE_FIELD,
+ VectorDataType.BYTE.getValue()
+ )
+ )
+ );
+ }
+
+ // Create an index with byte vector data_type and add a doc with values out of byte range which should throw exception
+ @SneakyThrows
+ public void testInvalidByteVectorRangeUsingFaissEngine() {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ Float[] vector = { -1000f, 155f };
+
+ ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector));
+ assertTrue(
+ ex.getMessage()
+ .contains(
+ String.format(
+ Locale.ROOT,
+ "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
+ VECTOR_DATA_TYPE_FIELD,
+ VectorDataType.BYTE.getValue(),
+ Byte.MIN_VALUE,
+ Byte.MAX_VALUE
+ )
+ )
+ );
+ }
+
+ // Create an index with byte vector data_type using faiss engine with an encoder which should throw an exception
+ @SneakyThrows
+ public void testByteVectorDataTypeWithFaissEngineUsingEncoderThrowsException() {
+ XContentBuilder builder = XContentFactory.jsonBuilder()
+ .startObject()
+ .startObject(PROPERTIES_FIELD)
+ .startObject(FIELD_NAME)
+ .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
+ .field(DIMENSION, 2)
+ .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())
+ .startObject(KNNConstants.KNN_METHOD)
+ .field(KNNConstants.NAME, METHOD_HNSW)
+ .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
+ .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
+ .startObject(PARAMETERS)
+ .field(KNNConstants.METHOD_PARAMETER_M, M)
+ .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
+ .startObject(METHOD_ENCODER_PARAMETER)
+ .field(NAME, ENCODER_SQ)
+ .startObject(PARAMETERS)
+ .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
+ .endObject()
+ .endObject()
+ .endObject()
+ .endObject()
+ .endObject()
+ .endObject()
+ .endObject();
+
+ String mapping = builder.toString();
+ expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping));
+ }
+
+ public void testDocValuesWithByteVectorDataTypeFaissEngine() throws Exception {
+ createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue());
+ ingestL2ByteTestData();
+
+ Byte[] queryVector = { 1, 1 };
+ Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
+ Response response = client().performRequest(request);
+ assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
+
+ validateL2SearchResults(response);
+ }
+
@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
@@ -491,6 +629,10 @@ private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spac
createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.LUCENE.getName());
}
+ private void createKnnIndexMappingWithFaissEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception {
+ createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.FAISS.getName());
+ }
+
private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spaceType, String vectorDataType, String engine)
throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -504,7 +646,7 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, engine)
- .startObject(KNNConstants.PARAMETERS)
+ .startObject(PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, M)
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
.endObject()
diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java
index cef875cfc..39415d811 100644
--- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java
@@ -20,7 +20,7 @@ public void testOffHeapVectorTransferFactory() {
assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10));
var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10);
- assertEquals(OffHeapByteVectorTransfer.class, binaryVectorTransfer.getClass());
+ assertEquals(OffHeapBinaryVectorTransfer.class, binaryVectorTransfer.getClass());
assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10));
}
}
diff --git a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java
index 6defa4c50..f142a9770 100644
--- a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java
+++ b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java
@@ -455,12 +455,12 @@ public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() {
);
}
- public void testValidateVectorDataType_whenByteLucene_thenValid() {
+ public void testValidateVectorDataType_whenByte_thenValid() {
validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, null);
+ validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, null);
}
- public void testValidateVectorDataType_whenByteNonLucene_thenException() {
- validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, "UnsupportedMethod");
+ public void testValidateVectorDataType_whenByte_thenException() {
validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_IVF, VectorDataType.BYTE, SpaceType.L2, "UnsupportedMethod");
}
diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java
index 316582f6c..1e2134581 100644
--- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java
+++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java
@@ -117,7 +117,7 @@ public void testClose_whenBinaryFiass_thenSuccess() {
KNNConstants.VECTOR_DATA_TYPE_FIELD,
VectorDataType.BINARY.getValue()
);
- long vectorMemoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors * dataLength);
+ long vectorMemoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors * dataLength);
TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine);
// Load index into memory
diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java
index 8a38cadb5..29fbdb978 100644
--- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java
+++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java
@@ -106,7 +106,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException {
KNNConstants.VECTOR_DATA_TYPE_FIELD,
VectorDataType.BINARY.getValue()
);
- long memoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors);
+ long memoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors);
TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine);
// Setup mock resource manager
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
index 762a36227..b7de89564 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
@@ -961,6 +961,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() {
Index dummyIndex = new Index("dummy", "dummy");
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4));
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+ when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
IndexSettings indexSettings = mock(IndexSettings.class);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);
diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java
index 6bbbc8a5b..b644c37f8 100644
--- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java
+++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java
@@ -403,7 +403,7 @@ public long loadDataToMemoryAddress() {
}
public long loadBinaryDataToMemoryAddress() {
- return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true);
+ return JNICommons.storeBinaryVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true);
}
@AllArgsConstructor
From 3cd458b4e2b24162f7fc6e190316f714940b0a34 Mon Sep 17 00:00:00 2001
From: Naveen Tatikonda
Date: Wed, 21 Aug 2024 23:09:41 -0500
Subject: [PATCH 2/2] Address Review Comments
Signed-off-by: Naveen Tatikonda
---
jni/include/commons.h | 2 ++
jni/src/faiss_index_service.cpp | 6 ++---
.../engine/faiss/AbstractFaissMethod.java | 17 ++++++++++++-
.../index/engine/faiss/FaissHNSWMethod.java | 2 +-
.../index/engine/faiss/FaissIVFMethod.java | 2 +-
.../knn/index/mapper/MethodFieldMapper.java | 24 ++++---------------
6 files changed, 26 insertions(+), 27 deletions(-)
diff --git a/jni/include/commons.h b/jni/include/commons.h
index e1aaacd9c..3f1ee19a3 100644
--- a/jni/include/commons.h
+++ b/jni/include/commons.h
@@ -49,6 +49,7 @@ namespace knn_jni {
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing binary data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
+ * @param append whether to append to the existing data (true) or start writing from index 0 (false) when called again with the same address
* @return memory address of std::vector where the data is stored.
*/
jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
@@ -63,6 +64,7 @@ namespace knn_jni {
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing int8 data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
+ * @param append whether to append to the existing data (true) or start writing from index 0 (false) when called again with the same address
* @return memory address of std::vector where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp
index b6e465741..16ded4bcb 100644
--- a/jni/src/faiss_index_service.cpp
+++ b/jni/src/faiss_index_service.cpp
@@ -327,7 +327,8 @@ void ByteIndexService::insertToIndex(
faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress);
- // Add vectors in batches by casting int8 vectors into float with a batch size of 1000
+ // Add vectors in batches of 1000, casting each batch of int8 vectors to float, to avoid an additional memory spike.
+ // See this GitHub issue for more details: https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255
int batchSize = 1000;
std::vector inputFloatVectors(batchSize * dim);
std::vector floatVectorsIds(batchSize);
@@ -337,8 +338,6 @@ void ByteIndexService::insertToIndex(
for (int id = 0; id < numVectors; id += batchSize) {
if (numVectors - id < batchSize) {
batchSize = numVectors - id;
- inputFloatVectors.resize(batchSize * dim);
- floatVectorsIds.resize(batchSize);
}
for (int i = 0; i < batchSize; ++i) {
@@ -364,6 +363,5 @@ void ByteIndexService::writeIndex(
throw std::runtime_error("Failed to write index to disk");
}
}
-
} // namespace faiss_wrapper
} // namesapce knn_jni
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java
index 908671a21..8c2dfc126 100644
--- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java
+++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java
@@ -5,6 +5,7 @@
package org.opensearch.knn.index.engine.faiss;
+import org.apache.commons.lang.StringUtils;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.AbstractKNNMethod;
@@ -20,6 +21,7 @@
import java.util.Objects;
import java.util.Set;
+import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled;
@@ -87,7 +89,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
throw new IllegalStateException("Unsupported vector data type " + vectorDataType);
}
- static KNNLibraryIndexingContext adjustPrefix(
+ static KNNLibraryIndexingContext adjustIndexDescription(
MethodAsMapBuilder methodAsMapBuilder,
MethodComponentContext methodComponentContext,
KNNMethodConfigContext knnMethodConfigContext
@@ -105,6 +107,19 @@ static KNNLibraryIndexingContext adjustPrefix(
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}
+ if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) {
+
+ // For the BYTE vector data type with the Faiss engine, rewrite the index description to use the "SQ8_direct_signed" scalar quantizer.
+ // For example, the index description "HNSW16,Flat" becomes "HNSW16,SQ8_direct_signed".
+ String indexDescription = methodAsMapBuilder.indexDescription;
+ if (StringUtils.isNotEmpty(indexDescription)) {
+ StringBuilder indexDescriptionBuilder = new StringBuilder();
+ indexDescriptionBuilder.append(indexDescription.split(",")[0]);
+ indexDescriptionBuilder.append(",");
+ indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ);
+ methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString();
+ }
+ }
methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription;
return methodAsMapBuilder.build();
}
diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java
index dbe89bd0a..41db777e3 100644
--- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java
+++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java
@@ -99,7 +99,7 @@ private static MethodComponent initMethodComponent() {
methodComponentContext,
knnMethodConfigContext
).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "");
- return adjustPrefix(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
+ return adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
}))
.build();
}
diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java
index bc30e372c..b3dd12c92 100644
--- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java
+++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java
@@ -95,7 +95,7 @@ private static MethodComponent initMethodComponent() {
methodComponentContext,
knnMethodConfigContext
).addParameter(METHOD_PARAMETER_NLIST, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "");
- return adjustPrefix(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
+ return adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
}))
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
// Size estimate formula: (4 * nlists * d) / 1024 + 1
diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java
index 75f01ba24..90d4ca879 100644
--- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java
+++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java
@@ -11,7 +11,6 @@
import org.opensearch.common.Explicit;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.SpaceType;
-import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
@@ -25,8 +24,6 @@
import java.util.Optional;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
-import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ;
-import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;
@@ -130,23 +127,10 @@ private MethodFieldMapper(
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());
try {
- Map libParams = knnLibraryIndexingContext.getLibraryParameters();
-
- // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer
- // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed"
- if (VectorDataType.BYTE == vectorDataType && libParams.containsKey(INDEX_DESCRIPTION_PARAMETER)) {
- String indexDescriptionValue = (String) libParams.get(INDEX_DESCRIPTION_PARAMETER);
- if (indexDescriptionValue != null && indexDescriptionValue.isEmpty() == false) {
- StringBuilder indexDescriptionBuilder = new StringBuilder();
- indexDescriptionBuilder.append(indexDescriptionValue.split(",")[0]);
- indexDescriptionBuilder.append(",");
- indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ);
-
- libParams.replace(INDEX_DESCRIPTION_PARAMETER, indexDescriptionBuilder.toString());
- libParams.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue());
- }
- }
- this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString());
+ this.fieldType.putAttribute(
+ PARAMETERS,
+ XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString()
+ );
} catch (IOException ioe) {
throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe));
}