Skip to content

Commit

Permalink
Encapsulate dimension, vector data type validation/processing inside …
Browse files Browse the repository at this point in the history
…Library (opensearch-project#1957)

Introduces a new configuration object, KNNMethodConfigContext.
The KNNMethodContext contains the user-provided
information about how they want their index to be built. However, it is
missing a few pieces that are defined outside of it. These
pieces are needed to validate the config and actually build the config.

This change integrates validation of KNNMethodContexts with
KNNMethodConfigContext and better encapsulates KNNLibrary
config info in the engine package.

Signed-off-by: John Mazanec <[email protected]>
Signed-off-by: Akash Shankaran <[email protected]>
  • Loading branch information
jmazanec15 authored and akashsha1 committed Sep 16, 2024
1 parent 08fe9f3 commit 6312f68
Show file tree
Hide file tree
Showing 71 changed files with 1,890 additions and 1,729 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877)
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Added Quantization Framework and implemented 1Bit and multibit quantizer [#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
46 changes: 46 additions & 0 deletions qa/restart-upgrade/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,29 @@ testClusters {
excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testEmptyParametersOnUpgrade"
}
}

// Skip the binary force-merge BWC test when the old cluster runs any 1.x release or 2.0 through 2.15.
// NOTE(review): presumably binary vector indices cannot be created on those versions - confirm.
// Every 2.x prefix ends with "." so a prefix only ever matches its own minor line
// (e.g. "2.1." must not match 2.10+, and "2.4." must not match a hypothetical 2.4x line).
if (knn_bwc_version.startsWith("1.") ||
    knn_bwc_version.startsWith("2.0.") ||
    knn_bwc_version.startsWith("2.1.") ||
    knn_bwc_version.startsWith("2.2.") ||
    knn_bwc_version.startsWith("2.3.") ||
    knn_bwc_version.startsWith("2.4.") ||
    knn_bwc_version.startsWith("2.5.") ||
    knn_bwc_version.startsWith("2.6.") ||
    knn_bwc_version.startsWith("2.7.") ||
    knn_bwc_version.startsWith("2.8.") ||
    knn_bwc_version.startsWith("2.9.") ||
    knn_bwc_version.startsWith("2.10.") ||
    knn_bwc_version.startsWith("2.11.") ||
    knn_bwc_version.startsWith("2.12.") ||
    knn_bwc_version.startsWith("2.13.") ||
    knn_bwc_version.startsWith("2.14.") ||
    knn_bwc_version.startsWith("2.15.")) {
    filter {
        excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexBinaryForceMerge"
    }
}

nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}")
nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}")
systemProperty 'tests.security.manager', 'false'
Expand Down Expand Up @@ -109,6 +132,29 @@ testClusters {
excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testEmptyParametersOnUpgrade"
}
}

// Skip the binary force-merge BWC test when the old cluster runs any 1.x release or 2.0 through 2.15.
// NOTE(review): presumably binary vector indices cannot be created on those versions - confirm.
// Every 2.x prefix ends with "." so a prefix only ever matches its own minor line
// (e.g. "2.1." must not match 2.10+, and "2.4." must not match a hypothetical 2.4x line).
if (knn_bwc_version.startsWith("1.") ||
    knn_bwc_version.startsWith("2.0.") ||
    knn_bwc_version.startsWith("2.1.") ||
    knn_bwc_version.startsWith("2.2.") ||
    knn_bwc_version.startsWith("2.3.") ||
    knn_bwc_version.startsWith("2.4.") ||
    knn_bwc_version.startsWith("2.5.") ||
    knn_bwc_version.startsWith("2.6.") ||
    knn_bwc_version.startsWith("2.7.") ||
    knn_bwc_version.startsWith("2.8.") ||
    knn_bwc_version.startsWith("2.9.") ||
    knn_bwc_version.startsWith("2.10.") ||
    knn_bwc_version.startsWith("2.11.") ||
    knn_bwc_version.startsWith("2.12.") ||
    knn_bwc_version.startsWith("2.13.") ||
    knn_bwc_version.startsWith("2.14.") ||
    knn_bwc_version.startsWith("2.15.")) {
    filter {
        excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexBinaryForceMerge"
    }
}

nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}")
nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}")
systemProperty 'tests.security.manager', 'false'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,4 @@ protected static final boolean isRunningAgainstOldCluster() {
/**
 * Reads the backwards-compatibility version from the system property named by {@code BWC_VERSION}.
 *
 * @return the property value, or an empty {@link Optional} when the property is not set
 */
protected final Optional<String> getBWCVersion() {
    // The single-argument lookup already yields null for a missing property.
    final String bwcVersion = System.getProperty(BWC_VERSION);
    return Optional.ofNullable(bwcVersion);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.junit.Assert;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Map;

Expand Down Expand Up @@ -68,6 +70,33 @@ public void testKNNIndexDefaultLegacyFieldMappingForceMerge() throws Exception {
}
}

/**
 * Ensures backwards compatibility of force merge on a binary-vector index.
 * <p>
 * On the old cluster the test creates an HNSW/Faiss index with Hamming space and the binary vector
 * data type, adds documents and flushes. On the upgraded cluster it force merges that index.
 */
public void testKNNIndexBinaryForceMerge() throws Exception {
    final int dimension = 40;

    waitForClusterHealthGreen(NODES_BWC_CLUSTER);

    // Upgraded cluster: the index already exists, so only the force merge is exercised.
    if (!isRunningAgainstOldCluster()) {
        forceMergeKnnIndex(testIndex);
        return;
    }

    createKnnIndex(
        testIndex,
        getKNNDefaultIndexSettings(),
        createKnnIndexMapping(
            TEST_FIELD,
            dimension,
            METHOD_HNSW,
            KNNEngine.FAISS.getName(),
            SpaceType.HAMMING.getValue(),
            true,
            VectorDataType.BINARY
        )
    );

    // NOTE(review): dimension / 8 is presumably the per-document byte length (8 binary dimensions per byte) - confirm.
    addKNNByteDocs(testIndex, TEST_FIELD, dimension / 8, DOC_ID, 100);

    // Flush to ensure that index is not re-indexed when node comes back up
    flush(testIndex, true);
}

// Custom Legacy Field Mapping
// space_type : "linf", engine : "nmslib", m : 2, ef_construction : 2
public void testKNNIndexCustomLegacyFieldMapping() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,17 +243,11 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
);
}

// Update index description of Faiss for binary data type
if (KNNEngine.FAISS == knnEngine
&& VectorDataType.BINARY.getValue()
.equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))
&& parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) {
parameters.put(
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString()
);
IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
}
// In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic.
// After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility,
// we need to ensure that if the description does not contain the prefix but the type is binary, we add the
// prefix.
maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes);

// Used to determine how many threads to use when indexing
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
Expand All @@ -265,6 +259,31 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
});
}

/**
 * Prepends the Faiss binary prefix to the index description when it is missing.
 * <p>
 * The parameters are left untouched unless all of the following hold: the engine is Faiss, the field's
 * vector data type attribute is binary, an index description is present, and that description does not
 * already start with {@code FAISS_BINARY_INDEX_DESCRIPTION_PREFIX}. When the prefix is added, the binary
 * vector data type is also written into the parameters via {@code IndexUtil}.
 *
 * @param knnEngine       engine the index is built with
 * @param parameters      index build parameters; may be modified in place
 * @param fieldAttributes field attributes the vector data type is read from
 */
private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
    if (KNNEngine.FAISS != knnEngine) {
        return;
    }

    final String vectorDataType = fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue());
    if (!VectorDataType.BINARY.getValue().equals(vectorDataType)) {
        return;
    }

    final Object rawDescription = parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER);
    if (rawDescription == null) {
        return;
    }

    final String description = rawDescription.toString();
    if (description.startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) {
        // Prefix already present; nothing to patch.
        return;
    }

    parameters.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER, FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + description);
    IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
}

/**
* Merges in the fields from the readers in mergeState
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.VectorSpaceInfo;
import org.opensearch.knn.index.VectorDataType;

import java.util.Locale;
import java.util.Map;

/**
Expand All @@ -25,44 +26,94 @@ public abstract class AbstractKNNLibrary implements KNNLibrary {

/**
 * Returns the search context of the method registered under {@code methodName}.
 *
 * @param methodName name used to look the method up in {@code methods}
 * @return the search context reported by the matching {@code KNNMethod}
 * @throws IllegalArgumentException if no method is registered under {@code methodName}
 */
// NOTE(review): this listing is a rendered diff; the bare validateMethodExists call below is the
// pre-change line and the throwIllegalArgOnNonNull(...) call is its replacement.
@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
validateMethodExists(methodName);
throwIllegalArgOnNonNull(validateMethodExists(methodName));
KNNMethod method = methods.get(methodName);
return method.getKNNLibrarySearchContext();
}

/**
 * Resolves the method named by the method context's component context and delegates to it to
 * obtain the indexing context.
 *
 * @throws IllegalArgumentException if the method name is not registered in {@code methods}
 */
// NOTE(review): this listing is a rendered diff; the one-parameter signature, the bare
// validateMethodExists call and the one-argument return are the pre-change lines, each followed
// by its replacement that also takes a KNNMethodConfigContext.
@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) {
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
String method = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(method);
throwIllegalArgOnNonNull(validateMethodExists(method));
KNNMethod knnMethod = methods.get(method);
return knnMethod.getKNNLibraryIndexingContext(knnMethodContext);
return knnMethod.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
}

