Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 29, 2024
1 parent 22caf2a commit 66ec044
Show file tree
Hide file tree
Showing 28 changed files with 558 additions and 131 deletions.
24 changes: 23 additions & 1 deletion jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,35 @@ namespace knn_jni {
*/
void freeVectorData(jlong);

/**
* Utility function that stores signed byte data in native memory. The rows of data are appended to a
* std::vector<int8_t> and the address of that vector is returned.
* On the first call pass memoryAddress = 0 so that a new vector is allocated with initialCapacity elements reserved.
* On subsequent calls pass the address returned by the previous call to keep appending to the same vector;
* initialCapacity is not used in that case.
* Throws an exception if the data cannot be stored.
*
* @param memoryAddress The address of the memory location where data will be stored, or 0 to allocate a new one.
* @param data 2D byte array containing data to be stored in native memory.
* @param initialCapacity The initial capacity (in elements) reserved when a new memory location is created.
* @return memory address of std::vector<int8_t> where the data is stored.
*/
jlong storeSignedByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the signed byte data stored at the memory address. This function should be used
* with the memory address returned by {@link JNICommons#storeSignedByteVectorData(long, byte[][], long)}.
* An address of 0 is ignored.
*
* @param memoryAddress address of the std::vector<int8_t> to be freed.
*/
void freeSignedByteVectorData(jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long)}
*
* @param memoryAddress address to be freed.
*/
void freeByteVectorData(jlong);

/**
* Extracts query time efSearch from method parameters
Expand Down
40 changes: 40 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,46 @@ class BinaryIndexService : public IndexService {
virtual ~BinaryIndexService() = default;
};

/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
*/
class ByteIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
ByteIndexService(std::unique_ptr<FaissMethods> faissMethods);
/**
* Create byte index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;
virtual ~ByteIndexService() = default;
};

}
}

Expand Down
3 changes: 3 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ namespace knn_jni {
int dim, std::vector<float> *vect ) = 0;
// Unsigned-byte (std::vector<uint8_t>) variant of the 2D-array conversion helpers.
// NOTE(review): presumably appends every row of array2dJ to vect and expects each row to have dim
// elements; the implementation is not visible here, so confirm before relying on it.
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<uint8_t> *vect ) = 0;
// Signed-byte (std::vector<int8_t>) variant of the 2D-array conversion helpers.
// The JNIUtil implementation appends every row of array2dJ (a Java byte[][]) to vect and throws
// std::runtime_error when array2dJ is null or a row's length differs from dim.
virtual void Convert2dJavaObjectArrayAndStoreToSignedByteVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<int8_t> *vect ) = 0;

// Converts a Java int[] (arrayJ) into a std::vector<int64_t>.
// NOTE(review): element order and error handling are not visible here; confirm against the implementation.
virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

Expand Down Expand Up @@ -174,6 +176,7 @@ namespace knn_jni {
// NOTE(review): presumably copies len bytes from buf into array starting at index start, like the
// JNIEnv function of the same name; the implementation is not visible here, so confirm.
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
// Float (std::vector<float>) variant of the 2D-array conversion helpers.
// NOTE(review): presumably appends every row of array2dJ to vect; implementation not visible here, confirm.
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
// Unsigned-byte (std::vector<uint8_t>) variant of the 2D-array conversion helpers.
// NOTE(review): presumably appends every row of array2dJ to vect; implementation not fully visible here, confirm.
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);
// Appends every row of array2dJ (a Java byte[][]) to vect as int8_t values.
// Throws std::runtime_error when array2dJ is null or a row's length differs from dim.
void Convert2dJavaObjectArrayAndStoreToSignedByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<int8_t> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*
* Arguments after JNIEnv/jclass, in order: document ids (int[]), native address of the vector data (long),
* vector dimension (int), index path (String), index parameters (Map).
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeSignedByteVectorData
* Signature: (J[[BJ)J
*
* Arguments after JNIEnv/jclass, in order: memory address (long, 0 to allocate), data (byte[][]),
* initial capacity (long). Returns the memory address of the stored data.
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeSignedByteVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
Expand All @@ -50,6 +58,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeSignedByteVectorData
* Signature: (J)V
*
* Argument after JNIEnv/jclass: memory address (long) previously returned by storeSignedByteVectorData.
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeSignedByteVectorData
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
Expand Down
22 changes: 22 additions & 0 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil,
return (jlong) vect;
}

/**
 * Appends the rows of dataJ (a Java byte[][]) to a std::vector<int8_t> held in native memory.
 *
 * @param jniUtil          helper used to read the Java array
 * @param env              jni environment
 * @param memoryAddressJ   0 to allocate a new vector, otherwise the address returned by a previous call
 * @param dataJ            2D byte array whose rows are appended
 * @param initialCapacityJ number of elements reserved when a new vector is allocated (ignored otherwise)
 * @return address of the std::vector<int8_t> holding the data
 *
 * Exceptions raised by jniUtil (or by the allocation) propagate to the caller. A vector that was
 * allocated by this call is released first: the caller never receives its address in that case, so
 * it would otherwise be leaked. A caller-supplied vector is never freed here.
 */
jlong knn_jni::commons::storeSignedByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
                                                  jobjectArray dataJ, jlong initialCapacityJ) {
    // Compare and convert the 64-bit jlong values directly; going through `long` would truncate them on
    // platforms where long is 32 bits wide.
    const bool allocatedHere = (memoryAddressJ == 0);
    std::vector<int8_t> *vect = allocatedHere
        ? new std::vector<int8_t>()
        : reinterpret_cast<std::vector<int8_t>*>(memoryAddressJ);

    try {
        if (allocatedHere) {
            vect->reserve(static_cast<size_t>(initialCapacityJ));
        }
        int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
        jniUtil->Convert2dJavaObjectArrayAndStoreToSignedByteVector(env, dataJ, dim, vect);
    } catch (...) {
        if (allocatedHere) {
            delete vect;
        }
        throw;
    }

    return reinterpret_cast<jlong>(vect);
}

