Skip to content

Commit

Permalink
Refactoring changes
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Aug 21, 2024
1 parent 1f9fb8f commit b500228
Show file tree
Hide file tree
Showing 16 changed files with 272 additions and 209 deletions.
18 changes: 9 additions & 9 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ namespace knn_jni {
*/
jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeVectorData(jlong);

/**
* This is utility function that can be used to store signed int8 data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
Expand All @@ -73,7 +65,15 @@ namespace knn_jni {
* @param initialCapacity The initial capacity of the memory location.
* @return memory address of std::vector<int8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeVectorData(jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
Expand Down
68 changes: 32 additions & 36 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,23 @@ class ByteIndexService : public IndexService {
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
ByteIndexService(std::unique_ptr<FaissMethods> faissMethods);

/**
* Initialize index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param parameters parameters to be applied to faiss index
* @return memory address of the native index object
*/
virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map<std::string, jobject> parameters) override;
/**
* Create byte index
* Add vectors to index
*
* @param jniUtil jni util
* @param env jni environment
Expand All @@ -145,45 +160,26 @@ class ByteIndexService : public IndexService {
* @param numIds number of vectors
* @param threadCount number of thread count to be used while adding data
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param idMap a map of document id and vector id
* @param parameters parameters to be applied to faiss index
*/
virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress) override;
/**
* Write index to disk
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param threadCount number of thread count to be used while adding data
* @param indexPath path to write index
* @param idMap a map of document id and vector id
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;
virtual void writeIndex(std::string indexPath, jlong idMapAddress) override;
virtual ~ByteIndexService() = default;

private:
virtual std::unique_ptr <faiss::Index> generateIndex(
knn_jni::JNIUtilInterface *jniUtil,
JNIEnv *env,
int vectorSize,
faiss::MetricType metric,
std::string indexDescription,
int dim,
int numIds,
int threadCount,
std::unordered_map <std::string, jobject> parameters
);

virtual void addVectorsToIndex(
faiss::Index* indexWriter,
std::vector <int8_t> *inputVectors,
int dim,
std::vector <int64_t> ids,
std::string indexPath
);
protected:
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override;
};

}
Expand Down
28 changes: 25 additions & 3 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEn
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: insertToIndex
Expand All @@ -50,6 +60,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: insertToByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: writeIndex
Expand All @@ -66,13 +86,15 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEn
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls,
jlong indexAddress,
jstring indexPathJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createByteIndex
* Method: writeByteIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls,
jlong indexAddress,
jstring indexPathJ);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
2 changes: 1 addition & 1 deletion jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeBinaryVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_JNICommons
Expand Down
7 changes: 6 additions & 1 deletion jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,19 @@ jlong knn_jni::commons::storeBinaryVectorData(knn_jni::JNIUtilInterface *jniUtil
}

jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<int8_t> *vect;
if (memoryAddressJ == 0) {
vect = new std::vector<int8_t>();
vect->reserve(static_cast<long>(initialCapacityJ));
} else {
vect = reinterpret_cast<std::vector<int8_t>*>(memoryAddressJ);
}

if (appendJ == JNI_FALSE) {
vect->clear();
}

int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect);

Expand Down
Loading

0 comments on commit b500228

Please sign in to comment.