/**
 * Validates a method context against its configuration context and collects the problems found.
 * <p>
 * An unknown method name is reported immediately as the only error. Otherwise the dimension check,
 * the space type check and the method's own validation run in that order, and the dimension and
 * method-level errors are accumulated into one exception.
 *
 * @return a {@link ValidationException} carrying the collected errors, or {@code null} when none were found
 */
// NOTE(review): this listing is a rendered diff; the one-parameter signature, the bare
// validateMethodExists call and the single-argument return directly below it are pre-change lines.
@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(methodName);
return methods.get(methodName).validate(knnMethodContext);
ValidationException validationException = null;
String invalidErrorMessage = validateMethodExists(methodName);
// Unknown method: nothing else can be validated, so return this single error.
if (invalidErrorMessage != null) {
validationException = new ValidationException();
validationException.addValidationError(invalidErrorMessage);
return validationException;
}
invalidErrorMessage = validateDimension(knnMethodContext, knnMethodConfigContext);
if (invalidErrorMessage != null) {
validationException = new ValidationException();
validationException.addValidationError(invalidErrorMessage);
}

// Returns nothing, so any failure it reports is not added to validationException here.
validateSpaceType(knnMethodContext, knnMethodConfigContext);
ValidationException methodValidation = methods.get(methodName).validate(knnMethodContext, knnMethodConfigContext);
if (methodValidation != null) {
// Reuse the exception created for a dimension error, if any, so both sets of errors are returned together.
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationErrors(methodValidation.validationErrors());
}

return validationException;
}

@Override
public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(methodName);
return methods.get(methodName).validateWithData(knnMethodContext, vectorSpaceInfo);
/**
 * Asks the method context's space type to validate the configured vector data type.
 * Does nothing when {@code knnMethodContext} is {@code null}.
 */
private void validateSpaceType(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
    if (knnMethodContext != null) {
        // NOTE(review): validateVectorDataType presumably throws on an unsupported combination - confirm.
        knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType());
    }
}

/**
 * Checks the configured dimension against the engine's maximum and, for binary vectors, against
 * the multiple-of-8 requirement.
 *
 * @return an error message describing the first violated rule, or {@code null} when the dimension
 *         is acceptable or {@code knnMethodContext} is {@code null}
 */
private String validateDimension(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
    if (knnMethodContext == null) {
        return null;
    }

    final int dimension = knnMethodConfigContext.getDimension();

    final boolean exceedsEngineLimit = dimension > KNNEngine.getMaxDimensionByEngine(knnMethodContext.getKnnEngine());
    if (exceedsEngineLimit) {
        return String.format(
            Locale.ROOT,
            "Dimension value cannot be greater than %s for vector with engine: %s",
            KNNEngine.getMaxDimensionByEngine(knnMethodContext.getKnnEngine()),
            knnMethodContext.getKnnEngine().getName()
        );
    }

    final boolean binaryWithPartialByte = VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType() && dimension % 8 != 0;
    return binaryWithPartialByte ? "Dimension should be multiply of 8 for binary vector data type" : null;
}

/**
 * Reports whether the method named by the method context requires training, as answered by the
 * matching {@code KNNMethod}.
 *
 * @throws IllegalArgumentException if the method name is not registered in {@code methods}
 */
// NOTE(review): this listing is a rendered diff; the bare validateMethodExists call below is the
// pre-change line and the throwIllegalArgOnNonNull(...) call is its replacement.
@Override
public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(methodName);
throwIllegalArgOnNonNull(validateMethodExists(methodName));
return methods.get(methodName).isTrainingRequired(knnMethodContext);
}

/**
 * Checks whether a method is registered under {@code methodName} in {@code methods}.
 *
 * @return an "Invalid method name" message when no such method exists, otherwise {@code null}
 */
// NOTE(review): this listing is a rendered diff; the void signature and the throw statement are
// the pre-change lines, each followed by its String-returning replacement.
private void validateMethodExists(String methodName) {
private String validateMethodExists(String methodName) {
KNNMethod method = methods.get(methodName);
if (method == null) {
throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName));
return String.format("Invalid method name: %s", methodName);
}
return null;
}

/**
 * Converts a validation message into an exception.
 *
 * @param errorMessage message to raise, or {@code null} when there is nothing to report
 * @throws IllegalArgumentException carrying {@code errorMessage} when it is not {@code null}
 */
private void throwIllegalArgOnNonNull(String errorMessage) {
    if (errorMessage == null) {
        return;
    }
    throw new IllegalArgumentException(errorMessage);
}
}
Loading

0 comments on commit 6312f68

Please sign in to comment.