void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
Expand All @@ -61,6 +76,13 @@ void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) {
}
}

// Releases the std::vector<int8_t> located at memoryAddressJ. An address of 0 is ignored.
void knn_jni::commons::freeSignedByteVectorData(jlong memoryAddressJ) {
    if (memoryAddressJ == 0) {
        return;
    }
    delete reinterpret_cast<std::vector<int8_t>*>(memoryAddressJ);
}

int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map<std::string, jobject> methodParams, std::string methodParam, int defaultValue) {
if (methodParams.empty()) {
return defaultValue;
Expand Down
84 changes: 84 additions & 0 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,89 @@ void BinaryIndexService::createIndex(
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

// Passes ownership of the FaissMethods wrapper on to the IndexService base class.
ByteIndexService::ByteIndexService(std::unique_ptr<FaissMethods> faissMethods)
    : IndexService(std::move(faissMethods)) {
}

/**
 * Builds a faiss index from signed byte vectors held in native memory and writes it to indexPath.
 *
 * vectorsAddress is interpreted as a pointer to a std::vector<int8_t> containing the vectors back to
 * back (dim values per vector). Every value is widened to float before being added, in batches, to an
 * IndexIDMap wrapped around the index created by the faiss index factory. This function does not free
 * the input vector.
 *
 * @throws std::runtime_error if vectorsAddress is 0, there are no vectors, dim is not positive,
 *         numIds differs from the number of vectors, or the created index requires training.
 */
void ByteIndexService::createIndex(
        knn_jni::JNIUtilInterface * jniUtil,
        JNIEnv * env,
        faiss::MetricType metric,
        std::string indexDescription,
        int dim,
        int numIds,
        int threadCount,
        int64_t vectorsAddress,
        std::vector<int64_t> ids,
        std::string indexPath,
        std::unordered_map<std::string, jobject> parameters
    ) {
    // Read vectors from memory address
    auto *inputVectors = reinterpret_cast<std::vector<int8_t>*>(vectorsAddress);
    if (inputVectors == nullptr) {
        throw std::runtime_error("Vectors address cannot be null");
    }

    if (inputVectors->empty()) {
        throw std::runtime_error("Number of vectors cannot be 0");
    }

    // Guard the division below: a zero dimension would divide by zero.
    if (dim <= 0) {
        throw std::runtime_error("Dimension of vectors must be greater than 0");
    }

    // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
    int numVectors = static_cast<int>(inputVectors->size() / static_cast<uint64_t>(dim));
    if (numVectors == 0) {
        throw std::runtime_error("Number of vectors cannot be 0");
    }

    if (numIds != numVectors) {
        throw std::runtime_error("Number of IDs does not match number of vectors");
    }

    std::unique_ptr<faiss::Index> indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));

    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
    if (threadCount != 0) {
        omp_set_num_threads(threadCount);
    }

    // Add extra parameters that cant be configured with the index factory
    SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get());

    // Check that the index does not need to be trained
    if (!indexWriter->is_trained) {
        throw std::runtime_error("Index is not trained");
    }

    // Add vectors
    std::unique_ptr<faiss::IndexIDMap> idMap(faissMethods->indexIdMap(indexWriter.get()));

    // Number of vectors converted to float and handed to faiss per add_with_ids call.
    const int maxBatchSize = 1;

    // Scratch buffers are automatic objects so they are released even if add_with_ids throws.
    // They are reserved once and reused for every batch (clear() keeps the capacity).
    std::vector<float> batchVectors;
    std::vector<int64_t> batchIds;
    batchVectors.reserve(static_cast<size_t>(maxBatchSize) * static_cast<size_t>(dim));
    batchIds.reserve(static_cast<size_t>(maxBatchSize));

    auto source = inputVectors->cbegin();
    int nextId = 0;
    while (nextId < numVectors) {
        const int remaining = numVectors - nextId;
        const int batchSize = remaining < maxBatchSize ? remaining : maxBatchSize;

        for (int i = 0; i < batchSize; ++i) {
            batchIds.push_back(ids[nextId++]);
            for (int j = 0; j < dim; ++j) {
                batchVectors.push_back(static_cast<float>(*source));
                ++source;
            }
        }
        idMap->add_with_ids(batchSize, batchVectors.data(), batchIds.data());

        batchVectors.clear();
        batchIds.clear();
    }

    // Write the index to disk
    faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

} // namespace faiss_wrapper
} // namespace knn_jni
33 changes: 33 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,39 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env,
env->DeleteLocalRef(array2dJ);
}

