Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HNSW changes to support Faiss byte vector #1823

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
### Bug Fixes
Expand Down
26 changes: 25 additions & 1 deletion jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,26 @@ namespace knn_jni {
* CAUTION: The behavior is undefined if the memory address is deallocated and the method is called
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing data to be stored in native memory.
* @param data 2D byte array containing binary data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append whether to append or start from index 0 when called subsequently with the same address
* @return memory address of std::vector<uint8_t> where the data is stored.
*/
jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

/**
* This is utility function that can be used to store signed int8 data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing int8 data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append whether to append or start from index 0 when called subsequently with the same address
* @return memory address of std::vector<int8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

/**
Expand All @@ -69,6 +85,14 @@ namespace knn_jni {
*/
void freeByteVectorData(jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeBinaryVectorData(long, byte[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeBinaryVectorData(jlong);

/**
* Extracts query time efSearch from method parameters
**/
Expand Down
57 changes: 57 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,63 @@ class BinaryIndexService : public IndexService {
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override;
};

/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
*/
class ByteIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
ByteIndexService(std::unique_ptr<FaissMethods> faissMethods);

/**
* Initialize index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param parameters parameters to be applied to faiss index
* @return memory address of the native index object
*/
virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map<std::string, jobject> parameters) override;
/**
* Add vectors to index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param idMap a map of document id and vector id
* @param parameters parameters to be applied to faiss index
*/
virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress) override;
/**
* Write index to disk
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param threadCount number of thread count to be used while adding data
* @param indexPath path to write index
* @param idMap a map of document id and vector id
* @param parameters parameters to be applied to faiss index
*/
virtual void writeIndex(std::string indexPath, jlong idMapAddress) override;
virtual ~ByteIndexService() = default;
protected:
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override;
};

}
}

Expand Down
7 changes: 5 additions & 2 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ namespace knn_jni {

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
virtual void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<uint8_t> *vect ) = 0;
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<int8_t> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

Expand Down Expand Up @@ -173,7 +175,8 @@ namespace knn_jni {
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);
void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<int8_t> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
30 changes: 30 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEn
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: insertToIndex
Expand All @@ -50,6 +60,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: insertToByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: writeIndex
Expand All @@ -66,6 +86,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEn
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls,
jlong indexAddress,
jstring indexPathJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: writeByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls,
jlong indexAddress,
jstring indexPathJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand Down
20 changes: 18 additions & 2 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Method: storeBinaryVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeBinaryVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeByteVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
Expand All @@ -44,7 +52,15 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
* Method: freeBinaryVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeBinaryVectorData
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeByteVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData
Expand Down
31 changes: 29 additions & 2 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE
return (jlong) vect;
}

jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jlong knn_jni::commons::storeBinaryVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<uint8_t> *vect;
if ((long) memoryAddressJ == 0) {
Expand All @@ -51,6 +51,26 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil,
vect->clear();
}

int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToBinaryVector(env, dataJ, dim, vect);

return (jlong) vect;
}

jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<int8_t> *vect;
if (memoryAddressJ == 0) {
vect = new std::vector<int8_t>();
vect->reserve(static_cast<long>(initialCapacityJ));
} else {
vect = reinterpret_cast<std::vector<int8_t>*>(memoryAddressJ);
}

if (appendJ == JNI_FALSE) {
vect->clear();
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
}

int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect);

Expand All @@ -64,13 +84,20 @@ void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
}
}

void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) {
void knn_jni::commons::freeBinaryVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddressJ);
delete vect;
}
}

void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast<std::vector<int8_t>*>(memoryAddressJ);
delete vect;
}
}

int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map<std::string, jobject> methodParams, std::string methodParam, int defaultValue) {
if (methodParams.empty()) {
return defaultValue;
Expand Down
113 changes: 113 additions & 0 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,118 @@ void BinaryIndexService::writeIndex(
}
}

ByteIndexService::ByteIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}

void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
if(auto * indexHNSWSQ = dynamic_cast<faiss::IndexHNSWSQ *>(index)) {
if(auto * indexScalarQuantizer = dynamic_cast<faiss::IndexScalarQuantizer *>(indexHNSWSQ->storage)) {
indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors);
}
return;
}
}

jlong ByteIndexService::initIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numVectors,
int threadCount,
std::unordered_map<std::string, jobject> parameters
) {
// Create index using Faiss factory method
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(threadCount != 0) {
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, index.get());

// Check that the index does not need to be trained
if(!index->is_trained) {
throw std::runtime_error("Index is not trained");
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

void ByteIndexService::insertToIndex(
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> & ids,
jlong idMapAddress
) {
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<int8_t>*>(vectorsAddress);

// The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
int numVectors = inputVectors->size() / dim;
if(numVectors == 0) {
throw std::runtime_error("Number of vectors cannot be 0");
}

if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(threadCount != 0) {
omp_set_num_threads(threadCount);
}

faiss::IndexIDMap * idMap = reinterpret_cast<faiss::IndexIDMap *> (idMapAddress);

// Add vectors in batches by casting int8 vectors into float with a batch size of 1000 to avoid additional memory spike.
// Refer to this github issue for more details https://github.com/opensearch-project/k-NN/issues/1659#issuecomment-2307390255
int batchSize = 1000;
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
std::vector <float> inputFloatVectors(batchSize * dim);
std::vector <int64_t> floatVectorsIds(batchSize);
int id = 0;
auto iter = inputVectors->begin();

for (int id = 0; id < numVectors; id += batchSize) {
if (numVectors - id < batchSize) {
batchSize = numVectors - id;
}

for (int i = 0; i < batchSize; ++i) {
floatVectorsIds[i] = ids[id + i];
for (int j = 0; j < dim; ++j, ++iter) {
inputFloatVectors[i * dim + j] = static_cast<float>(*iter);
}
}
idMap->add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data());
}
}

void ByteIndexService::writeIndex(
std::string indexPath,
jlong idMapAddress
) {
std::unique_ptr<faiss::IndexIDMap> idMap (reinterpret_cast<faiss::IndexIDMap *> (idMapAddress));

try {
// Write the index to disk
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
} catch(std::exception &e) {
throw std::runtime_error("Failed to write index to disk");
}
}
} // namespace faiss_wrapper
} // namesapce knn_jni
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading