diff --git a/CHANGELOG.md b/CHANGELOG.md
index c16c78cf9..5f1929a5a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
 - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
 - Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
+- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378]
 - Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
 - Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
 - Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java
index c9a74418f..e8bf47625 100644
--- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java
+++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java
@@ -225,7 +225,7 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex
 
         // Add training data
         createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, DIMENSION);
 
         XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -278,7 +278,7 @@ public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Ra
 
         // Add training data
         createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension);
 
         XContentBuilder builder = XContentFactory.jsonBuilder()
diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java
index bfd908a09..822f0e2ca 100644
--- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java
+++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java
@@ -22,6 +22,12 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
+
+import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
+import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
+import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
+import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
 
 /**
  * Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
@@ -108,6 +114,55 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
         return PerDimensionProcessor.NOOP_PROCESSOR;
     }
 
+    protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
+        return (trainingConfigValidationInput) -> {
+
+            KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
+            KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
+            Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();
+
+            TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
+
+            // validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
+            if (knnMethodContext != null && knnMethodConfigContext != null) {
+                if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
+                    && knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
+                        .getParameters()
+                        .get(ENCODER_PARAMETER_PQ_M) != 0) {
+                    builder.valid(false);
+                    return builder.build();
+                } else {
+                    builder.valid(true);
+                }
+            }
+
+            // validate number of training points should be greater than minimum clustering criteria defined in faiss
+            if (knnMethodContext != null && trainingVectors != null) {
+                long minTrainingVectorCount = 1000;
+
+                MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
+                    .getParameters()
+                    .get(METHOD_ENCODER_PARAMETER);
+
+                if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
+                    && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {
+
+                    int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
+                    int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
+                    minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
+                }
+
+                if (trainingVectors < minTrainingVectorCount) {
+                    builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
+                    return builder.build();
+                } else {
+                    builder.valid(true);
+                }
+            }
+            return builder.build();
+        };
+    }
+
     protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
         return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
     }
@@ -131,6 +186,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
             .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
             .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
             .vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
+            .trainingConfigValidationSetup(doGetTrainingConfigValidationSetup())
             .build();
     }
 
diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java
index 1ff677cd6..9bef9e2e4 100644
--- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java
+++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java
@@ -12,6 +12,7 @@
 import org.opensearch.knn.index.mapper.VectorValidator;
 
 import java.util.Map;
+import java.util.function.Function;
 
 /**
  * Context a library gives to build one of its indices
@@ -49,6 +50,12 @@ public interface KNNLibraryIndexingContext {
      */
     PerDimensionProcessor getPerDimensionProcessor();
 
+    /**
+     *
+     * @return Get function that validates training model parameters
+     */
+    Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup();
+
     /**
      * Get the vector transformer that will be used to transform the vector before indexing.
      * This will be applied at vector level once entire vector is parsed and validated.
diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java
index 9822033b7..46b5cb215 100644
--- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java
+++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java
@@ -14,6 +14,7 @@
 
 import java.util.Collections;
 import java.util.Map;
+import java.util.function.Function;
 
 /**
  * Simple implementation of {@link KNNLibraryIndexingContext}
@@ -29,6 +30,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
     private Map parameters = Collections.emptyMap();
     @Builder.Default
     private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;
+    private Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> trainingConfigValidationSetup;
 
     @Override
     public Map getLibraryParameters() {
@@ -59,4 +61,9 @@ public PerDimensionValidator getPerDimensionValidator() {
     public PerDimensionProcessor getPerDimensionProcessor() {
         return perDimensionProcessor;
     }
+
+    @Override
+    public Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup() {
+        return trainingConfigValidationSetup;
+    }
 }
diff --git a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationInput.java b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationInput.java
new file mode 100644
index 000000000..5070173f6
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationInput.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.engine;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+
+/**
+ * This object provides the input of the validation checks for training model inputs.
+ * The values in this object need to be dynamically set and calling code needs to handle
+ * the possibility that the values have not been set.
+ */
+@Setter
+@Getter
+@Builder
+@AllArgsConstructor
+public class TrainingConfigValidationInput {
+    private Long trainingVectorsCount;
+    private KNNMethodContext knnMethodContext;
+    private KNNMethodConfigContext knnMethodConfigContext;
+}
diff --git a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java
new file mode 100644
index 000000000..0cbe6cad5
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.engine;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+
+/**
+ * This object provides the output of the validation checks for training model inputs.
+ * The values in this object need to be dynamically set and calling code needs to handle
+ * the possibility that the values have not been set.
+ */
+@Setter
+@Getter
+@Builder
+@AllArgsConstructor
+public class TrainingConfigValidationOutput {
+    private boolean valid;
+    private long minTrainingVectorCount;
+}
diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java
index 78f3769c5..5de0405c4 100644
--- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java
+++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java
@@ -23,12 +23,18 @@
 import org.opensearch.common.ValidationException;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
+import org.opensearch.knn.index.engine.KNNMethodConfigContext;
+import org.opensearch.knn.index.engine.KNNMethodContext;
+import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
+import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.tasks.Task;
 import org.opensearch.transport.TransportRequestOptions;
 import org.opensearch.transport.TransportService;
 
 import java.util.Map;
+import java.util.function.Function;
 
 import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
 import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;
@@ -134,6 +140,29 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
             trainingVectors = trainingModelRequest.getMaximumVectorCount();
         }
 
+        KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext();
+        KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext();
+
+        KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
+            .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
+
+        Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
+            .getTrainingConfigValidationSetup();
+
+        TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
+
+        TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
+            inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
+        );
+        if (!validation.isValid()) {
+            ValidationException exception = new ValidationException();
+            exception.addValidationError(
+                String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount())
+            );
+            listener.onFailure(exception);
+            return;
+        }
+
         listener.onResponse(
             estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType())
         );
diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java
index 9906ab490..bd2c88347 100644
--- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java
+++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java
@@ -30,10 +30,14 @@
 import org.opensearch.knn.index.engine.EngineResolver;
 import org.opensearch.knn.index.util.IndexUtil;
 import org.opensearch.knn.index.engine.KNNMethodContext;
+import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
+import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
+import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
 import org.opensearch.knn.index.VectorDataType;
 import org.opensearch.knn.indices.ModelDao;
 
 import java.io.IOException;
+import java.util.function.Function;
 
 /**
  * Request to train and serialize a model
@@ -283,6 +287,21 @@ public ActionRequestValidationException validate() {
             exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
         }
 
+        KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
+            .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
+        Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
+            .getTrainingConfigValidationSetup();
+        TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
+        TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
+            inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build()
+        );
+
+        // Check if ENCODER_PARAMETER_PQ_M is divisible by vector dimension
+        if (!validation.isValid()) {
+            exception = exception == null ? new ActionRequestValidationException() : exception;
+            exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
+        }
+
         // Validate training index exists
         IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
         if (indexMetadata == null) {
diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java
index b479192e8..275aa2f47 100644
--- a/src/main/java/org/opensearch/knn/training/TrainingJob.java
+++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java
@@ -203,9 +203,7 @@ public void run() {
         } catch (Exception e) {
             logger.error("Failed to run training job for model \"" + modelId + "\": ", e);
             modelMetadata.setState(ModelState.FAILED);
-            modelMetadata.setError(
-                "Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training."
-            );
+            modelMetadata.setError("Failed to execute training. " + e.getMessage());
 
             KNNCounter.TRAINING_ERRORS.increment();
 
diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java
index e1f34c798..20249237d 100644
--- a/src/test/java/org/opensearch/knn/index/FaissIT.java
+++ b/src/test/java/org/opensearch/knn/index/FaissIT.java
@@ -304,7 +304,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN
 
         // training data needs to be at least equal to the number of centroids for PQ
        // which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
-        int trainingDataCount = 256;
+        int trainingDataCount = 1100;
 
         SpaceType spaceType = SpaceType.L2;
 
@@ -468,7 +468,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
 
         // training data needs to be at least equal to the number of centroids for PQ
         // which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ
-        int trainingDataCount = 256;
+        int trainingDataCount = 1100;
 
         SpaceType spaceType = SpaceType.L2;
 
@@ -736,7 +736,7 @@ public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() {
 
         // Add training data
         createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);
 
         XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -960,7 +960,7 @@ public void testIVFSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() {
 
         // Add training data
         createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);
 
         XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -1064,7 +1064,7 @@ public void testIVFSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenS
 
         // Add training data
         createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);
 
         XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -1144,7 +1144,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(
 
         // training data needs to be at least equal to the number of centroids for PQ
         // which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ
-        int trainingDataCount = 256;
+        int trainingDataCount = 1100;
 
         SpaceType spaceType = SpaceType.L2;
 
@@ -1412,7 +1412,7 @@ public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws Exce
 
         // Add training data
         createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);
 
         // Call train API - IVF with nlists = 1 is brute force, but will require training
@@ -1767,7 +1767,7 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {
 
         createKnnIndex(trainingIndexName, trainIndexMapping);
 
-        int trainingDataCount = 40;
+        int trainingDataCount = 1100;
         bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);
 
         XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
index 2e68339d4..07b2be40d 100644
--- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
+++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
@@ -603,7 +603,7 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
             .toString();
         createKnnIndex(INDEX_NAME, trainIndexMapping);
 
-        int trainingDataCount = 100;
+        int trainingDataCount = 1100;
         bulkIngestRandomByteVectors(INDEX_NAME, FIELD_NAME, trainingDataCount, dimension);
 
         XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java
index 36db43166..86341e8bc 100644
--- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java
+++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java
@@ -620,7 +620,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception {
         int dimensions = randomIntBetween(2, 10);
         String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions);
         createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping);
-        bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions);
+        bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, 1100, dimensions);
 
         XContentBuilder methodBuilder = XContentFactory.jsonBuilder()
             .startObject()
diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
index 8b2cf5d2b..217ec4547 100644
--- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
+++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
@@ -52,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase {
 
     private static final String TRAINING_INDEX_NAME = "training_index";
     private static final String TRAINING_FIELD_NAME = "training_field";
-    private static final int TRAINING_VECS = 20;
+    private static final int TRAINING_VECS = 1100;
     private static final int DIMENSION = 16;
     private static final int NUM_DOCS = 20;
 
diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java
index 9e09fe913..c5353ba1f 100644
--- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java
+++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java
@@ -71,7 +71,7 @@ public class RestKNNStatsHandlerIT extends KNNRestTestCase {
     private static final String FIELD_LUCENE_NAME = "lucene_test_field";
     private static final int DIMENSION = 4;
     private static int DOC_ID = 0;
-    private static final int NUM_DOCS = 10;
+    private static final int NUM_DOCS = 1100;
     private static final int DELAY_MILLI_SEC = 1000;
     private static final int NUM_OF_ATTEMPTS = 30;
 
diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java
index 7e84491c4..7abdc538f 100644
--- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java
+++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java
@@ -97,28 +97,11 @@ public void testTrainModel_fail_notEnoughData() throws Exception {
             .endObject();
 
         Map method = xContentBuilderToMap(builder);
-        Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description");
-
-        assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode()));
-
-        // Grab the model id from the response
-        String trainResponseBody = EntityUtils.toString(trainResponse.getEntity());
-        assertNotNull(trainResponseBody);
-
-        Map trainResponseMap = createParser(XContentType.JSON.xContent(), trainResponseBody).map();
-        String modelId = (String) trainResponseMap.get(MODEL_ID);
-        assertNotNull(modelId);
-
-        // Confirm that the model fails to create
-        Response getResponse = getModel(modelId, null);
-        String responseBody = EntityUtils.toString(getResponse.getEntity());
-        assertNotNull(responseBody);
-
-        Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map();
-
-        assertEquals(modelId, responseMap.get(MODEL_ID));
-
-        assertTrainingFails(modelId, 30, 1000);
+        ResponseException exception = expectThrows(
+            ResponseException.class,
+            () -> trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description")
+        );
+        assertTrue(exception.getMessage().contains("Number of training points should be greater than"));
     }
 
     public void testTrainModel_fail_tooMuchData() throws Exception {
@@ -132,7 +115,7 @@ public void testTrainModel_fail_tooMuchData() throws Exception {
 
         // Create a training index and randomly ingest data into it
         createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
-        int trainingDataCount = 20; // 20 * 16 * 4 ~= 10 kb
+        int trainingDataCount = 128;
         bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);
 
         // Call the train API with this definition:
@@ -491,7 +474,7 @@ public void testTrainModel_success_methodOverrideWithCompressionMode() throws Ex
         // Create a training index and randomly ingest data into it
         String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath);
         createKnnIndex(trainingIndexName, mapping);
-        int trainingDataCount = 200;
+        int trainingDataCount = 1100;
         bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension);
 
         // Call the train API with this definition:
diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java
index 30c5d33a1..aee45e2cc 100644
--- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java
+++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java
@@ -36,6 +36,7 @@
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.containsString;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
@@ -344,6 +345,55 @@ public void testTrainingIndexSize() {
         transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener);
     }
 
+    public void testTrainingIndexSizeFailure() {
+
+        String trainingIndexName = "training-index";
+        int dimension = 133;
+        int vectorCount = 100;
+
+        // Setup the request
+        TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
+            null,
+            getDefaultKNNMethodContextForModel(),
+            dimension,
+            trainingIndexName,
+            "training-field",
+            null,
+            "description",
+            VectorDataType.DEFAULT,
+            Mode.NOT_CONFIGURED,
+            CompressionLevel.NOT_CONFIGURED
+        );
+
+        // Mock client to return the right number of docs
+        TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO);
+        SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f);
+        SearchResponse searchResponse = mock(SearchResponse.class);
+        when(searchResponse.getHits()).thenReturn(searchHits);
+        Client client = mock(Client.class);
+        doAnswer(invocationOnMock -> {
+            ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse);
+            return null;
+        }).when(client).search(any(), any());
+
+        // Setup the action
+        ClusterService clusterService = mock(ClusterService.class);
+        TransportService transportService = mock(TransportService.class);
+        TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction(
+            transportService,
+            new ActionFilters(Collections.emptySet()),
+            clusterService,
+            client
+        );
+
+        ActionListener listener = ActionListener.wrap(
+            size -> size.intValue(),
+            e -> assertThat(e.getMessage(), containsString("Number of training points should be greater than"))
+        );
+
+        transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener);
+    }
+
     public void testTrainIndexSize_whenDataTypeIsBinary() {
         String trainingIndexName = "training-index";
         int dimension = 8;
diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java
index 6fd399434..fdffc91d0 100644
--- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java
+++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java
@@ -621,11 +621,61 @@ public void testValidation_invalid_descriptionToLong() {
         ActionRequestValidationException exception = trainingModelRequest.validate();
         assertNotNull(exception);
         List validationErrors = exception.validationErrors();
-        logger.error("Validation errorsa " + validationErrors);
+        logger.error("Validation errors " + validationErrors);
         assertEquals(1, validationErrors.size());
         assertTrue(validationErrors.get(0).contains("Description exceeds limit"));
     }
 
+    public void testValidation_invalid_mNotDivisibleByDimension() {
+
+        // Setup the training request
+        String modelId = "test-model-id";
+        int dimension = 10;
+        String trainingIndex = "test-training-index";
+        String trainingField = "test-training-field";
+        String trainingFieldModeId = "training-field-model-id";
+
+        Map parameters = Map.of("m", 3);
+
+        MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters);
+        final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext);
+
+        TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
+            modelId,
+            knnMethodContext,
+            dimension,
+            trainingIndex,
+            trainingField,
+            null,
+            null,
+            VectorDataType.DEFAULT,
+            Mode.NOT_CONFIGURED,
+            CompressionLevel.NOT_CONFIGURED
+        );
+
+        // Mock the model dao to return metadata for modelId to recognize it is a duplicate
+        ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
+        when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
+
+        ModelDao modelDao = mock(ModelDao.class);
+        when(modelDao.getMetadata(modelId)).thenReturn(null);
+        when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata);
+
+        // Cluster service that wont produce validation exception
+        ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension);
+
+        // Initialize static components with the mocks
+        TrainingModelRequest.initialize(modelDao, clusterService);
+
+        // Test that validation produces m not divisible by vector dimension error message
+        ActionRequestValidationException exception = trainingModelRequest.validate();
+        assertNotNull(exception);
+        List validationErrors = exception.validationErrors();
+        logger.error("Validation errors " + validationErrors);
+        assertEquals(2, validationErrors.size());
+        assertTrue(validationErrors.get(1).contains("Training request ENCODER_PARAMETER_PQ_M"));
+    }
+
     public void testValidation_valid_trainingIndexBuiltFromMethod() {
 
         // This cluster service will result in no validation exceptions
diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java
index ae162401b..9cff11271 100644
--- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java
+++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java
@@ -67,7 +67,7 @@ public class RecallTestsIT extends KNNRestTestCase {
     private final static String TRAIN_FIELD_NAME = "train_field";
     private final static String TEST_MODEL_ID = "test_model_id";
     private final static int TEST_DIMENSION = 32;
-    private final static int DOC_COUNT = 500;
+    private final static int DOC_COUNT = 1100;
     private final static int QUERY_COUNT = 100;
     private final static int TEST_K = 100;
     private final static double PERFECT_RECALL = 1.0;
diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java
index 4706bd000..8db9d67bc 100644
--- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java
+++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java
@@ -41,6 +41,7 @@
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 
+import static org.hamcrest.Matchers.containsString;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -217,7 +218,6 @@ public void testRun_success() throws IOException, ExecutionException {
 
         Model model = trainingJob.getModel();
         assertNotNull(model);
-        assertEquals(ModelState.CREATED, model.getModelMetadata().getState());
 
         // Simple test that creates the index from template and doesnt fail
@@ -308,6 +308,10 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept
 
         Model model = trainingJob.getModel();
         assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
+        assertThat(
+            "Failed to load training data into memory. " + "Check if there is enough memory to perform the request.",
+            containsString(trainingJob.getModel().getModelMetadata().getError())
+        );
         assertNotNull(model);
         assertFalse(model.getModelMetadata().getError().isEmpty());
     }
@@ -382,6 +386,10 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce
 
         Model model = trainingJob.getModel();
         assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
+        assertThat(
+            "Failed to allocate space in native memory for the model. " + "Check if there is enough memory to perform the request.",
+            containsString(trainingJob.getModel().getModelMetadata().getError())
+        );
         assertNotNull(model);
         assertFalse(model.getModelMetadata().getError().isEmpty());
     }
@@ -435,7 +443,7 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
         when(nativeMemoryAllocation.isClosed()).thenReturn(true);
         when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0);
 
-        // Throw error on getting data
+        // Throw error on allocation is closed
         when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);
 
         TrainingJob trainingJob = new TrainingJob(
@@ -443,7 +451,83 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
             knnMethodContext,
             nativeMemoryCacheManager,
             trainingDataEntryContext,
-            mock(NativeMemoryEntryContext.AnonymousEntryContext.class),
+            modelContext,
+            knnMethodConfigContext,
+            "",
+            "test-node",
+            Mode.NOT_CONFIGURED,
+            CompressionLevel.NOT_CONFIGURED
+        );
+
+        trainingJob.run();
+
+        Model model = trainingJob.getModel();
+        assertThat(
+            "Failed to execute training. Unable to load training data into memory: allocation is already closed",
+            containsString(trainingJob.getModel().getModelMetadata().getError())
+        );
+        assertNotNull(model);
+        assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
+    }
+
+    public void testRun_failure_closedModelAnonymousAllocation() throws ExecutionException {
+        // In this test, the model anonymous allocation should be closed. Then, run should fail and update the error of
+        // the model
+        String modelId = "test-model-id";
+
+        // Define the method setup for method that requires training
+        int nlists = 5;
+        int dimension = 16;
+        KNNEngine knnEngine = KNNEngine.FAISS;
+        KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
+            .vectorDataType(VectorDataType.FLOAT)
+            .dimension(dimension)
+            .versionCreated(Version.CURRENT)
+            .build();
+        KNNMethodContext knnMethodContext = new KNNMethodContext(
+            knnEngine,
+            SpaceType.INNER_PRODUCT,
+            new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))
+        );
+
+        String tdataKey = "t-data-key";
+        NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock(
+            NativeMemoryEntryContext.TrainingDataEntryContext.class
+        );
+        when(trainingDataEntryContext.getKey()).thenReturn(tdataKey);
+
+        // Setup model manager
+        NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class);
+
+        // Setup mock allocation for model that's closed
+        NativeMemoryAllocation modelAllocation = mock(NativeMemoryAllocation.class);
+        doAnswer(invocationOnMock -> null).when(modelAllocation).readLock();
+        doAnswer(invocationOnMock -> null).when(modelAllocation).readUnlock();
+        when(modelAllocation.isClosed()).thenReturn(true);
+
+        String modelKey = "model-test-key";
+        NativeMemoryEntryContext.AnonymousEntryContext modelContext = mock(NativeMemoryEntryContext.AnonymousEntryContext.class);
+        when(modelContext.getKey()).thenReturn(modelKey);
+
+        // Throw error on allocation is closed
+        when(nativeMemoryCacheManager.get(modelContext, false)).thenReturn(modelAllocation);
+        doAnswer(invocationOnMock -> null).when(nativeMemoryCacheManager).invalidate(modelKey);
+
+        // Setup mock allocation thats not closed
+        NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class);
+        doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readLock();
+        doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readUnlock();
+        when(nativeMemoryAllocation.isClosed()).thenReturn(false);
+        when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0);
+
+        when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);
+
+        TrainingJob trainingJob = new TrainingJob(
+            modelId,
+            knnMethodContext,
+            nativeMemoryCacheManager,
+            trainingDataEntryContext,
+            modelContext,
             knnMethodConfigContext,
             "",
             "test-node",
@@ -454,6 +538,10 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
         trainingJob.run();
 
         Model model = trainingJob.getModel();
+        assertThat(
+            "Failed to execute training. Unable to reserve memory for model: allocation is already closed",
+            containsString(trainingJob.getModel().getModelMetadata().getError())
+        );
         assertNotNull(model);
         assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
     }
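
For reference, the minimum-training-points rule enforced by the new doGetTrainingConfigValidationSetup() hook reduces to max(nlist, 2^code_size) when both an IVF nlist and a PQ code_size are configured, and to a 1000-vector floor otherwise, which is presumably why the integration tests above now ingest 1100 training vectors. Below is a minimal standalone sketch of that arithmetic; the class name and the sample values are illustrative only and are not part of the patch.

// Standalone sketch of the minimum training vector count rule (illustrative, not part of the diff).
public final class MinTrainingVectorCountSketch {

    // Fallback floor used when an IVF nlist and a PQ code_size are not both configured.
    private static final long DEFAULT_MIN_TRAINING_VECTORS = 1000;

    static long minTrainingVectorCount(Integer nlist, Integer pqCodeSize) {
        if (nlist == null || pqCodeSize == null) {
            return DEFAULT_MIN_TRAINING_VECTORS;
        }
        // IVF clustering needs at least nlist points; PQ needs at least 2^code_size points.
        return (long) Math.max(nlist, Math.pow(2, pqCodeSize));
    }

    public static void main(String[] args) {
        System.out.println(minTrainingVectorCount(4, 8));      // 256  -> 2^8 dominates
        System.out.println(minTrainingVectorCount(2048, 8));   // 2048 -> nlist dominates
        System.out.println(minTrainingVectorCount(128, null)); // 1000 -> default floor
    }
}

Under these assumptions, 1100 training vectors clear both the 1000-vector default floor and the 256-centroid case used by the HNSWPQ tests.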