/**
 * Appends every row of a Java byte[][] to vect as int8_t values.
 *
 * @param env      jni environment
 * @param array2dJ Java byte[][]; its local reference is deleted before returning successfully
 * @param dim      expected length of every row
 * @param vect     destination vector; rows are appended in order
 *
 * @throws std::runtime_error if array2dJ or one of its rows is null, if a row's length differs from
 *         dim, or if the row's elements cannot be obtained. Rows appended before a failure remain in vect.
 */
void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToSignedByteVector(JNIEnv *env, jobjectArray array2dJ,
                                                                         int dim, std::vector<int8_t> *vect) {

    if (array2dJ == nullptr) {
        throw std::runtime_error("Array cannot be null");
    }

    int numVectors = env->GetArrayLength(array2dJ);
    this->HasExceptionInStack(env);

    for (int i = 0; i < numVectors; ++i) {
        auto vectorArray = (jbyteArray)env->GetObjectArrayElement(array2dJ, i);
        this->HasExceptionInStack(env, "Unable to get object array element");

        // A null row raises no Java exception above, and must not be passed to GetArrayLength.
        if (vectorArray == nullptr) {
            throw std::runtime_error("Vector cannot be null");
        }

        if (dim != env->GetArrayLength(vectorArray)) {
            throw std::runtime_error("Dimension of vectors is inconsistent");
        }

        jbyte *elements = env->GetByteArrayElements(vectorArray, nullptr);
        if (elements == nullptr) {
            this->HasExceptionInStack(env);
            throw std::runtime_error("Unable to get byte array elements");
        }

        try {
            vect->insert(vect->end(), elements, elements + dim);
        } catch (...) {
            // Do not leave the elements pinned/copied if the append fails (e.g. allocation failure).
            env->ReleaseByteArrayElements(vectorArray, elements, JNI_ABORT);
            throw;
        }

        // JNI_ABORT: the elements were only read, nothing needs to be copied back.
        env->ReleaseByteArrayElements(vectorArray, elements, JNI_ABORT);
        // Release the per-row local reference so they do not accumulate across a large array.
        env->DeleteLocalRef(vectorArray);
    }
    this->HasExceptionInStack(env);
    env->DeleteLocalRef(array2dJ);
}

std::vector<int64_t> knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) {

if (arrayJ == nullptr) {
Expand Down
18 changes: 18 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
}
}

// JNI entry point for FaissService.createByteIndex: builds the byte index through ByteIndexService and then
// frees the native vector data. Any C++ exception is converted into a Java exception; in that case the
// delete below is not reached.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndex(JNIEnv * env, jclass cls, jintArray idsJ,
                                                                                jlong vectorsAddressJ, jint dimJ,
                                                                                jstring indexPathJ, jobject parametersJ)
{
    namespace fw = knn_jni::faiss_wrapper;

    try {
        std::unique_ptr<fw::FaissMethods> methods(new fw::FaissMethods());
        fw::ByteIndexService indexService(std::move(methods));
        fw::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &indexService);

        // Releasing the vectorsAddressJ memory as that is not required once we have created the index.
        // This is not the ideal approach, please refer this gh issue for long term solution:
        // https://github.com/opensearch-project/k-NN/issues/1600
        auto *storedVectors = reinterpret_cast<std::vector<int8_t>*>(vectorsAddressJ);
        delete storedVectors;
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
Expand Down
22 changes: 22 additions & 0 deletions jni/src/org_opensearch_knn_jni_JNICommons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)
return (long)memoryAddressJ;
}

// JNI entry point for JNICommons.storeSignedByteVectorData.
// Returns the address of the native std::vector<int8_t> holding the data. If the native call throws, the
// exception is rethrown on the Java side and the incoming memoryAddressJ is returned unchanged.
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeSignedByteVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)

{
    try {
        return knn_jni::commons::storeSignedByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    // Return the jlong as-is: casting through `long` would truncate the address where long is 32 bits wide.
    return memoryAddressJ;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ)
{
Expand All @@ -81,3 +93,13 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

// JNI entry point for JNICommons.freeSignedByteVectorData: frees the native signed byte vector at
// memoryAddressJ and converts any C++ exception into a Java exception.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeSignedByteVectorData(JNIEnv * env, jclass cls,
                                                                                       jlong memoryAddressJ)
{
    try {
        knn_jni::commons::freeSignedByteVectorData(memoryAddressJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}
Loading

0 comments on commit 66ec044

Please sign in to comment.