Skip to content

Commit

Permalink
Integrating storeVectors interfaces with createIndex and createIndexT…
Browse files Browse the repository at this point in the history
…emplate functions. (#1588)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Apr 4, 2024
1 parent fccc5a9 commit badbb1d
Show file tree
Hide file tree
Showing 22 changed files with 359 additions and 313 deletions.
4 changes: 2 additions & 2 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ namespace knn_jni {
namespace faiss_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
Expand Down
2 changes: 1 addition & 1 deletion jni/include/nmslib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace knn_jni {
namespace nmslib_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddress, jint dim,
jstring indexPathJ, jobject parametersJ);

// Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters
Expand Down
16 changes: 4 additions & 12 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndex
* Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
* Signature: ([I[[FLjava/lang/String;[BLjava/util/Map;)V
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down Expand Up @@ -122,14 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeVectors
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_NmslibService.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: createIndex
* Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
Expand Down
48 changes: 30 additions & 18 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,19 @@ bool isIndexIVFPQL2(faiss::Index * index);
// IndexIDMap which has member that will point to underlying index that stores the data
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ) {

if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand All @@ -109,16 +113,20 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Read data set
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
// numVectors can safely be an int here because the total number of docs in a Lucene segment never exceeds INT_MAX
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if(numVectors == 0) {
throw std::runtime_error("Number of vectors cannot be 0");
}

int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
Expand Down Expand Up @@ -148,22 +156,26 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, dataset.data(), idVector.data());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand All @@ -183,15 +195,15 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);
Expand All @@ -208,7 +220,7 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, dataset.data(), idVector.data());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
Expand Down
39 changes: 23 additions & 16 deletions jni/src/nmslib_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,19 @@ std::string TranslateSpaceType(const std::string& spaceType);
const similarity::LabelType DEFAULT_LABEL = -1;

void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ) {

if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand Down Expand Up @@ -91,12 +96,18 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
space.reset(similarity::SpaceFactoryRegistry<float>::Instance().CreateSpace(spaceTypeCpp,similarity::AnyParams()));

// Get number of ids and vectors and dimension
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
// numVectors can safely be an int here because the total number of docs in a Lucene segment never exceeds INT_MAX
int numVectors = (int) ( inputVectors->size() / (uint64_t) dim);
if(numVectors == 0) {
throw std::runtime_error("Number of vectors cannot be 0");
}

int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}
int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);

// Read dataset
similarity::ObjectVector dataset;
Expand All @@ -105,10 +116,12 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
try {
// Read in data set
idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr);

float* floatArrayCpp;
jfloatArray floatArrayJ;
size_t vectorSizeInBytes = dim*sizeof(float);
// vectorPointer must be an unsigned long long so that it does not go out of range when numVectors * dim
// becomes very large.
// Example: for 10M vectors of 1536 dimensions, vectorPointer reaches ~15.3B, which is already beyond the range of an int.
// By keeping it an unsigned long long, we never go above the range.
unsigned long long vectorPointer = 0;

// Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as
// opposed to individually will prevent heap fragmentation. We have observed that allocating individual
Expand All @@ -134,15 +147,9 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE);
ptr += similarity::DATALENGTH_SIZE;

floatArrayJ = (jfloatArray)jniUtil->GetObjectArrayElement(env, vectorsJ, i);
if (dim != jniUtil->GetJavaFloatArrayLength(env, floatArrayJ)) {
throw std::runtime_error("Dimension of vectors is inconsistent");
}

floatArrayCpp = jniUtil->GetFloatArrayElements(env, floatArrayJ, nullptr);
memcpy(ptr, floatArrayCpp, vectorSizeInBytes);
jniUtil->ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT);
memcpy(ptr, &(inputVectors->at(vectorPointer)), vectorSizeInBytes);
ptr += vectorSizeInBytes;
vectorPointer += dim;
}
jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT);

Expand Down
20 changes: 6 additions & 14 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,26 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ,
jobject parametersJ)
jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsJ, indexPathJ, parametersJ);
knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jobjectArray vectorsJ,
jlong vectorsAddressJ,
jint dimJ,
jstring indexPathJ,
jbyteArray templateIndexJ,
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsJ, indexPathJ, templateIndexJ, parametersJ);
knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down Expand Up @@ -189,12 +190,3 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors

return (jlong) vect;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ)
{
if (vectorsPointerJ != 0) {
auto *vect = reinterpret_cast<std::vector<float>*>(vectorsPointerJ);
delete vect;
}
}
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_NmslibService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jobject parametersJ)
{
try {
knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsJ, indexPathJ, parametersJ);
knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
Loading

0 comments on commit badbb1d

Please sign in to comment.