Skip to content

Commit

Permalink
Add the interface for streaming the vectors from java to jni layer wi…
Browse files Browse the repository at this point in the history
…th initial capacity

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Mar 29, 2024
1 parent 771c4b5 commit e980b6f
Show file tree
Hide file tree
Showing 13 changed files with 357 additions and 70 deletions.
3 changes: 2 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ endif()
# ----------------------------------------------------------------------------

# ---------------------------------- COMMON ----------------------------------
add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp)
add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_JNICommons.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/commons.cpp)
target_include_directories(${TARGET_LIB_COMMON} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE})
set_target_properties(${TARGET_LIB_COMMON} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_COMMON} PROPERTIES POSITION_INDEPENDENT_CODE ON)
Expand Down Expand Up @@ -236,6 +236,7 @@ if ("${WIN32}" STREQUAL "")
tests/faiss_util_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
)

target_link_libraries(
Expand Down
38 changes: 38 additions & 0 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#include "jni_util.h"
#include <jni.h>
namespace knn_jni {
    namespace commons {
        /**
         * Stores the rows of a Java float[][] in native memory as one flat sequence of floats.
         *
         * On the first call pass memoryAddress = 0: a std::vector<float> holding initialCapacity
         * elements is allocated and its address is returned. On subsequent calls pass the address
         * returned earlier so that more data is written into the same vector. The vector never grows;
         * if the data does not fit between startIndex and the end of the vector an exception is thrown.
         *
         * The returned address must eventually be released with freeVectorData.
         *
         * @param jniUtil helper used to read the Java array.
         * @param env JNI environment handed to jniUtil.
         * @param memoryAddress 0 to allocate a new vector, otherwise an address returned by a previous call.
         * @param data 2D float array containing the data to be stored in native memory.
         * @param initialCapacity number of floats the vector holds when it is first allocated.
         * @param startIndex index of the first float, inside the native vector, to be written.
         * @return memory address where the data is stored.
         * @throws std::runtime_error if startIndex plus the number of floats in data exceeds the
         *         size of the native vector.
         */
        jlong storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddress,
                              jobjectArray data, jlong initialCapacity, jlong startIndex);

        /**
         * Frees the native vector stored at the given address. The address must be 0 (in which case
         * nothing happens) or a value returned by storeVectorData that has not been freed yet.
         *
         * @param memoryAddress address to be freed.
         */
        void freeVectorData(jlong memoryAddress);
    }
}
8 changes: 0 additions & 8 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectorsV2
* Signature: (J[[F)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeVectors
Expand Down
40 changes: 40 additions & 0 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_opensearch_knn_jni_JNICommons */

#ifndef _Included_org_opensearch_knn_jni_JNICommons
#define _Included_org_opensearch_knn_jni_JNICommons
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
#endif
52 changes: 52 additions & 0 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#ifndef OPENSEARCH_KNN_COMMONS_H
#define OPENSEARCH_KNN_COMMONS_H
#include "commons.h"

#include <jni.h>

#include <algorithm>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "jni_util.h"

/**
 * Copies the contents of a Java float[][] into a natively allocated std::vector<float>.
 *
 * When memoryAddressJ is 0 a new vector of initialCapacityJ floats is allocated; otherwise
 * memoryAddressJ is interpreted as a pointer to a vector returned by an earlier call and
 * initialCapacityJ is not used. The converted data is written starting at element startingIndexJ.
 * The vector is never resized.
 *
 * @param jniUtil helper used to read the Java array.
 * @param env JNI environment forwarded to jniUtil.
 * @param memoryAddressJ 0 to allocate, or an address previously returned by this function.
 * @param dataJ 2D float array to copy.
 * @param initialCapacityJ number of floats to allocate when a new vector is created.
 * @param startingIndexJ index of the first float in the native vector to overwrite.
 * @return address of the vector holding the data.
 * @throws std::runtime_error if startingIndexJ or a requested initialCapacityJ is negative, or if
 *         the data does not fit between startingIndexJ and the end of the vector.
 */
jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
                                        jobjectArray dataJ, jlong initialCapacityJ, jlong startingIndexJ) {
    // Reject negative indices up front: converted to an unsigned size they would wrap around and
    // could slip past the capacity check below.
    if (startingIndexJ < 0) {
        throw std::runtime_error("Cannot store data on memory location as starting index: "
                                 + std::to_string(startingIndexJ) + " is negative");
    }

    // Owns the vector only while it was allocated by this call, so that it is released if reading
    // the Java array or the capacity check throws. An existing vector stays owned by the caller.
    std::unique_ptr<std::vector<float>> newlyAllocated;
    std::vector<float> *vect;
    // Compare and convert the full 64-bit jlong; narrowing through `long` loses bits where long is 32-bit.
    if (memoryAddressJ == 0) {
        if (initialCapacityJ < 0) {
            throw std::runtime_error("Cannot allocate memory location as initial capacity: "
                                     + std::to_string(initialCapacityJ) + " is negative");
        }
        newlyAllocated.reset(new std::vector<float>(static_cast<std::size_t>(initialCapacityJ)));
        vect = newlyAllocated.get();
    } else {
        vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
    }

    const int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ);
    const auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, dataJ, dim);

    // Written as two comparisons so the check itself cannot overflow.
    const std::size_t startingIndex = static_cast<std::size_t>(startingIndexJ);
    if (startingIndex > vect->size() || dataset.size() > vect->size() - startingIndex) {
        throw std::runtime_error("Cannot store data on memory location as total size: "
                                 + std::to_string(startingIndex + dataset.size()) + " will be greater than actual capacity: "
                                 + std::to_string(vect->size()));
    }

    // set the vector values at correct location.
    std::copy(dataset.begin(), dataset.end(), vect->data() + startingIndex);

    // Success: hand ownership to the caller (no-op when the vector already existed).
    newlyAllocated.release();
    return reinterpret_cast<jlong>(vect);
}

