Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 13, 2024
1 parent 374b81c commit 66ed657
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,20 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
}

private VectorTransfer getVectorTransfer(FieldInfo field) {
if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) {
return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
}
if (VectorDataType.BYTE.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))
&& KNNEngine.FAISS.getName().equalsIgnoreCase(field.attributes().get(KNNConstants.KNN_ENGINE))) {
return new VectorTransferByteToFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
String vectorDataType = field.attributes().getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue());
String knnEngine = field.attributes().get(KNNConstants.KNN_ENGINE);
long memoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();

switch (vectorDataType.toLowerCase()) {
case "binary":
return new VectorTransferByte(memoryLimit);
case "byte":
if (KNNEngine.FAISS.getName().equalsIgnoreCase(knnEngine)) {
return new VectorTransferByteToFloat(memoryLimit);
}
default:
return new VectorTransferFloat(memoryLimit);
}
return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -70,8 +72,9 @@ private static float[] byteToFloatArray(ByteArrayInputStream byteStream) {
final byte[] vectorAsByteArray = byteStream.readAllBytes();
final int sizeOfFloatArray = vectorAsByteArray.length;
final float[] vector = new float[sizeOfFloatArray];
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorAsByteArray).order(ByteOrder.nativeOrder());
for (int i = 0; i < sizeOfFloatArray; i++) {
vector[i] = vectorAsByteArray[i];
vector[i] = byteBuffer.get(i);
}
return vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
// For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed"
if (VectorDataType.BYTE.equals(vectorDataType) && methodParamsMap.containsKey(INDEX_DESCRIPTION_PARAMETER)) {
String indexDescriptionValue = (String) methodParamsMap.get(INDEX_DESCRIPTION_PARAMETER);
String updatedIndexDescription = indexDescriptionValue.split(",")[0] + "," + FAISS_SIGNED_BYTE_SQ;
methodParamsMap.replace(INDEX_DESCRIPTION_PARAMETER, updatedIndexDescription);
if (indexDescriptionValue != null && !indexDescriptionValue.isEmpty()) {
String updatedIndexDescription = indexDescriptionValue.split(",")[0] + "," + FAISS_SIGNED_BYTE_SQ;
methodParamsMap.replace(INDEX_DESCRIPTION_PARAMETER, updatedIndexDescription);
}
}
this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(methodParamsMap).toString());
} catch (IOException ioe) {
Expand Down
26 changes: 16 additions & 10 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -609,16 +609,8 @@ protected Query doToQuery(QueryShardContext context) {
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(
(VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)
? this.vector
: null
)
.byteVector(
(VectorDataType.BYTE == vectorDataType && KNNEngine.LUCENE == knnEngine) || VectorDataType.BINARY == vectorDataType
? byteVector
: null
)
.vector(getFloatVectorForCreatingQueryRequest(vectorDataType, knnEngine))
.byteVector(getByteVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector))
.vectorDataType(vectorDataType)
.k(this.k)
.methodParameters(this.methodParameters)
Expand Down Expand Up @@ -696,6 +688,20 @@ private void updateQueryStats(VectorQueryType vectorQueryType) {
}
}

private float[] getFloatVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) {
if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) {
return this.vector;
}
return null;
}

private byte[] getByteVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, byte[] byteVector) {
if (VectorDataType.BINARY == vectorDataType || (VectorDataType.BYTE == vectorDataType && KNNEngine.LUCENE == knnEngine)) {
return byteVector;
}
return null;
}

@Override
protected boolean doEquals(KNNQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
Expand Down

0 comments on commit 66ed657

Please sign in to comment.