Skip to content

Commit

Permalink
Throw proper exception to invalid k-NN query (#1380) (#1381)
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei authored Jan 9, 2024
1 parent 722bc63 commit 7c65643
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Increase Lucene max dimension limit to 16,000 [#1346](https://github.com/opensearch-project/k-NN/pull/1346)
* Tuned default values for ef_search and ef_construction for better indexing and search performance for vector search [#1353](https://github.com/opensearch-project/k-NN/pull/1353)
* Enabled Filtering on Nested Vector fields with top level filters [#1372](https://github.com/opensearch-project/k-NN/pull/1372)
* Throw proper exception to invalid k-NN query [#1380](https://github.com/opensearch-project/k-NN/pull/1380)
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,14 @@ public static void initialize(ModelDao modelDao) {
}

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
return vec;
Expand Down
51 changes: 51 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.script.Script;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -425,6 +426,56 @@ public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
);
}

@SneakyThrows
public void testSearchWithInvalidSearchVectorType() {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
ingestL2FloatTestData();
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
List<Object> invalidTypeQueryVector = new ArrayList<>();
invalidTypeQueryVector.add(1.5);
invalidTypeQueryVector.add(2.5);
invalidTypeQueryVector.add("a");
invalidTypeQueryVector.add(null);
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", invalidTypeQueryVector)
.field("k", 4)
.endObject()
.endObject()
.endObject()
.endObject();
request.setJsonEntity(builder.toString());

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
}

@SneakyThrows
public void testSearchWithMissingQueryVector() {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
ingestL2FloatTestData();
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("k", 4)
.endObject()
.endObject()
.endObject()
.endObject();
request.setJsonEntity(builder.toString());

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -149,6 +150,70 @@ public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Excep
expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser));
}

public void testFromXContent_invalidQueryVectorType() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

List<Object> invalidTypeQueryVector = new ArrayList<>();
invalidTypeQueryVector.add(1.5);
invalidTypeQueryVector.add(2.5);
invalidTypeQueryVector.add("a");
invalidTypeQueryVector.add(null);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(FIELD_NAME);
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector);
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builder.endObject();
builder.endObject();
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> KNNQueryBuilder.fromXContent(contentParser)
);
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
}

public void testFromXContent_missingQueryVector() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

// Test without vector field
XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder();
builderWithoutVectorField.startObject();
builderWithoutVectorField.startObject(FIELD_NAME);
builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builderWithoutVectorField.endObject();
builderWithoutVectorField.endObject();
XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField);
contentParserWithoutVectorField.nextToken();
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> KNNQueryBuilder.fromXContent(contentParserWithoutVectorField)
);
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));

// Test empty vector field
List<Object> emptyQueryVector = new ArrayList<>();
XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder();
builderWithEmptyVector.startObject();
builderWithEmptyVector.startObject(FIELD_NAME);
builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector);
builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builderWithEmptyVector.endObject();
builderWithEmptyVector.endObject();
XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector);
contentParserWithEmptyVector.nextToken();
exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParserWithEmptyVector));
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> list = ClusterModule.getNamedXWriteables();
Expand Down

0 comments on commit 7c65643

Please sign in to comment.