Skip to content

Commit

Permalink
Add tests for Faiss HNSW Byte Vector
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 12, 2024
1 parent 24a3d4c commit 5d532b0
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
* Add HNSW changes to support Faiss byte vector [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
### Enhancements
### Bug Fixes
* Fixing the arithmetic to find the number of vectors to stream from the Java layer to the JNI layer. [#1804](https://github.com/opensearch-project/k-NN/pull/1804)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.transfer;
Expand Down
150 changes: 149 additions & 1 deletion src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

Expand Down Expand Up @@ -479,6 +485,144 @@ public void testSearchWithMissingQueryVector() {
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
}

// Index one byte vector document into a Faiss engine index and verify that exactly one document is present
@SneakyThrows
public void testAddDocWithByteVectorUsingFaissEngine() {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);

    final Byte[] docVector = new Byte[] { 6, 6 };
    addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, docVector);
    refreshAllIndices();

    assertEquals(1, getDocCount(INDEX_NAME));
}

// Add a byte vector document, overwrite it with a new byte vector, and verify the index still holds one document
@SneakyThrows
public void testUpdateDocWithByteVectorUsingFaissEngine() {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);

    final Byte[] initialVector = new Byte[] { -36, 78 };
    final Byte[] replacementVector = new Byte[] { 89, -8 };

    addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, initialVector);
    updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, replacementVector);
    refreshAllIndices();

    assertEquals(1, getDocCount(INDEX_NAME));
}

// Add a byte vector document, delete it again, and verify the index ends up empty
@SneakyThrows
public void testDeleteDocWithByteVectorUsingFaissEngine() {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);

    final Byte[] docVector = new Byte[] { 35, -46 };
    addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, docVector);
    deleteKnnDoc(INDEX_NAME, DOC_ID);

    refreshAllIndices();
    assertEquals(0, getDocCount(INDEX_NAME));
}

// Ingest the shared L2 byte test data into a Faiss engine index, run a k-NN query and validate the results
@SneakyThrows
public void testSearchWithByteVectorUsingFaissEngine() {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);
    ingestL2ByteTestData();

    // The same value is used both as k for the query and as the requested result size
    final int neighbours = 4;
    final Byte[] query = new Byte[] { 1, 1 };
    final KNNQueryBuilder knnQuery = new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(query), neighbours);

    final Response searchResponse = searchKNNIndex(INDEX_NAME, knnQuery, neighbours);
    validateL2SearchResults(searchResponse);
}

// Create an index with byte vector data_type and add a doc with float values which should throw exception
@SneakyThrows
public void testInvalidVectorDataUsingFaissEngine() {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);

    final Float[] floatValues = new Float[] { -10.76f, 15.89f };
    final String expectedMessage = String.format(
        Locale.ROOT,
        "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
        VECTOR_DATA_TYPE_FIELD,
        byteDataType
    );

    final ResponseException thrown = expectThrows(
        ResponseException.class,
        () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, floatValues)
    );
    assertTrue(thrown.getMessage().contains(expectedMessage));
}

// Create an index with byte vector data_type and add a doc whose values fall outside
// [Byte.MIN_VALUE, Byte.MAX_VALUE]; the request is expected to fail with a range error
@SneakyThrows
public void testInvalidByteVectorRangeUsingFaissEngine() {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);

    final Float[] outOfRangeValues = new Float[] { -1000f, 155f };
    final String expectedMessage = String.format(
        Locale.ROOT,
        "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
        VECTOR_DATA_TYPE_FIELD,
        byteDataType,
        Byte.MIN_VALUE,
        Byte.MAX_VALUE
    );

    final ResponseException thrown = expectThrows(
        ResponseException.class,
        () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, outOfRangeValues)
    );
    assertTrue(thrown.getMessage().contains(expectedMessage));
}

// Create an index with byte vector data_type using faiss engine with an SQ (fp16) encoder,
// which should throw an exception because the byte data type does not support that encoder
@SneakyThrows
public void testByteVectorDataTypeWithFaissEngineUsingEncoderThrowsException() {
    final String byteDataType = VectorDataType.BYTE.getValue();

    XContentBuilder builder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject(PROPERTIES_FIELD)
        .startObject(FIELD_NAME)
        .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
        .field(DIMENSION, 2)
        .field(VECTOR_DATA_TYPE_FIELD, byteDataType)
        .startObject(KNNConstants.KNN_METHOD)
        .field(NAME, METHOD_HNSW)
        // Write the space type's string value, as createKnnIndexMappingWithCustomEngine does,
        // instead of handing the enum object to the builder
        .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
        .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
        .startObject(PARAMETERS)
        .field(KNNConstants.METHOD_PARAMETER_M, M)
        .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
        .startObject(METHOD_ENCODER_PARAMETER)
        .field(NAME, ENCODER_SQ)
        .startObject(PARAMETERS)
        .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject();

    final String mapping = builder.toString();
    final String expectedMessage = String.format(
        Locale.ROOT,
        "%s data type does not support %s encoder",
        byteDataType,
        ENCODER_SQ
    );

    // Index creation itself is expected to fail, so no documents are ingested here
    ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping));
    assertTrue(ex.getMessage().contains(expectedMessage));
}