/**
 * Deletes the std::vector<float> located at memoryAddressJ.
 *
 * @param memoryAddressJ address of the vector to delete; 0 means there is nothing to free.
 */
void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
    // Nothing was ever allocated for a zero address.
    if (memoryAddressJ == 0) {
        return;
    }
    delete reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}
#endif //OPENSEARCH_KNN_COMMONS_H
19 changes: 0 additions & 19 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include <jni.h>

#include <algorithm>
#include <vector>

#include "faiss_wrapper.h"
Expand Down Expand Up @@ -191,24 +190,6 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
return (jlong) vect;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
{
std::vector<float> *vect;
if ((long) vectorsPointerJ == 0) {
vect = new std::vector<float>;
} else {
vect = reinterpret_cast<std::vector<float>*>(vectorsPointerJ);
}

int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);
vect->insert(vect->end(), dataset.begin(), dataset.end());

return (jlong) vect;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ)
{
Expand Down
60 changes: 60 additions & 0 deletions jni/src/org_opensearch_knn_jni_JNICommons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#include "org_opensearch_knn_jni_JNICommons.h"

#include <jni.h>
#include "commons.h"
#include "jni_util.h"

static knn_jni::JNIUtil jniUtil;
static const jint KNN_JNICOMMONS_JNI_VERSION = JNI_VERSION_1_1;

/**
 * Library load hook: fetches the JNIEnv for KNN_JNICOMMONS_JNI_VERSION and initializes jniUtil with it.
 *
 * @param vm the Java VM loading this library.
 * @param reserved unused.
 * @return KNN_JNICOMMONS_JNI_VERSION on success, JNI_ERR if the environment could not be obtained.
 */
jint JNI_OnLoad(JavaVM* vm, void* reserved) {
    JNIEnv* env;
    const jint status = vm->GetEnv(reinterpret_cast<void**>(&env), KNN_JNICOMMONS_JNI_VERSION);
    if (status == JNI_OK) {
        jniUtil.Initialize(env);
        return KNN_JNICOMMONS_JNI_VERSION;
    }
    return JNI_ERR;
}

/**
 * Library unload hook: obtains the JNIEnv and lets jniUtil undo its initialization.
 *
 * If the environment cannot be obtained the function returns without calling
 * jniUtil.Uninitialize, so that Uninitialize is never handed an invalid JNIEnv pointer.
 *
 * @param vm the Java VM unloading this library.
 * @param reserved unused.
 */
void JNI_OnUnload(JavaVM *vm, void *reserved) {
    JNIEnv* env = nullptr;
    const jint status = vm->GetEnv(reinterpret_cast<void**>(&env), KNN_JNICOMMONS_JNI_VERSION);
    if (status != JNI_OK || env == nullptr) {
        return;
    }
    jniUtil.Uninitialize(env);
}


/**
 * JNI entry point for JNICommons.storeVectorData. Delegates to knn_jni::commons::storeVectorData
 * and converts any C++ exception into a Java exception through jniUtil.
 *
 * @return the address returned by knn_jni::commons::storeVectorData; if an exception was raised,
 *         the incoming memoryAddressJ is returned unchanged.
 */
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jlong startingIndexJ)

