From 17846337451791625451018f9c78d3f4f3b1b7af Mon Sep 17 00:00:00 2001 From: Subhobrata Dey Date: Wed, 26 Apr 2023 21:53:43 +0000 Subject: [PATCH] adds hnsw graph based storage & query layer to events correlation engine Signed-off-by: Subhobrata Dey --- .../events-correlation-engine/build.gradle | 2 + .../CorrelationVectorsEngineIT.java | 247 +++++++++++ .../correlation/EventsCorrelationPlugin.java | 39 +- .../core/index/CorrelationParamsContext.java | 173 ++++++++ .../correlation/core/index/VectorField.java | 39 ++ .../BasePerFieldCorrelationVectorsFormat.java | 102 +++++ .../index/codec/CorrelationCodecService.java | 38 ++ .../index/codec/CorrelationCodecVersion.java | 103 +++++ .../correlation950/CorrelationCodec.java | 46 ++ .../PerFieldCorrelationVectorsFormat.java | 35 ++ .../codec/correlation950/package-info.java | 12 + .../core/index/codec/package-info.java | 12 + .../mapper/CorrelationVectorFieldMapper.java | 172 ++++++++ .../core/index/mapper/VectorFieldMapper.java | 394 ++++++++++++++++++ .../core/index/mapper/package-info.java | 12 + .../correlation/core/index/package-info.java | 12 + .../index/query/CorrelationQueryBuilder.java | 334 +++++++++++++++ .../index/query/CorrelationQueryFactory.java | 142 +++++++ .../core/index/query/package-info.java | 12 + .../CorrelationVectorSerializer.java | 58 +++ .../core/index/serializer/package-info.java | 12 + .../services/org.apache.lucene.codecs.Codec | 1 + .../correlation950/CorrelationCodecTests.java | 120 ++++++ .../CorrelationVectorFieldMapperTests.java | 208 +++++++++ .../query/CorrelationQueryBuilderTests.java | 268 ++++++++++++ .../CorrelationVectorSerializerTests.java | 57 +++ 26 files changed, 2649 insertions(+), 1 deletion(-) create mode 100644 plugins/events-correlation-engine/src/javaRestTest/java/org/opensearch/plugin/correlation/CorrelationVectorsEngineIT.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/CorrelationParamsContext.java 
create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/VectorField.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/BasePerFieldCorrelationVectorsFormat.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecService.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecVersion.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodec.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/PerFieldCorrelationVectorsFormat.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/package-info.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/package-info.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapper.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/VectorFieldMapper.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/package-info.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/package-info.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilder.java create mode 100644 
plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryFactory.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/package-info.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializer.java create mode 100644 plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/package-info.java create mode 100644 plugins/events-correlation-engine/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec create mode 100644 plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodecTests.java create mode 100644 plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapperTests.java create mode 100644 plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilderTests.java create mode 100644 plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializerTests.java diff --git a/plugins/events-correlation-engine/build.gradle b/plugins/events-correlation-engine/build.gradle index c3eff30012b1d..2782b7b8a29f3 100644 --- a/plugins/events-correlation-engine/build.gradle +++ b/plugins/events-correlation-engine/build.gradle @@ -17,5 +17,7 @@ opensearchplugin { classname 'org.opensearch.plugin.correlation.EventsCorrelationPlugin' } +forbiddenApis.ignoreFailures = true + dependencies { } diff --git a/plugins/events-correlation-engine/src/javaRestTest/java/org/opensearch/plugin/correlation/CorrelationVectorsEngineIT.java b/plugins/events-correlation-engine/src/javaRestTest/java/org/opensearch/plugin/correlation/CorrelationVectorsEngineIT.java new 
file mode 100644 index 0000000000000..8a878070bb20f --- /dev/null +++ b/plugins/events-correlation-engine/src/javaRestTest/java/org/opensearch/plugin/correlation/CorrelationVectorsEngineIT.java @@ -0,0 +1,247 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.correlation; + +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.junit.Assert; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexSettings; +import org.opensearch.rest.RestStatus; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Correlation Vectors Engine e2e tests + */ +public class CorrelationVectorsEngineIT extends OpenSearchRestTestCase { + + private static final int DIMENSION = 4; + private static final String PROPERTIES_FIELD_NAME = "properties"; + private static final String TYPE_FIELD_NAME = "type"; + private static final String CORRELATION_VECTOR_TYPE = "correlation_vector"; + private static final String DIMENSION_FIELD_NAME = "dimension"; + private static final int M = 16; + private static final int EF_CONSTRUCTION 
= 128; + private static final String INDEX_NAME = "test-index-1"; + private static final Float[][] TEST_VECTORS = new Float[][] { + { 1.0f, 1.0f, 1.0f, 1.0f }, + { 2.0f, 2.0f, 2.0f, 2.0f }, + { 3.0f, 3.0f, 3.0f, 3.0f } }; + private static final float[][] TEST_QUERY_VECTORS = new float[][] { + { 1.0f, 1.0f, 1.0f, 1.0f }, + { 2.0f, 2.0f, 2.0f, 2.0f }, + { 3.0f, 3.0f, 3.0f, 3.0f } }; + private static final Map> VECTOR_SIMILARITY_TO_SCORE = Map.of( + VectorSimilarityFunction.EUCLIDEAN, + (similarity) -> 1 / (1 + similarity), + VectorSimilarityFunction.DOT_PRODUCT, + (similarity) -> (1 + similarity) / 2, + VectorSimilarityFunction.COSINE, + (similarity) -> (1 + similarity) / 2 + ); + + /** + * test the e2e storage and query layer of events-correlation-engine + * @throws IOException IOException + */ + @SuppressWarnings("unchecked") + public void testQuery() throws IOException { + String textField = "text-field"; + String luceneField = "lucene-field"; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(textField) + .field(TYPE_FIELD_NAME, "text") + .endObject() + .startObject(luceneField) + .field(TYPE_FIELD_NAME, CORRELATION_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject("correlation_ctx") + .field("similarityFunction", VectorSimilarityFunction.EUCLIDEAN.name()) + .startObject("parameters") + .field("m", M) + .field("ef_construction", EF_CONSTRUCTION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + createTestIndexWithMappingJson(client(), INDEX_NAME, mapping, getCorrelationDefaultIndexSettings()); + + for (int idx = 0; idx < TEST_VECTORS.length; ++idx) { + addCorrelationDoc( + INDEX_NAME, + String.valueOf(idx + 1), + List.of(textField, luceneField), + List.of(java.util.UUID.randomUUID().toString(), TEST_VECTORS[idx]) + ); + } + refreshAllIndices(); + Assert.assertEquals(TEST_VECTORS.length, 
getDocCount(INDEX_NAME)); + + int k = 2; + for (float[] query : TEST_QUERY_VECTORS) { + + String correlationQuery = "{\n" + + " \"query\": {\n" + + " \"correlation\": {\n" + + " \"lucene-field\": {\n" + + " \"vector\": \n" + + Arrays.toString(query) + + " ,\n" + + " \"k\": 2,\n" + + " \"boost\": 1\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + Response response = searchCorrelationIndex(INDEX_NAME, correlationQuery, k); + Map responseBody = entityAsMap(response); + Assert.assertEquals(2, ((List) ((Map) responseBody.get("hits")).get("hits")).size()); + @SuppressWarnings("unchecked") + double actualScore1 = Double.parseDouble( + ((List>) ((Map) responseBody.get("hits")).get("hits")).get(0).get("_score").toString() + ); + @SuppressWarnings("unchecked") + double actualScore2 = Double.parseDouble( + ((List>) ((Map) responseBody.get("hits")).get("hits")).get(1).get("_score").toString() + ); + @SuppressWarnings("unchecked") + List hit1 = ((Map>) ((List>) ((Map) responseBody.get("hits")) + .get("hits")).get(0).get("_source")).get(luceneField).stream().map(Double::floatValue).collect(Collectors.toList()); + float[] resultVector1 = new float[hit1.size()]; + for (int i = 0; i < hit1.size(); ++i) { + resultVector1[i] = hit1.get(i); + } + + @SuppressWarnings("unchecked") + List hit2 = ((Map>) ((List>) ((Map) responseBody.get("hits")) + .get("hits")).get(1).get("_source")).get(luceneField).stream().map(Double::floatValue).collect(Collectors.toList()); + float[] resultVector2 = new float[hit2.size()]; + for (int i = 0; i < hit2.size(); ++i) { + resultVector2[i] = hit2.get(i); + } + + double rawScore1 = VectorSimilarityFunction.EUCLIDEAN.compare(resultVector1, query); + Assert.assertEquals(rawScore1, actualScore1, 0.0001); + double rawScore2 = VectorSimilarityFunction.EUCLIDEAN.compare(resultVector2, query); + Assert.assertEquals(rawScore2, actualScore2, 0.0001); + } + } + + private String createTestIndexWithMappingJson(RestClient client, String index, String mapping, Settings 
settings) throws IOException { + Request request = new Request("PUT", "/" + index); + String entity = "{\"settings\": " + Strings.toString(XContentType.JSON, settings); + if (mapping != null) { + entity = entity + ",\"mappings\" : " + mapping; + } + + entity = entity + "}"; + if (!settings.getAsBoolean(IndexSettings.INDEX_SOFT_DELETES_SETTING.getKey(), true)) { + expectSoftDeletesWarning(request, index); + } + + request.setJsonEntity(entity); + client.performRequest(request); + return index; + } + + private Settings getCorrelationDefaultIndexSettings() { + return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.correlation", true).build(); + } + + private void addCorrelationDoc(String index, String docId, List fieldNames, List vectors) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldNames.size(); i++) { + builder.field(fieldNames.get(i), vectors.get(i)); + } + builder.endObject(); + + request.setJsonEntity(Strings.toString(builder)); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + private Response searchCorrelationIndex(String index, String correlationQuery, int resultSize) throws IOException { + Request request = new Request("POST", "/" + index + "/_search"); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + request.addParameter("search_type", "query_then_fetch"); + request.setJsonEntity(correlationQuery); + + Response response = client().performRequest(request); + Assert.assertEquals("Search failed", RestStatus.OK, restStatus(response)); + return response; + } + + private int getDocCount(String index) throws IOException { + Response response 
= makeRequest( + client(), + "GET", + String.format(Locale.getDefault(), "/%s/_count", index), + Collections.emptyMap(), + null + ); + Assert.assertEquals(RestStatus.OK, restStatus(response)); + return Integer.parseInt(entityAsMap(response).get("count").toString()); + } + + private Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + HttpEntity entity, + Header... headers + ) throws IOException { + Request request = new Request(method, endpoint); + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(WarningsHandler.PERMISSIVE); + + for (Header header : headers) { + options.addHeader(header.getName(), header.getValue()); + } + request.setOptions(options.build()); + request.addParameters(params); + if (entity != null) { + request.setEntity(entity); + } + return client.performRequest(request); + } + + private RestStatus restStatus(Response response) { + return RestStatus.fromCode(response.getStatusLine().getStatusCode()); + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/EventsCorrelationPlugin.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/EventsCorrelationPlugin.java index 443a794bd99df..6945f21a0fd7c 100644 --- a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/EventsCorrelationPlugin.java +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/EventsCorrelationPlugin.java @@ -23,13 +23,23 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.codec.CodecServiceFactory; +import org.opensearch.index.mapper.Mapper; +import org.opensearch.plugin.correlation.core.index.codec.CorrelationCodecService; +import 
org.opensearch.plugin.correlation.core.index.mapper.CorrelationVectorFieldMapper; +import org.opensearch.plugin.correlation.core.index.mapper.VectorFieldMapper; +import org.opensearch.plugin.correlation.core.index.query.CorrelationQueryBuilder; import org.opensearch.plugin.correlation.rules.action.IndexCorrelationRuleAction; import org.opensearch.plugin.correlation.rules.resthandler.RestIndexCorrelationRuleAction; import org.opensearch.plugin.correlation.rules.transport.TransportIndexCorrelationRuleAction; import org.opensearch.plugin.correlation.settings.EventsCorrelationSettings; import org.opensearch.plugin.correlation.utils.CorrelationRuleIndices; import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.EnginePlugin; +import org.opensearch.plugins.MapperPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; @@ -38,13 +48,16 @@ import org.opensearch.watcher.ResourceWatcherService; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; /** * Plugin class for events-correlation-engine */ -public class EventsCorrelationPlugin extends Plugin implements ActionPlugin { +public class EventsCorrelationPlugin extends Plugin implements ActionPlugin, MapperPlugin, SearchPlugin, EnginePlugin { /** * events-correlation-engine base uri @@ -93,6 +106,30 @@ public List getRestHandlers( return List.of(new RestIndexCorrelationRuleAction()); } + @Override + public Map getMappers() { + return Collections.singletonMap(CorrelationVectorFieldMapper.CONTENT_TYPE, new VectorFieldMapper.TypeParser()); + } + + @Override + public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { + if (indexSettings.getValue(EventsCorrelationSettings.IS_CORRELATION_INDEX_SETTING)) { + 
return Optional.of(CorrelationCodecService::new); + } + return Optional.empty(); + } + + @Override + public List> getQueries() { + return Collections.singletonList( + new QuerySpec<>( + CorrelationQueryBuilder.NAME_FIELD.getPreferredName(), + CorrelationQueryBuilder::new, + CorrelationQueryBuilder::parse + ) + ); + } + @Override public List> getActions() { return List.of(new ActionPlugin.ActionHandler<>(IndexCorrelationRuleAction.INSTANCE, TransportIndexCorrelationRuleAction.class)); diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/CorrelationParamsContext.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/CorrelationParamsContext.java new file mode 100644 index 0000000000000..5b2db0aeea73d --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/CorrelationParamsContext.java @@ -0,0 +1,173 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.correlation.core.index; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.mapper.MapperParsingException; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +/** + * Defines vector similarity function, m and ef_construction hyper parameters field mappings for correlation_vector type. 
+ * + * @opensearch.internal + */ +public class CorrelationParamsContext implements ToXContentFragment, Writeable { + + /** + * Vector Similarity Function field + */ + public static final String VECTOR_SIMILARITY_FUNCTION = "similarityFunction"; + /** + * Parameters field to define m and ef_construction + */ + public static final String PARAMETERS = "parameters"; + + private final VectorSimilarityFunction similarityFunction; + private final Map parameters; + + /** + * Parameterized ctor for CorrelationParamsContext + * @param similarityFunction Vector Similarity Function + * @param parameters Parameters to define m and ef_construction + */ + public CorrelationParamsContext(VectorSimilarityFunction similarityFunction, Map parameters) { + this.similarityFunction = similarityFunction; + this.parameters = parameters; + } + + /** + * Parameterized ctor for CorrelationParamsContext + * @param sin StreamInput + * @throws IOException IOException + */ + public CorrelationParamsContext(StreamInput sin) throws IOException { + this.similarityFunction = VectorSimilarityFunction.valueOf(sin.readString()); + if (sin.available() > 0) { + this.parameters = sin.readMap(); + } else { + this.parameters = null; + } + } + + /** + * Parse into CorrelationParamsContext + * @param in Object + * @return CorrelationParamsContext + */ + public static CorrelationParamsContext parse(Object in) { + if (!(in instanceof Map)) { + throw new MapperParsingException("Unable to parse CorrelationParamsContext"); + } + + @SuppressWarnings("unchecked") + Map contextMap = (Map) in; + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + Map parameters = new HashMap<>(); + + for (Map.Entry contextEntry : contextMap.entrySet()) { + String key = contextEntry.getKey(); + Object value = contextEntry.getValue(); + + if (VECTOR_SIMILARITY_FUNCTION.equals(key)) { + if (value != null && !(value instanceof String)) { + throw new MapperParsingException(String.format(Locale.getDefault(), 
"%s must be a string", VECTOR_SIMILARITY_FUNCTION)); + } + + try { + similarityFunction = VectorSimilarityFunction.valueOf((String) value); + } catch (IllegalArgumentException ex) { + throw new MapperParsingException( + String.format(Locale.getDefault(), "Invalid %s: %s", VECTOR_SIMILARITY_FUNCTION, value) + ); + } + } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + + if (!(value instanceof Map)) { + throw new MapperParsingException("Unable to parse parameters for Correlation context"); + } + + @SuppressWarnings("unchecked") + Map valueMap = (Map) value; + assert parameters != null; + parameters.putAll(valueMap); + } else { + throw new MapperParsingException(String.format(Locale.getDefault(), "Invalid parameter for : %s", key)); + } + } + return new CorrelationParamsContext(similarityFunction, parameters); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VECTOR_SIMILARITY_FUNCTION, similarityFunction.name()); + if (params == null) { + builder.field(PARAMETERS, (String) null); + } else { + builder.startObject(PARAMETERS); + for (Map.Entry parameter : parameters.entrySet()) { + builder.field(parameter.getKey(), parameter.getValue()); + } + builder.endObject(); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CorrelationParamsContext that = (CorrelationParamsContext) o; + return similarityFunction == that.similarityFunction && parameters.equals(that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(similarityFunction, parameters); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(similarityFunction.name()); + if (this.parameters != null) { + out.writeMap(parameters); + } + } + + /** + * get Vector 
Similarity Function + * @return Vector Similarity Function + */ + public VectorSimilarityFunction getSimilarityFunction() { + return similarityFunction; + } + + /** + * Get Parameters to define m and ef_construction + * @return Parameters to define m and ef_construction + */ + public Map getParameters() { + return parameters; + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/VectorField.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/VectorField.java new file mode 100644 index 0000000000000..e292924e2b1b6 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/VectorField.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.correlation.core.index; + +import org.apache.lucene.document.Field; +import org.apache.lucene.index.IndexableFieldType; +import org.apache.lucene.util.BytesRef; +import org.opensearch.plugin.correlation.core.index.serializer.CorrelationVectorSerializer; + +/** + * Generic Vector Field defining a correlation vector name, float array. 
+ * + * @opensearch.internal + */ +public class VectorField extends Field { + + /** + * Parameterized ctor for VectorField + * @param name name of the field + * @param value float array value for the field + * @param type type of the field + */ + public VectorField(String name, float[] value, IndexableFieldType type) { + super(name, new BytesRef(), type); + try { + final CorrelationVectorSerializer vectorSerializer = new CorrelationVectorSerializer(); + final byte[] floatToByte = vectorSerializer.floatToByteArray(value); + this.setBytesValue(floatToByte); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/BasePerFieldCorrelationVectorsFormat.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/BasePerFieldCorrelationVectorsFormat.java new file mode 100644 index 0000000000000..87856c80b9fa4 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/BasePerFieldCorrelationVectorsFormat.java @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.plugin.correlation.core.index.mapper.CorrelationVectorFieldMapper; + +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * Class to define the hyper-parameters m and ef_construction for insert and store of correlation vectors into HNSW graphs based lucene index. + */ +public abstract class BasePerFieldCorrelationVectorsFormat extends PerFieldKnnVectorsFormat { + /** + * the hyper-parameters for constructing HNSW graphs. + * https://lucene.apache.org/core/9_4_0/core/org/apache/lucene/util/hnsw/HnswGraph.html + */ + public static final String METHOD_PARAMETER_M = "m"; + /** + * the hyper-parameters for constructing HNSW graphs. 
+ * https://lucene.apache.org/core/9_4_0/core/org/apache/lucene/util/hnsw/HnswGraph.html + */ + public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; + + private final Optional mapperService; + private final int defaultMaxConnections; + private final int defaultBeamWidth; + private final Supplier defaultFormatSupplier; + private final BiFunction formatSupplier; + + /** + * Parameterized ctor of BasePerFieldCorrelationVectorsFormat + * @param mapperService mapper service + * @param defaultMaxConnections default m + * @param defaultBeamWidth default ef_construction + * @param defaultFormatSupplier default format supplier + * @param formatSupplier format supplier + */ + public BasePerFieldCorrelationVectorsFormat( + Optional mapperService, + int defaultMaxConnections, + int defaultBeamWidth, + Supplier defaultFormatSupplier, + BiFunction formatSupplier + ) { + this.mapperService = mapperService; + this.defaultMaxConnections = defaultMaxConnections; + this.defaultBeamWidth = defaultBeamWidth; + this.defaultFormatSupplier = defaultFormatSupplier; + this.formatSupplier = formatSupplier; + } + + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + if (!isCorrelationVectorFieldType(field)) { + return defaultFormatSupplier.get(); + } + + var type = (CorrelationVectorFieldMapper.CorrelationVectorFieldType) mapperService.orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.getDefault(), "Cannot read field type for field [%s] because mapper service is not available", field) + ) + ).fieldType(field); + + var params = type.getCorrelationParams().getParameters(); + int maxConnections = getMaxConnections(params); + int beamWidth = getBeamWidth(params); + + return formatSupplier.apply(maxConnections, beamWidth); + } + + private boolean isCorrelationVectorFieldType(final String field) { + return mapperService.isPresent() + && mapperService.get().fieldType(field) instanceof 
CorrelationVectorFieldMapper.CorrelationVectorFieldType; + } + + private int getMaxConnections(final Map params) { + if (params != null && params.containsKey(METHOD_PARAMETER_M)) { + return (int) params.get(METHOD_PARAMETER_M); + } + return defaultMaxConnections; + } + + private int getBeamWidth(final Map params) { + if (params != null && params.containsKey(METHOD_PARAMETER_EF_CONSTRUCTION)) { + return (int) params.get(METHOD_PARAMETER_EF_CONSTRUCTION); + } + return defaultBeamWidth; + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecService.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecService.java new file mode 100644 index 0000000000000..0b70e7ed66f3d --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecService.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.codec; + +import org.apache.lucene.codecs.Codec; +import org.opensearch.index.codec.CodecService; +import org.opensearch.index.codec.CodecServiceConfig; +import org.opensearch.index.mapper.MapperService; + +/** + * custom Correlation Codec Service + * + * @opensearch.internal + */ +public class CorrelationCodecService extends CodecService { + + private final MapperService mapperService; + + /** + * Parameterized ctor for CorrelationCodecService + * @param codecServiceConfig Generic codec service config + */ + public CorrelationCodecService(CodecServiceConfig codecServiceConfig) { + super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger()); + mapperService = codecServiceConfig.getMapperService(); + } + + @Override + public Codec codec(String name) { + return CorrelationCodecVersion.current().getCorrelationCodecSupplier().apply(super.codec(name), mapperService); + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecVersion.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecVersion.java new file mode 100644 index 0000000000000..5e2cb8bfbc03a --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/CorrelationCodecVersion.java @@ -0,0 +1,103 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene95.Lucene95Codec; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.plugin.correlation.core.index.codec.correlation950.CorrelationCodec; +import org.opensearch.plugin.correlation.core.index.codec.correlation950.PerFieldCorrelationVectorsFormat; + +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * CorrelationCodecVersion enum + * + * @opensearch.internal + */ +public enum CorrelationCodecVersion { + V_9_5_0( + "CorrelationCodec", + new Lucene95Codec(), + new PerFieldCorrelationVectorsFormat(Optional.empty()), + (userCodec, mapperService) -> new CorrelationCodec(userCodec, new PerFieldCorrelationVectorsFormat(Optional.of(mapperService))), + CorrelationCodec::new + ); + + private static final CorrelationCodecVersion CURRENT = V_9_5_0; + private final String codecName; + private final Codec defaultCodecDelegate; + private final PerFieldCorrelationVectorsFormat perFieldCorrelationVectorsFormat; + private final BiFunction correlationCodecSupplier; + private final Supplier defaultCorrelationCodecSupplier; + + CorrelationCodecVersion( + String codecName, + Codec defaultCodecDelegate, + PerFieldCorrelationVectorsFormat perFieldCorrelationVectorsFormat, + BiFunction correlationCodecSupplier, + Supplier defaultCorrelationCodecSupplier + ) { + this.codecName = codecName; + this.defaultCodecDelegate = defaultCodecDelegate; + this.perFieldCorrelationVectorsFormat = perFieldCorrelationVectorsFormat; + this.correlationCodecSupplier = correlationCodecSupplier; + this.defaultCorrelationCodecSupplier = defaultCorrelationCodecSupplier; + } + + /** + * get codec name + * @return codec name + */ + public String getCodecName() { + return codecName; + } + + /** + * get default codec delegate + * @return default codec delegate + */ + public Codec 
getDefaultCodecDelegate() { + return defaultCodecDelegate; + } + + /** + * get correlation vectors format + * @return correlation vectors format + */ + public PerFieldCorrelationVectorsFormat getPerFieldCorrelationVectorsFormat() { + return perFieldCorrelationVectorsFormat; + } + + /** + * get correlation codec supplier + * @return correlation codec supplier + */ + public BiFunction getCorrelationCodecSupplier() { + return correlationCodecSupplier; + } + + /** + * get default correlation codec supplier + * @return default correlation codec supplier + */ + public Supplier getDefaultCorrelationCodecSupplier() { + return defaultCorrelationCodecSupplier; + } + + /** + * static method to get correlation codec version + * @return correlation codec version + */ + public static final CorrelationCodecVersion current() { + return CURRENT; + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodec.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodec.java new file mode 100644 index 0000000000000..f91ba429fbea9 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodec.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.codec.correlation950; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.opensearch.plugin.correlation.core.index.codec.CorrelationCodecVersion; + +/** + * Correlation Codec class + * + * @opensearch.internal + */ +public class CorrelationCodec extends FilterCodec { + private static final CorrelationCodecVersion VERSION = CorrelationCodecVersion.V_9_5_0; + private final PerFieldCorrelationVectorsFormat perFieldCorrelationVectorsFormat; + + /** + * ctor for CorrelationCodec + */ + public CorrelationCodec() { + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldCorrelationVectorsFormat()); + } + + /** + * Parameterized ctor for CorrelationCodec + * @param delegate codec delegate + * @param perFieldCorrelationVectorsFormat correlation vectors format + */ + public CorrelationCodec(Codec delegate, PerFieldCorrelationVectorsFormat perFieldCorrelationVectorsFormat) { + super(VERSION.getCodecName(), delegate); + this.perFieldCorrelationVectorsFormat = perFieldCorrelationVectorsFormat; + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldCorrelationVectorsFormat; + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/PerFieldCorrelationVectorsFormat.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/PerFieldCorrelationVectorsFormat.java new file mode 100644 index 0000000000000..f6862ecc17736 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/PerFieldCorrelationVectorsFormat.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 
license or a + * compatible open source license. + */ + +package org.opensearch.plugin.correlation.core.index.codec.correlation950; + +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.plugin.correlation.core.index.codec.BasePerFieldCorrelationVectorsFormat; + +import java.util.Optional; + +/** + * Class to define the hyper-parameters m and ef_construction for insert and store of correlation vectors into HNSW graphs based lucene index. + */ +public class PerFieldCorrelationVectorsFormat extends BasePerFieldCorrelationVectorsFormat { + + /** + * Parameterized ctor for PerFieldCorrelationVectorsFormat + * @param mapperService mapper service + */ + public PerFieldCorrelationVectorsFormat(final Optional mapperService) { + super( + mapperService, + Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + Lucene95HnswVectorsFormat::new, + Lucene95HnswVectorsFormat::new + ); + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/package-info.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/package-info.java new file mode 100644 index 0000000000000..b4dad34d2718e --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +/** + * custom Lucene9.5 codec package for events-correlation-engine + */ +package org.opensearch.plugin.correlation.core.index.codec.correlation950; diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/package-info.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/package-info.java new file mode 100644 index 0000000000000..862b7cd253f04 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/codec/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * custom codec package for events-correlation-engine + */ +package org.opensearch.plugin.correlation.core.index.codec; diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapper.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapper.java new file mode 100644 index 0000000000000..04cd5d3a66546 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapper.java @@ -0,0 +1,172 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.common.Explicit; +import org.opensearch.index.mapper.FieldMapper; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.plugin.correlation.core.index.CorrelationParamsContext; +import org.opensearch.plugin.correlation.core.index.VectorField; + +import java.io.IOException; +import java.util.Optional; + +import static org.apache.lucene.index.FloatVectorValues.MAX_DIMENSIONS; + +/** + * Field mapper for the correlation vector type + * + * @opensearch.internal + */ +public class CorrelationVectorFieldMapper extends VectorFieldMapper { + + private static final int LUCENE_MAX_DIMENSION = MAX_DIMENSIONS; + + private final FieldType vectorFieldType; + + /** + * Parameterized ctor for CorrelationVectorFieldMapper + * @param input Object containing name of the field, type and other details. 
+ */ + public CorrelationVectorFieldMapper(final CreateLuceneFieldMapperInput input) { + super( + input.getName(), + input.getMappedFieldType(), + input.getMultiFields(), + input.getCopyTo(), + input.getIgnoreMalformed(), + input.isStored(), + input.isHasDocValues() + ); + + this.correlationParams = input.getCorrelationParams(); + final VectorSimilarityFunction vectorSimilarityFunction = this.correlationParams.getSimilarityFunction(); + + final int dimension = input.getMappedFieldType().getDimension(); + if (dimension > LUCENE_MAX_DIMENSION) { + throw new IllegalArgumentException( + String.format( + "Dimension value cannot be greater than [%s] but got [%s] for vector [%s]", + LUCENE_MAX_DIMENSION, + dimension, + input.getName() + ) + ); + } + + this.fieldType = KnnFloatVectorField.createFieldType(dimension, vectorSimilarityFunction); + + if (this.hasDocValues) { + this.vectorFieldType = buildDocValuesFieldType(); + } else { + this.vectorFieldType = null; + } + } + + private static FieldType buildDocValuesFieldType() { + FieldType field = new FieldType(); + field.setDocValuesType(DocValuesType.BINARY); + field.freeze(); + return field; + } + + @Override + protected void parseCreateField(ParseContext context, int dimension) throws IOException { + Optional arrayOptional = getFloatsFromContext(context, dimension); + + if (arrayOptional.isEmpty()) { + return; + } + final float[] array = arrayOptional.get(); + + KnnFloatVectorField point = new KnnFloatVectorField(name(), array, fieldType); + + context.doc().add(point); + if (fieldType.stored()) { + context.doc().add(new StoredField(name(), point.toString())); + } + if (hasDocValues && vectorFieldType != null) { + context.doc().add(new VectorField(name(), array, vectorFieldType)); + } + context.path().remove(); + } + + static class CreateLuceneFieldMapperInput { + String name; + + CorrelationVectorFieldType mappedFieldType; + + FieldMapper.MultiFields multiFields; + + FieldMapper.CopyTo copyTo; + + Explicit 
ignoreMalformed; + boolean stored; + boolean hasDocValues; + + CorrelationParamsContext correlationParams; + + public CreateLuceneFieldMapperInput( + String name, + CorrelationVectorFieldType mappedFieldType, + FieldMapper.MultiFields multiFields, + FieldMapper.CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + CorrelationParamsContext correlationParams + ) { + this.name = name; + this.mappedFieldType = mappedFieldType; + this.multiFields = multiFields; + this.copyTo = copyTo; + this.ignoreMalformed = ignoreMalformed; + this.stored = stored; + this.hasDocValues = hasDocValues; + this.correlationParams = correlationParams; + } + + public String getName() { + return name; + } + + public CorrelationVectorFieldType getMappedFieldType() { + return mappedFieldType; + } + + public FieldMapper.MultiFields getMultiFields() { + return multiFields; + } + + public FieldMapper.CopyTo getCopyTo() { + return copyTo; + } + + public Explicit getIgnoreMalformed() { + return ignoreMalformed; + } + + public boolean isStored() { + return stored; + } + + public boolean isHasDocValues() { + return hasDocValues; + } + + public CorrelationParamsContext getCorrelationParams() { + return correlationParams; + } + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/VectorFieldMapper.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/VectorFieldMapper.java new file mode 100644 index 0000000000000..315c770136a25 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/VectorFieldMapper.java @@ -0,0 +1,394 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.mapper; + +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.Query; +import org.opensearch.common.Explicit; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.FieldMapper; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.Mapper; +import org.opensearch.index.mapper.MapperParsingException; +import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.index.mapper.TextSearchInfo; +import org.opensearch.index.mapper.ValueFetcher; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; +import org.opensearch.plugin.correlation.core.index.CorrelationParamsContext; +import org.opensearch.search.lookup.SearchLookup; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +/** + * Parameterized field mapper for Correlation Vector type + * + * @opensearch.internal + */ +public abstract class VectorFieldMapper extends ParametrizedFieldMapper { + + /** + * name of Correlation Vector type + */ + public static final String CONTENT_TYPE = "correlation_vector"; + /** + * dimension of the correlation vectors + */ + public static final String DIMENSION = "dimension"; + /** + * context e.g. 
parameters and vector similarity function of Correlation Vector type + */ + public static final String CORRELATION_CONTEXT = "correlation_ctx"; + + private static VectorFieldMapper toType(FieldMapper in) { + return (VectorFieldMapper) in; + } + + /** + * definition of VectorFieldMapper.Builder + */ + public static class Builder extends ParametrizedFieldMapper.Builder { + protected Boolean ignoreMalformed; + + protected final Parameter stored = Parameter.boolParam("store", false, m -> toType(m).stored, false); + protected final Parameter hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true); + protected final Parameter dimension = new Parameter<>(DIMENSION, false, () -> -1, (n, c, o) -> { + if (o == null) { + throw new IllegalArgumentException("Dimension cannot be null"); + } + int value; + try { + value = XContentMapValues.nodeIntegerValue(o); + } catch (Exception ex) { + throw new IllegalArgumentException( + String.format(Locale.getDefault(), "Unable to parse [dimension] from provided value [%s] for vector [%s]", o, name) + ); + } + if (value <= 0) { + throw new IllegalArgumentException( + String.format(Locale.getDefault(), "Dimension value must be greater than 0 for vector: %s", name) + ); + } + return value; + }, m -> toType(m).dimension); + + protected final Parameter correlationParamsContext = new Parameter<>( + CORRELATION_CONTEXT, + false, + () -> null, + (n, c, o) -> CorrelationParamsContext.parse(o), + m -> toType(m).correlationParams + ); + + protected final Parameter> meta = Parameter.metaParam(); + + /** + * Parameterized ctor for VectorFieldMapper.Builder + * @param name name + */ + public Builder(String name) { + super(name); + } + + @Override + protected List> getParameters() { + return Arrays.asList(stored, hasDocValues, dimension, meta, correlationParamsContext); + } + + protected Explicit ignoreMalformed(BuilderContext context) { + if (ignoreMalformed != null) { + return new Explicit<>(ignoreMalformed, true); + 
} + if (context.indexSettings() != null) { + return new Explicit<>(IGNORE_MALFORMED_SETTING.get(context.indexSettings()), false); + } + return Defaults.IGNORE_MALFORMED; + } + + @Override + public ParametrizedFieldMapper build(BuilderContext context) { + final CorrelationParamsContext correlationParams = correlationParamsContext.getValue(); + final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); + final CopyTo copyToBuilder = copyTo.build(); + final Explicit ignoreMalformed = ignoreMalformed(context); + final Map metaValue = meta.getValue(); + + final CorrelationVectorFieldType mappedFieldType = new CorrelationVectorFieldType( + buildFullName(context), + metaValue, + dimension.getValue(), + correlationParams + ); + + CorrelationVectorFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = + new CorrelationVectorFieldMapper.CreateLuceneFieldMapperInput( + name, + mappedFieldType, + multiFieldsBuilder, + copyToBuilder, + ignoreMalformed, + stored.get(), + hasDocValues.get(), + correlationParams + ); + return new CorrelationVectorFieldMapper(createLuceneFieldMapperInput); + } + } + + /** + * deifintion of VectorFieldMapper.TypeParser + */ + public static class TypeParser implements Mapper.TypeParser { + + /** + * default constructor of VectorFieldMapper.TypeParser + */ + public TypeParser() {} + + @Override + public Mapper.Builder parse(String name, Map node, ParserContext context) throws MapperParsingException { + Builder builder = new VectorFieldMapper.Builder(name); + builder.parse(name, context, node); + + if (builder.dimension.getValue() == -1) { + throw new IllegalArgumentException(String.format(Locale.getDefault(), "Dimension value missing for vector: %s", name)); + } + return builder; + } + } + + /** + * deifintion of VectorFieldMapper.CorrelationVectorFieldType + */ + public static class CorrelationVectorFieldType extends MappedFieldType { + int dimension; + CorrelationParamsContext correlationParams; + + /** + * 
Parameterized ctor for VectorFieldMapper.CorrelationVectorFieldType + * @param name name of the field + * @param meta meta of the field + * @param dimension dimension of the field + */ + public CorrelationVectorFieldType(String name, Map meta, int dimension) { + this(name, meta, dimension, null); + } + + /** + * Parameterized ctor for VectorFieldMapper.CorrelationVectorFieldType + * @param name name of the field + * @param meta meta of the field + * @param dimension dimension of the field + * @param correlationParams correlation params for the field + */ + public CorrelationVectorFieldType( + String name, + Map meta, + int dimension, + CorrelationParamsContext correlationParams + ) { + super(name, false, false, true, TextSearchInfo.NONE, meta); + this.dimension = dimension; + this.correlationParams = correlationParams; + } + + @Override + public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String s) { + throw new UnsupportedOperationException("Correlation Vector do not support fields search"); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public Query existsQuery(QueryShardContext context) { + return new FieldExistsQuery(name()); + } + + @Override + public Query termQuery(Object o, QueryShardContext context) { + throw new QueryShardException( + context, + String.format( + Locale.getDefault(), + "Correlation vector do not support exact searching, use Correlation queries instead: [%s]", + name() + ) + ); + } + + /** + * get dimension + * @return dimension + */ + public int getDimension() { + return dimension; + } + + /** + * get correlation params + * @return correlation params + */ + public CorrelationParamsContext getCorrelationParams() { + return correlationParams; + } + } + + protected Explicit ignoreMalformed; + protected boolean stored; + protected boolean hasDocValues; + protected Integer dimension; + protected CorrelationParamsContext correlationParams; + + /** + * Parameterized 
ctor for VectorFieldMapper + * @param simpleName name of field + * @param mappedFieldType field type of field + * @param multiFields multi fields + * @param copyTo copy to + * @param ignoreMalformed ignore malformed + * @param stored stored field + * @param hasDocValues has doc values + */ + public VectorFieldMapper( + String simpleName, + CorrelationVectorFieldType mappedFieldType, + FieldMapper.MultiFields multiFields, + FieldMapper.CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues + ) { + super(simpleName, mappedFieldType, multiFields, copyTo); + this.ignoreMalformed = ignoreMalformed; + this.stored = stored; + this.hasDocValues = hasDocValues; + this.dimension = mappedFieldType.getDimension(); + } + + @Override + protected VectorFieldMapper clone() { + return (VectorFieldMapper) super.clone(); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + protected void parseCreateField(ParseContext parseContext) throws IOException { + parseCreateField(parseContext, fieldType().getDimension()); + } + + protected abstract void parseCreateField(ParseContext parseContext, int dimension) throws IOException; + + Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { + context.path().add(simpleName()); + + List vector = new ArrayList<>(); + XContentParser.Token token = context.parser().currentToken(); + float value; + if (token == XContentParser.Token.START_ARRAY) { + token = context.parser().nextToken(); + while (token != XContentParser.Token.END_ARRAY) { + value = context.parser().floatValue(); + + if (Float.isNaN(value)) { + throw new IllegalArgumentException("Correlation vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("Correlation vector values cannot be infinity"); + } + vector.add(value); + token = context.parser().nextToken(); + } + } else if (token == XContentParser.Token.VALUE_NUMBER) { + value = 
context.parser().floatValue(); + if (Float.isNaN(value)) { + throw new IllegalArgumentException("Correlation vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("Correlation vector values cannot be infinity"); + } + vector.add(value); + context.parser().nextToken(); + } else if (token == XContentParser.Token.VALUE_NULL) { + context.path().remove(); + return Optional.empty(); + } + + if (dimension != vector.size()) { + String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vector.size()); + throw new IllegalArgumentException(errorMessage); + } + + float[] array = new float[vector.size()]; + int i = 0; + for (Float f : vector) { + array[i++] = f; + } + return Optional.of(array); + } + + @Override + protected boolean docValuesByDefault() { + return true; + } + + @Override + public ParametrizedFieldMapper.Builder getMergeBuilder() { + return new VectorFieldMapper.Builder(simpleName()).init(this); + } + + @Override + public boolean parsesArrayValue() { + return true; + } + + @Override + public CorrelationVectorFieldType fieldType() { + return (CorrelationVectorFieldType) super.fieldType(); + } + + @Override + protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException { + super.doXContentBody(builder, includeDefaults, params); + if (includeDefaults || ignoreMalformed.explicit()) { + builder.field(Names.IGNORE_MALFORMED, ignoreMalformed.value()); + } + } + + /** + * Class for constants used in parent class VectorFieldMapper + */ + public static class Names { + public static final String IGNORE_MALFORMED = "ignore_malformed"; + } + + /** + * Class for constants used in parent class VectorFieldMapper + */ + public static class Defaults { + public static final Explicit IGNORE_MALFORMED = new Explicit<>(false, false); + } +} diff --git 
a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/package-info.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/package-info.java new file mode 100644 index 0000000000000..4fdc622c3d886 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/mapper/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * correlation field mapper package + */ +package org.opensearch.plugin.correlation.core.index.mapper; diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/package-info.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/package-info.java new file mode 100644 index 0000000000000..cfc0ffdfa81f1 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +/** + * package to wrap Lucene KnnFloatVectorField and KnnFloatVectorQuery for Opensearch events-correlation-engine + */ +package org.opensearch.plugin.correlation.core.index; diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilder.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilder.java new file mode 100644 index 0000000000000..c91a497763a95 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilder.java @@ -0,0 +1,334 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.correlation.core.index.query; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.Query; +import org.opensearch.common.ParsingException; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.plugin.correlation.core.index.mapper.VectorFieldMapper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + +/** + * Constructs a query to get correlated events or documents for a particular event or 
document. + * + * @opensearch.internal + */ +public class CorrelationQueryBuilder extends AbstractQueryBuilder { + + private static final Logger log = LogManager.getLogger(CorrelationQueryBuilder.class); + protected static final ParseField VECTOR_FIELD = new ParseField("vector"); + protected static final ParseField K_FIELD = new ParseField("k"); + protected static final ParseField FILTER_FIELD = new ParseField("filter"); + /** + * max number of neighbors that can be retrieved. + */ + public static int K_MAX = 10000; + + /** + * name of the query + */ + public static final ParseField NAME_FIELD = new ParseField("correlation"); + + private String fieldName; + private float[] vector; + private int k = 0; + private double boost; + private QueryBuilder filter; + + private CorrelationQueryBuilder() {} + + /** + * parameterized ctor for CorrelationQueryBuilder + * @param fieldName field name for query + * @param vector query vector + * @param k number of nearby neighbors + */ + public CorrelationQueryBuilder(String fieldName, float[] vector, int k) { + this(fieldName, vector, k, null); + } + + /** + * parameterized ctor for CorrelationQueryBuilder + * @param fieldName field name for query + * @param vector query vector + * @param k number of nearby neighbors + * @param filter optional filter query + */ + public CorrelationQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { + if (Strings.isNullOrEmpty(fieldName)) { + throw new IllegalArgumentException( + String.format(Locale.getDefault(), "[%s] requires fieldName", NAME_FIELD.getPreferredName()) + ); + } + if (vector == null) { + throw new IllegalArgumentException( + String.format(Locale.getDefault(), "[%s] requires query vector", NAME_FIELD.getPreferredName()) + ); + } + if (vector.length == 0) { + throw new IllegalArgumentException( + String.format(Locale.getDefault(), "[%s] query vector is empty", NAME_FIELD.getPreferredName()) + ); + } + if (k <= 0) { + throw new 
IllegalArgumentException(String.format(Locale.getDefault(), "[%s] requires k > 0", NAME_FIELD.getPreferredName())); + } + if (k > K_MAX) { + throw new IllegalArgumentException(String.format(Locale.getDefault(), "[%s] requires k <= ", K_MAX)); + } + + this.fieldName = fieldName; + this.vector = vector; + this.k = k; + this.filter = filter; + } + + /** + * parameterized ctor for CorrelationQueryBuilder + * @param sin StreamInput + * @throws IOException IOException + */ + public CorrelationQueryBuilder(StreamInput sin) throws IOException { + super(sin); + try { + this.fieldName = sin.readString(); + this.vector = sin.readFloatArray(); + this.k = sin.readInt(); + this.filter = sin.readOptionalNamedWriteable(QueryBuilder.class); + } catch (IOException ex) { + throw new RuntimeException("Unable to create CorrelationQueryBuilder", ex); + } + } + + private static float[] objectsToFloats(List objs) { + float[] vector = new float[objs.size()]; + for (int i = 0; i < objs.size(); ++i) { + vector[i] = ((Number) objs.get(i)).floatValue(); + } + return vector; + } + + /** + * parse into CorrelationQueryBuilder + * @param xcp XContentParser + * @return CorrelationQueryBuilder + */ + public static CorrelationQueryBuilder parse(XContentParser xcp) throws IOException { + String fieldName = null; + List vector = null; + float boost = AbstractQueryBuilder.DEFAULT_BOOST; + + int k = 0; + QueryBuilder filter = null; + String queryName = null; + String currentFieldName = null; + XContentParser.Token token; + while ((token = xcp.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = xcp.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + throwParsingExceptionOnMultipleFields(NAME_FIELD.getPreferredName(), xcp.getTokenLocation(), fieldName, currentFieldName); + fieldName = currentFieldName; + while ((token = xcp.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == 
XContentParser.Token.FIELD_NAME) { + currentFieldName = xcp.currentName(); + } else if (token.isValue() || token == XContentParser.Token.START_ARRAY) { + if (VECTOR_FIELD.match(currentFieldName, xcp.getDeprecationHandler())) { + vector = xcp.list(); + } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, xcp.getDeprecationHandler())) { + boost = xcp.floatValue(); + } else if (K_FIELD.match(currentFieldName, xcp.getDeprecationHandler())) { + k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(xcp.objectBytes(), false); + } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, xcp.getDeprecationHandler())) { + queryName = xcp.text(); + } else { + throw new ParsingException( + xcp.getTokenLocation(), + "[" + NAME_FIELD.getPreferredName() + "] query does not support [" + currentFieldName + "]" + ); + } + } else if (token == XContentParser.Token.START_OBJECT) { + String tokenName = xcp.currentName(); + if (FILTER_FIELD.getPreferredName().equals(tokenName)) { + filter = parseInnerQueryBuilder(xcp); + } else { + throw new ParsingException( + xcp.getTokenLocation(), + "[" + NAME_FIELD.getPreferredName() + "] unknown token [" + token + "]" + ); + } + } else { + throw new ParsingException( + xcp.getTokenLocation(), + "[" + NAME_FIELD.getPreferredName() + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); + } + } + } else { + throwParsingExceptionOnMultipleFields(NAME_FIELD.getPreferredName(), xcp.getTokenLocation(), fieldName, xcp.currentName()); + fieldName = xcp.currentName(); + vector = xcp.list(); + } + } + + assert vector != null; + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(fieldName, objectsToFloats(vector), k, filter); + correlationQueryBuilder.queryName(queryName); + correlationQueryBuilder.boost(boost); + return correlationQueryBuilder; + } + + public void setFieldName(String fieldName) { + this.fieldName = fieldName; + } + + /** + * get field name + * @return field name + */ + 
public String fieldName() { + return fieldName; + } + + public void setVector(float[] vector) { + this.vector = vector; + } + + /** + * get query vector + * @return query vector + */ + public Object vector() { + return vector; + } + + public void setK(int k) { + this.k = k; + } + + /** + * get number of nearby neighbors + * @return number of nearby neighbors + */ + public int getK() { + return k; + } + + public void setBoost(double boost) { + this.boost = boost; + } + + /** + * get boost + * @return boost + */ + public double getBoost() { + return boost; + } + + public void setFilter(QueryBuilder filter) { + this.filter = filter; + } + + /** + * get optional filter + * @return optional filter + */ + public QueryBuilder getFilter() { + return filter; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(fieldName); + out.writeFloatArray(vector); + out.writeInt(k); + out.writeOptionalNamedWriteable(filter); + } + + @Override + public void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(fieldName); + + builder.field(VECTOR_FIELD.getPreferredName(), vector); + builder.field(K_FIELD.getPreferredName(), k); + if (filter != null) { + builder.field(FILTER_FIELD.getPreferredName(), filter); + } + printBoostAndQueryName(builder); + builder.endObject(); + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + MappedFieldType mappedFieldType = context.fieldMapper(fieldName); + + if (!(mappedFieldType instanceof VectorFieldMapper.CorrelationVectorFieldType)) { + throw new IllegalArgumentException(String.format(Locale.getDefault(), "Field '%s' is not knn_vector type.", this.fieldName)); + } + + VectorFieldMapper.CorrelationVectorFieldType correlationVectorFieldType = + (VectorFieldMapper.CorrelationVectorFieldType) mappedFieldType; + int fieldDimension = correlationVectorFieldType.getDimension(); + + if (fieldDimension != vector.length) { + throw 
new IllegalArgumentException( + String.format( + Locale.getDefault(), + "Query vector has invalid dimension: %d. Dimension should be: %d", + vector.length, + fieldDimension + ) + ); + } + + String indexName = context.index().getName(); + CorrelationQueryFactory.CreateQueryRequest createQueryRequest = new CorrelationQueryFactory.CreateQueryRequest( + indexName, + this.fieldName, + this.vector, + this.k, + this.filter, + context + ); + return CorrelationQueryFactory.create(createQueryRequest); + } + + @Override + protected boolean doEquals(CorrelationQueryBuilder other) { + return Objects.equals(fieldName, other.fieldName) && Arrays.equals(vector, other.vector) && Objects.equals(k, other.k); + } + + @Override + protected int doHashCode() { + return Objects.hash(fieldName, vector, k); + } + + @Override + public String getWriteableName() { + return NAME_FIELD.getPreferredName(); + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryFactory.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryFactory.java new file mode 100644 index 0000000000000..d5db299bfa3a5 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryFactory.java @@ -0,0 +1,142 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.query; + +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; + +import java.io.IOException; +import java.util.Optional; + +/** + * CorrelationQueryFactory util class is used to construct a Lucene KnnFloatVectorQuery. + * + * @opensearch.internal + */ +public class CorrelationQueryFactory { + + /** + * static method which takes input params to construct a Lucene KnnFloatVectorQuery. + * @param createQueryRequest object parameter containing inputs for constructing Lucene KnnFloatVectorQuery. + * @return generic Lucene Query object + */ + public static Query create(CreateQueryRequest createQueryRequest) { + final String indexName = createQueryRequest.getIndexName(); + final String fieldName = createQueryRequest.getFieldName(); + final int k = createQueryRequest.getK(); + final float[] vector = createQueryRequest.getVector(); + + if (createQueryRequest.getFilter().isPresent()) { + final QueryShardContext context = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + + try { + final Query filterQuery = createQueryRequest.getFilter().get().toQuery(context); + return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery); + } catch (IOException ex) { + throw new RuntimeException("Cannot create knn query with filter", ex); + } + } + return new KnnFloatVectorQuery(fieldName, vector, k); + } + + /** + * class containing params to construct a Lucene KnnFloatVectorQuery. 
+ * + * @opensearch.internal + */ + public static class CreateQueryRequest { + private String indexName; + + private String fieldName; + + private float[] vector; + + private int k; + + private QueryBuilder filter; + + private QueryShardContext context; + + /** + * Parameterized ctor for CreateQueryRequest + * @param indexName index name + * @param fieldName field name + * @param vector query vector + * @param k number of nearby neighbors + * @param filter additional filter query + * @param context QueryShardContext + */ + public CreateQueryRequest( + String indexName, + String fieldName, + float[] vector, + int k, + QueryBuilder filter, + QueryShardContext context + ) { + this.indexName = indexName; + this.fieldName = fieldName; + this.vector = vector; + this.k = k; + this.filter = filter; + this.context = context; + } + + /** + * get index name + * @return get index name + */ + public String getIndexName() { + return indexName; + } + + /** + * get field name + * @return get field name + */ + public String getFieldName() { + return fieldName; + } + + /** + * get vector + * @return get vector + */ + public float[] getVector() { + return vector; + } + + /** + * get number of nearby neighbors + * @return number of nearby neighbors + */ + public int getK() { + return k; + } + + /** + * get optional filter query + * @return get optional filter query + */ + public Optional getFilter() { + return Optional.ofNullable(filter); + } + + /** + * get optional query shard context + * @return get optional query shard context + */ + public Optional getContext() { + return Optional.ofNullable(context); + } + } +} diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/package-info.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/package-info.java new file mode 100644 index 0000000000000..2cf5db786a60f --- /dev/null +++ 
b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/query/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * correlation query builder package + */ +package org.opensearch.plugin.correlation.core.index.query; diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializer.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializer.java new file mode 100644 index 0000000000000..12edeebce16a0 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializer.java @@ -0,0 +1,58 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.plugin.correlation.core.index.serializer;

import org.opensearch.ExceptionsHelper;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
 * CorrelationVectorSerializer class to do serde operation of converting float vectors to byte array and vice versa.
 * The byte form is the Java object-serialization encoding of the float array.
 *
 * @opensearch.internal
 */
public class CorrelationVectorSerializer {

    /**
     * converts float array based vector to byte array.
     * @param input float array
     * @return byte array
     */
    public byte[] floatToByteArray(float[] input) {
        try (
            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
            ObjectOutputStream objectStream = new ObjectOutputStream(byteStream)
        ) {
            objectStream.writeObject(input);
            // Push anything still buffered by the object stream before snapshotting the bytes.
            objectStream.flush();
            return byteStream.toByteArray();
        } catch (IOException ex) {
            throw ExceptionsHelper.convertToOpenSearchException(ex);
        }
    }

    /**
     * converts byte array to float array
     * @param byteStream byte array input stream
     * @return float array
     */
    public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
        // NOTE(review): this is Java-native deserialization; it is only safe if the bytes were
        // produced by floatToByteArray. Confirm no caller passes externally supplied data, or
        // add an ObjectInputFilter restricting the stream to float[].
        try (ObjectInputStream objectStream = new ObjectInputStream(byteStream)) {
            return (float[]) objectStream.readObject();
        } catch (IOException | ClassNotFoundException ex) {
            throw ExceptionsHelper.convertToOpenSearchException(ex);
        }
    }
}
diff --git a/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/package-info.java b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/package-info.java new file mode 100644 index 0000000000000..61f587c30ba59 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/java/org/opensearch/plugin/correlation/core/index/serializer/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +/** + * serde package of correlation vector + */ +package org.opensearch.plugin.correlation.core.index.serializer; diff --git a/plugins/events-correlation-engine/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/plugins/events-correlation-engine/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec new file mode 100644 index 0000000000000..598a3b6af73c2 --- /dev/null +++ b/plugins/events-correlation-engine/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -0,0 +1 @@ +org.opensearch.plugin.correlation.core.index.codec.correlation950.CorrelationCodec diff --git a/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodecTests.java b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodecTests.java new file mode 100644 index 0000000000000..ac859773f6350 --- /dev/null +++ b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/codec/correlation950/CorrelationCodecTests.java @@ -0,0 +1,120 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.codec.correlation950; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.SerialMergeScheduler; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.plugin.correlation.core.index.CorrelationParamsContext; +import org.opensearch.plugin.correlation.core.index.mapper.VectorFieldMapper; +import org.opensearch.plugin.correlation.core.index.query.CorrelationQueryFactory; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.spy; +import static org.opensearch.plugin.correlation.core.index.codec.BasePerFieldCorrelationVectorsFormat.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.plugin.correlation.core.index.codec.BasePerFieldCorrelationVectorsFormat.METHOD_PARAMETER_M; +import static org.opensearch.plugin.correlation.core.index.codec.CorrelationCodecVersion.V_9_5_0; + +/** + * Unit tests for custom correlation codec + */ +public class CorrelationCodecTests extends OpenSearchTestCase { + + private static final String FIELD_NAME_ONE = "test_vector_one"; + private static final String FIELD_NAME_TWO = "test_vector_two"; + + /** + * test correlation vector index + * @throws Exception 
Exception + */ + public void testCorrelationVectorIndex() throws Exception { + Function perFieldCorrelationVectorsProvider = + mapperService -> new PerFieldCorrelationVectorsFormat(Optional.of(mapperService)); + Function correlationCodecProvider = (correlationVectorsFormat -> new CorrelationCodec( + V_9_5_0.getDefaultCodecDelegate(), + correlationVectorsFormat + )); + testCorrelationVectorIndex(correlationCodecProvider, perFieldCorrelationVectorsProvider); + } + + private void testCorrelationVectorIndex( + final Function codecProvider, + final Function perFieldCorrelationVectorsProvider + ) throws Exception { + final MapperService mapperService = mock(MapperService.class); + final CorrelationParamsContext correlationParamsContext = new CorrelationParamsContext( + VectorSimilarityFunction.EUCLIDEAN, + Map.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 256) + ); + + final VectorFieldMapper.CorrelationVectorFieldType mappedFieldType1 = new VectorFieldMapper.CorrelationVectorFieldType( + FIELD_NAME_ONE, + Map.of(), + 3, + correlationParamsContext + ); + final VectorFieldMapper.CorrelationVectorFieldType mappedFieldType2 = new VectorFieldMapper.CorrelationVectorFieldType( + FIELD_NAME_TWO, + Map.of(), + 2, + correlationParamsContext + ); + when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1); + when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2); + + var perFieldCorrelationVectorsFormatSpy = spy(perFieldCorrelationVectorsProvider.apply(mapperService)); + final Codec codec = codecProvider.apply(perFieldCorrelationVectorsFormatSpy); + + Directory dir = newFSDirectory(createTempDir()); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setCodec(codec); + + final FieldType luceneFieldType = KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); + float[] array = { 1.0f, 3.0f, 4.0f }; + KnnFloatVectorField vectorField = new 
KnnFloatVectorField(FIELD_NAME_ONE, array, luceneFieldType); + RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + doc.add(vectorField); + writer.addDocument(doc); + writer.commit(); + IndexReader reader = writer.getReader(); + writer.close(); + + verify(perFieldCorrelationVectorsFormatSpy).getKnnVectorsFormatForField(eq(FIELD_NAME_ONE)); + + IndexSearcher searcher = new IndexSearcher(reader); + Query query = CorrelationQueryFactory.create( + new CorrelationQueryFactory.CreateQueryRequest("dummy", FIELD_NAME_ONE, new float[] { 1.0f, 0.0f, 0.0f }, 1, null, null) + ); + + assertEquals(1, searcher.count(query)); + + reader.close(); + dir.close(); + } +} diff --git a/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapperTests.java b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapperTests.java new file mode 100644 index 0000000000000..d2cb46ca5c193 --- /dev/null +++ b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/mapper/CorrelationVectorFieldMapperTests.java @@ -0,0 +1,208 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.plugin.correlation.core.index.mapper; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.junit.Assert; +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.ContentPath; +import org.opensearch.index.mapper.Mapper; +import org.opensearch.index.mapper.MapperParsingException; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.HashSet; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for correlation vector field mapper + */ +public class CorrelationVectorFieldMapperTests extends OpenSearchTestCase { + + private static final String CORRELATION_VECTOR_TYPE = "correlation_vector"; + private static final String DIMENSION_FIELD_NAME = "dimension"; + private static final String TYPE_FIELD_NAME = "type"; + + /** + * test builder construction from parse of correlation params context + * @throws IOException IOException + */ + public void testBuilder_parse_fromCorrelationParamsContext() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index-name"; + Settings settings = Settings.builder().put(settings(Version.CURRENT).build()).build(); + + VectorFieldMapper.TypeParser typeParser = new VectorFieldMapper.TypeParser(); + + int efConstruction = 321; + int m = 12; + int dimension = 10; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, 
CORRELATION_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .startObject("correlation_ctx") + .field("similarityFunction", VectorSimilarityFunction.EUCLIDEAN.name()) + .startObject("parameters") + .field("m", m) + .field("ef_construction", efConstruction) + .endObject() + .endObject() + .endObject(); + + VectorFieldMapper.Builder builder = (VectorFieldMapper.Builder) typeParser.parse( + fieldName, + XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(), + buildParserContext(indexName, settings) + ); + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + builder.build(builderContext); + + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN, builder.correlationParamsContext.getValue().getSimilarityFunction()); + Assert.assertEquals(321, builder.correlationParamsContext.getValue().getParameters().get("ef_construction")); + + XContentBuilder xContentBuilderEmptyParams = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, CORRELATION_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .startObject("correlation_ctx") + .field("similarityFunction", VectorSimilarityFunction.EUCLIDEAN.name()) + .endObject() + .endObject(); + + VectorFieldMapper.Builder builderEmptyParams = (VectorFieldMapper.Builder) typeParser.parse( + fieldName, + XContentHelper.convertToMap(BytesReference.bytes(xContentBuilderEmptyParams), true, xContentBuilderEmptyParams.contentType()) + .v2(), + buildParserContext(indexName, settings) + ); + + Assert.assertEquals( + VectorSimilarityFunction.EUCLIDEAN, + builderEmptyParams.correlationParamsContext.getValue().getSimilarityFunction() + ); + Assert.assertTrue(builderEmptyParams.correlationParamsContext.getValue().getParameters().isEmpty()); + } + + /** + * test type parser construction throw error for invalid dimension of correlation vectors + * @throws IOException IOException + */ + public void 
testTypeParser_parse_fromCorrelationParamsContext_InvalidDimension() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index-name"; + Settings settings = Settings.builder().put(settings(Version.CURRENT).build()).build(); + + VectorFieldMapper.TypeParser typeParser = new VectorFieldMapper.TypeParser(); + + int efConstruction = 321; + int m = 12; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, CORRELATION_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 2000) + .startObject("correlation_ctx") + .field("similarityFunction", VectorSimilarityFunction.EUCLIDEAN.name()) + .startObject("parameters") + .field("m", m) + .field("ef_construction", efConstruction) + .endObject() + .endObject() + .endObject(); + + VectorFieldMapper.Builder builder = (VectorFieldMapper.Builder) typeParser.parse( + fieldName, + XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(), + buildParserContext(indexName, settings) + ); + + expectThrows(IllegalArgumentException.class, () -> builder.build(new Mapper.BuilderContext(settings, new ContentPath()))); + } + + /** + * test type parser construction error for invalid vector similarity function + * @throws IOException IOException + */ + public void testTypeParser_parse_fromCorrelationParamsContext_InvalidVectorSimilarityFunction() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index-name"; + Settings settings = Settings.builder().put(settings(Version.CURRENT).build()).build(); + + VectorFieldMapper.TypeParser typeParser = new VectorFieldMapper.TypeParser(); + + int efConstruction = 321; + int m = 12; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, CORRELATION_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 2000) + .startObject("correlation_ctx") + .field("similarityFunction", "invalid") + 
.startObject("parameters") + .field("m", m) + .field("ef_construction", efConstruction) + .endObject() + .endObject() + .endObject(); + + expectThrows( + MapperParsingException.class, + () -> typeParser.parse( + fieldName, + XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(), + buildParserContext(indexName, settings) + ) + ); + } + + private IndexMetadata buildIndexMetaData(String index, Settings settings) { + return IndexMetadata.builder(index) + .settings(settings) + .numberOfShards(1) + .numberOfReplicas(0) + .version(7) + .mappingVersion(0) + .settingsVersion(0) + .aliasesVersion(0) + .creationDate(0) + .build(); + } + + private Mapper.TypeParser.ParserContext buildParserContext(String index, Settings settings) { + IndexSettings indexSettings = new IndexSettings( + buildIndexMetaData(index, settings), + Settings.EMPTY, + new IndexScopedSettings(Settings.EMPTY, new HashSet<>(IndexScopedSettings.BUILT_IN_INDEX_SETTINGS)) + ); + + MapperService mapperService = mock(MapperService.class); + when(mapperService.getIndexSettings()).thenReturn(indexSettings); + + return new Mapper.TypeParser.ParserContext( + null, + mapperService, + type -> new VectorFieldMapper.TypeParser(), + Version.CURRENT, + null, + null, + null + ); + } +} diff --git a/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilderTests.java b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilderTests.java new file mode 100644 index 0000000000000..fd3c7220aad74 --- /dev/null +++ b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/query/CorrelationQueryBuilderTests.java @@ -0,0 +1,268 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open 
source license. + */ + +package org.opensearch.plugin.correlation.core.index.query; + +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.junit.Assert; +import org.opensearch.Version; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.Index; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.plugin.correlation.core.index.mapper.VectorFieldMapper; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for Correlation Query Builder + */ +public class CorrelationQueryBuilderTests extends OpenSearchTestCase { + + private static final String FIELD_NAME = "myvector"; + private static final int K = 1; + private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + + /** + * test invalid number of nearby 
neighbors + */ + public void testInvalidK() { + float[] queryVector = { 1.0f, 1.0f }; + + expectThrows(IllegalArgumentException.class, () -> new CorrelationQueryBuilder(FIELD_NAME, queryVector, -K)); + expectThrows(IllegalArgumentException.class, () -> new CorrelationQueryBuilder(FIELD_NAME, queryVector, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new CorrelationQueryBuilder(FIELD_NAME, queryVector, CorrelationQueryBuilder.K_MAX + 1) + ); + } + + /** + * test empty vector scenario + */ + public void testEmptyVector() { + final float[] queryVector = null; + expectThrows(IllegalArgumentException.class, () -> new CorrelationQueryBuilder(FIELD_NAME, queryVector, 1)); + final float[] queryVector1 = new float[] {}; + expectThrows(IllegalArgumentException.class, () -> new CorrelationQueryBuilder(FIELD_NAME, queryVector1, 1)); + } + + /** + * test serde with xcontent + * @throws IOException IOException + */ + public void testFromXContent() throws IOException { + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(correlationQueryBuilder.fieldName()); + builder.field(CorrelationQueryBuilder.VECTOR_FIELD.getPreferredName(), correlationQueryBuilder.vector()); + builder.field(CorrelationQueryBuilder.K_FIELD.getPreferredName(), correlationQueryBuilder.getK()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + CorrelationQueryBuilder actualBuilder = CorrelationQueryBuilder.parse(contentParser); + Assert.assertEquals(actualBuilder, correlationQueryBuilder); + } + + /** + * test serde with xcontent + * @throws IOException IOException + */ + public void testFromXContentFromString() throws IOException { + String correlationQuery = "{\n" + + " \"myvector\" : {\n" + + " \"vector\" : [\n" + + " 1.0,\n" + + " 2.0,\n" + + " 
3.0,\n" + + " 4.0\n" + + " ],\n" + + " \"k\" : 1,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}"; + XContentParser contentParser = createParser(JsonXContent.jsonXContent, correlationQuery); + contentParser.nextToken(); + CorrelationQueryBuilder actualBuilder = CorrelationQueryBuilder.parse(contentParser); + Assert.assertEquals(correlationQuery.replace("\n", "").replace(" ", ""), Strings.toString(XContentType.JSON, actualBuilder)); + } + + /** + * test serde with xcontent with filters + * @throws IOException IOException + */ + public void testFromXContentWithFilters() throws IOException { + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(correlationQueryBuilder.fieldName()); + builder.field(CorrelationQueryBuilder.VECTOR_FIELD.getPreferredName(), correlationQueryBuilder.vector()); + builder.field(CorrelationQueryBuilder.K_FIELD.getPreferredName(), correlationQueryBuilder.getK()); + builder.field(CorrelationQueryBuilder.FILTER_FIELD.getPreferredName(), correlationQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + CorrelationQueryBuilder actualBuilder = CorrelationQueryBuilder.parse(contentParser); + Assert.assertEquals(actualBuilder, correlationQueryBuilder); + } + + /** + * test conversion o KnnFloatVectorQuery logic + * @throws IOException IOException + */ + public void testDoToQuery() throws IOException { + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + VectorFieldMapper.CorrelationVectorFieldType mockCorrVectorField = mock(VectorFieldMapper.CorrelationVectorFieldType.class); + 
when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockCorrVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockCorrVectorField); + KnnFloatVectorQuery query = (KnnFloatVectorQuery) correlationQueryBuilder.doToQuery(mockQueryShardContext); + Assert.assertEquals(FIELD_NAME, query.getField()); + Assert.assertArrayEquals(QUERY_VECTOR, query.getTargetCopy(), 0.1f); + Assert.assertEquals(K, query.getK()); + } + + /** + * test conversion o KnnFloatVectorQuery logic with filter + * @throws IOException IOException + */ + public void testDoToQueryWithFilter() throws IOException { + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + VectorFieldMapper.CorrelationVectorFieldType mockCorrVectorField = mock(VectorFieldMapper.CorrelationVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockCorrVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockCorrVectorField); + KnnFloatVectorQuery query = (KnnFloatVectorQuery) correlationQueryBuilder.doToQuery(mockQueryShardContext); + Assert.assertEquals(FIELD_NAME, query.getField()); + Assert.assertArrayEquals(QUERY_VECTOR, query.getTargetCopy(), 0.1f); + Assert.assertEquals(K, query.getK()); + Assert.assertEquals(TERM_QUERY.toQuery(mockQueryShardContext), query.getFilter()); + } + + /** + * test conversion o KnnFloatVectorQuery logic failure with invalid dimensions + */ + public void testDoToQueryInvalidDimensions() { + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + 
VectorFieldMapper.CorrelationVectorFieldType mockCorrVectorField = mock(VectorFieldMapper.CorrelationVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockCorrVectorField.getDimension()).thenReturn(400); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockCorrVectorField); + expectThrows(IllegalArgumentException.class, () -> correlationQueryBuilder.doToQuery(mockQueryShardContext)); + } + + /** + * test conversion o KnnFloatVectorQuery logic failure with invalid field type + */ + public void testDoToQueryInvalidFieldType() { + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + NumberFieldMapper.NumberFieldType mockCorrVectorField = mock(NumberFieldMapper.NumberFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockCorrVectorField); + expectThrows(IllegalArgumentException.class, () -> correlationQueryBuilder.doToQuery(mockQueryShardContext)); + } + + /** + * test serialization of Correlation Query Builder + * @throws Exception + */ + public void testSerialization() throws Exception { + assertSerialization(Optional.empty()); + assertSerialization(Optional.of(TERM_QUERY)); + } + + private void assertSerialization(final Optional queryBuilderOptional) throws IOException { + final CorrelationQueryBuilder builder = queryBuilderOptional.isPresent() + ? 
new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get()) + : new CorrelationQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(Version.CURRENT); + output.writeNamedWriteable(builder); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + in.setVersion(Version.CURRENT); + final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + + assertNotNull(deserializedQuery); + assertTrue(deserializedQuery instanceof CorrelationQueryBuilder); + final CorrelationQueryBuilder deserializedKnnQueryBuilder = (CorrelationQueryBuilder) deserializedQuery; + assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); + assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); + assertEquals(K, deserializedKnnQueryBuilder.getK()); + if (queryBuilderOptional.isPresent()) { + assertNotNull(deserializedKnnQueryBuilder.getFilter()); + assertEquals(queryBuilderOptional.get(), deserializedKnnQueryBuilder.getFilter()); + } else { + assertNull(deserializedKnnQueryBuilder.getFilter()); + } + } + } + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List list = ClusterModule.getNamedXWriteables(); + SearchPlugin.QuerySpec spec = new SearchPlugin.QuerySpec<>( + TermQueryBuilder.NAME, + TermQueryBuilder::new, + TermQueryBuilder::fromXContent + ); + list.add(new NamedXContentRegistry.Entry(QueryBuilder.class, spec.getName(), (p, c) -> spec.getParser().fromXContent(p))); + NamedXContentRegistry registry = new NamedXContentRegistry(list); + return registry; + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + final List entries = ClusterModule.getNamedWriteables(); + entries.add( + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + CorrelationQueryBuilder.NAME_FIELD.getPreferredName(), + CorrelationQueryBuilder::new + ) + ); + 
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + return new NamedWriteableRegistry(entries); + } +} diff --git a/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializerTests.java b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializerTests.java new file mode 100644 index 0000000000000..7553a36c7a184 --- /dev/null +++ b/plugins/events-correlation-engine/src/test/java/org/opensearch/plugin/correlation/core/index/serializer/CorrelationVectorSerializerTests.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.correlation.core.index.serializer; + +import org.junit.Assert; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.util.Random; + +/** + * Unit tests for Correlation Vector Serializer + */ +public class CorrelationVectorSerializerTests extends OpenSearchTestCase { + + private final Random random = new Random(); + + /** + * test float vector to array serializer + * @throws IOException IOException + */ + public void testVectorAsArraySerializer() throws IOException { + final float[] vector = getArrayOfRandomFloats(20); + + final ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); + final ObjectOutputStream objectStream = new ObjectOutputStream(byteStream); + objectStream.writeObject(vector); + final byte[] serializedVector = byteStream.toByteArray(); + + CorrelationVectorSerializer vectorSerializer = new CorrelationVectorSerializer(); + final byte[] actualSerializedVector = 
vectorSerializer.floatToByteArray(vector); + + Assert.assertNotNull(actualSerializedVector); + Assert.assertArrayEquals(serializedVector, actualSerializedVector); + + final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(new ByteArrayInputStream(actualSerializedVector)); + Assert.assertNotNull(actualDeserializedVector); + Assert.assertArrayEquals(vector, actualDeserializedVector, 0.1f); + } + + private float[] getArrayOfRandomFloats(int length) { + float[] vector = new float[length]; + for (int i = 0; i < 20; ++i) { + vector[i] = random.nextFloat(); + } + return vector; + } +}