Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Aug 22, 2024
1 parent 5d83d1e commit ec1570d
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 28 deletions.
2 changes: 2 additions & 0 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace knn_jni {
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing binary data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append whether to append to the previously stored data or start again from index 0 when called subsequently with the same address
* @return memory address of std::vector<uint8_t> where the data is stored.
*/
jlong storeBinaryVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
Expand All @@ -63,6 +64,7 @@ namespace knn_jni {
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing int8 data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append whether to append to the previously stored data or start again from index 0 when called subsequently with the same address
* @return memory address of std::vector<int8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);
Expand Down
8 changes: 3 additions & 5 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,14 @@ void ByteIndexService::insertToIndex(

// Add vectors in batches by casting int8 vectors into float with a batch size of 1000
int batchSize = 1000;
std::vector <float> inputFloatVectors(batchSize * dim);
std::vector <int64_t> floatVectorsIds(batchSize);
float inputFloatVectors[batchSize * dim];
int64_t floatVectorsIds[batchSize];
int id = 0;
auto iter = inputVectors->begin();

for (int id = 0; id < numVectors; id += batchSize) {
if (numVectors - id < batchSize) {
batchSize = numVectors - id;
inputFloatVectors.resize(batchSize * dim);
floatVectorsIds.resize(batchSize);
}

for (int i = 0; i < batchSize; ++i) {
Expand All @@ -347,7 +345,7 @@ void ByteIndexService::insertToIndex(
inputFloatVectors[i * dim + j] = static_cast<float>(*iter);
}
}
idMap->add_with_ids(batchSize, inputFloatVectors.data(), floatVectorsIds.data());
idMap->add_with_ids(batchSize, inputFloatVectors, floatVectorsIds);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.engine.faiss;

import org.apache.commons.lang.StringUtils;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.AbstractKNNMethod;
Expand All @@ -20,6 +21,7 @@
import java.util.Objects;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled;
Expand Down Expand Up @@ -87,7 +89,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
throw new IllegalStateException("Unsupported vector data type " + vectorDataType);
}

static KNNLibraryIndexingContext adjustPrefix(
static KNNLibraryIndexingContext adjustIndexDescription(
MethodAsMapBuilder methodAsMapBuilder,
MethodComponentContext methodComponentContext,
KNNMethodConfigContext knnMethodConfigContext
Expand All @@ -105,6 +107,19 @@ static KNNLibraryIndexingContext adjustPrefix(
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) {

// If the VectorDataType is BYTE with the Faiss engine, rewrite the index description to use the "SQ8_direct_signed" scalar quantizer.
// For example, the index description "HNSW16,Flat" is updated to "HNSW16,SQ8_direct_signed".
String indexDescription = methodAsMapBuilder.indexDescription;
if (StringUtils.isNotEmpty(indexDescription)) {
StringBuilder indexDescriptionBuilder = new StringBuilder();
indexDescriptionBuilder.append(indexDescription.split(",")[0]);
indexDescriptionBuilder.append(",");
indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ);
methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString();
}
}
methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription;
return methodAsMapBuilder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private static MethodComponent initMethodComponent() {
methodComponentContext,
knnMethodConfigContext
).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "");
return adjustPrefix(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
return adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
}))
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private static MethodComponent initMethodComponent() {
methodComponentContext,
knnMethodConfigContext
).addParameter(METHOD_PARAMETER_NLIST, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "");
return adjustPrefix(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
return adjustIndexDescription(methodAsMapBuilder, methodComponentContext, knnMethodConfigContext);
}))
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
// Size estimate formula: (4 * nlists * d) / 1024 + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.opensearch.common.Explicit;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
Expand All @@ -25,8 +24,6 @@
import java.util.Optional;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;
Expand Down Expand Up @@ -130,23 +127,10 @@ private MethodFieldMapper(
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());

try {
Map<String, Object> libParams = knnLibraryIndexingContext.getLibraryParameters();

// If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer
// For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed"
if (VectorDataType.BYTE == vectorDataType && libParams.containsKey(INDEX_DESCRIPTION_PARAMETER)) {
String indexDescriptionValue = (String) libParams.get(INDEX_DESCRIPTION_PARAMETER);
if (indexDescriptionValue != null && indexDescriptionValue.isEmpty() == false) {
StringBuilder indexDescriptionBuilder = new StringBuilder();
indexDescriptionBuilder.append(indexDescriptionValue.split(",")[0]);
indexDescriptionBuilder.append(",");
indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ);

libParams.replace(INDEX_DESCRIPTION_PARAMETER, indexDescriptionBuilder.toString());
libParams.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue());
}
}
this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString());
this.fieldType.putAttribute(
PARAMETERS,
XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString()
);
} catch (IOException ioe) {
throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe));
}
Expand Down

0 comments on commit ec1570d

Please sign in to comment.