From 50a6935e422580c2e96572453cac7b008e91f691 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Thu, 15 Aug 2024 18:35:41 -0400 Subject: [PATCH] Encapsulate dimension, vector data type validation/processing inside Library Introduces a new configuration object KNNMethodConfigContext. The KNNMethodContext contains the user provided information for want their index to be built like. However, it is missing a few pieces that are defined outside of it. These pieces are needed to validate the config and actually build the config. This change Integrates validation of KNNMethodContexts with KNNMethodConfigContext and better encapsulates KNNLibrary config info in the engine package Signed-off-by: John Mazanec --- CHANGELOG.md | 1 + qa/restart-upgrade/build.gradle | 46 +++ .../bwc/AbstractRestartUpgradeTestCase.java | 1 - .../org/opensearch/knn/bwc/IndexingIT.java | 29 ++ .../KNN80Codec/KNN80DocValuesConsumer.java | 41 ++- .../knn/index/engine/AbstractKNNLibrary.java | 83 ++++- .../knn/index/engine/AbstractKNNMethod.java | 87 +++-- .../engine/DefaultHnswSearchContext.java | 5 +- .../index/engine/DefaultIVFSearchContext.java | 5 +- .../knn/index/engine/JVMLibrary.java | 2 +- .../knn/index/engine/KNNEngine.java | 21 +- .../knn/index/engine/KNNLibrary.java | 24 +- .../engine/KNNLibraryIndexingContext.java | 23 +- .../engine/KNNLibraryIndexingContextImpl.java | 21 ++ .../knn/index/engine/KNNMethod.java | 23 +- .../index/engine/KNNMethodConfigContext.java | 51 +++ .../knn/index/engine/KNNMethodContext.java | 24 +- .../knn/index/engine/MethodComponent.java | 108 +++--- .../index/engine/MethodComponentContext.java | 6 - .../knn/index/engine/NativeLibrary.java | 4 +- .../knn/index/engine/Parameter.java | 261 ++------------ .../engine/faiss/AbstractFaissMethod.java | 84 +++++ .../knn/index/engine/faiss/FaissFP16Util.java | 145 ++++++++ .../index/engine/faiss/FaissFlatEncoder.java | 16 +- .../index/engine/faiss/FaissHNSWMethod.java | 37 +- .../engine/faiss/FaissHNSWPQEncoder.java | 25 +- .../index/engine/faiss/FaissIVFMethod.java | 30 +- .../index/engine/faiss/FaissIVFPQEncoder.java | 35 +- .../index/engine/faiss/FaissSQEncoder.java | 19 +- .../engine/faiss/MethodAsMapBuilder.java | 14 +- .../index/engine/lucene/LuceneHNSWMethod.java | 9 +- .../lucene/LuceneHNSWSearchContext.java | 5 +- .../index/engine/lucene/LuceneSQEncoder.java | 10 +- .../index/engine/nmslib/NmslibHNSWMethod.java | 9 +- .../engine/validation/ParameterValidator.java | 11 +- .../index/mapper/FlatVectorFieldMapper.java | 16 +- .../index/mapper/KNNVectorFieldMapper.java | 284 +++++++-------- .../mapper/KNNVectorFieldMapperUtil.java | 202 +---------- .../knn/index/mapper/LuceneFieldMapper.java | 87 ++--- .../knn/index/mapper/MethodFieldMapper.java | 105 ++---- .../knn/index/mapper/ModelFieldMapper.java | 95 +++-- .../index/mapper/PerDimensionProcessor.java | 16 - .../index/mapper/PerDimensionValidator.java | 14 - .../knn/index/query/KNNQueryBuilder.java | 4 +- .../opensearch/knn/index/util/IndexUtil.java | 2 +- .../transport/TrainingModelRequest.java | 122 +------ .../TrainingModelTransportAction.java | 16 +- .../opensearch/knn/training/TrainingJob.java | 58 +-- .../knn/training/VectorSpaceInfo.java | 26 -- .../java/org/opensearch/knn/KNNTestCase.java | 14 +- .../opensearch/knn/index/OpenSearchIT.java | 4 +- .../knn/index/VectorDataTypeIT.java | 28 +- .../KNN80DocValuesConsumerTests.java | 21 +- .../knn/index/codec/KNNCodecTestCase.java | 13 +- .../index/engine/AbstractKNNLibraryTests.java | 36 +- .../index/engine/AbstractKNNMethodTests.java | 44 ++- .../{ => engine}/KNNMethodContextTests.java | 124 +++++-- .../index/engine/MethodComponentTests.java | 69 ++-- .../knn/index/engine/ParameterTests.java | 157 ++++---- .../engine/faiss/FaissFP16UtilTests.java | 60 ++++ .../knn/index/engine/faiss/FaissTests.java | 69 +++- .../knn/index/engine/lucene/LuceneTests.java | 19 +- .../mapper/KNNVectorFieldMapperTests.java | 339 +++++++++--------- .../mapper/KNNVectorFieldMapperUtilTests.java | 54 --- .../index/mapper/MethodFieldMapperTests.java | 39 -- .../integ/BinaryIndexInvalidMappingIT.java | 6 +- .../opensearch/knn/jni/JNIServiceTests.java | 49 ++- .../LibraryInitializedSupplierTests.java | 16 +- .../transport/TrainingModelRequestTests.java | 23 +- .../knn/training/TrainingJobTests.java | 65 ++-- .../org/opensearch/knn/KNNRestTestCase.java | 8 + 71 files changed, 1890 insertions(+), 1729 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/faiss/FaissFP16Util.java delete mode 100644 src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java rename src/test/java/org/opensearch/knn/index/{ => engine}/KNNMethodContextTests.java (80%) create mode 100644 src/test/java/org/opensearch/knn/index/engine/faiss/FaissFP16UtilTests.java delete mode 100644 src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 86defd59e..ccac8d7c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,3 +38,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) * Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) * Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) +* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957) diff --git a/qa/restart-upgrade/build.gradle b/qa/restart-upgrade/build.gradle index 6bae754a2..a629e87ff 100644 --- a/qa/restart-upgrade/build.gradle +++ b/qa/restart-upgrade/build.gradle @@ -58,6 +58,29 @@ testClusters { excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testEmptyParametersOnUpgrade" } } + + if (knn_bwc_version.startsWith("1.") || + knn_bwc_version.startsWith("2.0.") || + knn_bwc_version.startsWith("2.1.") || + knn_bwc_version.startsWith("2.2.") || + knn_bwc_version.startsWith("2.3.") || + knn_bwc_version.startsWith("2.4") || + knn_bwc_version.startsWith("2.5.") || + knn_bwc_version.startsWith("2.6.") || + knn_bwc_version.startsWith("2.7.") || + knn_bwc_version.startsWith("2.8.") || + knn_bwc_version.startsWith("2.9.") || + knn_bwc_version.startsWith("2.10.") || + knn_bwc_version.startsWith("2.11.") || + knn_bwc_version.startsWith("2.12.") || + knn_bwc_version.startsWith("2.13.") || + knn_bwc_version.startsWith("2.14.") || + knn_bwc_version.startsWith("2.15.")) { + filter { + excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexBinaryForceMerge" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -101,6 +124,29 @@ testClusters { excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testEmptyParametersOnUpgrade" } } + + if (knn_bwc_version.startsWith("1.") || + knn_bwc_version.startsWith("2.0.") || + knn_bwc_version.startsWith("2.1.") || + knn_bwc_version.startsWith("2.2.") || + knn_bwc_version.startsWith("2.3.") || + knn_bwc_version.startsWith("2.4") || + knn_bwc_version.startsWith("2.5.") || + knn_bwc_version.startsWith("2.6.") || + knn_bwc_version.startsWith("2.7.") || + knn_bwc_version.startsWith("2.8.") || + knn_bwc_version.startsWith("2.9.") || + knn_bwc_version.startsWith("2.10.") || + knn_bwc_version.startsWith("2.11.") || + knn_bwc_version.startsWith("2.12.") || + knn_bwc_version.startsWith("2.13.") || + knn_bwc_version.startsWith("2.14.") || + knn_bwc_version.startsWith("2.15.")) { + filter { + excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexBinaryForceMerge" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/AbstractRestartUpgradeTestCase.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/AbstractRestartUpgradeTestCase.java index ed11ca9d8..667a39d16 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/AbstractRestartUpgradeTestCase.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/AbstractRestartUpgradeTestCase.java @@ -58,5 +58,4 @@ protected static final boolean isRunningAgainstOldCluster() { protected final Optional getBWCVersion() { return Optional.ofNullable(System.getProperty(BWC_VERSION, null)); } - } diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java index 1531dd0da..9c1dfb018 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java @@ -8,6 +8,8 @@ import org.junit.Assert; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; import java.util.Map; @@ -68,6 +70,33 @@ public void testKNNIndexDefaultLegacyFieldMappingForceMerge() throws Exception { } } + // Ensure bwc works for binary force merge + public void testKNNIndexBinaryForceMerge() throws Exception { + int dimension = 40; + + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + if (isRunningAgainstOldCluster()) { + createKnnIndex( + testIndex, + getKNNDefaultIndexSettings(), + createKnnIndexMapping( + TEST_FIELD, + dimension, + METHOD_HNSW, + KNNEngine.FAISS.getName(), + SpaceType.HAMMING.getValue(), + true, + VectorDataType.BINARY + ) + ); + addKNNByteDocs(testIndex, TEST_FIELD, dimension / 8, DOC_ID, 100); + // Flush to ensure that index is not re-indexed when node comes back up + flush(testIndex, true); + } else { + forceMergeKnnIndex(testIndex); + } + } + // Custom Legacy Field Mapping // space_type : "linf", engine : "nmslib", m : 2, ef_construction : 2 public void testKNNIndexCustomLegacyFieldMapping() throws Exception { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 5874eaded..f8bd0b3f7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -238,17 +238,11 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa ); } - // Update index description of Faiss for binary data type - if (KNNEngine.FAISS == knnEngine - && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); - } + // In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic. + // After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility, + // we need to ensure that if the description does not contain the prefix but the type is binary, we add the + // description. + maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes); // Used to determine how many threads to use when indexing parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); @@ -260,6 +254,31 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa }); } + private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map parameters, Map fieldAttributes) { + if (KNNEngine.FAISS != knnEngine) { + return; + } + + if (!VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) { + return; + } + + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) { + return; + } + + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) { + return; + } + + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + } + /** * Merges in the fields from the readers in mergeState * diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java index 92e34be7c..9b38b1b6b 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java @@ -9,8 +9,9 @@ import lombok.AllArgsConstructor; import lombok.Getter; import org.opensearch.common.ValidationException; -import org.opensearch.knn.training.VectorSpaceInfo; +import org.opensearch.knn.index.VectorDataType; +import java.util.Locale; import java.util.Map; /** @@ -25,44 +26,94 @@ public abstract class AbstractKNNLibrary implements KNNLibrary { @Override public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) { - validateMethodExists(methodName); + throwIllegalArgOnNonNull(validateMethodExists(methodName)); KNNMethod method = methods.get(methodName); return method.getKNNLibrarySearchContext(); } @Override - public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { + public KNNLibraryIndexingContext getKNNLibraryIndexingContext( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { String method = knnMethodContext.getMethodComponentContext().getName(); - validateMethodExists(method); + throwIllegalArgOnNonNull(validateMethodExists(method)); KNNMethod knnMethod = methods.get(method); - return knnMethod.getKNNLibraryIndexingContext(knnMethodContext); + return knnMethod.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); } @Override - public ValidationException validateMethod(KNNMethodContext knnMethodContext) { + public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); - validateMethodExists(methodName); - return methods.get(methodName).validate(knnMethodContext); + ValidationException validationException = null; + String invalidErrorMessage = validateMethodExists(methodName); + if (invalidErrorMessage != null) { + validationException = new ValidationException(); + validationException.addValidationError(invalidErrorMessage); + return validationException; + } + invalidErrorMessage = validateDimension(knnMethodContext, knnMethodConfigContext); + if (invalidErrorMessage != null) { + validationException = new ValidationException(); + validationException.addValidationError(invalidErrorMessage); + } + + validateSpaceType(knnMethodContext, knnMethodConfigContext); + ValidationException methodValidation = methods.get(methodName).validate(knnMethodContext, knnMethodConfigContext); + if (methodValidation != null) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationErrors(methodValidation.validationErrors()); + } + + return validationException; } - @Override - public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { - String methodName = knnMethodContext.getMethodComponentContext().getName(); - validateMethodExists(methodName); - return methods.get(methodName).validateWithData(knnMethodContext, vectorSpaceInfo); + private void validateSpaceType(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + if (knnMethodContext == null) { + return; + } + knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType()); + } + + private String validateDimension(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + if (knnMethodContext == null) { + return null; + } + int dimension = knnMethodConfigContext.getDimension(); + if (dimension > KNNEngine.getMaxDimensionByEngine(knnMethodContext.getKnnEngine())) { + return String.format( + Locale.ROOT, + "Dimension value cannot be greater than %s for vector with engine: %s", + KNNEngine.getMaxDimensionByEngine(knnMethodContext.getKnnEngine()), + knnMethodContext.getKnnEngine().getName() + ); + } + + if (VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType() && dimension % 8 != 0) { + return "Dimension should be multiply of 8 for binary vector data type"; + } + + return null; } @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); - validateMethodExists(methodName); + throwIllegalArgOnNonNull(validateMethodExists(methodName)); return methods.get(methodName).isTrainingRequired(knnMethodContext); } - private void validateMethodExists(String methodName) { + private String validateMethodExists(String methodName) { KNNMethod method = methods.get(methodName); if (method == null) { - throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName)); + return String.format("Invalid method name: %s", methodName); + } + return null; + } + + private void throwIllegalArgOnNonNull(String errorMessage) { + if (errorMessage != null) { + throw new IllegalArgumentException(errorMessage); } } } diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index 6e57e6913..52cc79129 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -9,7 +9,11 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.training.VectorSpaceInfo; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.PerDimensionProcessor; +import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.SpaceVectorValidator; +import org.opensearch.knn.index.mapper.VectorValidator; import java.util.ArrayList; import java.util.HashMap; @@ -35,7 +39,7 @@ public boolean isSpaceTypeSupported(SpaceType space) { } @Override - public ValidationException validate(KNNMethodContext knnMethodContext) { + public ValidationException validate(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { List errorMessages = new ArrayList<>(); if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) { errorMessages.add( @@ -49,7 +53,10 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { ); } - ValidationException methodValidation = methodComponent.validate(knnMethodContext.getMethodComponentContext()); + ValidationException methodValidation = methodComponent.validate( + knnMethodContext.getMethodComponentContext(), + knnMethodConfigContext + ); if (methodValidation != null) { errorMessages.addAll(methodValidation.validationErrors()); } @@ -64,52 +71,58 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { } @Override - public ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { - List errorMessages = new ArrayList<>(); - if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) { - errorMessages.add( - String.format( - Locale.ROOT, - "\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".", - this.methodComponent.getName(), - knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT), - knnMethodContext.getSpaceType().getValue() - ) - ); - } + public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { + return methodComponent.isTrainingRequired(knnMethodContext.getMethodComponentContext()); + } - ValidationException methodValidation = methodComponent.validateWithData( - knnMethodContext.getMethodComponentContext(), - vectorSpaceInfo - ); - if (methodValidation != null) { - errorMessages.addAll(methodValidation.validationErrors()); - } + @Override + public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + return methodComponent.estimateOverheadInKB(knnMethodContext.getMethodComponentContext(), knnMethodConfigContext.getDimension()); + } - if (errorMessages.isEmpty()) { - return null; + protected PerDimensionValidator doGetPerDimensionValidator( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType(); + + if (VectorDataType.BINARY == vectorDataType) { + return PerDimensionValidator.DEFAULT_BIT_VALIDATOR; } - ValidationException validationException = new ValidationException(); - validationException.addValidationErrors(errorMessages); - return validationException; + if (VectorDataType.BYTE == vectorDataType) { + return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } + return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; } - @Override - public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { - return methodComponent.isTrainingRequired(knnMethodContext.getMethodComponentContext()); + protected VectorValidator doGetVectorValidator(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + return new SpaceVectorValidator(knnMethodContext.getSpaceType()); } - @Override - public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { - return methodComponent.estimateOverheadInKB(knnMethodContext.getMethodComponentContext(), dimension); + protected PerDimensionProcessor doGetPerDimensionProcessor( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + return PerDimensionProcessor.NOOP_PROCESSOR; } @Override - public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { - Map parameterMap = new HashMap<>(methodComponent.getAsMap(knnMethodContext.getMethodComponentContext())); + public KNNLibraryIndexingContext getKNNLibraryIndexingContext( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + Map parameterMap = new HashMap<>( + methodComponent.getAsMap(knnMethodContext.getMethodComponentContext(), knnMethodConfigContext) + ); parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); - return KNNLibraryIndexingContextImpl.builder().parameters(parameterMap).build(); + parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue()); + return KNNLibraryIndexingContextImpl.builder() + .parameters(parameterMap) + .vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext)) + .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext)) + .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext)) + .build(); } @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java b/src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java index ecc11f338..884657442 100644 --- a/src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java @@ -17,7 +17,10 @@ public final class DefaultHnswSearchContext implements KNNLibrarySearchContext { private final Map> supportedMethodParameters = ImmutableMap.>builder() - .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) + .put( + MethodParameter.EF_SEARCH.getName(), + new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, (value, context) -> true) + ) .build(); @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java b/src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java index cc612bf8c..16e3f67d8 100644 --- a/src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java @@ -14,7 +14,10 @@ public final class DefaultIVFSearchContext implements KNNLibrarySearchContext { private final Map> supportedMethodParameters = ImmutableMap.>builder() - .put(MethodParameter.NPROBE.getName(), new Parameter.IntegerParameter(MethodParameter.NPROBE.getName(), null, value -> true)) + .put( + MethodParameter.NPROBE.getName(), + new Parameter.IntegerParameter(MethodParameter.NPROBE.getName(), null, (value, context) -> true) + ) .build(); @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/JVMLibrary.java b/src/main/java/org/opensearch/knn/index/engine/JVMLibrary.java index 762966567..bfb25c7c6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/JVMLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/JVMLibrary.java @@ -25,7 +25,7 @@ public JVMLibrary(Map methods, String version) { } @Override - public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { + public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { throw new UnsupportedOperationException("Estimating overhead is not supported for JVM based libraries."); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index c7b271783..2f3cb3430 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -11,7 +11,6 @@ import org.opensearch.knn.index.engine.faiss.Faiss; import org.opensearch.knn.index.engine.lucene.Lucene; import org.opensearch.knn.index.engine.nmslib.Nmslib; -import org.opensearch.knn.training.VectorSpaceInfo; import java.util.List; import java.util.Map; @@ -161,13 +160,8 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { } @Override - public ValidationException validateMethod(KNNMethodContext knnMethodContext) { - return knnLibrary.validateMethod(knnMethodContext); - } - - @Override - public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { - return knnLibrary.validateMethodWithData(knnMethodContext, vectorSpaceInfo); + public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + return knnLibrary.validateMethod(knnMethodContext, knnMethodConfigContext); } @Override @@ -176,8 +170,11 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { } @Override - public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { - return knnLibrary.getKNNLibraryIndexingContext(knnMethodContext); + public KNNLibraryIndexingContext getKNNLibraryIndexingContext( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + return knnLibrary.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); } @Override @@ -186,8 +183,8 @@ public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) { } @Override - public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { - return knnLibrary.estimateOverheadInKB(knnMethodContext, dimension); + public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + return knnLibrary.estimateOverheadInKB(knnMethodContext, knnMethodConfigContext); } @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java index 96d492307..14085243f 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java @@ -7,7 +7,6 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Collections; import java.util.List; @@ -76,19 +75,10 @@ public interface KNNLibrary { * deemed invalid. * * @param knnMethodContext to be validated + * @param knnMethodConfigContext configuration context for the method * @return ValidationException produced by validation errors; null if no validations errors. */ - ValidationException validateMethod(KNNMethodContext knnMethodContext); - - /** - * Validate the knnMethodContext for the given library, using additional data not present in the method context. A ValidationException should be thrown if the method is - * deemed invalid. - * - * @param knnMethodContext to be validated - * @param vectorSpaceInfo additional data not present in the method context - * @return ValidationException produced by validation errors; null if no validations errors. - */ - ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo); + ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext); /** * Returns whether training is required or not from knnMethodContext for the given library. @@ -102,18 +92,22 @@ public interface KNNLibrary { * Estimate overhead of KNNMethodContext in Kilobytes. * * @param knnMethodContext to estimate size for - * @param dimension to estimate size for + * @param knnMethodConfigContext configuration context for the method * @return size overhead estimate in KB */ - int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension); + int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext); /** * Get the context from the library needed to build the index. * * @param knnMethodContext to get build context for + * @param knnMethodConfigContext configuration context for the method * @return parameter map */ - KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext); + KNNLibraryIndexingContext getKNNLibraryIndexingContext( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ); /** * Gets metadata related to methods supported by the library diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java index d00b7c436..20285471e 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -5,7 +5,10 @@ package org.opensearch.knn.index.engine; -import java.util.Collections; +import org.opensearch.knn.index.mapper.PerDimensionProcessor; +import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorValidator; + import java.util.Map; /** @@ -19,5 +22,21 @@ public interface KNNLibraryIndexingContext { */ Map getLibraryParameters(); - KNNLibraryIndexingContext EMPTY = Collections::emptyMap; + /** + * + * @return Get the vector validator + */ + VectorValidator getVectorValidator(); + + /** + * + * @return Get the per dimension validator + */ + PerDimensionValidator getPerDimensionValidator(); + + /** + * + * @return Get the per dimension processor + */ + PerDimensionProcessor getPerDimensionProcessor(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java index b7c775261..51b60d9e5 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -6,6 +6,9 @@ package org.opensearch.knn.index.engine; import lombok.Builder; +import org.opensearch.knn.index.mapper.PerDimensionProcessor; +import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Map; @@ -15,10 +18,28 @@ @Builder public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext { + private VectorValidator vectorValidator; + private PerDimensionValidator perDimensionValidator; + private PerDimensionProcessor perDimensionProcessor; private Map parameters; @Override public Map getLibraryParameters() { return parameters; } + + @Override + public VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + public PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + public PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java index 326e5c1e0..0bcccacf0 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java @@ -7,7 +7,6 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.training.VectorSpaceInfo; /** * KNNMethod defines the structure of a method supported by a particular k-NN library. It is used to validate @@ -28,18 +27,10 @@ public interface KNNMethod { * Validate that the configured KNNMethodContext is valid for this method * * @param knnMethodContext to be validated + * @param knnMethodConfigContext to be validated * @return ValidationException produced by validation errors; null if no validations errors. */ - ValidationException validate(KNNMethodContext knnMethodContext); - - /** - * Validate that the configured KNNMethodContext is valid for this method, using additional data not present in the method context - * - * @param knnMethodContext to be validated - * @param vectorSpaceInfo additional data not present in the method context - * @return ValidationException produced by validation errors; null if no validations errors. - */ - ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo); + ValidationException validate(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext); /** * returns whether training is required or not @@ -53,18 +44,22 @@ public interface KNNMethod { * Returns the estimated overhead of the method in KB * * @param knnMethodContext context to estimate overhead - * @param dimension dimension to make estimate with + * @param knnMethodConfigContext config context to estimate overhead * @return estimate overhead in KB */ - int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension); + int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext); /** * Parse knnMethodContext into context that the library can use to build the index * * @param knnMethodContext to generate the context for + * @param knnMethodConfigContext to generate the context for * @return KNNLibraryIndexingContext */ - KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext); + KNNLibraryIndexingContext getKNNLibraryIndexingContext( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ); /** * Get the search context for a particular method diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java new file mode 100644 index 000000000..731085f0b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.opensearch.Version; +import org.opensearch.knn.index.VectorDataType; + +/** + * This object provides additional context that the user does not provide when {@link KNNMethodContext} is + * created via parsing. The values in this object need to be dynamically set and calling code needs to handle + * the possibility that the values have not been set. + */ +@Setter +@Getter +@Builder +@AllArgsConstructor +public final class KNNMethodConfigContext { + private VectorDataType vectorDataType; + private Integer dimension; + private Version versionCreated; + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + KNNMethodConfigContext other = (KNNMethodConfigContext) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(vectorDataType, other.vectorDataType); + equalsBuilder.append(dimension, other.dimension); + equalsBuilder.append(versionCreated, other.versionCreated); + + return equalsBuilder.isEquals(); + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(vectorDataType).append(dimension).append(versionCreated).toHashCode(); + } + + public static final KNNMethodConfigContext EMPTY = KNNMethodConfigContext.builder().build(); +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java index d210483e6..8b2f00f74 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java @@ -24,7 +24,6 @@ import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; -import org.opensearch.knn.training.VectorSpaceInfo; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; @@ -60,22 +59,13 @@ public KNNMethodContext(StreamInput in) throws IOException { } /** - * This method uses the knnEngine to validate that the method is compatible with the engine + * This method uses the knnEngine to validate that the method is compatible with the engine. * + * @param knnMethodConfigContext context to validate against * @return ValidationException produced by validation errors; null if no validations errors. */ - public ValidationException validate() { - return knnEngine.validateMethod(this); - } - - /** - * This method uses the knnEngine to validate that the method is compatible with the engine, using additional data not present in the method context - * - * @param vectorSpaceInfo additional data not present in the method context - * @return ValidationException produced by validation errors; null if no validations errors. - */ - public ValidationException validateWithData(VectorSpaceInfo vectorSpaceInfo) { - return knnEngine.validateMethodWithData(this, vectorSpaceInfo); + public ValidationException validate(KNNMethodConfigContext knnMethodConfigContext) { + return knnEngine.validateMethod(this, knnMethodConfigContext); } /** @@ -90,11 +80,11 @@ public boolean isTrainingRequired() { /** * This method estimates the overhead the knn method adds irrespective of the number of vectors * - * @param dimension dimension to make estimate with + * @param knnMethodConfigContext context to estimate overhead * @return size in Kilobytes */ - public int estimateOverheadInKB(int dimension) { - return knnEngine.estimateOverheadInKB(this, dimension); + public int estimateOverheadInKB(KNNMethodConfigContext knnMethodConfigContext) { + return knnEngine.estimateOverheadInKB(this, knnMethodConfigContext); } /** diff --git a/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java b/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java index cd9377ef1..988812e61 100644 --- a/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java @@ -10,14 +10,14 @@ import org.opensearch.common.TriFunction; import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.IndexHyperParametersUtil; -import org.opensearch.knn.training.VectorSpaceInfo; import java.util.HashMap; +import java.util.HashSet; +import java.util.Locale; import java.util.Map; -import java.util.function.BiFunction; -import java.util.List; -import java.util.ArrayList; +import java.util.Set; import static org.opensearch.knn.index.engine.validation.ParameterValidator.validateParameters; @@ -27,12 +27,13 @@ public class MethodComponent { @Getter - private String name; + private final String name; @Getter - private Map> parameters; - private BiFunction> mapGenerator; - private TriFunction overheadInKBEstimator; - final private boolean requiresTraining; + private final Map> parameters; + private final TriFunction> mapGenerator; + private final TriFunction overheadInKBEstimator; + private final boolean requiresTraining; + private final Set supportedVectorDataTypes; /** * Constructor @@ -45,6 +46,7 @@ private MethodComponent(Builder builder) { this.mapGenerator = builder.mapGenerator; this.overheadInKBEstimator = builder.overheadInKBEstimator; this.requiresTraining = builder.requiresTraining; + this.supportedVectorDataTypes = builder.supportedDataTypes; } /** @@ -53,61 +55,49 @@ private MethodComponent(Builder builder) { * @param methodComponentContext from which to generate map * @return Method component as a map */ - public Map getAsMap(MethodComponentContext methodComponentContext) { + public Map getAsMap(MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext) { if (mapGenerator == null) { Map parameterMap = new HashMap<>(); parameterMap.put(KNNConstants.NAME, methodComponentContext.getName()); - parameterMap.put(KNNConstants.PARAMETERS, getParameterMapWithDefaultsAdded(methodComponentContext, this)); + parameterMap.put( + KNNConstants.PARAMETERS, + getParameterMapWithDefaultsAdded(methodComponentContext, this, knnMethodConfigContext) + ); return parameterMap; } - return mapGenerator.apply(this, methodComponentContext); + return mapGenerator.apply(this, methodComponentContext, knnMethodConfigContext); } /** * Validate that the methodComponentContext is a valid configuration for this methodComponent * * @param methodComponentContext to be validated + * @param knnMethodConfigContext context for the method configuration * @return ValidationException produced by validation errors; null if no validations errors. */ - public ValidationException validate(MethodComponentContext methodComponentContext) { + public ValidationException validate(MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext) { Map providedParameters = methodComponentContext.getParameters(); - return validateParameters(parameters, providedParameters); - } - - /** - * Validate that the methodComponentContext is a valid configuration for this methodComponent, using additional data not present in the method component context - * - * @param methodComponentContext to be validated - * @param vectorSpaceInfo additional data not present in the method component context - * @return ValidationException produced by validation errors; null if no validations errors. - */ - public ValidationException validateWithData(MethodComponentContext methodComponentContext, VectorSpaceInfo vectorSpaceInfo) { - Map providedParameters = methodComponentContext.getParameters(); - List errorMessages = new ArrayList<>(); - if (providedParameters == null) { - return null; + ValidationException validationException = null; + if (!supportedVectorDataTypes.contains(knnMethodConfigContext.getVectorDataType())) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format( + Locale.ROOT, + "Method \"%s\" is not supported for vector data type \"%s\".", + name, + knnMethodConfigContext.getVectorDataType() + ) + ); } - ValidationException parameterValidation; - for (Map.Entry parameter : providedParameters.entrySet()) { - if (!parameters.containsKey(parameter.getKey())) { - errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName())); - continue; - } - - parameterValidation = parameters.get(parameter.getKey()).validateWithData(parameter.getValue(), vectorSpaceInfo); - if (parameterValidation != null) { - errorMessages.addAll(parameterValidation.validationErrors()); - } - } + ValidationException methodValidationException = validateParameters(parameters, providedParameters, knnMethodConfigContext); - if (errorMessages.isEmpty()) { - return null; + if (methodValidationException != null) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationErrors(methodValidationException.validationErrors()); } - ValidationException validationException = new ValidationException(); - validationException.addValidationErrors(errorMessages); return validationException; } @@ -217,11 +207,12 @@ public int estimateOverheadInKB(MethodComponentContext methodComponentContext, i */ public static class Builder { - private String name; - private Map> parameters; - private BiFunction> mapGenerator; + private final String name; + private final Map> parameters; + private TriFunction> mapGenerator; private TriFunction overheadInKBEstimator; private boolean requiresTraining; + private final Set supportedDataTypes; /** * Method to get a Builder instance @@ -230,7 +221,7 @@ public static class Builder { * @return Builder instance */ public static Builder builder(String name) { - return new MethodComponent.Builder(name); + return new Builder(name); } private Builder(String name) { @@ -238,6 +229,7 @@ private Builder(String name) { this.parameters = new HashMap<>(); this.mapGenerator = null; this.overheadInKBEstimator = (mc, mcc, d) -> 0L; + this.supportedDataTypes = new HashSet<>(); } /** @@ -258,7 +250,9 @@ public Builder addParameter(String parameterName, Parameter parameter) { * @param mapGenerator function to parse a MethodComponentContext as a map * @return this builder */ - public Builder setMapGenerator(BiFunction> mapGenerator) { + public Builder setMapGenerator( + TriFunction> mapGenerator + ) { this.mapGenerator = mapGenerator; return this; } @@ -284,6 +278,17 @@ public Builder setOverheadInKBEstimator(TriFunction dataTypeSet) { + supportedDataTypes.addAll(dataTypeSet); + return this; + } + /** * Build MethodComponent * @@ -303,11 +308,12 @@ public MethodComponent build() { */ public static Map getParameterMapWithDefaultsAdded( MethodComponentContext methodComponentContext, - MethodComponent methodComponent + MethodComponent methodComponent, + KNNMethodConfigContext knnMethodConfigContext ) { Map parametersWithDefaultsMap = new HashMap<>(); Map userProvidedParametersMap = methodComponentContext.getParameters(); - Version indexCreationVersion = methodComponentContext.getIndexVersion(); + Version indexCreationVersion = knnMethodConfigContext.getVersionCreated(); for (Parameter parameter : methodComponent.getParameters().values()) { if (methodComponentContext.getParameters().containsKey(parameter.getName())) { parametersWithDefaultsMap.put(parameter.getName(), userProvidedParametersMap.get(parameter.getName())); diff --git a/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java b/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java index fb4327487..586cc338f 100644 --- a/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java @@ -8,9 +8,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.Setter; import org.apache.commons.lang.math.NumberUtils; -import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -51,10 +49,6 @@ public class MethodComponentContext implements ToXContentFragment, Writeable { private final String name; private final Map parameters; - @Getter - @Setter - private Version indexVersion; - /** * Constructor from stream. * diff --git a/src/main/java/org/opensearch/knn/index/engine/NativeLibrary.java b/src/main/java/org/opensearch/knn/index/engine/NativeLibrary.java index 1e34cc380..c3c61292a 100644 --- a/src/main/java/org/opensearch/knn/index/engine/NativeLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/NativeLibrary.java @@ -59,9 +59,9 @@ public float score(float rawScore, SpaceType spaceType) { } @Override - public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { + public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); - return methods.get(methodName).estimateOverheadInKB(knnMethodContext, dimension); + return methods.get(methodName).estimateOverheadInKB(knnMethodContext, knnMethodConfigContext); } @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/Parameter.java b/src/main/java/org/opensearch/knn/index/engine/Parameter.java index fbd3ae692..4dd6b9c33 100644 --- a/src/main/java/org/opensearch/knn/index/engine/Parameter.java +++ b/src/main/java/org/opensearch/knn/index/engine/Parameter.java @@ -5,14 +5,13 @@ package org.opensearch.knn.index.engine; +import lombok.Getter; import org.opensearch.common.ValidationException; -import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; -import java.util.function.Predicate; /** * Parameter that can be set for a method component @@ -21,10 +20,11 @@ */ public abstract class Parameter { - private String name; - private T defaultValue; - protected Predicate validator; - protected BiFunction validatorWithData; + @Getter + private final String name; + @Getter + private final T defaultValue; + protected BiFunction validator; /** * Constructor @@ -33,74 +33,31 @@ public abstract class Parameter { * @param defaultValue of the parameter * @param validator used to validate a parameter value passed */ - public Parameter(String name, T defaultValue, Predicate validator) { + public Parameter(String name, T defaultValue, BiFunction validator) { this.name = name; this.defaultValue = defaultValue; this.validator = validator; - this.validatorWithData = null; - } - - public Parameter(String name, T defaultValue, Predicate validator, BiFunction validatorWithData) { - this.name = name; - this.defaultValue = defaultValue; - this.validator = validator; - this.validatorWithData = validatorWithData; - } - - /** - * Getter for parameter name - * - * @return parameter name - */ - public String getName() { - return name; - } - - /** - * Get default value for parameter - * - * @return default value of the parameter - */ - public T getDefaultValue() { - return defaultValue; } /** * Check if the value passed in is valid * * @param value to be checked + * @param knnMethodConfigContext context for the validation * @return ValidationException produced by validation errors; null if no validations errors. */ - public abstract ValidationException validate(Object value); - - /** - * Check if the value passed in is valid, using additional data not present in the value - * - * @param value to be checked - * @param vectorSpaceInfo additional data not present in the value - * @return ValidationException produced by validation errors; null if no validations errors. - */ - public abstract ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo); + public abstract ValidationException validate(Object value, KNNMethodConfigContext knnMethodConfigContext); /** * Boolean method parameter */ public static class BooleanParameter extends Parameter { - public BooleanParameter(String name, Boolean defaultValue, Predicate validator) { + public BooleanParameter(String name, Boolean defaultValue, BiFunction validator) { super(name, defaultValue, validator); } - public BooleanParameter( - String name, - Boolean defaultValue, - Predicate validator, - BiFunction validatorWithData - ) { - super(name, defaultValue, validator, validatorWithData); - } - @Override - public ValidationException validate(Object value) { + public ValidationException validate(Object value, KNNMethodConfigContext knnMethodConfigContext) { ValidationException validationException = null; if (!(value instanceof Boolean)) { validationException = new ValidationException(); @@ -110,74 +67,24 @@ public ValidationException validate(Object value) { return validationException; } - if (!validator.test((Boolean) value)) { + if (!validator.apply((Boolean) value, knnMethodConfigContext)) { validationException = new ValidationException(); validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName())); } return validationException; } - - @Override - public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { - ValidationException validationException = null; - if (!(value instanceof Boolean)) { - validationException = new ValidationException(); - validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName())); - return validationException; - } - - if (validatorWithData == null) { - return null; - } - - if (!validatorWithData.apply((Boolean) value, vectorSpaceInfo)) { - validationException = new ValidationException(); - validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName())); - } - - return validationException; - } } /** * Integer method parameter */ public static class IntegerParameter extends Parameter { - public IntegerParameter(String name, Integer defaultValue, Predicate validator) { + public IntegerParameter(String name, Integer defaultValue, BiFunction validator) { super(name, defaultValue, validator); } - public IntegerParameter( - String name, - Integer defaultValue, - Predicate validator, - BiFunction validatorWithData - ) { - super(name, defaultValue, validator, validatorWithData); - } - @Override - public ValidationException validate(Object value) { - ValidationException validationException = null; - if (!(value instanceof Integer)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format("Value not of type Integer for Integer " + "parameter \"%s\".", getName()) - ); - return validationException; - } - - if (!validator.test((Integer) value)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format("Parameter validation failed for Integer " + "parameter \"%s\".", getName()) - ); - } - return validationException; - } - - @Override - public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + public ValidationException validate(Object value, KNNMethodConfigContext knnMethodConfigContext) { ValidationException validationException = null; if (!(value instanceof Integer)) { validationException = new ValidationException(); @@ -187,11 +94,7 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector return validationException; } - if (validatorWithData == null) { - return null; - } - - if (!validatorWithData.apply((Integer) value, vectorSpaceInfo)) { + if (!validator.apply((Integer) value, knnMethodConfigContext)) { validationException = new ValidationException(); validationException.addValidationError(String.format("parameter validation failed for Integer parameter [%s].", getName())); } @@ -204,53 +107,18 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector * Double method parameter */ public static class DoubleParameter extends Parameter { - public DoubleParameter(String name, Double defaultValue, Predicate validator) { + public DoubleParameter(String name, Double defaultValue, BiFunction validator) { super(name, defaultValue, validator); } - public DoubleParameter( - String name, - Double defaultValue, - Predicate validator, - BiFunction validatorWithData - ) { - super(name, defaultValue, validator, validatorWithData); - } - @Override - public ValidationException validate(Object value) { + public ValidationException validate(Object value, KNNMethodConfigContext knnMethodConfigContext) { if (Objects.isNull(value)) { String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName()); return getValidationException(validationErrorMsg); } - if (value.equals(0)) value = 0.0; - if (!(value instanceof Double)) { - String validationErrorMsg = String.format( - Locale.ROOT, - "Value not of type Double for Double " + "parameter \"%s\".", - getName() - ); - return getValidationException(validationErrorMsg); - } - - if (!validator.test((Double) value)) { - String validationErrorMsg = String.format( - Locale.ROOT, - "Parameter validation failed for Double " + "parameter \"%s\".", - getName() - ); - return getValidationException(validationErrorMsg); - } - return null; - } - - @Override - public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { - if (Objects.isNull(value)) { - String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName()); - return getValidationException(validationErrorMsg); - } + if (value.equals(0)) value = 0.0; if (!(value instanceof Double)) { String validationErrorMsg = String.format( @@ -261,11 +129,7 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector return getValidationException(validationErrorMsg); } - if (validatorWithData == null) { - return null; - } - - if (!validatorWithData.apply((Double) value, vectorSpaceInfo)) { + if (!validator.apply((Double) value, knnMethodConfigContext)) { String validationErrorMsg = String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName()); return getValidationException(validationErrorMsg); } @@ -291,47 +155,12 @@ public static class StringParameter extends Parameter { * @param defaultValue value to assign if the parameter is not set * @param validator used to validate the parameter value passed */ - public StringParameter(String name, String defaultValue, Predicate validator) { + public StringParameter(String name, String defaultValue, BiFunction validator) { super(name, defaultValue, validator); } - public StringParameter( - String name, - String defaultValue, - Predicate validator, - BiFunction validatorWithData - ) { - super(name, defaultValue, validator, validatorWithData); - } - - /** - * Check if the value passed in is valid - * - * @param value to be checked - * @return ValidationException produced by validation errors; null if no validations errors. - */ - @Override - public ValidationException validate(Object value) { - ValidationException validationException = null; - if (!(value instanceof String)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format("Value not of type String for String " + "parameter \"%s\".", getName()) - ); - return validationException; - } - - if (!validator.test((String) value)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format("Parameter validation failed for String " + "parameter \"%s\".", getName()) - ); - } - return validationException; - } - @Override - public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + public ValidationException validate(Object value, KNNMethodConfigContext knnMethodConfigContext) { ValidationException validationException = null; if (!(value instanceof String)) { validationException = new ValidationException(); @@ -341,11 +170,7 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector return validationException; } - if (validatorWithData == null) { - return null; - } - - if (!validatorWithData.apply((String) value, vectorSpaceInfo)) { + if (!validator.apply((String) value, knnMethodConfigContext)) { validationException = new ValidationException(); validationException.addValidationError(String.format("parameter validation failed for String parameter [%s].", getName())); } @@ -361,7 +186,7 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector */ public static class MethodComponentContextParameter extends Parameter { - private Map methodComponents; + private final Map methodComponents; /** * Constructor @@ -375,46 +200,18 @@ public MethodComponentContextParameter( MethodComponentContext defaultValue, Map methodComponents ) { - super(name, defaultValue, methodComponentContext -> { - if (!methodComponents.containsKey(methodComponentContext.getName())) { - return false; - } - - return methodComponents.get(methodComponentContext.getName()).validate(methodComponentContext) == null; - }, (methodComponentContext, vectorSpaceInfo) -> { + super(name, defaultValue, (methodComponentContext, knnMethodConfigContext) -> { if (!methodComponents.containsKey(methodComponentContext.getName())) { return false; } return methodComponents.get(methodComponentContext.getName()) - .validateWithData(methodComponentContext, vectorSpaceInfo) == null; + .validate(methodComponentContext, knnMethodConfigContext) == null; }); this.methodComponents = methodComponents; } @Override - public ValidationException validate(Object value) { - ValidationException validationException = null; - if (!(value instanceof MethodComponentContext)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format("Value not of type MethodComponentContext for" + " MethodComponentContext parameter \"%s\".", getName()) - ); - return validationException; - } - - if (!validator.test((MethodComponentContext) value)) { - validationException = new ValidationException(); - validationException.addValidationError("Parameter validation failed."); - validationException.addValidationError( - String.format("Parameter validation failed for " + "MethodComponentContext parameter \"%s\".", getName()) - ); - } - - return validationException; - } - - @Override - public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + public ValidationException validate(Object value, KNNMethodConfigContext knnMethodConfigContext) { ValidationException validationException = null; if (!(value instanceof MethodComponentContext)) { validationException = new ValidationException(); @@ -424,11 +221,7 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector return validationException; } - if (validatorWithData == null) { - return null; - } - - if (!validatorWithData.apply((MethodComponentContext) value, vectorSpaceInfo)) { + if (!validator.apply((MethodComponentContext) value, knnMethodConfigContext)) { validationException = new ValidationException(); validationException.addValidationError( String.format("parameter validation failed for MethodComponentContext parameter [%s].", getName()) diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java new file mode 100644 index 000000000..52b7efe73 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.AbstractKNNMethod; +import org.opensearch.knn.index.engine.KNNLibrarySearchContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.mapper.PerDimensionProcessor; +import org.opensearch.knn.index.mapper.PerDimensionValidator; + +import java.util.Set; + +import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQfp16; + +public abstract class AbstractFaissMethod extends AbstractKNNMethod { + + /** + * Constructor for the AbstractFaissMethod class. + * + * @param methodComponent The method component used to create the method + * @param spaces The set of spaces supported by the method + * @param knnLibrarySearchContext The KNN library search context + */ + public AbstractFaissMethod(MethodComponent methodComponent, Set spaces, KNNLibrarySearchContext knnLibrarySearchContext) { + super(methodComponent, spaces, knnLibrarySearchContext); + } + + @Override + protected PerDimensionValidator doGetPerDimensionValidator( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType(); + if (VectorDataType.BINARY == vectorDataType) { + return PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + } + + if (VectorDataType.BYTE == vectorDataType) { + return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } + + if (VectorDataType.FLOAT == vectorDataType) { + if (isFaissSQfp16(knnMethodContext.getMethodComponentContext())) { + return FaissFP16Util.FP16_VALIDATOR; + } + return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + throw new IllegalStateException("Unsupported vector data type " + vectorDataType); + } + + @Override + protected PerDimensionProcessor doGetPerDimensionProcessor( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType(); + + if (VectorDataType.BINARY == vectorDataType) { + return PerDimensionProcessor.NOOP_PROCESSOR; + } + + if (VectorDataType.BYTE == vectorDataType) { + return PerDimensionProcessor.NOOP_PROCESSOR; + } + + if (VectorDataType.FLOAT == vectorDataType) { + if (isFaissSQClipToFP16RangeEnabled(knnMethodContext.getMethodComponentContext())) { + return FaissFP16Util.CLIP_TO_FP16_PROCESSOR; + } + return PerDimensionProcessor.NOOP_PROCESSOR; + } + + throw new IllegalStateException("Unsupported vector data type " + vectorDataType); + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFP16Util.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFP16Util.java new file mode 100644 index 000000000..8e76ca0fb --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFP16Util.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.PerDimensionProcessor; +import org.opensearch.knn.index.mapper.PerDimensionValidator; + +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; +import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; +import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; + +public class FaissFP16Util { + + // Validates if it is a finite number and within the fp16 range of [-65504 to 65504]. + static PerDimensionValidator FP16_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFP16VectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FP16_VALIDATOR should only be used for float vectors"); + } + }; + + // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be + // clipped to FP16 range. + static PerDimensionProcessor CLIP_TO_FP16_PROCESSOR = new PerDimensionProcessor() { + @Override + public float process(float value) { + return clipVectorValueToFP16Range(value); + } + + @Override + public float processByte(float value) { + throw new IllegalStateException("CLIP_TO_FP16_PROCESSOR should not be called with byte type"); + } + }; + + /** + * Validate the float vector value and if it is outside FP16 range, + * then it will be clipped to FP16 range of [-65504 to 65504]. + * + * @param value float vector value + * @return vector value clipped to FP16 range + */ + public static float clipVectorValueToFP16Range(float value) { + validateFloatVectorValue(value); + if (value < FP16_MIN_VALUE) return FP16_MIN_VALUE; + if (value > FP16_MAX_VALUE) return FP16_MAX_VALUE; + return value; + } + + /** + * Validate the float vector value and throw exception if it is not a number or not in the finite range + * or is not within the FP16 range of [-65504 to 65504]. + * + * @param value float vector value + */ + public static void validateFP16VectorValue(float value) { + validateFloatVectorValue(value); + if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ); + } + } + + /** + * Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" + * + * @param methodComponentContext MethodComponentContext + * @return true if it is a "faiss" Index using "sq" encoder of type "fp16" + */ + static boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { + MethodComponentContext encoderContext = extractEncoderMethodComponentContext(methodComponentContext); + if (encoderContext == null) { + return false; + } + + // returns true if encoder name is "sq" and type is "fp16" + return ENCODER_SQ.equals(encoderContext.getName()) + && FAISS_SQ_ENCODER_FP16.equals(encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)); + } + + /** + * Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index + * using "sq" encoder of type "fp16". + * + * @param methodComponentContext MethodComponentContext + * @return boolean value of "clip" parameter + */ + static boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { + MethodComponentContext encoderContext = extractEncoderMethodComponentContext(methodComponentContext); + if (encoderContext == null) { + return false; + } + return (boolean) encoderContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); + } + + static MethodComponentContext extractEncoderMethodComponentContext(MethodComponentContext methodComponentContext) { + if (Objects.isNull(methodComponentContext)) { + return null; + } + + if (methodComponentContext.getParameters().isEmpty()) { + return null; + } + + Map methodComponentParams = methodComponentContext.getParameters(); + + // The method component parameters should have an encoder + if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { + return null; + } + + // Validate if the object is of type MethodComponentContext before casting it later + if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return null; + } + + return (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java index aea3bf51a..5e6e4060f 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java @@ -5,24 +5,36 @@ package org.opensearch.knn.index.engine.faiss; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; +import java.util.Set; + /** * Flat faiss encoder. Flat encoding means that it does nothing. It needs an encoder, though, because it * is used in generating the index description. */ public class FaissFlatEncoder implements Encoder { + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of( + VectorDataType.FLOAT, + VectorDataType.BYTE, + VectorDataType.BINARY + ); + private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(KNNConstants.ENCODER_FLAT) .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + ((methodComponent, methodComponentContext, knnMethodConfigContext) -> MethodAsMapBuilder.builder( KNNConstants.FAISS_FLAT_DESCRIPTION, methodComponent, - methodComponentContext + methodComponentContext, + knnMethodConfigContext ).build()) ) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .build(); @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index 382a71741..ee6a4f101 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -5,9 +5,11 @@ package org.opensearch.knn.index.engine.faiss; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.AbstractKNNMethod; import org.opensearch.knn.index.engine.DefaultHnswSearchContext; import org.opensearch.knn.index.engine.Encoder; @@ -27,11 +29,15 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** * Faiss HNSW method implementation */ -public class FaissHNSWMethod extends AbstractKNNMethod { +public class FaissHNSWMethod extends AbstractFaissMethod { + + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BINARY); + public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, SpaceType.HAMMING, @@ -56,30 +62,41 @@ public FaissHNSWMethod() { private static MethodComponent initMethodComponent() { return MethodComponent.Builder.builder(METHOD_HNSW) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( METHOD_PARAMETER_M, - new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0) + new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, (v, context) -> v > 0) ) .addParameter( METHOD_PARAMETER_EF_CONSTRUCTION, new Parameter.IntegerParameter( METHOD_PARAMETER_EF_CONSTRUCTION, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, - v -> v > 0 + (v, context) -> v > 0 ) ) .addParameter( METHOD_PARAMETER_EF_SEARCH, - new Parameter.IntegerParameter(METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, v -> v > 0) + new Parameter.IntegerParameter( + METHOD_PARAMETER_EF_SEARCH, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + (v, context) -> v > 0 + ) ) .addParameter(METHOD_ENCODER_PARAMETER, initEncoderParameter()) - .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( - FAISS_HNSW_DESCRIPTION, + .setMapGenerator(((methodComponent, methodComponentContext, knnMethodConfigContext) -> { + String prefix = ""; + if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { + prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + } + + return MethodAsMapBuilder.builder( + prefix + FAISS_HNSW_DESCRIPTION, methodComponent, - methodComponentContext - ).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build()) - ) + methodComponentContext, + knnMethodConfigContext + ).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build(); + })) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java index 9880b2cd9..8d53f3c0a 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java @@ -5,12 +5,15 @@ package org.opensearch.knn.index.engine.faiss; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.Parameter; import java.util.Objects; +import java.util.Set; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT; @@ -26,30 +29,34 @@ */ public class FaissHNSWPQEncoder implements Encoder { + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); + private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( ENCODER_PARAMETER_PQ_M, - new Parameter.IntegerParameter( - ENCODER_PARAMETER_PQ_M, - ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT, - (v, vectorSpaceInfo) -> vectorSpaceInfo.getDimension() % v == 0 - ) + new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, (v, context) -> { + boolean isValueGreaterThan0 = v > 0; + boolean isValueLessThanCodeCountLimit = v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT; + boolean isDimensionDivisibleByValue = context.getDimension() % v == 0; + return isValueGreaterThan0 && isValueLessThanCodeCountLimit && isDimensionDivisibleByValue; + }) ) .addParameter( ENCODER_PARAMETER_PQ_CODE_SIZE, new Parameter.IntegerParameter( ENCODER_PARAMETER_PQ_CODE_SIZE, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, - v -> Objects.equals(v, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT) + (v, context) -> Objects.equals(v, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT) ) ) .setRequiresTraining(true) .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + ((methodComponent, methodComponentContext, knnMethodConfigContext) -> MethodAsMapBuilder.builder( FAISS_PQ_DESCRIPTION, methodComponent, - methodComponentContext + methodComponentContext, + knnMethodConfigContext ).addParameter(ENCODER_PARAMETER_PQ_M, "", "").build()) ) .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index aa05e8c87..a21810b50 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.engine.faiss; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.AbstractKNNMethod; import org.opensearch.knn.index.engine.DefaultIVFSearchContext; import org.opensearch.knn.index.engine.Encoder; @@ -30,11 +32,14 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_LIMIT; +import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** * Faiss ivf implementation */ -public class FaissIVFMethod extends AbstractKNNMethod { +public class FaissIVFMethod extends AbstractFaissMethod { + + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BINARY); public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, @@ -60,12 +65,13 @@ public FaissIVFMethod() { private static MethodComponent initMethodComponent() { return MethodComponent.Builder.builder(METHOD_IVF) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( METHOD_PARAMETER_NPROBES, new Parameter.IntegerParameter( METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT, - v -> v > 0 && v < METHOD_PARAMETER_NPROBES_LIMIT + (v, context) -> v > 0 && v < METHOD_PARAMETER_NPROBES_LIMIT ) ) .addParameter( @@ -73,18 +79,24 @@ private static MethodComponent initMethodComponent() { new Parameter.IntegerParameter( METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT, - v -> v > 0 && v < METHOD_PARAMETER_NLIST_LIMIT + (v, context) -> v > 0 && v < METHOD_PARAMETER_NLIST_LIMIT ) ) .addParameter(METHOD_ENCODER_PARAMETER, initEncoderParameter()) .setRequiresTraining(true) - .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( - FAISS_IVF_DESCRIPTION, + .setMapGenerator(((methodComponent, methodComponentContext, knnMethodConfigContext) -> { + String prefix = ""; + if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { + prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + } + + return MethodAsMapBuilder.builder( + prefix + FAISS_IVF_DESCRIPTION, methodComponent, - methodComponentContext - ).addParameter(METHOD_PARAMETER_NLIST, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build()) - ) + methodComponentContext, + knnMethodConfigContext + ).addParameter(METHOD_PARAMETER_NLIST, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build(); + })) .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { // Size estimate formula: (4 * nlists * d) / 1024 + 1 diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java index b9632004d..b38f5c816 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java @@ -5,11 +5,15 @@ package org.opensearch.knn.index.engine.faiss; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.Parameter; +import java.util.Set; + import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT; @@ -24,30 +28,35 @@ * {@link FaissHNSWPQEncoder}. Hence, they are separate classes. */ public class FaissIVFPQEncoder implements Encoder { + + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); + private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( ENCODER_PARAMETER_PQ_M, - new Parameter.IntegerParameter( - ENCODER_PARAMETER_PQ_M, - ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT, - (v, vectorSpaceInfo) -> vectorSpaceInfo.getDimension() % v == 0 - ) + new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, (v, context) -> { + boolean isValueGreaterThan0 = v > 0; + boolean isValueLessThanCodeCountLimit = v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT; + boolean isDimensionDivisibleByValue = context.getDimension() % v == 0; + return isValueGreaterThan0 && isValueLessThanCodeCountLimit && isDimensionDivisibleByValue; + }) ) .addParameter( ENCODER_PARAMETER_PQ_CODE_SIZE, - new Parameter.IntegerParameter( - ENCODER_PARAMETER_PQ_CODE_SIZE, - ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT - ) + new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, (v, context) -> { + boolean isValueGreaterThan0 = v > 0; + boolean isValueLessThanCodeSizeLimit = v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT; + return isValueGreaterThan0 && isValueLessThanCodeSizeLimit; + }) ) .setRequiresTraining(true) .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + ((methodComponent, methodComponentContext, knnMethodConfigContext) -> MethodAsMapBuilder.builder( FAISS_PQ_DESCRIPTION, methodComponent, - methodComponentContext + methodComponentContext, + knnMethodConfigContext ).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build()) ) .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java index eb0af9c38..2d0d184ca 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java @@ -5,11 +5,14 @@ package org.opensearch.knn.index.engine.faiss; +import com.google.common.collect.ImmutableSet; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.Parameter; import java.util.Objects; +import java.util.Set; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; @@ -22,14 +25,22 @@ * Faiss SQ encoder */ public class FaissSQEncoder implements Encoder { + + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); + private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(ENCODER_SQ) - .addParameter(FAISS_SQ_TYPE, new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_TYPES::contains)) - .addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, Objects::nonNull)) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) + .addParameter( + FAISS_SQ_TYPE, + new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, (v, context) -> FAISS_SQ_ENCODER_TYPES.contains(v)) + ) + .addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, (v, context) -> Objects.nonNull(v))) .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + ((methodComponent, methodComponentContext, knnMethodConfigContext) -> MethodAsMapBuilder.builder( FAISS_SQ_DESCRIPTION, methodComponent, - methodComponentContext + methodComponentContext, + knnMethodConfigContext ).addParameter(FAISS_SQ_TYPE, "", "").build()) ) .build(); diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java index 445abfdd8..abb3d08c9 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; @@ -29,6 +30,7 @@ class MethodAsMapBuilder { String indexDescription; MethodComponent methodComponent; Map methodAsMap; + KNNMethodConfigContext knnMethodConfigContext; /** * Add a parameter that will be used in the index description for the given method component @@ -55,7 +57,7 @@ MethodAsMapBuilder addParameter(String parameterName, String prefix, String suff subMethodComponentContext.getName() ); - Map subMethodAsMap = subMethodComponent.getAsMap(subMethodComponentContext); + Map subMethodAsMap = subMethodComponent.getAsMap(subMethodComponentContext, knnMethodConfigContext); indexDescription += subMethodAsMap.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER); subMethodAsMap.remove(KNNConstants.INDEX_DESCRIPTION_PARAMETER); @@ -85,11 +87,15 @@ Map build() { static MethodAsMapBuilder builder( String baseDescription, MethodComponent methodComponent, - MethodComponentContext methodComponentContext + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext ) { Map initialMap = new HashMap<>(); initialMap.put(NAME, methodComponent.getName()); - initialMap.put(PARAMETERS, MethodComponent.getParameterMapWithDefaultsAdded(methodComponentContext, methodComponent)); - return new MethodAsMapBuilder(baseDescription, methodComponent, initialMap); + initialMap.put( + PARAMETERS, + MethodComponent.getParameterMapWithDefaultsAdded(methodComponentContext, methodComponent, knnMethodConfigContext) + ); + return new MethodAsMapBuilder(baseDescription, methodComponent, initialMap, knnMethodConfigContext); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java index c6fcdb7c4..317f67c10 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java @@ -5,9 +5,11 @@ package org.opensearch.knn.index.engine.lucene; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.AbstractKNNMethod; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; @@ -30,6 +32,8 @@ */ public class LuceneHNSWMethod extends AbstractKNNMethod { + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE); + public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, SpaceType.L2, @@ -54,16 +58,17 @@ public LuceneHNSWMethod() { private static MethodComponent initMethodComponent() { return MethodComponent.Builder.builder(METHOD_HNSW) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( METHOD_PARAMETER_M, - new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0) + new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, (v, context) -> v > 0) ) .addParameter( METHOD_PARAMETER_EF_CONSTRUCTION, new Parameter.IntegerParameter( METHOD_PARAMETER_EF_CONSTRUCTION, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, - v -> v > 0 + (v, context) -> v > 0 ) ) .addParameter(METHOD_ENCODER_PARAMETER, initEncoderParameter()) diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java index 2c4da27df..bcc1c9af0 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java @@ -17,7 +17,10 @@ public class LuceneHNSWSearchContext implements KNNLibrarySearchContext { private final Map> supportedMethodParameters = ImmutableMap.>builder() - .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) + .put( + MethodParameter.EF_SEARCH.getName(), + new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, (v, context) -> true) + ) .build(); @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java index fac851ea1..0ec43db41 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java @@ -5,11 +5,14 @@ package org.opensearch.knn.index.engine.lucene; +import com.google.common.collect.ImmutableSet; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.Parameter; import java.util.List; +import java.util.Set; import static org.opensearch.knn.common.KNNConstants.DYNAMIC_CONFIDENCE_INTERVAL; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -23,19 +26,22 @@ * Lucene scalar quantization encoder */ public class LuceneSQEncoder implements Encoder { + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); + private final static List LUCENE_SQ_BITS_SUPPORTED = List.of(7); private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(ENCODER_SQ) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( LUCENE_SQ_CONFIDENCE_INTERVAL, new Parameter.DoubleParameter( LUCENE_SQ_CONFIDENCE_INTERVAL, null, - v -> v == DYNAMIC_CONFIDENCE_INTERVAL || (v >= MINIMUM_CONFIDENCE_INTERVAL && v <= MAXIMUM_CONFIDENCE_INTERVAL) + (v, context) -> v == DYNAMIC_CONFIDENCE_INTERVAL || (v >= MINIMUM_CONFIDENCE_INTERVAL && v <= MAXIMUM_CONFIDENCE_INTERVAL) ) ) .addParameter( LUCENE_SQ_BITS, - new Parameter.IntegerParameter(LUCENE_SQ_BITS, LUCENE_SQ_DEFAULT_BITS, LUCENE_SQ_BITS_SUPPORTED::contains) + new Parameter.IntegerParameter(LUCENE_SQ_BITS, LUCENE_SQ_DEFAULT_BITS, (v, context) -> LUCENE_SQ_BITS_SUPPORTED.contains(v)) ) .build(); diff --git a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java index e8e27bcd6..779c16cd3 100644 --- a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.engine.nmslib; +import com.google.common.collect.ImmutableSet; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.AbstractKNNMethod; import org.opensearch.knn.index.engine.DefaultHnswSearchContext; import org.opensearch.knn.index.engine.MethodComponent; @@ -25,6 +27,8 @@ */ public class NmslibHNSWMethod extends AbstractKNNMethod { + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); + public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, SpaceType.L2, @@ -44,16 +48,17 @@ public NmslibHNSWMethod() { private static MethodComponent initMethodComponent() { return MethodComponent.Builder.builder(METHOD_HNSW) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) .addParameter( METHOD_PARAMETER_M, - new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0) + new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, (v, context) -> v > 0) ) .addParameter( METHOD_PARAMETER_EF_CONSTRUCTION, new Parameter.IntegerParameter( METHOD_PARAMETER_EF_CONSTRUCTION, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, - v -> v > 0 + (v, context) -> v > 0 ) ) .build(); diff --git a/src/main/java/org/opensearch/knn/index/engine/validation/ParameterValidator.java b/src/main/java/org/opensearch/knn/index/engine/validation/ParameterValidator.java index 6a16b48a5..c79778503 100644 --- a/src/main/java/org/opensearch/knn/index/engine/validation/ParameterValidator.java +++ b/src/main/java/org/opensearch/knn/index/engine/validation/ParameterValidator.java @@ -7,6 +7,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.Parameter; import java.util.ArrayList; @@ -17,14 +18,17 @@ public final class ParameterValidator { /** * A function which validates request parameters. + * * @param validParameters A set of valid parameters that can be requestParameters can be validated against * @param requestParameters parameters from the request - * @return + * @param knnMethodConfigContext context of the knn method + * @return ValidationException if there are any validation errors, null otherwise */ @Nullable public static ValidationException validateParameters( final Map> validParameters, - final Map requestParameters + final Map requestParameters, + KNNMethodConfigContext knnMethodConfigContext ) { if (validParameters == null) { @@ -38,7 +42,8 @@ public static ValidationException validateParameters( final List errorMessages = new ArrayList<>(); for (Map.Entry parameter : requestParameters.entrySet()) { if (validParameters.containsKey(parameter.getKey())) { - final ValidationException parameterValidation = validParameters.get(parameter.getKey()).validate(parameter.getValue()); + final ValidationException parameterValidation = validParameters.get(parameter.getKey()) + .validate(parameter.getValue(), knnMethodConfigContext); if (parameterValidation != null) { errorMessages.addAll(parameterValidation.validationErrors()); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 146b5132f..d37ab9b86 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -10,6 +10,7 @@ import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import java.util.Map; @@ -25,16 +26,19 @@ public static FlatVectorFieldMapper createFieldMapper( String fullname, String simpleName, Map metaValue, - VectorDataType vectorDataType, - Integer dimension, + KNNMethodConfigContext knnMethodConfigContext, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues, - Version indexCreatedVersion + boolean hasDocValues ) { - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, () -> dimension); + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + fullname, + metaValue, + knnMethodConfigContext.getVectorDataType(), + knnMethodConfigContext::getDimension + ); return new FlatVectorFieldMapper( simpleName, mappedFieldType, @@ -43,7 +47,7 @@ public static FlatVectorFieldMapper createFieldMapper( ignoreMalformed, stored, hasDocValues, - indexCreatedVersion + knnMethodConfigContext.getVersionCreated() ); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 94756f595..65c3cfb66 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -15,6 +15,8 @@ import java.util.function.Supplier; import java.util.stream.Collectors; +import lombok.Getter; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; @@ -24,6 +26,7 @@ import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.ValidationException; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.ToXContent; @@ -36,18 +39,16 @@ import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelDao; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createKNNMethodContextFromLegacy; @@ -55,7 +56,6 @@ import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfCircuitBreakerIsNotTriggered; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfKNNPluginEnabled; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataType; import static org.opensearch.knn.index.mapper.ModelFieldMapper.UNSET_MODEL_DIMENSION_IDENTIFIER; /** @@ -154,19 +154,10 @@ public static class Builder extends ParametrizedFieldMapper.Builder { }), m -> m.getMethodComponentContext().getName()).setValidator(v -> { if (v == null) return; - ValidationException validationException = null; + ValidationException validationException; if (v.isTrainingRequired()) { validationException = new ValidationException(); validationException.addValidationError(String.format(Locale.ROOT, "\"%s\" requires training.", KNN_METHOD)); - } - - ValidationException methodValidation = v.validate(); - if (methodValidation != null) { - validationException = validationException == null ? new ValidationException() : validationException; - validationException.addValidationErrors(methodValidation.validationErrors()); - } - - if (validationException != null) { throw validationException; } }); @@ -190,13 +181,24 @@ public static class Builder extends ParametrizedFieldMapper.Builder { // (https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L322-L324). // So, what we do is pass in a "resolvedKNNMethodContext" that will either be null or be set via the merge builder // constructor. A similar approach was taken for https://github.com/opendistro-for-elasticsearch/k-NN/issues/288 + @Setter + @Getter private KNNMethodContext resolvedKNNMethodContext; - - public Builder(String name, ModelDao modelDao, Version indexCreatedVersion, KNNMethodContext resolvedKNNMethodContext) { + @Setter + private KNNMethodConfigContext knnMethodConfigContext; + + public Builder( + String name, + ModelDao modelDao, + Version indexCreatedVersion, + KNNMethodContext resolvedKNNMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; this.resolvedKNNMethodContext = resolvedKNNMethodContext; + this.knnMethodConfigContext = knnMethodConfigContext; } @Override @@ -214,12 +216,6 @@ protected Explicit ignoreMalformed(BuilderContext context) { return KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED; } - private void validateFlatMapper() { - if (modelId.get() != null || knnMethodContext.get() != null) { - throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false"); - } - } - @Override public KNNVectorFieldMapper build(BuilderContext context) { validateFullFieldName(context); @@ -229,15 +225,13 @@ public KNNVectorFieldMapper build(BuilderContext context) { final Explicit ignoreMalformed = ignoreMalformed(context); final Map metaValue = meta.getValue(); - // Index is being created from model - String modelIdAsString = this.modelId.get(); - if (modelIdAsString != null) { + if (modelId.get() != null) { return ModelFieldMapper.createFieldMapper( buildFullName(context), name, metaValue, vectorDataType.getValue(), - modelIdAsString, + modelId.get(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, @@ -248,48 +242,24 @@ public KNNVectorFieldMapper build(BuilderContext context) { ); } - // If the field mapper is using the legacy context and being constructed from another field mapper, - // the settings will be empty. See https://github.com/opendistro-for-elasticsearch/k-NN/issues/288. In this - // case, the input resolvedKNNMethodContext will be null and the settings wont exist (so flat mapper should - // be used). Otherwise, we need to check the setting. - boolean isResolvedNull = resolvedKNNMethodContext == null; - boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(context.indexSettings()); - boolean isKnnSettingNotPresentOrFalse = !isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(context.indexSettings()); - if (isResolvedNull && isKnnSettingNotPresentOrFalse) { - validateFlatMapper(); + if (resolvedKNNMethodContext == null) { return FlatVectorFieldMapper.createFieldMapper( buildFullName(context), name, metaValue, - vectorDataType.getValue(), - dimension.getValue(), + KNNMethodConfigContext.builder() + .vectorDataType(vectorDataType.getValue()) + .versionCreated(indexCreatedVersion) + .dimension(dimension.getValue()) + .build(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), - hasDocValues.get(), - indexCreatedVersion + hasDocValues.get() ); } - // See resolvedKNNMethodContext definition for explanation - if (isResolvedNull) { - resolvedKNNMethodContext = this.knnMethodContext.getValue(); - setDefaultSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); - validateSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); - validateDimensions(resolvedKNNMethodContext, vectorDataType.getValue()); - validateEncoder(resolvedKNNMethodContext, vectorDataType.getValue()); - } - - // If the knnMethodContext is null at this point, that means user built the index with the legacy k-NN - // settings to specify algo params. We need to convert this here to a KNNMethodContext so that we can - // properly configure the rest of the index - if (resolvedKNNMethodContext == null) { - resolvedKNNMethodContext = createKNNMethodContextFromLegacy(context, vectorDataType.getValue(), indexCreatedVersion); - } - - validateVectorDataType(resolvedKNNMethodContext, vectorDataType.getValue()); - resolvedKNNMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); if (resolvedKNNMethodContext.getKnnEngine() == KNNEngine.LUCENE) { log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput @@ -300,16 +270,13 @@ public KNNVectorFieldMapper build(BuilderContext context) { .ignoreMalformed(ignoreMalformed) .stored(stored.getValue()) .hasDocValues(hasDocValues.getValue()) - .vectorDataType(vectorDataType.getValue()) - .indexVersion(indexCreatedVersion) .originalKnnMethodContext(knnMethodContext.get()) .build(); return LuceneFieldMapper.createFieldMapper( buildFullName(context), metaValue, - vectorDataType.getValue(), - dimension.getValue(), resolvedKNNMethodContext, + knnMethodConfigContext, createLuceneFieldMapperInput ); } @@ -318,107 +285,17 @@ public KNNVectorFieldMapper build(BuilderContext context) { buildFullName(context), name, metaValue, - vectorDataType.getValue(), - dimension.getValue(), resolvedKNNMethodContext, + knnMethodConfigContext, knnMethodContext.get(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.getValue(), - hasDocValues.getValue(), - indexCreatedVersion + hasDocValues.getValue() ); } - private void validateEncoder(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) { - if (knnMethodContext == null) { - return; - } - - if (VectorDataType.FLOAT == vectorDataType) { - return; - } - - if (knnMethodContext.getMethodComponentContext() == null) { - return; - } - - if (knnMethodContext.getMethodComponentContext().getParameters() == null) { - return; - } - - if (knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER) == null) { - return; - } - - if (knnMethodContext.getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext == false) { - return; - } - - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER); - - if (ENCODER_FLAT.equals(encoderMethodComponentContext.getName()) == false) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "%s data type does not support %s encoder", - vectorDataType.getValue(), - encoderMethodComponentContext.getName() - ) - ); - } - } - - private void setDefaultSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) { - if (knnMethodContext == null) { - return; - } - - if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) { - if (VectorDataType.BINARY == vectorDataType) { - knnMethodContext.setSpaceType(SpaceType.DEFAULT_BINARY); - } else { - knnMethodContext.setSpaceType(SpaceType.DEFAULT); - } - } - } - - private void validateSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) { - if (knnMethodContext == null) { - return; - } - - knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); - } - - private KNNEngine validateDimensions(final KNNMethodContext knnMethodContext, final VectorDataType dataType) { - final KNNEngine knnEngine; - if (knnMethodContext != null) { - knnEngine = knnMethodContext.getKnnEngine(); - } else { - knnEngine = KNNEngine.DEFAULT; - } - if (dimension.getValue() > KNNEngine.getMaxDimensionByEngine(knnEngine)) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Dimension value cannot be greater than %s for vector: %s", - KNNEngine.getMaxDimensionByEngine(knnEngine), - name - ) - ); - } - if (VectorDataType.BINARY == dataType && dimension.getValue() % 8 != 0) { - throw new IllegalArgumentException("Dimension should be multiply of 8 for binary vector data type"); - } - return knnEngine; - } - /** * Validate whether provided full field name contain any invalid characters for physical file name. * At the moment, we use a field name as a part of file name while we throw an exception @@ -458,7 +335,13 @@ public TypeParser(Supplier modelDaoSupplier) { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated(), null); + Builder builder = new KNNVectorFieldMapper.Builder( + name, + modelDaoSupplier.get(), + parserContext.indexVersionCreated(), + null, + null + ); builder.parse(name, parserContext, node); // All parse(String name, Map node, ParserCont ); } + // Check for flat configuration + if (isKNNDisabled(parserContext.getSettings())) { + validateFromFlat(builder); + } else if (builder.modelId.get() != null) { + validateFromModel(builder); + } else { + resolveKNNMethodComponents(builder, parserContext); + validateFromKNNMethod(builder); + } + + return builder; + } + + private void validateFromFlat(KNNVectorFieldMapper.Builder builder) { + if (builder.modelId.get() != null || builder.knnMethodContext.get() != null) { + throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false"); + } + validateDimensionSet(builder); + } + + private void validateFromModel(KNNVectorFieldMapper.Builder builder) { // Dimension should not be null unless modelId is used if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", name)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", builder.name())); } + } - return builder; + private void validateFromKNNMethod(KNNVectorFieldMapper.Builder builder) { + if (builder.resolvedKNNMethodContext != null) { + ValidationException validationException = builder.resolvedKNNMethodContext.validate(builder.knnMethodConfigContext); + if (validationException != null) { + throw validationException; + } + } + validateDimensionSet(builder); + } + + private void validateDimensionSet(KNNVectorFieldMapper.Builder builder) { + if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", builder.name())); + } + } + + private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, ParserContext parserContext) { + builder.setKnnMethodConfigContext( + KNNMethodConfigContext.builder() + .vectorDataType(builder.vectorDataType.getValue()) + .versionCreated(parserContext.indexVersionCreated()) + .dimension(builder.dimension.getValue()) + .build() + ); + + // Configure method from map or legacy + builder.setResolvedKNNMethodContext( + builder.knnMethodContext.getValue() != null + ? builder.knnMethodContext.getValue() + : createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated()) + ); + // TODO: We should remove this and set it based on the KNNMethodContext + setDefaultSpaceType(builder.resolvedKNNMethodContext, builder.vectorDataType.getValue()); + } + + private boolean isKNNDisabled(Settings settings) { + boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(settings); + return !isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(settings); + } + + private void setDefaultSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) { + if (knnMethodContext == null) { + return; + } + + if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) { + if (VectorDataType.BINARY == vectorDataType) { + knnMethodContext.setSpaceType(SpaceType.DEFAULT_BINARY); + } else { + knnMethodContext.setSpaceType(SpaceType.DEFAULT); + } + } } } @@ -707,11 +663,25 @@ Optional getFloatsFromContext(ParseContext context, int dimension) thro @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { + // We cannot get the dimension from the model based indices at this field because the + // cluster state may not be available. So, we need to set it to null. + KNNMethodConfigContext knnMethodConfigContext; + if (fieldType().getKnnMappingConfig().getModelId().isPresent()) { + knnMethodConfigContext = null; + } else { + knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(vectorDataType) + .versionCreated(indexCreatedVersion) + .dimension(fieldType().getKnnMappingConfig().getDimension()) + .build(); + } + return new KNNVectorFieldMapper.Builder( simpleName(), modelDao, indexCreatedVersion, - fieldType().getKnnMappingConfig().getKnnMethodContext().orElse(null) + fieldType().getKnnMappingConfig().getKnnMethodContext().orElse(null), + knnMethodConfigContext ).init(this); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 9cd6bb467..57a4dd062 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -20,7 +20,6 @@ import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.settings.Settings; -import org.opensearch.index.mapper.Mapper; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KnnCircuitBreakerException; import org.opensearch.knn.index.SpaceType; @@ -32,29 +31,15 @@ import org.opensearch.knn.index.util.IndexHyperParametersUtil; import java.util.Arrays; -import java.util.Locale; import java.util.Map; -import java.util.Objects; -import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; -import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; -import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; -import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; /** * Utility class for KNNVectorFieldMapper @@ -63,99 +48,6 @@ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { - /** - * Validate the float vector value and throw exception if it is not a number or not in the finite range - * or is not within the FP16 range of [-65504 to 65504]. - * - * @param value float vector value - */ - public static void validateFP16VectorValue(float value) { - validateFloatVectorValue(value); - if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", - ENCODER_SQ, - FAISS_SQ_ENCODER_FP16, - FP16_MIN_VALUE, - FP16_MAX_VALUE - ) - ); - } - } - - /** - * Validate the float vector value and if it is outside FP16 range, - * then it will be clipped to FP16 range of [-65504 to 65504]. - * - * @param value float vector value - * @return vector value clipped to FP16 range - */ - public static float clipVectorValueToFP16Range(float value) { - validateFloatVectorValue(value); - if (value < FP16_MIN_VALUE) return FP16_MIN_VALUE; - if (value > FP16_MAX_VALUE) return FP16_MAX_VALUE; - return value; - } - - /** - * Validates if the vector data type is supported with given method context - * - * @param methodContext methodContext - * @param vectorDataType vector data type - */ - public static void validateVectorDataType(KNNMethodContext methodContext, VectorDataType vectorDataType) { - if (VectorDataType.FLOAT == vectorDataType) { - return; - } - - if (VectorDataType.BYTE == vectorDataType) { - if (KNNEngine.LUCENE == methodContext.getKnnEngine()) { - return; - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue(), - LUCENE_NAME - ) - ); - } - } - - if (VectorDataType.BINARY == vectorDataType) { - if (KNNEngine.FAISS == methodContext.getKnnEngine()) { - if (METHOD_HNSW.equals(methodContext.getMethodComponentContext().getName())) { - return; - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] method", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue(), - METHOD_HNSW - ) - ); - } - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue(), - FAISS_NAME - ) - ); - } - } - throw new IllegalArgumentException("This line should not be reached"); - } - /** * @param knnEngine KNNEngine * @return DocValues FieldType of type Binary @@ -254,12 +146,10 @@ static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) { return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNSettings.getIsLuceneVectorFormatEnabled(); } - private static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { + private static SpaceType getSpaceType(final Settings indexSettings) { String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); if (spaceType == null) { - spaceType = VectorDataType.BINARY == vectorDataType - ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY - : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; + spaceType = KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; log.info( String.format( "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", @@ -303,97 +193,17 @@ private static int getEfConstruction(Settings indexSettings, Version indexVersio return Integer.parseInt(efConstruction); } - /** - * Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" - * - * @param methodComponentContext MethodComponentContext - * @return true if it is a "faiss" Index using "sq" encoder of type "fp16" - */ - static boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { - if (Objects.isNull(methodComponentContext)) { - return false; - } - - if (methodComponentContext.getParameters().size() == 0) { - return false; - } - - Map methodComponentParams = methodComponentContext.getParameters(); - - // The method component parameters should have an encoder - if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { - return false; - } - - // Validate if the object is of type MethodComponentContext before casting it later - if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { - return false; - } - - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); - - // returns true if encoder name is "sq" and type is "fp16" - return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) - && FAISS_SQ_ENCODER_FP16.equals( - encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) - ); - - } - - /** - * Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index - * using "sq" encoder of type "fp16". - * - * @param methodComponentContext MethodComponentContext - * @return boolean value of "clip" parameter - */ - static boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { - if (Objects.nonNull(methodComponentContext)) { - return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); - } - return false; - } - - /** - * Extract MethodComponentContext from KNNMethodContext - * - * @param knnMethodContext KNNMethodContext - * @return MethodComponentContext - */ - static MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { - if (Objects.isNull(knnMethodContext)) { - return null; - } - return knnMethodContext.getMethodComponentContext(); - } - - static KNNMethodContext createKNNMethodContextFromLegacy( - Mapper.BuilderContext context, - VectorDataType vectorDataType, - Version indexCreatedVersion - ) { - if (VectorDataType.FLOAT != vectorDataType) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is not supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue(), - NMSLIB_NAME - ) - ); - } - + static KNNMethodContext createKNNMethodContextFromLegacy(Settings indexSettings, Version indexCreatedVersion) { return new KNNMethodContext( KNNEngine.NMSLIB, - KNNVectorFieldMapperUtil.getSpaceType(context.indexSettings(), vectorDataType), + KNNVectorFieldMapperUtil.getSpaceType(indexSettings), new MethodComponentContext( METHOD_HNSW, Map.of( METHOD_PARAMETER_M, - KNNVectorFieldMapperUtil.getM(context.indexSettings()), + KNNVectorFieldMapperUtil.getM(indexSettings), METHOD_PARAMETER_EF_CONSTRUCTION, - KNNVectorFieldMapperUtil.getEfConstruction(context.indexSettings(), indexCreatedVersion) + KNNVectorFieldMapperUtil.getEfConstruction(indexSettings, indexCreatedVersion) ) ) ); diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 7c3d942b6..744ba4bd5 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -19,11 +18,12 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.VectorSimilarityFunction; -import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; @@ -37,36 +37,43 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { /** FieldType used for initializing VectorField, which is used for creating binary doc values. **/ private final FieldType vectorFieldType; - private final VectorDataType vectorDataType; - private PerDimensionProcessor perDimensionProcessor; - private PerDimensionValidator perDimensionValidator; - private VectorValidator vectorValidator; + private final PerDimensionProcessor perDimensionProcessor; + private final PerDimensionValidator perDimensionValidator; + private final VectorValidator vectorValidator; static LuceneFieldMapper createFieldMapper( String fullname, Map metaValue, - VectorDataType vectorDataType, - Integer dimension, KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, CreateLuceneFieldMapperInput createLuceneFieldMapperInput ) { - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { - @Override - public Optional getKnnMethodContext() { - return Optional.of(knnMethodContext); + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + fullname, + metaValue, + knnMethodConfigContext.getVectorDataType(), + new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } + + @Override + public int getDimension() { + return knnMethodConfigContext.getDimension(); + } } + ); - @Override - public int getDimension() { - return dimension; - } - }); - - return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput); + return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext); } - private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input) { + private LuceneFieldMapper( + final KNNVectorFieldType mappedFieldType, + final CreateLuceneFieldMapperInput input, + KNNMethodConfigContext knnMethodConfigContext + ) { super( input.getName(), mappedFieldType, @@ -75,30 +82,18 @@ private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final Create input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getIndexVersion(), + knnMethodConfigContext.getVersionCreated(), mappedFieldType.knnMappingConfig.getKnnMethodContext().orElse(null) ); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() .orElseThrow(() -> new IllegalArgumentException("KNN method context is missing")); - vectorDataType = input.getVectorDataType(); + VectorDataType vectorDataType = mappedFieldType.getVectorDataType(); final VectorSimilarityFunction vectorSimilarityFunction = knnMethodContext.getSpaceType() .getKnnVectorSimilarityFunction() .getVectorSimilarityFunction(); - if (knnMappingConfig.getDimension() > KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Dimension value cannot be greater than [%s] but got [%s] for vector [%s]", - KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE), - knnMappingConfig.getDimension(), - input.getName() - ) - ); - } - this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), vectorSimilarityFunction); if (this.hasDocValues) { @@ -107,8 +102,11 @@ private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final Create this.vectorFieldType = null; } - initValidatorsAndProcessors(knnMethodContext); - knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); + this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); + this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); } @Override @@ -141,21 +139,6 @@ protected List getFieldsForByteVector(final byte[] array) { return fieldsToBeAdded; } - private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { - this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); - this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - if (VectorDataType.BINARY == vectorDataType) { - this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; - return; - } - - if (VectorDataType.BYTE == vectorDataType) { - this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; - return; - } - this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; - } - @Override protected VectorValidator getVectorValidator() { return vectorValidator; @@ -190,8 +173,6 @@ static class CreateLuceneFieldMapperInput { Explicit ignoreMalformed; boolean stored; boolean hasDocValues; - VectorDataType vectorDataType; - Version indexVersion; KNNMethodContext originalKnnMethodContext; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index cc2c43386..c602b53fb 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -8,14 +8,14 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.VectorEncoding; -import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.MethodComponentContext; import java.io.IOException; import java.util.Map; @@ -23,49 +23,48 @@ import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.getMethodComponentContext; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for method definition in mapping */ public class MethodFieldMapper extends KNNVectorFieldMapper { - private PerDimensionProcessor perDimensionProcessor; - private PerDimensionValidator perDimensionValidator; - private VectorValidator vectorValidator; + private final PerDimensionProcessor perDimensionProcessor; + private final PerDimensionValidator perDimensionValidator; + private final VectorValidator vectorValidator; public static MethodFieldMapper createFieldMapper( String fullname, String simpleName, Map metaValue, - VectorDataType vectorDataType, - Integer dimension, KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, KNNMethodContext originalKNNMethodContext, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues, - Version indexCreatedVersion + boolean hasDocValues ) { - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { - @Override - public Optional getKnnMethodContext() { - return Optional.of(knnMethodContext); - } - - @Override - public int getDimension() { - return dimension; + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + fullname, + metaValue, + knnMethodConfigContext.getVectorDataType(), + new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } + + @Override + public int getDimension() { + return knnMethodConfigContext.getDimension(); + } } - }); + ); return new MethodFieldMapper( simpleName, mappedFieldType, @@ -74,8 +73,8 @@ public int getDimension() { ignoreMalformed, stored, hasDocValues, - indexCreatedVersion, - originalKNNMethodContext + originalKNNMethodContext, + knnMethodConfigContext ); } @@ -87,8 +86,8 @@ private MethodFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - Version indexVerision, - KNNMethodContext originalKNNMethodContext + KNNMethodContext originalKNNMethodContext, + KNNMethodConfigContext knnMethodConfigContext ) { super( @@ -99,7 +98,7 @@ private MethodFieldMapper( ignoreMalformed, stored, hasDocValues, - indexVerision, + knnMethodConfigContext.getVersionCreated(), originalKNNMethodContext ); this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); @@ -115,9 +114,15 @@ private MethodFieldMapper( KNNEngine knnEngine = knnMethodContext.getKnnEngine(); this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); + KNNLibraryIndexingContext knnLibraryIndexingContext = knnEngine.getKNNLibraryIndexingContext( + knnMethodContext, + knnMethodConfigContext + ); try { - Map libParams = knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); - this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString()); + this.fieldType.putAttribute( + PARAMETERS, + XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString() + ); } catch (IOException ioe) { throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); } @@ -139,43 +144,9 @@ private MethodFieldMapper( } this.fieldType.freeze(); - initValidatorsAndProcessors(knnMethodContext); - knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); - } - - private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { - this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); - - if (VectorDataType.BINARY == vectorDataType) { - this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; - this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - - if (VectorDataType.BYTE == vectorDataType) { - this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; - this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - - MethodComponentContext methodComponentContext = getMethodComponentContext(knnMethodContext); - if (!isFaissSQfp16(methodComponentContext)) { - // Normal float and byte processor - this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; - this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - - this.perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR; - - if (!isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) - )) { - this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - - this.perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR; + this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); + this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); + this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 6c7e45e7e..954d6addf 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -13,6 +13,9 @@ import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -22,10 +25,7 @@ import java.util.Map; import java.util.Optional; -import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for model in mapping @@ -131,7 +131,18 @@ private void initVectorValidator() { return; } ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); - vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType()); + + KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); + // Need to handle BWC case + if (knnMethodContext == null || knnMethodConfigContext == null) { + vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType()); + return; + } + + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + vectorValidator = knnLibraryIndexingContext.getVectorValidator(); } private void initPerDimensionValidator() { @@ -139,25 +150,25 @@ private void initPerDimensionValidator() { return; } ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); - MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); - VectorDataType dataType = modelMetadata.getVectorDataType(); - - if (VectorDataType.BINARY == dataType) { - perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; - return; - } - if (VectorDataType.BYTE == dataType) { - perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; - return; - } + KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); + // Need to handle BWC case + if (knnMethodContext == null || knnMethodConfigContext == null) { + if (modelMetadata.getVectorDataType() == VectorDataType.BINARY) { + perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + } else if (modelMetadata.getVectorDataType() == VectorDataType.BYTE) { + perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } else { + perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } - if (!isFaissSQfp16(methodComponentContext)) { - perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; return; } - perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR; + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); } private void initPerDimensionProcessor() { @@ -165,31 +176,18 @@ private void initPerDimensionProcessor() { return; } ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); - MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); - VectorDataType dataType = modelMetadata.getVectorDataType(); - - if (VectorDataType.BINARY == dataType) { - perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - if (VectorDataType.BYTE == dataType) { + KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); + // Need to handle BWC case + if (knnMethodContext == null || knnMethodConfigContext == null) { perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; return; } - if (!isFaissSQfp16(methodComponentContext)) { - perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - - if (!isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) - )) { - perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; - return; - } - perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR; + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); } @Override @@ -214,6 +212,27 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType()); } + private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) { + MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); + if (methodComponentContext == MethodComponentContext.EMPTY) { + return null; + } + return new KNNMethodContext(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType(), methodComponentContext); + } + + private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata(ModelMetadata modelMetadata) { + MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); + if (methodComponentContext == MethodComponentContext.EMPTY) { + return null; + } + // TODO: Need to fix this version check by serializing the model + return KNNMethodConfigContext.builder() + .vectorDataType(modelMetadata.getVectorDataType()) + .dimension(modelMetadata.getDimension()) + .versionCreated(Version.V_2_14_0) + .build(); + } + private static ModelMetadata getModelMetadata(ModelDao modelDao, String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java index 21139f2ad..9a3bbfb6b 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java @@ -5,8 +5,6 @@ package org.opensearch.knn.index.mapper; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; - /** * Process values per dimension. Good to have if we want to do some kind of cleanup on data as it is coming in. */ @@ -34,18 +32,4 @@ default float processByte(float value) { PerDimensionProcessor NOOP_PROCESSOR = new PerDimensionProcessor() { }; - - // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be - // clipped to FP16 range. - PerDimensionProcessor CLIP_TO_FP16_PROCESSOR = new PerDimensionProcessor() { - @Override - public float process(float value) { - return clipVectorValueToFP16Range(value); - } - - @Override - public float processByte(float value) { - throw new IllegalStateException("CLIP_TO_FP16_PROCESSOR should not be called with byte type"); - } - }; } diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java index 2ca0761c0..60d8540c6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java @@ -9,7 +9,6 @@ import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; /** * Validates per dimension fields @@ -41,19 +40,6 @@ public void validateByte(float value) { } }; - // Validates if it is a finite number and within the fp16 range of [-65504 to 65504]. - PerDimensionValidator DEFAULT_FP16_VALIDATOR = new PerDimensionValidator() { - @Override - public void validate(float value) { - validateFP16VectorValue(value); - } - - @Override - public void validateByte(float value) { - throw new IllegalStateException("DEFAULT_FP16_VALIDATOR should only be used for float vectors"); - } - }; - PerDimensionValidator DEFAULT_BYTE_VALIDATOR = new PerDimensionValidator() { @Override public void validate(float value) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index af8e410d4..61dba45c8 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -23,6 +23,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.model.QueryContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; @@ -423,7 +424,8 @@ protected Query doToQuery(QueryShardContext context) { QueryContext queryContext = new QueryContext(vectorQueryType); ValidationException validationException = validateParameters( engineSpecificMethodContext.supportedMethodParameters(queryContext), - (Map) methodParameters + (Map) methodParameters, + KNNMethodConfigContext.EMPTY ); if (validationException != null) { throw new IllegalArgumentException( diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 15e1959f8..853b5237a 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -344,7 +344,7 @@ public static boolean isBinaryIndex(VectorDataType vectorDataType) { */ public static void updateVectorDataTypeToParameters(Map parameters, VectorDataType vectorDataType) { if (VectorDataType.BINARY == vectorDataType) { - parameters.put(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + parameters.put(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 9464ea806..3634d13f0 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -11,6 +11,8 @@ package org.opensearch.knn.plugin.transport; +import lombok.Getter; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.cluster.metadata.IndexMetadata; @@ -19,17 +21,18 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.training.VectorSpaceInfo; import java.io.IOException; /** * Request to train and serialize a model */ +@Getter public class TrainingModelRequest extends ActionRequest { private static ClusterService clusterService; @@ -37,16 +40,15 @@ public class TrainingModelRequest extends ActionRequest { private final String modelId; private final KNNMethodContext knnMethodContext; + private final KNNMethodConfigContext knnMethodConfigContext; private final int dimension; private final String trainingIndex; private final String trainingField; private final String preferredNodeId; private final String description; private final VectorDataType vectorDataType; - private int maximumVectorCount; private int searchSize; - private int trainingDataSizeInKB; /** @@ -87,6 +89,11 @@ public TrainingModelRequest( // Training data size in kilobytes. By default, this is invalid (it cant have negative kb). It eventually gets // calculated in transit. A user cannot set this value directly. this.trainingDataSizeInKB = -1; + this.knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(vectorDataType) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); } /** @@ -112,6 +119,11 @@ public TrainingModelRequest(StreamInput in) throws IOException { } else { this.vectorDataType = VectorDataType.DEFAULT; } + this.knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(vectorDataType) + .dimension(dimension) + .versionCreated(in.getVersion()) + .build(); } /** @@ -125,79 +137,6 @@ public static void initialize(ModelDao modelDao, ClusterService clusterService) TrainingModelRequest.clusterService = clusterService; } - /** - * Getter for modelId - * - * @return modelId - */ - public String getModelId() { - return modelId; - } - - /** - * Getter for knnMethodContext - * - * @return knnMethodContext - */ - public KNNMethodContext getKnnMethodContext() { - return knnMethodContext; - } - - /** - * Getter for dimension - * - * @return dimension - */ - public int getDimension() { - return dimension; - } - - /** - * Getter for trainingIndex - * - * @return trainingIndex - */ - public String getTrainingIndex() { - return trainingIndex; - } - - /** - * Getter for trainingField - * - * @return trainingField - */ - public String getTrainingField() { - return trainingField; - } - - /** - * Getter for preferredNodeId - * - * @return preferredNodeId - */ - public String getPreferredNodeId() { - return preferredNodeId; - } - - /** - * Getter description of the model - * - * @return description - */ - public String getDescription() { - return description; - } - - /** - * Getter for maximum vector count. This corresponds to the maximum number of vectors from the training index - * a user wants to use for training. - * - * @return maximumVectorCount - */ - public int getMaximumVectorCount() { - return maximumVectorCount; - } - /** * Setter for maximum vector count * @@ -212,20 +151,6 @@ public void setMaximumVectorCount(int maximumVectorCount) { this.maximumVectorCount = maximumVectorCount; } - /** - * Getter for search size. This value corresponds to how many vectors are pulled from the training index per - * search request - * - * @return searchSize - */ - public int getSearchSize() { - return searchSize; - } - - public VectorDataType getVectorDataType() { - return vectorDataType; - } - /** * Setter for search size. * @@ -240,15 +165,6 @@ public void setSearchSize(int searchSize) { this.searchSize = searchSize; } - /** - * Getter for training data size in kilobytes. - * - * @return trainingDataSizeInKB - */ - public int getTrainingDataSizeInKB() { - return trainingDataSizeInKB; - } - /** * Setter for trainingDataSizeInKB. Package private to prevent users from changing this value directly. * @@ -289,13 +205,7 @@ public ActionRequestValidationException validate() { } // Confirm that the passed in knnMethodContext is valid and requires training - ValidationException validationException = this.knnMethodContext.validate(); - if (validationException != null) { - exception = new ActionRequestValidationException(); - exception.addValidationErrors(validationException.validationErrors()); - } - - validationException = this.knnMethodContext.validateWithData(new VectorSpaceInfo(dimension)); + ValidationException validationException = this.knnMethodContext.validate(knnMethodConfigContext); if (validationException != null) { exception = new ActionRequestValidationException(); exception.addValidationErrors(validationException.validationErrors()); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index a9eca609d..963142c1f 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -11,11 +11,13 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.Version; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -57,7 +59,14 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener // Allocation representing size model will occupy in memory during training NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext = new NativeMemoryEntryContext.AnonymousEntryContext( - request.getKnnMethodContext().estimateOverheadInKB(request.getDimension()), + request.getKnnMethodContext() + .estimateOverheadInKB( + KNNMethodConfigContext.builder() + .dimension(request.getDimension()) + .vectorDataType(request.getVectorDataType()) + .versionCreated(Version.CURRENT) + .build() + ), NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance() ); @@ -67,10 +76,9 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener NativeMemoryCacheManager.getInstance(), trainingDataEntryContext, modelAnonymousEntryContext, - request.getDimension(), + request.getKnnMethodConfigContext(), request.getDescription(), - clusterService.localNode().getEphemeralId(), - request.getVectorDataType() + clusterService.localNode().getEphemeralId() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 3bdb50ad0..e30d860db 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -11,14 +11,14 @@ package org.opensearch.knn.training; +import lombok.Getter; import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -34,8 +34,6 @@ import java.util.Map; import java.util.Objects; -import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; - /** * Encapsulates all information required to generate and train a model. */ @@ -44,12 +42,15 @@ public class TrainingJob implements Runnable { public static Logger logger = LogManager.getLogger(TrainingJob.class); private final KNNMethodContext knnMethodContext; + private final KNNMethodConfigContext knnMethodConfigContext; private final NativeMemoryCacheManager nativeMemoryCacheManager; private final NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext; private final NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext; + @Getter private final Model model; - private String modelId; + @Getter + private final String modelId; /** * Constructor. @@ -59,7 +60,6 @@ public class TrainingJob implements Runnable { * @param nativeMemoryCacheManager Cache manager loads training data into native memory. * @param trainingDataEntryContext Training data configuration * @param modelAnonymousEntryContext Model allocation context - * @param dimension model's dimension * @param description user provided description of the model. */ public TrainingJob( @@ -68,14 +68,14 @@ public TrainingJob( NativeMemoryCacheManager nativeMemoryCacheManager, NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext, NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, - int dimension, + KNNMethodConfigContext knnMethodConfigContext, String description, - String nodeAssignment, - VectorDataType vectorDataType + String nodeAssignment ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); this.knnMethodContext = Objects.requireNonNull(knnMethodContext, "MethodContext cannot be null."); + this.knnMethodConfigContext = knnMethodConfigContext; this.nativeMemoryCacheManager = Objects.requireNonNull(nativeMemoryCacheManager, "NativeMemoryCacheManager cannot be null."); this.trainingDataEntryContext = Objects.requireNonNull(trainingDataEntryContext, "TrainingDataEntryContext cannot be null."); this.modelAnonymousEntryContext = Objects.requireNonNull(modelAnonymousEntryContext, "AnonymousEntryContext cannot be null."); @@ -83,38 +83,20 @@ public TrainingJob( new ModelMetadata( knnMethodContext.getKnnEngine(), knnMethodContext.getSpaceType(), - dimension, + knnMethodConfigContext.getDimension(), ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), description, "", nodeAssignment, knnMethodContext.getMethodComponentContext(), - vectorDataType + knnMethodConfigContext.getVectorDataType() ), null, this.modelId ); } - /** - * Getter for model id. - * - * @return modelId - */ - public String getModelId() { - return modelId; - } - - /** - * Getter for model - * - * @return model - */ - public Model getModel() { - return model; - } - @Override public void run() { NativeMemoryAllocation trainingDataAllocation = null; @@ -181,25 +163,15 @@ public void run() { if (trainingDataAllocation.isClosed()) { throw new RuntimeException("Unable to load training data into memory: allocation is already closed"); } - setVersionInKnnMethodContext(); Map trainParameters = model.getModelMetadata() .getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext) + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) .getLibraryParameters(); trainParameters.put( KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); - if (VectorDataType.BINARY == model.getModelMetadata().getVectorDataType()) { - trainParameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + trainParameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - } - - IndexUtil.updateVectorDataTypeToParameters(trainParameters, model.getModelMetadata().getVectorDataType()); - byte[] modelBlob = JNIService.trainIndex( trainParameters, model.getModelMetadata().getDimension(), @@ -227,10 +199,4 @@ public void run() { nativeMemoryCacheManager.invalidate(modelAnonymousEntryContext.getKey()); } } - - private void setVersionInKnnMethodContext() { - // We are picking up the node version here. For more details why we did this please check below conversation - // Ref: https://github.com/opensearch-project/k-NN/pull/1353#discussion_r1434428542 - knnMethodContext.getMethodComponentContext().setIndexVersion(trainingDataEntryContext.getClusterService().localNode().getVersion()); - } } diff --git a/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java b/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java deleted file mode 100644 index 13843486d..000000000 --- a/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.knn.training; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; - -/** - * A data spec containing relevant information for validation. - */ -@Getter -@Setter -@AllArgsConstructor -public class VectorSpaceInfo { - private int dimension; -} diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index fb09fb30b..2aa11a247 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -7,7 +7,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -103,16 +102,17 @@ public Map xContentBuilderToMap(XContentBuilder xContentBuilder) public static KNNMethodContext getDefaultKNNMethodContext() { MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); - methodComponentContext.setIndexVersion(Version.CURRENT); - return defaultInstance; + return new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); + } + + public static KNNMethodContext getDefaultByteKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + return new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); } public static KNNMethodContext getDefaultBinaryKNNMethodContext() { MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT_BINARY, methodComponentContext); - methodComponentContext.setIndexVersion(Version.CURRENT); - return defaultInstance; + return new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT_BINARY, methodComponentContext); } public static KNNMappingConfig getMappingConfigForMethodMapping(KNNMethodContext knnMethodContext, int dimension) { diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 81a5ab142..c1d3b47c3 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -326,8 +326,8 @@ public void testVectorMappingValidation_invalidDimension() { containsString( "Dimension value cannot be greater than " + KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT) - + " for vector: " - + FIELD_NAME + + " for vector with engine: " + + KNNEngine.DEFAULT.getName() ) ); } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index ec2d49de0..0b045e848 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -32,9 +32,7 @@ import java.util.Map; import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -245,18 +243,7 @@ public void testByteVectorDataTypeWithNmslibEngine() { ResponseException.class, () -> createKnnIndexMappingWithNmslibEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()) ); - assertTrue( - ex.getMessage() - .contains( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - VectorDataType.BYTE.getValue(), - LUCENE_NAME - ) - ) - ); + assertTrue(ex.getMessage().contains("is not supported for vector data type")); } @SneakyThrows @@ -278,18 +265,7 @@ public void testByteVectorDataTypeWithLegacyFieldMapperKnnIndexSetting() { String mapping = builder.toString(); ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping)); - assertTrue( - ex.getMessage() - .contains( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is not supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - VectorDataType.BYTE.getValue(), - NMSLIB_NAME - ) - ) - ); + assertTrue(ex.getMessage(), ex.getMessage().contains("is not supported for vector data type")); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index e87531561..4c235a896 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -25,6 +25,7 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.vectorvalues.TestVectorValues; @@ -202,6 +203,10 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException .codec(codec) .build(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, spaceType, @@ -209,7 +214,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException ); String parameterString = XContentFactory.jsonBuilder() - .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext).getLibraryParameters()) .toString(); FieldInfo[] fieldInfoArray = new FieldInfo[] { @@ -267,15 +272,18 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException .codec(codec) .build(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, spaceType, new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) ); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); String parameterString = XContentFactory.jsonBuilder() - .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext).getLibraryParameters()) .toString(); FieldInfo[] fieldInfoArray = new FieldInfo[] { @@ -334,15 +342,18 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException .codec(codec) .build(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.BINARY) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, spaceType, new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) ); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); String parameterString = XContentFactory.jsonBuilder() - .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext).getLibraryParameters()) .toString(); FieldInfo[] fieldInfoArray = new FieldInfo[] { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index bf2c33bf9..1158d3ebb 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -16,13 +16,13 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; -import org.opensearch.Version; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -93,16 +93,23 @@ public class KNNCodecTestCase extends KNNTestCase { private static final FieldType sampleFieldType; static { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(CURRENT) + .vectorDataType(VectorDataType.DEFAULT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.DEFAULT, SpaceType.DEFAULT, new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) ); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); String parameterString; try { parameterString = XContentFactory.jsonBuilder() - .map(knnMethodContext.getKnnEngine().getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .map( + knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters() + ) .toString(); } catch (IOException e) { throw new RuntimeException(e); diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java index c6ab9ccdb..6f8f9afe5 100644 --- a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java @@ -27,24 +27,25 @@ public class AbstractKNNLibraryTests extends KNNTestCase { private final static String CURRENT_VERSION = "test-version"; private final static String INVALID_METHOD_THROWS_VALIDATION_NAME = "test-method-1"; private final static KNNMethod INVALID_METHOD_THROWS_VALIDATION = new AbstractKNNMethod( - MethodComponent.Builder.builder(INVALID_METHOD_THROWS_VALIDATION_NAME).build(), + MethodComponent.Builder.builder(INVALID_METHOD_THROWS_VALIDATION_NAME).addSupportedDataTypes(Set.of(VectorDataType.FLOAT)).build(), Set.of(SpaceType.DEFAULT), new DefaultHnswSearchContext() ) { @Override - public ValidationException validate(KNNMethodContext knnMethodContext) { + public ValidationException validate(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { return new ValidationException(); } }; private final static String VALID_METHOD_NAME = "test-method-2"; private final static KNNLibrarySearchContext VALID_METHOD_CONTEXT = ctx -> ImmutableMap.of( "myparameter", - new Parameter.BooleanParameter("myparameter", null, value -> true) + new Parameter.BooleanParameter("myparameter", null, (v, context) -> true) ); private final static Map VALID_EXPECTED_MAP = ImmutableMap.of("test-key", "test-param"); private final static KNNMethod VALID_METHOD = new AbstractKNNMethod( MethodComponent.Builder.builder(VALID_METHOD_NAME) - .setMapGenerator((methodComponent, methodComponentContext) -> VALID_EXPECTED_MAP) + .setMapGenerator((methodComponent, methodComponentContext, knnMethodConfigContext) -> VALID_EXPECTED_MAP) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) .build(), Set.of(SpaceType.DEFAULT), VALID_METHOD_CONTEXT @@ -60,17 +61,22 @@ public void testGetVersion() { } public void testValidateMethod() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(10) + .vectorDataType(VectorDataType.FLOAT) + .build(); // Invalid - method not supported XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, "invalid").endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); - expectThrows(IllegalArgumentException.class, () -> TEST_LIBRARY.validateMethod(knnMethodContext1)); + assertNotNull(TEST_LIBRARY.validateMethod(knnMethodContext1, knnMethodConfigContext)); // Invalid - method validation xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, INVALID_METHOD_THROWS_VALIDATION_NAME).endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); - assertNotNull(TEST_LIBRARY.validateMethod(knnMethodContext2)); + expectThrows(IllegalStateException.class, () -> TEST_LIBRARY.validateMethod(knnMethodContext2, knnMethodConfigContext)); } public void testEngineSpecificMethods() { @@ -84,15 +90,24 @@ public void testEngineSpecificMethods() { } public void testGetKNNLibraryIndexingContext() { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(10) + .vectorDataType(VectorDataType.FLOAT) + .build(); // Check that map is expected Map expectedMap = new HashMap<>(VALID_EXPECTED_MAP); expectedMap.put(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); + expectedMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()); KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.DEFAULT, SpaceType.DEFAULT, new MethodComponentContext(VALID_METHOD_NAME, Collections.emptyMap()) ); - assertEquals(expectedMap, TEST_LIBRARY.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()); + assertEquals( + expectedMap, + TEST_LIBRARY.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext).getLibraryParameters() + ); // Check when invalid method is passed in KNNMethodContext invalidKnnMethodContext = new KNNMethodContext( @@ -100,7 +115,10 @@ public void testGetKNNLibraryIndexingContext() { SpaceType.DEFAULT, new MethodComponentContext("invalid", Collections.emptyMap()) ); - expectThrows(IllegalArgumentException.class, () -> TEST_LIBRARY.getKNNLibraryIndexingContext(invalidKnnMethodContext)); + expectThrows( + IllegalArgumentException.class, + () -> TEST_LIBRARY.getKNNLibraryIndexingContext(invalidKnnMethodContext, knnMethodConfigContext) + ); } private static class TestAbstractKNNLibrary extends AbstractKNNLibrary { @@ -133,7 +151,7 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { } @Override - public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { + public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { return 0; } diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java index 2c739c6f7..4d743c42a 100644 --- a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java @@ -11,7 +11,7 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.training.VectorSpaceInfo; +import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import java.util.HashMap; @@ -49,9 +49,14 @@ public void testHasSpace() { * Test KNNMethod validate */ public void testValidate() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(10) + .vectorDataType(VectorDataType.FLOAT) + .build(); String methodName = "test-method"; KNNMethod knnMethod = new TestKNNMethod( - MethodComponent.Builder.builder(methodName).build(), + MethodComponent.Builder.builder(methodName).addSupportedDataTypes(Set.of(VectorDataType.FLOAT)).build(), Set.of(SpaceType.L2), EMPTY_ENGINE_SPECIFIC_CONTEXT ); @@ -64,7 +69,7 @@ public void testValidate() throws IOException { .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); - assertNotNull(knnMethod.validate(knnMethodContext1)); + assertNotNull(knnMethod.validate(knnMethodContext1, knnMethodConfigContext)); // Invalid methodComponent xContentBuilder = XContentFactory.jsonBuilder() @@ -78,7 +83,7 @@ public void testValidate() throws IOException { in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); - assertNotNull(knnMethod.validate(knnMethodContext2)); + assertNotNull(knnMethod.validate(knnMethodContext2, knnMethodConfigContext)); // Valid everything xContentBuilder = XContentFactory.jsonBuilder() @@ -88,22 +93,25 @@ public void testValidate() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext3 = KNNMethodContext.parse(in); - assertNull(knnMethod.validate(knnMethodContext3)); + assertNull(knnMethod.validate(knnMethodContext3, knnMethodConfigContext)); } /** * Test KNNMethod validateWithData */ - public void testValidateWithData() throws IOException { + public void testValidateWithContext() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); String methodName = "test-method"; KNNMethod knnMethod = new TestKNNMethod( - MethodComponent.Builder.builder(methodName).build(), + MethodComponent.Builder.builder(methodName).addSupportedDataTypes(Set.of(VectorDataType.FLOAT)).build(), Set.of(SpaceType.L2), EMPTY_ENGINE_SPECIFIC_CONTEXT ); - VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(4); - // Invalid space XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() @@ -112,7 +120,7 @@ public void testValidateWithData() throws IOException { .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); - assertNotNull(knnMethod.validateWithData(knnMethodContext1, testVectorSpaceInfo)); + assertNotNull(knnMethod.validate(knnMethodContext1, knnMethodConfigContext)); // Invalid methodComponent xContentBuilder = XContentFactory.jsonBuilder() @@ -125,8 +133,7 @@ public void testValidateWithData() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); - - assertNotNull(knnMethod.validateWithData(knnMethodContext2, testVectorSpaceInfo)); + assertNotNull(knnMethod.validate(knnMethodContext2, knnMethodConfigContext)); // Valid everything xContentBuilder = XContentFactory.jsonBuilder() @@ -136,26 +143,33 @@ public void testValidateWithData() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext3 = KNNMethodContext.parse(in); - assertNull(knnMethod.validateWithData(knnMethodContext3, testVectorSpaceInfo)); + assertNull(knnMethod.validate(knnMethodContext3, knnMethodConfigContext)); } public void testGetKNNLibraryIndexingContext() { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); SpaceType spaceType = SpaceType.DEFAULT; String methodName = "test-method"; Map generatedMap = ImmutableMap.of("test-key", "test-value"); MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .setMapGenerator(((methodComponent1, methodComponentContext) -> methodComponentContext.getParameters())) + .setMapGenerator(((methodComponent1, methodComponentContext, methodConfigContext) -> methodComponentContext.getParameters())) .build(); KNNMethod knnMethod = new TestKNNMethod(methodComponent, Set.of(SpaceType.L2), EMPTY_ENGINE_SPECIFIC_CONTEXT); Map expectedMap = new HashMap<>(generatedMap); expectedMap.put(KNNConstants.SPACE_TYPE, spaceType.getValue()); + expectedMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()); assertEquals( expectedMap, knnMethod.getKNNLibraryIndexingContext( - new KNNMethodContext(KNNEngine.DEFAULT, spaceType, new MethodComponentContext(methodName, generatedMap)) + new KNNMethodContext(KNNEngine.DEFAULT, spaceType, new MethodComponentContext(methodName, generatedMap)), + knnMethodConfigContext ).getLibraryParameters() ); } diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java similarity index 80% rename from src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java rename to src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java index f71fbaae0..6defa4c50 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java @@ -1,26 +1,21 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.knn.index; +package org.opensearch.knn.index.engine; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import com.google.common.collect.ImmutableMap; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.mapper.MapperParsingException; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.MethodComponentContext; import java.io.IOException; import java.util.Collections; @@ -93,23 +88,25 @@ public void testGetSpaceType() { * Test KNNMethodContext validation */ public void testValidate() { - // Check valid default - this should not throw any exception - assertNull(getDefaultKNNMethodContext().validate()); - // Check a valid nmslib method MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(2) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, hnswMethod); - assertNull(knnMethodContext.validate()); + assertNull(knnMethodContext.validate(knnMethodConfigContext)); // Check invalid parameter nmslib hnswMethod = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of("invalid", 111)); KNNMethodContext knnMethodContext1 = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, hnswMethod); - assertNotNull(knnMethodContext1.validate()); + assertNotNull(knnMethodContext1.validate(knnMethodConfigContext)); // Check invalid method nmslib MethodComponentContext invalidMethod = new MethodComponentContext("invalid", Collections.emptyMap()); KNNMethodContext knnMethodContext2 = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, invalidMethod); - expectThrows(IllegalArgumentException.class, knnMethodContext2::validate); + assertNotNull(knnMethodContext2.validate(knnMethodConfigContext)); } /** @@ -146,16 +143,26 @@ public void testRequiresTraining() { public void testEstimateOverheadInKB_whenMethodIsHNSWFlatNmslib_thenSizeIsExpectedValue() { // For HNSW no encoding we expect 0 MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(2) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, hnswMethod); - assertEquals(0, knnMethodContext.estimateOverheadInKB(1000)); + assertEquals(0, knnMethodContext.estimateOverheadInKB(knnMethodConfigContext)); } public void testEstimateOverheadInKB_whenMethodIsHNSWFlatFaiss_thenSizeIsExpectedValue() { // For HNSW no encoding we expect 0 MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(168) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, hnswMethod); - assertEquals(0, knnMethodContext.estimateOverheadInKB(168)); + assertEquals(0, knnMethodContext.estimateOverheadInKB(knnMethodConfigContext)); } @@ -172,8 +179,13 @@ public void testEstimateOverheadInKB_whenMethodIsHNSWPQFaiss_thenSizeIsExpectedV METHOD_HNSW, ImmutableMap.of(METHOD_ENCODER_PARAMETER, pqMethodContext) ); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, hnswMethodPq); - assertEquals(expectedHnswPq, knnMethodContext.estimateOverheadInKB(dimension)); + assertEquals(expectedHnswPq, knnMethodContext.estimateOverheadInKB(knnMethodConfigContext)); } public void testEstimateOverheadInKB_whenMethodIsIVFFlatFaiss_thenSizeIsExpectedValue() { @@ -183,8 +195,13 @@ public void testEstimateOverheadInKB_whenMethodIsIVFFlatFaiss_thenSizeIsExpected int expectedIvf = 4 * nlists * dimension / BYTES_PER_KILOBYTES + 1; MethodComponentContext ivfMethod = new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, ivfMethod); - assertEquals(expectedIvf, knnMethodContext.estimateOverheadInKB(dimension)); + assertEquals(expectedIvf, knnMethodContext.estimateOverheadInKB(knnMethodConfigContext)); } public void testEstimateOverheadInKB_whenMethodIsIVFPQFaiss_thenSizeIsExpectedValue() { @@ -206,8 +223,13 @@ public void testEstimateOverheadInKB_whenMethodIsIVFPQFaiss_thenSizeIsExpectedVa METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists, METHOD_ENCODER_PARAMETER, pqMethodContext) ); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, ivfMethodPq); - assertEquals(expectedIvfPq, knnMethodContext.estimateOverheadInKB(dimension)); + assertEquals(expectedIvfPq, knnMethodContext.estimateOverheadInKB(knnMethodConfigContext)); } /** @@ -411,4 +433,62 @@ public void testHashCode() { assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode()); assertNotEquals(methodContext1.hashCode(), methodContext5.hashCode()); } + + public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, SpaceType.HAMMING, null); + } + + public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() { + validateValidateVectorDataType( + KNNEngine.LUCENE, + KNNConstants.METHOD_HNSW, + VectorDataType.BINARY, + SpaceType.HAMMING, + "UnsupportedMethod" + ); + validateValidateVectorDataType( + KNNEngine.NMSLIB, + KNNConstants.METHOD_HNSW, + VectorDataType.BINARY, + SpaceType.HAMMING, + "UnsupportedMethod" + ); + } + + public void testValidateVectorDataType_whenByteLucene_thenValid() { + validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, null); + } + + public void testValidateVectorDataType_whenByteNonLucene_thenException() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, SpaceType.L2, "UnsupportedMethod"); + validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_IVF, VectorDataType.BYTE, SpaceType.L2, "UnsupportedMethod"); + } + + public void testValidateVectorDataType_whenFloat_thenValid() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, SpaceType.L2, null); + validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, SpaceType.L2, null); + validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, SpaceType.L2, null); + } + + private void validateValidateVectorDataType( + final KNNEngine knnEngine, + final String methodName, + final VectorDataType vectorDataType, + final SpaceType spaceType, + final String expectedErrMsg + ) { + MethodComponentContext methodComponentContext = new MethodComponentContext(methodName, Collections.emptyMap()); + KNNMethodContext methodContext = new KNNMethodContext(knnEngine, spaceType, methodComponentContext); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(vectorDataType) + .dimension(8) + .versionCreated(Version.CURRENT) + .build(); + if (expectedErrMsg == null) { + assertNull(methodContext.validate(knnMethodConfigContext)); + } else { + assertNotNull(methodContext.validate(knnMethodConfigContext)); + } + } + } diff --git a/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java b/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java index c72f59be4..a5c72f5ee 100644 --- a/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java @@ -6,12 +6,15 @@ package org.opensearch.knn.index.engine; import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import java.util.Map; +import java.util.Set; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -33,7 +36,7 @@ public void testGetParameters() { String name = "test"; String paramKey = "key"; MethodComponent methodComponent = MethodComponent.Builder.builder(name) - .addParameter(paramKey, new Parameter.IntegerParameter(paramKey, 1, v -> v > 0)) + .addParameter(paramKey, new Parameter.IntegerParameter(paramKey, 1, (v, context) -> v > 0)) .build(); assertEquals(1, methodComponent.getParameters().size()); assertTrue(methodComponent.getParameters().containsKey(paramKey)); @@ -52,11 +55,18 @@ public void testValidate() throws IOException { .field("invalid", "invalid") .endObject() .endObject(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .dimension(1) + .versionCreated(Version.CURRENT) + .vectorDataType(VectorDataType.FLOAT) + .build(); Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext componentContext1 = MethodComponentContext.parse(in); - MethodComponent methodComponent1 = MethodComponent.Builder.builder(methodName).build(); - assertNotNull(methodComponent1.validate(componentContext1)); + MethodComponent methodComponent1 = MethodComponent.Builder.builder(methodName) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) + .build(); + assertNotNull(methodComponent1.validate(componentContext1, knnMethodConfigContext)); // Invalid parameter type xContentBuilder = XContentFactory.jsonBuilder() @@ -70,9 +80,10 @@ public void testValidate() throws IOException { MethodComponentContext componentContext2 = MethodComponentContext.parse(in); MethodComponent methodComponent2 = MethodComponent.Builder.builder(methodName) - .addParameter("valid", new Parameter.IntegerParameter("valid", 1, v -> v > 0)) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) + .addParameter("valid", new Parameter.IntegerParameter("valid", 1, (v, context) -> v > 0)) .build(); - assertNotNull(methodComponent2.validate(componentContext2)); + assertNotNull(methodComponent2.validate(componentContext2, knnMethodConfigContext)); // valid configuration xContentBuilder = XContentFactory.jsonBuilder() @@ -87,10 +98,11 @@ public void testValidate() throws IOException { MethodComponentContext componentContext3 = MethodComponentContext.parse(in); MethodComponent methodComponent3 = MethodComponent.Builder.builder(methodName) - .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) - .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, v -> v > 0)) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) + .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, (v, context) -> v > 0)) + .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, (v, context) -> v > 0)) .build(); - assertNull(methodComponent3.validate(componentContext3)); + assertNull(methodComponent3.validate(componentContext3, knnMethodConfigContext)); // valid configuration - empty parameters xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName).endObject(); @@ -98,10 +110,11 @@ public void testValidate() throws IOException { MethodComponentContext componentContext4 = MethodComponentContext.parse(in); MethodComponent methodComponent4 = MethodComponent.Builder.builder(methodName) - .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) - .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, v -> v > 0)) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) + .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, (v, context) -> v > 0)) + .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, (v, context) -> v > 0)) .build(); - assertNull(methodComponent4.validate(componentContext4)); + assertNull(methodComponent4.validate(componentContext4, knnMethodConfigContext)); } @SuppressWarnings("unchecked") @@ -113,8 +126,8 @@ public void testGetAsMap_withoutGenerator() throws IOException { int default2 = 5; MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .addParameter(parameterName1, new Parameter.IntegerParameter(parameterName1, default1, v -> v > 0)) - .addParameter(parameterName2, new Parameter.IntegerParameter(parameterName2, default2, v -> v > 0)) + .addParameter(parameterName1, new Parameter.IntegerParameter(parameterName1, default1, (v, context) -> v > 0)) + .addParameter(parameterName2, new Parameter.IntegerParameter(parameterName2, default2, (v, context) -> v > 0)) .build(); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -128,13 +141,19 @@ public void testGetAsMap_withoutGenerator() throws IOException { Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); - assertEquals(in, methodComponent.getAsMap(methodComponentContext)); + assertEquals( + in, + methodComponent.getAsMap(methodComponentContext, KNNMethodConfigContext.builder().versionCreated(Version.CURRENT).build()) + ); xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName).endObject(); in = xContentBuilderToMap(xContentBuilder); methodComponentContext = MethodComponentContext.parse(in); - Map methodAsMap = methodComponent.getAsMap(methodComponentContext); + Map methodAsMap = methodComponent.getAsMap( + methodComponentContext, + KNNMethodConfigContext.builder().versionCreated(Version.CURRENT).build() + ); assertEquals(default1, ((Map) methodAsMap.get(PARAMETERS)).get(parameterName1)); assertEquals(default2, ((Map) methodAsMap.get(PARAMETERS)).get(parameterName2)); } @@ -143,16 +162,19 @@ public void testGetAsMap_withGenerator() throws IOException { String methodName = "test-method"; Map generatedMap = ImmutableMap.of("test-key", "test-value"); MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) - .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, v -> v > 0)) - .setMapGenerator((methodComponent1, methodComponentContext) -> generatedMap) + .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, (v, context) -> v > 0)) + .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, (v, context) -> v > 0)) + .setMapGenerator((methodComponent1, methodComponentContext, knnMethodConfigContext) -> generatedMap) .build(); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName).endObject(); Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); - assertEquals(generatedMap, methodComponent.getAsMap(methodComponentContext)); + assertEquals( + generatedMap, + methodComponent.getAsMap(methodComponentContext, KNNMethodConfigContext.builder().versionCreated(Version.CURRENT).build()) + ); } public void testBuilder() { @@ -163,15 +185,18 @@ public void testBuilder() { assertEquals(0, methodComponent.getParameters().size()); assertEquals(name, methodComponent.getName()); - builder.addParameter("test", new Parameter.IntegerParameter("test", 1, v -> v > 0)); + builder.addParameter("test", new Parameter.IntegerParameter("test", 1, (v, context) -> v > 0)); methodComponent = builder.build(); assertEquals(1, methodComponent.getParameters().size()); Map generatedMap = ImmutableMap.of("test-key", "test-value"); - builder.setMapGenerator((methodComponent1, methodComponentContext) -> generatedMap); + builder.setMapGenerator((methodComponent1, methodComponentContext, knnMethodConfigContext) -> generatedMap); methodComponent = builder.build(); - assertEquals(generatedMap, methodComponent.getAsMap(null)); + assertEquals( + generatedMap, + methodComponent.getAsMap(null, KNNMethodConfigContext.builder().versionCreated(Version.CURRENT).build()) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/engine/ParameterTests.java b/src/test/java/org/opensearch/knn/index/engine/ParameterTests.java index 7224e2824..9f3979314 100644 --- a/src/test/java/org/opensearch/knn/index/engine/ParameterTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/ParameterTests.java @@ -6,14 +6,16 @@ package org.opensearch.knn.index.engine; import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Parameter.IntegerParameter; import org.opensearch.knn.index.engine.Parameter.StringParameter; import org.opensearch.knn.index.engine.Parameter.MethodComponentContextParameter; -import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; +import java.util.Set; public class ParameterTests extends KNNTestCase { /** @@ -21,17 +23,11 @@ public class ParameterTests extends KNNTestCase { */ public void testGetDefaultValue() { String defaultValue = "test-default"; - Parameter parameter = new Parameter("test", defaultValue, v -> true) { + Parameter parameter = new Parameter("test", defaultValue, (v, context) -> true) { @Override - public ValidationException validate(Object value) { + public ValidationException validate(Object value, KNNMethodConfigContext context) { return null; } - - @Override - public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { - return null; - } - }; assertEquals(defaultValue, parameter.getDefaultValue()); @@ -41,95 +37,97 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector * Test integer parameter validate */ public void testIntegerParameter_validate() { - final IntegerParameter parameter = new IntegerParameter("test", 1, v -> v > 0); - + final IntegerParameter parameter = new IntegerParameter("test", 1, (v, context) -> v > 0); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .dimension(1) + .versionCreated(Version.CURRENT) + .vectorDataType(VectorDataType.FLOAT) + .build(); // Invalid type - assertNotNull(parameter.validate("String")); + assertNotNull(parameter.validate("String", knnMethodConfigContext)); // Invalid value - assertNotNull(parameter.validate(-1)); + assertNotNull(parameter.validate(-1, knnMethodConfigContext)); // valid value - assertNull(parameter.validate(12)); + assertNull(parameter.validate(12, knnMethodConfigContext)); } /** * Test integer parameter validate */ - public void testIntegerParameter_validateWithData() { - final IntegerParameter parameter = new IntegerParameter( - "test", - 1, - v -> v > 0, - (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension() - ); + public void testIntegerParameter_validateWithContext() { + final IntegerParameter parameter = new IntegerParameter("test", 1, (v, context) -> v > 0 && v > context.getDimension()); - VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder().dimension(0).build(); // Invalid type - assertNotNull(parameter.validateWithData("String", testVectorSpaceInfo)); + assertNotNull(parameter.validate("String", knnMethodConfigContext)); // Invalid value - assertNotNull(parameter.validateWithData(-1, testVectorSpaceInfo)); + assertNotNull(parameter.validate(-1, knnMethodConfigContext)); // valid value - assertNull(parameter.validateWithData(12, testVectorSpaceInfo)); + assertNull(parameter.validate(12, knnMethodConfigContext)); } public void testStringParameter_validate() { - final StringParameter parameter = new StringParameter("test_parameter", "default_value", v -> "test".equals(v)); - + final StringParameter parameter = new StringParameter("test_parameter", "default_value", (v, context) -> "test".equals(v)); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .dimension(1) + .versionCreated(Version.CURRENT) + .vectorDataType(VectorDataType.FLOAT) + .build(); // Invalid type - assertNotNull(parameter.validate(5)); + assertNotNull(parameter.validate(5, knnMethodConfigContext)); // null - assertNotNull(parameter.validate(null)); + assertNotNull(parameter.validate(null, knnMethodConfigContext)); // valid value - assertNull(parameter.validate("test")); + assertNull(parameter.validate("test", knnMethodConfigContext)); } public void testStringParameter_validateWithData() { - final StringParameter parameter = new StringParameter( - "test_parameter", - "default_value", - v -> "test".equals(v), - (v, vectorSpaceInfo) -> { - if (vectorSpaceInfo.getDimension() > 0) { - return "test".equals(v); - } - return false; + final StringParameter parameter = new StringParameter("test_parameter", "default_value", (v, context) -> { + if (context.getDimension() > 0) { + return "test".equals(v); } - ); + return false; + }); - VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(1); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder().dimension(1).build(); // Invalid type - assertNotNull(parameter.validateWithData(5, testVectorSpaceInfo)); + assertNotNull(parameter.validate(5, knnMethodConfigContext)); // null - assertNotNull(parameter.validateWithData(null, testVectorSpaceInfo)); + assertNotNull(parameter.validate(null, knnMethodConfigContext)); // valid value - assertNull(parameter.validateWithData("test", testVectorSpaceInfo)); + assertNull(parameter.validate("test", knnMethodConfigContext)); - testVectorSpaceInfo.setDimension(0); + knnMethodConfigContext.setDimension(0); // invalid value - assertNotNull(parameter.validateWithData("test", testVectorSpaceInfo)); + assertNotNull(parameter.validate("test", knnMethodConfigContext)); } public void testDoubleParameter_validate() { - final Parameter.DoubleParameter parameter = new Parameter.DoubleParameter("test_parameter", 1.0, v -> v >= 0); - + final Parameter.DoubleParameter parameter = new Parameter.DoubleParameter("test_parameter", 1.0, (v, context) -> v >= 0); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .dimension(1) + .versionCreated(Version.CURRENT) + .vectorDataType(VectorDataType.FLOAT) + .build(); // valid value - assertNull(parameter.validate(0.9)); + assertNull(parameter.validate(0.9, knnMethodConfigContext)); // Invalid type - assertNotNull(parameter.validate(true)); + assertNotNull(parameter.validate(true, knnMethodConfigContext)); // Invalid type - assertNotNull(parameter.validate(-1)); + assertNotNull(parameter.validate(-1, knnMethodConfigContext)); } @@ -137,20 +135,19 @@ public void testDoubleParameter_validateWithData() { final Parameter.DoubleParameter parameter = new Parameter.DoubleParameter( "test", 1.0, - v -> v > 0, - (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension() + (v, context) -> v > 0 && v > context.getDimension() ); - VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder().dimension(0).build(); // Invalid type - assertNotNull(parameter.validateWithData("String", testVectorSpaceInfo)); + assertNotNull(parameter.validate("String", knnMethodConfigContext)); // Invalid value - assertNotNull(parameter.validateWithData(-1, testVectorSpaceInfo)); + assertNotNull(parameter.validate(-1, knnMethodConfigContext)); // valid value - assertNull(parameter.validateWithData(1.2, testVectorSpaceInfo)); + assertNull(parameter.validate(1.2, knnMethodConfigContext)); } public void testMethodComponentContextParameter_validate() { @@ -161,10 +158,17 @@ public void testMethodComponentContextParameter_validate() { Map defaultParameterMap = ImmutableMap.of(parameterKey1, parameterValue1); MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, defaultParameterMap); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .dimension(1) + .versionCreated(Version.CURRENT) + .vectorDataType(VectorDataType.FLOAT) + .build(); + Map methodComponentMap = ImmutableMap.of( methodComponentName1, MethodComponent.Builder.builder(parameterKey1) - .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, v -> v > 0)) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) + .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, (v, context) -> v > 0)) .build() ); @@ -175,26 +179,26 @@ public void testMethodComponentContextParameter_validate() { ); // Invalid type - assertNotNull(parameter.validate(17)); - assertNotNull(parameter.validate("invalid-value")); + assertNotNull(parameter.validate(17, knnMethodConfigContext)); + assertNotNull(parameter.validate("invalid-value", knnMethodConfigContext)); // Invalid value String invalidMethodComponentName = "invalid-method"; MethodComponentContext invalidMethodComponentContext1 = new MethodComponentContext(invalidMethodComponentName, defaultParameterMap); - assertNotNull(parameter.validate(invalidMethodComponentContext1)); + assertNotNull(parameter.validate(invalidMethodComponentContext1, knnMethodConfigContext)); String invalidParameterKey = "invalid-parameter"; Map invalidParameterMap1 = ImmutableMap.of(invalidParameterKey, parameterValue1); MethodComponentContext invalidMethodComponentContext2 = new MethodComponentContext(methodComponentName1, invalidParameterMap1); - assertNotNull(parameter.validate(invalidMethodComponentContext2)); + assertNotNull(parameter.validate(invalidMethodComponentContext2, knnMethodConfigContext)); String invalidParameterValue = "invalid-value"; Map invalidParameterMap2 = ImmutableMap.of(parameterKey1, invalidParameterValue); MethodComponentContext invalidMethodComponentContext3 = new MethodComponentContext(methodComponentName1, invalidParameterMap2); - assertNotNull(parameter.validate(invalidMethodComponentContext3)); + assertNotNull(parameter.validate(invalidMethodComponentContext3, knnMethodConfigContext)); // valid value - assertNull(parameter.validate(methodComponentContext)); + assertNull(parameter.validate(methodComponentContext, knnMethodConfigContext)); } public void testMethodComponentContextParameter_validateWithData() { @@ -208,10 +212,8 @@ public void testMethodComponentContextParameter_validateWithData() { Map methodComponentMap = ImmutableMap.of( methodComponentName1, MethodComponent.Builder.builder(parameterKey1) - .addParameter( - parameterKey1, - new IntegerParameter(parameterKey1, 1, v -> v > 0, (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension()) - ) + .addSupportedDataTypes(Set.of(VectorDataType.FLOAT)) + .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, (v, context) -> v > 0 && v > context.getDimension())) .build() ); @@ -221,29 +223,32 @@ public void testMethodComponentContextParameter_validateWithData() { methodComponentMap ); - VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .dimension(0) + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .build(); // Invalid type - assertNotNull(parameter.validateWithData(17, testVectorSpaceInfo)); - assertNotNull(parameter.validateWithData("invalid-value", testVectorSpaceInfo)); + assertNotNull(parameter.validate("invalid-value", knnMethodConfigContext)); // Invalid value String invalidMethodComponentName = "invalid-method"; MethodComponentContext invalidMethodComponentContext1 = new MethodComponentContext(invalidMethodComponentName, defaultParameterMap); - assertNotNull(parameter.validateWithData(invalidMethodComponentContext1, testVectorSpaceInfo)); + assertNotNull(parameter.validate(invalidMethodComponentContext1, knnMethodConfigContext)); String invalidParameterKey = "invalid-parameter"; Map invalidParameterMap1 = ImmutableMap.of(invalidParameterKey, parameterValue1); MethodComponentContext invalidMethodComponentContext2 = new MethodComponentContext(methodComponentName1, invalidParameterMap1); - assertNotNull(parameter.validateWithData(invalidMethodComponentContext2, testVectorSpaceInfo)); + assertNotNull(parameter.validate(invalidMethodComponentContext2, knnMethodConfigContext)); String invalidParameterValue = "invalid-value"; Map invalidParameterMap2 = ImmutableMap.of(parameterKey1, invalidParameterValue); MethodComponentContext invalidMethodComponentContext3 = new MethodComponentContext(methodComponentName1, invalidParameterMap2); - assertNotNull(parameter.validateWithData(invalidMethodComponentContext3, testVectorSpaceInfo)); + assertNotNull(parameter.validate(invalidMethodComponentContext3, knnMethodConfigContext)); // valid value - assertNull(parameter.validateWithData(methodComponentContext, testVectorSpaceInfo)); + assertNull(parameter.validate(methodComponentContext, knnMethodConfigContext)); } public void testMethodComponentContextParameter_getMethodComponent() { @@ -257,7 +262,7 @@ public void testMethodComponentContextParameter_getMethodComponent() { Map methodComponentMap = ImmutableMap.of( methodComponentName1, MethodComponent.Builder.builder(parameterKey1) - .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, v -> v > 0)) + .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, (v, context) -> v > 0)) .build() ); diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissFP16UtilTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissFP16UtilTests.java new file mode 100644 index 000000000..81afef877 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissFP16UtilTests.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.knn.KNNTestCase; + +import java.util.Locale; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; +import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; +import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.clipVectorValueToFP16Range; +import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.validateFP16VectorValue; + +public class FaissFP16UtilTests extends KNNTestCase { + + public void testValidateFp16VectorValue_outOfRange_throwsException() { + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(65505.25f)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + + IllegalArgumentException ex1 = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(-65525.65f)); + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + } + + public void testClipVectorValuetoFP16Range_succeed() { + assertEquals(65504.0f, clipVectorValueToFP16Range(65504.10f), 0.0f); + assertEquals(65504.0f, clipVectorValueToFP16Range(1000000.89f), 0.0f); + assertEquals(-65504.0f, clipVectorValueToFP16Range(-65504.10f), 0.0f); + assertEquals(-65504.0f, clipVectorValueToFP16Range(-1000000.89f), 0.0f); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java index af5086491..5dfd6c58c 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java @@ -10,6 +10,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -40,6 +42,12 @@ public class FaissTests extends KNNTestCase { public void testGetMethodAsMap_whenMethodIsHNSWFlat_thenCreateCorrectIndexDescription() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); + int mParam = 65; String expectedIndexDescription = String.format(Locale.ROOT, "HNSW%d,Flat", mParam); @@ -53,15 +61,20 @@ public void testGetMethodAsMap_whenMethodIsHNSWFlat_thenCreateCorrectIndexDescri .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); } public void testGetMethodAsMap_whenMethodIsHNSWPQ_thenCreateCorrectIndexDescription() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); int hnswMParam = 65; int pqMParam = 17; String expectedIndexDescription = String.format(Locale.ROOT, "HNSW%d,PQ%d", hnswMParam, pqMParam); @@ -82,9 +95,9 @@ public void testGetMethodAsMap_whenMethodIsHNSWPQ_thenCreateCorrectIndexDescript .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -92,6 +105,11 @@ public void testGetMethodAsMap_whenMethodIsHNSWPQ_thenCreateCorrectIndexDescript @SneakyThrows public void testGetMethodAsMap_whenMethodIsHNSWSQFP16_thenCreateCorrectIndexDescription() { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); int hnswMParam = 65; String expectedIndexDescription = String.format(Locale.ROOT, "HNSW%d,SQfp16", hnswMParam); @@ -111,15 +129,20 @@ public void testGetMethodAsMap_whenMethodIsHNSWSQFP16_thenCreateCorrectIndexDesc .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); } public void testGetMethodAsMap_whenMethodIsIVFFlat_thenCreateCorrectIndexDescription() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); int nlists = 88; String expectedIndexDescription = String.format(Locale.ROOT, "IVF%d,Flat", nlists); @@ -134,13 +157,19 @@ public void testGetMethodAsMap_whenMethodIsIVFFlat_thenCreateCorrectIndexDescrip Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); } public void testGetMethodAsMap_whenMethodIsIVFPQ_thenCreateCorrectIndexDescription() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); int ivfNlistsParam = 88; int pqMParam = 17; int pqCodeSizeParam = 53; @@ -164,7 +193,8 @@ public void testGetMethodAsMap_whenMethodIsIVFPQ_thenCreateCorrectIndexDescripti Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -172,6 +202,11 @@ public void testGetMethodAsMap_whenMethodIsIVFPQ_thenCreateCorrectIndexDescripti @SneakyThrows public void testGetMethodAsMap_whenMethodIsIVFSQFP16_thenCreateCorrectIndexDescription() { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); int nlists = 88; String expectedIndexDescription = String.format(Locale.ROOT, "IVF%d,SQfp16", nlists); @@ -192,7 +227,8 @@ public void testGetMethodAsMap_whenMethodIsIVFSQFP16_thenCreateCorrectIndexDescr Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -210,9 +246,9 @@ public void testMethodAsMapBuilder() throws IOException { String parameter3 = "test-parameter-3"; Integer defaultValue3 = 3; MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .addParameter(parameter1, new Parameter.IntegerParameter(parameter1, defaultValue1, value -> value > 0)) - .addParameter(parameter2, new Parameter.IntegerParameter(parameter2, defaultValue2, value -> value > 0)) - .addParameter(parameter3, new Parameter.IntegerParameter(parameter3, defaultValue3, value -> value > 0)) + .addParameter(parameter1, new Parameter.IntegerParameter(parameter1, defaultValue1, (value, context) -> value > 0)) + .addParameter(parameter2, new Parameter.IntegerParameter(parameter2, defaultValue2, (value, context) -> value > 0)) + .addParameter(parameter3, new Parameter.IntegerParameter(parameter3, defaultValue3, (value, context) -> value > 0)) .build(); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -234,9 +270,12 @@ public void testMethodAsMapBuilder() throws IOException { expectedMap.put(NAME, methodName); expectedMap.put(INDEX_DESCRIPTION_PARAMETER, methodDescription + value1); - Map methodAsMap = MethodAsMapBuilder.builder(methodDescription, methodComponent, methodComponentContext) - .addParameter(parameter1, "", "") - .build(); + Map methodAsMap = MethodAsMapBuilder.builder( + methodDescription, + methodComponent, + methodComponentContext, + KNNMethodConfigContext.builder().versionCreated(Version.CURRENT).build() + ).addParameter(parameter1, "", "").build(); assertEquals(expectedMap, methodAsMap); } diff --git a/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneTests.java b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneTests.java index d149f4b3b..2d2025d49 100644 --- a/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneTests.java @@ -9,7 +9,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -28,6 +30,11 @@ public class LuceneTests extends KNNTestCase { public void testLucenHNSWMethod() throws IOException { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(10) + .vectorDataType(VectorDataType.FLOAT) + .build(); int efConstruction = 100; int m = 17; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -41,7 +48,7 @@ public void testLucenHNSWMethod() throws IOException { .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); - assertNull(KNNEngine.LUCENE.validateMethod(knnMethodContext1)); + assertNull(KNNEngine.LUCENE.validateMethod(knnMethodContext1, knnMethodConfigContext)); // Invalid parameter String invalidParameter = "invalid"; @@ -54,7 +61,8 @@ public void testLucenHNSWMethod() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); - assertNotNull(KNNEngine.LUCENE.validateMethod(knnMethodContext2)); + knnMethodContext2.setSpaceType(SpaceType.L2); + assertNotNull(KNNEngine.LUCENE.validateMethod(knnMethodContext2, knnMethodConfigContext)); // Valid parameter, invalid value int invalidEfConstruction = -1; @@ -67,7 +75,8 @@ public void testLucenHNSWMethod() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext3 = KNNMethodContext.parse(in); - assertNotNull(KNNEngine.LUCENE.validateMethod(knnMethodContext3)); + knnMethodContext3.setSpaceType(SpaceType.L2); + assertNotNull(KNNEngine.LUCENE.validateMethod(knnMethodContext3, knnMethodConfigContext)); // Unsupported space type SpaceType invalidSpaceType = SpaceType.LINF; // Not currently supported @@ -78,7 +87,7 @@ public void testLucenHNSWMethod() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext4 = KNNMethodContext.parse(in); - assertNotNull(KNNEngine.LUCENE.validateMethod(knnMethodContext4)); + assertNotNull(KNNEngine.LUCENE.validateMethod(knnMethodContext4, knnMethodConfigContext)); // Check INNER_PRODUCT is supported with Lucene Engine xContentBuilder = XContentFactory.jsonBuilder() @@ -92,7 +101,7 @@ public void testLucenHNSWMethod() throws IOException { .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext5 = KNNMethodContext.parse(in); - assertNull(KNNEngine.LUCENE.validateMethod(knnMethodContext5)); + assertNull(KNNEngine.LUCENE.validateMethod(knnMethodContext5, knnMethodConfigContext)); } public void testGetExtension() { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index b3139fa5c..4034e4cb6 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index.mapper; -import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.KnnByteVectorField; @@ -37,6 +36,7 @@ import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; @@ -51,7 +51,6 @@ import java.util.HashSet; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -62,9 +61,6 @@ import static org.opensearch.Version.CURRENT; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; -import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; -import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; @@ -80,8 +76,6 @@ import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; @Log4j2 public class KNNVectorFieldMapperTests extends KNNTestCase { @@ -108,7 +102,7 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, null, null); assertEquals(7, builder.getParameters().size()); List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); @@ -116,46 +110,64 @@ public void testBuilder_getParameters() { assertEquals(expectedParams, actualParams); } - public void testBuilder_build_fromKnnMethodContext() { + public void testTypeParser_build_fromKnnMethodContext() throws IOException { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; - int m = 17; - int efConstruction = 17; + int mRight = 17; + int mWrong = 71; + + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mRight) + .endObject() + .endObject() + .endObject(); // Setup settings Settings settings = Settings.builder() .put(settings(CURRENT).build()) - .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) - .put(KNNSettings.KNN_ALGO_PARAM_M, m) - .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNNSettings.KNN_ALGO_PARAM_M, mWrong) .put(KNN_INDEX, true) .build(); - builder.knnMethodContext.setValue( - new KNNMethodContext( - KNNEngine.DEFAULT, - spaceType, - new MethodComponentContext( - METHOD_HNSW, - ImmutableMap.of(METHOD_PARAMETER_M, m, METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction) - ) - ) + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + "test-field-name-1", + xContentBuilderToMap(xContentBuilder), + buildParserContext("test", settings) ); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); + assertEquals(spaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); + assertEquals( + mRight, + knnVectorFieldMapper.fieldType() + .getKnnMappingConfig() + .getKnnMethodContext() + .get() + .getMethodComponentContext() + .getParameters() + .get(METHOD_PARAMETER_M) + ); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); } public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -193,15 +205,19 @@ public void testBuilder_build_fromModel() { assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isEmpty()); } - public void testBuilder_build_fromLegacy() { + public void testBuilder_build_fromLegacy() throws IOException { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); - int m = 17; int efConstruction = 17; + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 12) + .endObject(); - // Setup settings Settings settings = Settings.builder() .put(settings(CURRENT).build()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) @@ -209,6 +225,13 @@ public void testBuilder_build_fromLegacy() { .put(KNN_INDEX, true) .build(); + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + "test-field-name-1", + xContentBuilderToMap(xContentBuilder), + buildParserContext("test", settings) + ); + + // Setup settings Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); @@ -327,17 +350,16 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws .endObject() .endObject() .endObject(); - KNNVectorFieldMapper.Builder builderOverMaxDimension = (KNNVectorFieldMapper.Builder) typeParser.parse( - fieldName, - xContentBuilderToMap(xContentBuilderOverMaxDimension), - buildParserContext(indexName, settings) - ); IllegalArgumentException ex = expectThrows( IllegalArgumentException.class, - () -> builderOverMaxDimension.build(new Mapper.BuilderContext(settings, new ContentPath())) + () -> typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilderOverMaxDimension), + buildParserContext(indexName, settings) + ) ); - assertEquals("Dimension value cannot be greater than 16000 for vector: test-field-name", ex.getMessage()); + assertTrue(ex.getMessage().contains("Dimension value cannot be greater than 16000 for vector with engine: lucene")); XContentBuilder xContentBuilderInvalidDimension = XContentFactory.jsonBuilder() .startObject() @@ -417,7 +439,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidSpaceType() throws String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -449,7 +471,7 @@ public void testTypeParser_parse_fromKnnMethodContext() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -546,7 +568,7 @@ public void testTypeParser_parse_fromModel() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -751,8 +773,12 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT log.info("Vector Data Type is : {}", dataType); int dimension = dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION; final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - methodComponentContext.setIndexVersion(CURRENT); SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(dataType) + .versionCreated(CURRENT) + .dimension(dimension) + .build(); final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext); ParseContext.Document document = new ParseContext.Document(); @@ -767,16 +793,14 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), - dataType, - dimension, knnMethodContext, + knnMethodConfigContext, knnMethodContext, FieldMapper.MultiFields.empty(), FieldMapper.CopyTo.empty(), new Explicit<>(true, true), false, - false, - CURRENT + false ) ); @@ -819,16 +843,14 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), - dataType, - dimension, knnMethodContext, + knnMethodConfigContext, knnMethodContext, FieldMapper.MultiFields.empty(), FieldMapper.CopyTo.empty(), new Explicit<>(true, true), false, - false, - CURRENT + false ) ); @@ -854,21 +876,24 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = - createLuceneFieldMapperInputBuilder(VectorDataType.FLOAT); + createLuceneFieldMapperInputBuilder(); ParseContext.Document document = new ParseContext.Document(); ContentPath contentPath = new ContentPath(); ParseContext parseContext = mock(ParseContext.class); when(parseContext.doc()).thenReturn(document); when(parseContext.path()).thenReturn(contentPath); - + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(CURRENT) + .dimension(TEST_DIMENSION) + .build(); LuceneFieldMapper luceneFieldMapper = Mockito.spy( LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - VectorDataType.FLOAT, - TEST_DIMENSION, getDefaultKNNMethodContext(), + knnMethodConfigContext, inputBuilder.build() ) ); @@ -908,18 +933,19 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { inputBuilder.hasDocValues(false); - KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.LUCENE, - SpaceType.DEFAULT, - new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) - ); + knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(CURRENT) + .dimension(TEST_DIMENSION) + .build(); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.DEFAULT, methodComponentContext); luceneFieldMapper = Mockito.spy( LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - VectorDataType.FLOAT, - TEST_DIMENSION, knnMethodContext, + knnMethodConfigContext, inputBuilder.build() ) ); @@ -942,7 +968,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { // Create a lucene field mapper that creates a binary doc values field as well as KnnByteVectorField LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = - createLuceneFieldMapperInputBuilder(VectorDataType.BYTE); + createLuceneFieldMapperInputBuilder(); ParseContext.Document document = new ParseContext.Document(); ContentPath contentPath = new ContentPath(); @@ -954,9 +980,12 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - VectorDataType.BYTE, - TEST_DIMENSION, - getDefaultKNNMethodContext(), + getDefaultByteKNNMethodContext(), + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.BYTE) + .versionCreated(CURRENT) + .dimension(TEST_DIMENSION) + .build(), inputBuilder.build() ) ); @@ -1000,9 +1029,12 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - VectorDataType.BYTE, - TEST_DIMENSION, - getDefaultKNNMethodContext(), + getDefaultByteKNNMethodContext(), + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.BYTE) + .versionCreated(CURRENT) + .dimension(TEST_DIMENSION) + .build(), inputBuilder.build() ) ); @@ -1021,131 +1053,105 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); } - public void testValidateFp16VectorValue_outOfRange_throwsException() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(65505.25f)); - assertTrue( - ex.getMessage() - .contains( - String.format( - Locale.ROOT, - "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", - ENCODER_SQ, - FAISS_SQ_ENCODER_FP16, - FP16_MIN_VALUE, - FP16_MAX_VALUE - ) - ) - ); - - IllegalArgumentException ex1 = expectThrows(IllegalArgumentException.class, () -> validateFP16VectorValue(-65525.65f)); - assertTrue( - ex1.getMessage() - .contains( - String.format( - Locale.ROOT, - "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", - ENCODER_SQ, - FAISS_SQ_ENCODER_FP16, - FP16_MIN_VALUE, - FP16_MAX_VALUE - ) - ) - ); - } - - public void testClipVectorValuetoFP16Range_succeed() { - assertEquals(65504.0f, clipVectorValueToFP16Range(65504.10f), 0.0f); - assertEquals(65504.0f, clipVectorValueToFP16Range(1000000.89f), 0.0f); - assertEquals(-65504.0f, clipVectorValueToFP16Range(-65504.10f), 0.0f); - assertEquals(-65504.0f, clipVectorValueToFP16Range(-1000000.89f), 0.0f); + public void testTypeParser_whenBinaryFaissHNSW_thenValid() throws IOException { + testTypeParserWithBinaryDataType(KNNEngine.FAISS, SpaceType.HAMMING, METHOD_HNSW, 8, null); } - public void testBuilder_whenBinaryFaissHNSW_thenValid() { - testBuilderWithBinaryDataType(KNNEngine.FAISS, SpaceType.UNDEFINED, METHOD_HNSW, 8, null); + public void testTypeParser_whenBinaryWithInvalidDimension_thenException() throws IOException { + testTypeParserWithBinaryDataType(KNNEngine.FAISS, SpaceType.UNDEFINED, METHOD_HNSW, 4, "should be multiply of 8"); } - public void testBuilder_whenBinaryWithInvalidDimension_thenException() { - testBuilderWithBinaryDataType(KNNEngine.FAISS, SpaceType.UNDEFINED, METHOD_HNSW, 4, "should be multiply of 8"); - } - - public void testBuilder_whenBinaryFaissHNSWWithInvalidSpaceType_thenException() { + public void testTypeParser_whenBinaryFaissHNSWWithInvalidSpaceType_thenException() throws IOException { for (SpaceType spaceType : SpaceType.values()) { if (SpaceType.UNDEFINED == spaceType || SpaceType.HAMMING == spaceType) { continue; } - testBuilderWithBinaryDataType(KNNEngine.FAISS, spaceType, METHOD_HNSW, 8, "is not supported"); + testTypeParserWithBinaryDataType(KNNEngine.FAISS, spaceType, METHOD_HNSW, 8, "is not supported with"); } } - public void testBuilder_whenBinaryNonFaiss_thenException() { - testBuilderWithBinaryDataType(KNNEngine.LUCENE, SpaceType.UNDEFINED, METHOD_HNSW, 8, "is only supported for"); - testBuilderWithBinaryDataType(KNNEngine.NMSLIB, SpaceType.UNDEFINED, METHOD_HNSW, 8, "is only supported for"); + public void testTypeParser_whenBinaryNonFaiss_thenException() throws IOException { + testTypeParserWithBinaryDataType(KNNEngine.LUCENE, SpaceType.UNDEFINED, METHOD_HNSW, 8, "is not supported for vector data type"); + testTypeParserWithBinaryDataType(KNNEngine.NMSLIB, SpaceType.UNDEFINED, METHOD_HNSW, 8, "is not supported for vector data type"); } - private void testBuilderWithBinaryDataType( + private void testTypeParserWithBinaryDataType( KNNEngine knnEngine, SpaceType spaceType, String method, int dimension, String expectedErrMsg - ) { + ) throws IOException { + // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + String fieldName = "test-field-name-1"; + String indexName = "test-index"; // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); - builder.knnMethodContext.setValue( - new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(method, Collections.emptyMap())) - ); - builder.vectorDataType.setValue(VectorDataType.BINARY); - builder.dimension.setValue(dimension); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, knnEngine.getName()) + .endObject() + .endObject(); - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); if (expectedErrMsg == null) { - KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); - if (SpaceType.UNDEFINED == spaceType) { - assertEquals( - SpaceType.HAMMING, - knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType() - ); - } + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); + + assertEquals(spaceType, builder.getResolvedKNNMethodContext().getSpaceType()); } else { - Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); + Exception ex = expectThrows(Exception.class, () -> { + typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + }); assertTrue(ex.getMessage(), ex.getMessage().contains(expectedErrMsg)); } } - public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { + public void testTypeParser_whenBinaryFaissHNSWWithSQ_thenException() throws IOException { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); - + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); - builder.knnMethodContext.setValue( - new KNNMethodContext( - KNNEngine.FAISS, - SpaceType.HAMMING, - new MethodComponentContext( - METHOD_HNSW, - Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_SQ, Collections.emptyMap())) - ) - ) - ); - builder.vectorDataType.setValue(VectorDataType.BINARY); - builder.dimension.setValue(8); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 8) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .endObject() + .endObject() + .endObject() + .endObject(); - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); - Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); - assertTrue(ex.getMessage(), ex.getMessage().contains("data type does not support")); + Exception ex = expectThrows( + Exception.class, + () -> typeParser.parse("test", xContentBuilderToMap(xContentBuilder), buildParserContext("test", settings)) + ); + assertTrue(ex.getMessage(), ex.getMessage().contains("parameter validation failed for MethodComponentContext parameter [encoder]")); } public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1157,19 +1163,28 @@ public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { assertTrue(knnVectorFieldMapper instanceof FlatVectorFieldMapper); } - public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { + public void testTypeParser_whenBinaryWithLegacyKNNEnabled_thenException() throws IOException { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); - builder.vectorDataType.setValue(VectorDataType.BINARY); - builder.dimension.setValue(8); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + String fieldName = "test-field-name-1"; + String indexName = "test-index"; // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); - Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); - assertTrue(ex.getMessage(), ex.getMessage().contains("is not supported for")); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 8) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .endObject(); + + Exception ex = expectThrows(Exception.class, () -> { + typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + }); + + assertTrue(ex.getMessage(), ex.getMessage().contains("is not supported with")); } public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { @@ -1183,21 +1198,18 @@ public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { // IllegalArgumentException should be thrown. Exception e = assertThrows(IllegalArgumentException.class, () -> { - new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null).build(builderContext); + new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null, null).build(builderContext); }); assertTrue(e.getMessage(), e.getMessage().contains("Vector field name must not include")); } } - private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( - VectorDataType vectorDataType - ) { + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder() { return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() .name(TEST_FIELD_NAME) .multiFields(FieldMapper.MultiFields.empty()) .copyTo(FieldMapper.CopyTo.empty()) .hasDocValues(true) - .vectorDataType(vectorDataType) .ignoreMalformed(new Explicit<>(true, true)) .originalKnnMethodContext(getDefaultKNNMethodContext()); } @@ -1247,6 +1259,5 @@ public Mapper.TypeParser.ParserContext buildParserContext(String indexName, Sett null, null ); - } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index ad041e47e..740d75206 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -18,17 +18,11 @@ import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.MethodComponentContext; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.engine.KNNEngine; import java.util.Arrays; -import java.util.Collections; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -81,35 +75,6 @@ public void testGetExpectedVectorLengthSuccess() { assertEquals(4, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased)); } - public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { - validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, null); - } - - public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() { - validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, "only supported"); - validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, "only supported"); - } - - public void testValidateVectorDataType_whenBinaryFaissIVF_thenException() { - validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_IVF, VectorDataType.BINARY, "only supported"); - } - - public void testValidateVectorDataType_whenByteLucene_thenValid() { - validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, null); - validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_IVF, VectorDataType.BYTE, null); - } - - public void testValidateVectorDataType_whenByteNonLucene_thenException() { - validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, "only supported"); - validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_IVF, VectorDataType.BYTE, "only supported"); - } - - public void testValidateVectorDataType_whenFloat_thenValid() { - validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); - validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); - validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); - } - public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { final KNNSettings knnSettings = mock(KNNSettings.class); final MockedStatic mockedStatic = Mockito.mockStatic(KNNSettings.class); @@ -125,23 +90,4 @@ public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { // this mocking mockedStatic.close(); } - - private void validateValidateVectorDataType( - final KNNEngine knnEngine, - final String methodName, - final VectorDataType vectorDataType, - final String expectedErrMsg - ) { - MethodComponentContext methodComponentContext = new MethodComponentContext(methodName, Collections.emptyMap()); - KNNMethodContext methodContext = new KNNMethodContext(knnEngine, SpaceType.UNDEFINED, methodComponentContext); - if (expectedErrMsg == null) { - KNNVectorFieldMapperUtil.validateVectorDataType(methodContext, vectorDataType); - } else { - Exception ex = expectThrows( - IllegalArgumentException.class, - () -> KNNVectorFieldMapperUtil.validateVectorDataType(methodContext, vectorDataType) - ); - assertTrue(ex.getMessage().contains(expectedErrMsg)); - } - } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java deleted file mode 100644 index faae3e35d..000000000 --- a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import org.opensearch.Version; -import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNMethodContext; - -import java.util.Collections; - -public class MethodFieldMapperTests extends KNNTestCase { - public void testMethodFieldMapper_whenVectorDataTypeAndContextMismatch_thenThrow() { - // Expect that we cannot create the mapper with an invalid field type - KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); - expectThrows( - IllegalArgumentException.class, - () -> MethodFieldMapper.createFieldMapper( - "testField", - "simpleName", - Collections.emptyMap(), - VectorDataType.BINARY, - 1, - knnMethodContext, - knnMethodContext, - null, - new FieldMapper.CopyTo.Builder().build(), - KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, - true, - true, - Version.CURRENT - ) - ); - } -} diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java index 5832f2718..29e710ec1 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java @@ -49,17 +49,17 @@ public static Collection parameters() throws IOException { $( "Creation of binary index with lucene engine should fail", createKnnHnswBinaryIndexMapping(KNNEngine.LUCENE, FIELD_NAME, 16, null), - "only supported for [faiss] engine" + "Validation Failed" ), $( "Creation of binary index with nmslib engine should fail", createKnnHnswBinaryIndexMapping(KNNEngine.NMSLIB, FIELD_NAME, 16, null), - "only supported for [faiss] engine" + "Validation Failed" ), $( "Creation of binary index with encoder should fail", createKnnHnswBinaryIndexMapping(KNNEngine.FAISS, FIELD_NAME, 16, ENCODER_SQ), - "does not support sq encoder" + "Validation Failed" ) ) ); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index ae9ad7106..53245cc62 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -21,6 +21,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.nmslib.NmslibHNSWMethod; @@ -611,7 +612,13 @@ public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(Version.CURRENT) + .dimension(128) + .vectorDataType(VectorDataType.FLOAT) + .build(); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1131,7 +1138,13 @@ public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOExceptio .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(testData.indexData.getDimension()) + .versionCreated(Version.CURRENT) + .build(); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1162,7 +1175,13 @@ public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(Version.CURRENT) + .dimension(128) + .vectorDataType(VectorDataType.FLOAT) + .build(); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1189,8 +1208,13 @@ public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(testData.indexData.getDimension()) + .versionCreated(Version.CURRENT) + .build(); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1223,6 +1247,11 @@ public void testCreateIndexFromTemplate() throws IOException { } SpaceType spaceType = SpaceType.L2; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(Version.CURRENT) + .dimension(128) + .vectorDataType(VectorDataType.FLOAT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.FAISS, spaceType, @@ -1238,7 +1267,7 @@ public void testCreateIndexFromTemplate() throws IOException { ); String description = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext) + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) .getLibraryParameters() .get(INDEX_DESCRIPTION_PARAMETER) .toString(); @@ -1361,7 +1390,11 @@ private void assertQueryResultsMatch(float[][] testQueries, int k, List in private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, SpaceType spaceType) throws IOException { long trainPointer = JNIService.transferVectors(0, testData.indexData.vectors); assertNotEquals(0, trainPointer); - + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(Version.CURRENT) + .dimension(128) + .vectorDataType(VectorDataType.FLOAT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( KNNEngine.FAISS, spaceType, @@ -1380,7 +1413,7 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac ); String description = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext) + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) .getLibraryParameters() .get(INDEX_DESCRIPTION_PARAMETER) .toString(); diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 7fa0d3bca..4399b3318 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -14,10 +14,10 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; import org.opensearch.knn.index.engine.KNNLibrarySearchContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNLibrary; -import org.opensearch.knn.training.VectorSpaceInfo; import org.opensearch.test.OpenSearchTestCase; public class LibraryInitializedSupplierTests extends OpenSearchTestCase { @@ -74,12 +74,7 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { } @Override - public ValidationException validateMethod(KNNMethodContext knnMethodContext) { - return null; - } - - @Override - public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { return null; } @@ -89,12 +84,15 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { } @Override - public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { + public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { return 0; } @Override - public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { + public KNNLibraryIndexingContext getKNNLibraryIndexingContext( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext + ) { return null; } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 83d39cfdc..d7920d987 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -149,7 +150,7 @@ public void testValidation_invalid_modelIdAlreadyExists() { // Setup the training request String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); int dimension = 10; String trainingIndex = "test-training-index"; @@ -206,7 +207,7 @@ public void testValidation_blocked_modelId() { // Setup the training request String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); int dimension = 10; String trainingIndex = "test-training-index"; @@ -252,7 +253,7 @@ public void testValidation_invalid_invalidMethodContext() { String validationExceptionMessage = "knn method invalid"; ValidationException validationException = new ValidationException(); validationException.addValidationError(validationExceptionMessage); - when(knnMethodContext.validate()).thenReturn(validationException); + when(knnMethodContext.validate(any())).thenReturn(validationException); when(knnMethodContext.isTrainingRequired()).thenReturn(false); when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); @@ -297,7 +298,7 @@ public void testValidation_invalid_trainingIndexDoesNotExist() { String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); int dimension = 10; @@ -344,7 +345,7 @@ public void testValidation_invalid_trainingFieldDoesNotExist() { String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); int dimension = 10; @@ -396,7 +397,7 @@ public void testValidation_invalid_trainingFieldNotKnnVector() { String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); int dimension = 10; @@ -452,7 +453,7 @@ public void testValidation_invalid_dimensionDoesNotMatch() { String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); @@ -511,7 +512,7 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { // Setup the training request String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; @@ -574,7 +575,7 @@ public void testValidation_invalid_descriptionToLong() { // Setup the training request String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; @@ -625,7 +626,7 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { // Setup the training request String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; @@ -663,7 +664,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { // Setup the training request String modelId = "test-model-id"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.validate(any())).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index b6d76c68e..adecca43a 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -15,6 +15,7 @@ import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -66,10 +67,9 @@ public void testGetModelId() { mock(NativeMemoryCacheManager.class), mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), mock(NativeMemoryEntryContext.AnonymousEntryContext.class), - 10, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.DEFAULT).dimension(10).versionCreated(Version.CURRENT).build(), "", - "test-node", - VectorDataType.DEFAULT + "test-node" ); assertEquals(modelId, trainingJob.getModelId()); @@ -96,10 +96,13 @@ public void testGetModel() { mock(NativeMemoryCacheManager.class), mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), mock(NativeMemoryEntryContext.AnonymousEntryContext.class), - dimension, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(), description, - nodeAssignment, - VectorDataType.DEFAULT + nodeAssignment ); Model model = new Model( @@ -130,6 +133,11 @@ public void testRun_success() throws IOException, ExecutionException { int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.INNER_PRODUCT, @@ -185,10 +193,9 @@ public void testRun_success() throws IOException, ExecutionException { nativeMemoryCacheManager, trainingDataEntryContext, modelContext, - dimension, + knnMethodConfigContext, "", - "test-node", - VectorDataType.DEFAULT + "test-node" ); trainingJob.run(); @@ -225,6 +232,11 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.INNER_PRODUCT, @@ -264,10 +276,9 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept nativeMemoryCacheManager, trainingDataEntryContext, modelContext, - dimension, + knnMethodConfigContext, "", - "test-node", - VectorDataType.DEFAULT + "test-node" ); trainingJob.run(); @@ -287,6 +298,11 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.INNER_PRODUCT, @@ -332,10 +348,9 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce nativeMemoryCacheManager, trainingDataEntryContext, modelContext, - dimension, + knnMethodConfigContext, "", - "test-node", - VectorDataType.DEFAULT + "test-node" ); trainingJob.run(); @@ -355,6 +370,11 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.INNER_PRODUCT, @@ -399,10 +419,9 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep nativeMemoryCacheManager, trainingDataEntryContext, mock(NativeMemoryEntryContext.AnonymousEntryContext.class), - dimension, + knnMethodConfigContext, "", - "test-node", - VectorDataType.DEFAULT + "test-node" ); trainingJob.run(); @@ -420,6 +439,11 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { int nlists = 1024; // setting this to 1024 will cause training to fail when there is only 2 data points int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.INNER_PRODUCT, @@ -473,10 +497,9 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { nativeMemoryCacheManager, trainingDataEntryContext, modelContext, - dimension, + knnMethodConfigContext, "", - "test-node", - VectorDataType.DEFAULT + "test-node" ); trainingJob.run(); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 50b149830..a90935869 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1222,6 +1222,14 @@ public void addKNNDocs(String testIndex, String testField, int dimension, int fi } } + public void addKNNByteDocs(String testIndex, String testField, int dimension, int firstDocID, int numDocs) throws IOException { + for (int i = firstDocID; i < firstDocID + numDocs; i++) { + Byte[] indexVector = new Byte[dimension]; + Arrays.fill(indexVector, (byte) i); + addKnnDoc(testIndex, Integer.toString(i), testField, indexVector); + } + } + public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k) throws Exception { validateKNNSearch(testIndex, testField, dimension, numDocs, k, null); }