// Ingest the shared L2 byte test data into a Faiss engine index, run a script query
// against it, and validate both the HTTP status and the returned results
public void testDocValuesWithByteVectorDataTypeFaissEngine() throws Exception {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithFaissEngine(2, SpaceType.L2, byteDataType);
    ingestL2ByteTestData();

    final Byte[] query = new Byte[] { 1, 1 };
    final Request scriptRequest = createScriptQueryRequest(query, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
    final Response scriptResponse = client().performRequest(scriptRequest);

    final RestStatus actualStatus = RestStatus.fromCode(scriptResponse.getStatusLine().getStatusCode());
    assertEquals(scriptRequest.getEndpoint() + ": failed", RestStatus.OK, actualStatus);

    validateL2SearchResults(scriptResponse);
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down Expand Up @@ -517,6 +661,10 @@ private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spac
createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.LUCENE.getName());
}

// Convenience wrapper: delegates to createKnnIndexMappingWithCustomEngine with the Faiss engine name
private void createKnnIndexMappingWithFaissEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception {
    final String faissEngineName = KNNEngine.FAISS.getName();
    createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, faissEngineName);
}

private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spaceType, String vectorDataType, String engine)
throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand All @@ -530,7 +678,7 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, engine)
.startObject(KNNConstants.PARAMETERS)
.startObject(PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, M)
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
.endObject()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.transfer;

import junit.framework.TestCase;
import lombok.SneakyThrows;
import org.opensearch.knn.index.codec.util.SerializationMode;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

import static org.junit.Assert.assertNotEquals;

/**
 * Unit tests for {@link VectorTransferByteToFloat}, exercising the transfer of serialized
 * byte vectors and the serialization mode the transfer reports.
 */
public class VectorTransferByteToFloatTests extends TestCase {

    /**
     * Transfers two streams of vectors and checks the vector address before and after the
     * second transfer: it is expected to be 0 after the first and non-zero after the second.
     */
    @SneakyThrows
    public void testTransfer_whenCalled_thenAdded() {
        final ByteArrayInputStream bais1 = getByteArrayOfVectors(20);
        final ByteArrayInputStream bais2 = getByteArrayOfVectors(20);
        VectorTransferByteToFloat vectorTransfer = new VectorTransferByteToFloat(1000);
        try {
            vectorTransfer.init(2);

            vectorTransfer.transfer(bais1);
            // flush is not called
            assertEquals(0, vectorTransfer.getVectorAddress());

            vectorTransfer.transfer(bais2);
            // flush should be called
            assertNotEquals(0, vectorTransfer.getVectorAddress());
        } finally {
            // Release the vector data through JNICommons, but only if an address was actually set
            if (vectorTransfer.getVectorAddress() != 0) {
                JNICommons.freeVectorData(vectorTransfer.getVectorAddress());
            }
        }
    }

    /**
     * Verifies that the transfer reports {@link SerializationMode#COLLECTION_OF_FLOATS}
     * for a stream of bytes.
     */
    @SneakyThrows
    public void testSerializationMode_whenCalled_thenReturn() {
        final ByteArrayInputStream bais = getByteArrayOfVectors(20);
        VectorTransferByteToFloat vectorTransfer = new VectorTransferByteToFloat(1000);

        // Verify
        assertEquals(SerializationMode.COLLECTION_OF_FLOATS, vectorTransfer.getSerializationMode(bais));
    }

    /**
     * Builds an input stream over {@code vectorLength} random bytes.
     *
     * @param vectorLength number of random bytes to generate
     * @return a stream over the generated bytes
     */
    private ByteArrayInputStream getByteArrayOfVectors(int vectorLength) {
        byte[] vector = new byte[vectorLength];
        // nextInt's upper bound is exclusive, so use MAX_VALUE + 1 to make Byte.MAX_VALUE reachable
        IntStream.range(0, vectorLength)
            .forEach(index -> vector[index] = (byte) ThreadLocalRandom.current().nextInt(Byte.MIN_VALUE, Byte.MAX_VALUE + 1));
        return new ByteArrayInputStream(vector);
    }
}

0 comments on commit 5d532b0

Please sign in to comment.