{
    try {
        return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, startingIndexJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    // Return the jlong as-is: casting through `long` would drop the upper half of the address on
    // platforms where long is 32 bits wide.
    return memoryAddressJ;
}

/**
 * JNI entry point for JNICommons.freeVectorData. Delegates to knn_jni::commons::freeVectorData
 * and converts any C++ exception into a Java exception through jniUtil.
 */
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls,
                                                                             jlong memoryAddressJ)
{
    try {
        knn_jni::commons::freeVectorData(memoryAddressJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}
76 changes: 76 additions & 0 deletions jni/tests/commons_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/


#include "test_util.h"
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "jni_util.h"
#include "commons.h"

// Exercises knn_jni::commons::storeVectorData end to end: first allocation, a second write into the
// same memory location, verification of the stored values, rejection of data beyond the capacity,
// and finally release of the native memory.
TEST(CommonsTests, BasicAssertions) {
    long dim = 3;
    long totalNumberOfVector = 5;
    const auto expectedSize = static_cast<std::size_t>(totalNumberOfVector * dim);

    // First batch: all vectors but the last one.
    std::vector<std::vector<float>> data;
    for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) {
        std::vector<float> vector;
        for(int j = 0 ; j < dim ; j ++) {
            vector.push_back((float)j);
        }
        data.push_back(vector);
    }
    JNIEnv *jniEnv = nullptr;

    testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0,
            reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim) , (jlong)0);
    ASSERT_NE(memoryAddress, 0);
    auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddress);
    ASSERT_EQ(vect->size(), expectedSize);

    // Check by inserting more vectors at same memory location
    jlong oldMemoryAddress = memoryAddress;
    std::vector<std::vector<float>> data2;
    std::vector<float> vector;
    for(int j = 0 ; j < dim ; j ++) {
        vector.push_back((float)j);
    }
    data2.push_back(vector);
    memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress,
            reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim) , (jlong)(data.size() * dim));
    ASSERT_NE(memoryAddress, 0);
    ASSERT_EQ(memoryAddress, oldMemoryAddress);
    vect = reinterpret_cast<std::vector<float>*>(memoryAddress);
    std::size_t currentIndex = 0;
    ASSERT_EQ(vect->size(), expectedSize);

    // Validate if all vectors data are at correct location
    for(auto & i : data) {
        for(float j : i) {
            ASSERT_FLOAT_EQ(vect->at(currentIndex), j);
            currentIndex++;
        }
    }

    for(auto & i : data2) {
        for(float j : i) {
            ASSERT_FLOAT_EQ(vect->at(currentIndex), j);
            currentIndex++;
        }
    }

    // Validate that providing more data than the initial capacity throws an exception
    ASSERT_THROW(
        knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress,
            reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim), (jlong)(vect->size())), std::runtime_error);

    // Release the native vector so the test does not leak it.
    knn_jni::commons::freeVectorData(memoryAddress);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.jni.JNICommons;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -42,9 +42,9 @@
@State(Scope.Benchmark)
public class TransferVectorsBenchmarks {
private static final Random random = new Random(1212121212);
private static final int TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED = 1000000;
private static final long TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED = 1000000;

@Param({ "128", "256", "384", "512" })
@Param({ "128", "256", "384", "512", "960", "1024", "1536" })
private int dimension;

@Param({ "100000", "500000", "1000000" })
Expand All @@ -61,20 +61,32 @@ public void setup() {
}

@Benchmark
public void transferVectors() {
public void transferVectors_withCapacity() {
long vectorsAddress = 0;
List<float[]> vectorToTransfer = new ArrayList<>();
long startingIndex = 0;
for (float[] floats : vectorList) {
if (vectorToTransfer.size() == vectorsPerTransfer) {
vectorsAddress = JNIService.transferVectorsV2(vectorsAddress, vectorToTransfer.toArray(new float[][] {}));
vectorsAddress = JNICommons.storeVectorData(
vectorsAddress,
vectorToTransfer.toArray(new float[][] {}),
dimension * TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED,
startingIndex
);
startingIndex += vectorsPerTransfer;
vectorToTransfer = new ArrayList<>();
}
vectorToTransfer.add(floats);
}
if (!vectorToTransfer.isEmpty()) {
vectorsAddress = JNIService.transferVectorsV2(vectorsAddress, vectorToTransfer.toArray(new float[][] {}));
vectorsAddress = JNICommons.storeVectorData(
vectorsAddress,
vectorToTransfer.toArray(new float[][] {}),
dimension * TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED,
startingIndex
);
}
JNIService.freeVectors(vectorsAddress);
JNICommons.freeVectorData(vectorsAddress);
}

private float[] generateRandomVector(int dimensions) {
Expand Down
Loading

0 comments on commit e980b6f

Please sign in to comment.