Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Encapsulate dimension, vector data type validation/processing inside … #1983

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877)
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Added Quantization Framework and implemented 1Bit and multibit quantizer [#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
46 changes: 46 additions & 0 deletions qa/restart-upgrade/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ testClusters {
excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testEmptyParametersOnUpgrade"
}
}

// Exclude the binary force-merge BWC test when the old cluster runs any version from 1.x through 2.15.x.
// NOTE(review): presumably binary vector indices cannot be created on those versions - confirm.
// Every prefix ends with '.' so that e.g. "2.4." cannot accidentally match a future "2.40.x" release.
if (knn_bwc_version.startsWith("1.") ||
    knn_bwc_version.startsWith("2.0.") ||
    knn_bwc_version.startsWith("2.1.") ||
    knn_bwc_version.startsWith("2.2.") ||
    knn_bwc_version.startsWith("2.3.") ||
    knn_bwc_version.startsWith("2.4.") ||
    knn_bwc_version.startsWith("2.5.") ||
    knn_bwc_version.startsWith("2.6.") ||
    knn_bwc_version.startsWith("2.7.") ||
    knn_bwc_version.startsWith("2.8.") ||
    knn_bwc_version.startsWith("2.9.") ||
    knn_bwc_version.startsWith("2.10.") ||
    knn_bwc_version.startsWith("2.11.") ||
    knn_bwc_version.startsWith("2.12.") ||
    knn_bwc_version.startsWith("2.13.") ||
    knn_bwc_version.startsWith("2.14.") ||
    knn_bwc_version.startsWith("2.15.")) {
    filter {
        excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexBinaryForceMerge"
    }
}

nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}")
nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}")
systemProperty 'tests.security.manager', 'false'
Expand Down Expand Up @@ -101,6 +124,29 @@ testClusters {
excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testEmptyParametersOnUpgrade"
}
}

// Exclude the binary force-merge BWC test when the old cluster runs any version from 1.x through 2.15.x.
// NOTE(review): presumably binary vector indices cannot be created on those versions - confirm.
// Every prefix ends with '.' so that e.g. "2.4." cannot accidentally match a future "2.40.x" release.
if (knn_bwc_version.startsWith("1.") ||
    knn_bwc_version.startsWith("2.0.") ||
    knn_bwc_version.startsWith("2.1.") ||
    knn_bwc_version.startsWith("2.2.") ||
    knn_bwc_version.startsWith("2.3.") ||
    knn_bwc_version.startsWith("2.4.") ||
    knn_bwc_version.startsWith("2.5.") ||
    knn_bwc_version.startsWith("2.6.") ||
    knn_bwc_version.startsWith("2.7.") ||
    knn_bwc_version.startsWith("2.8.") ||
    knn_bwc_version.startsWith("2.9.") ||
    knn_bwc_version.startsWith("2.10.") ||
    knn_bwc_version.startsWith("2.11.") ||
    knn_bwc_version.startsWith("2.12.") ||
    knn_bwc_version.startsWith("2.13.") ||
    knn_bwc_version.startsWith("2.14.") ||
    knn_bwc_version.startsWith("2.15.")) {
    filter {
        excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexBinaryForceMerge"
    }
}

nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}")
nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}")
systemProperty 'tests.security.manager', 'false'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,4 @@ protected static final boolean isRunningAgainstOldCluster() {
/**
 * Reads the backwards-compatibility version from the system property named by {@code BWC_VERSION}.
 *
 * @return the property value, or an empty Optional when the property is not set
 */
protected final Optional<String> getBWCVersion() {
    // The single-argument lookup yields null for a missing property, same as passing a null default.
    final String bwcVersion = System.getProperty(BWC_VERSION);
    return Optional.ofNullable(bwcVersion);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.junit.Assert;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Map;

Expand Down Expand Up @@ -68,6 +70,33 @@ public void testKNNIndexDefaultLegacyFieldMappingForceMerge() throws Exception {
}
}

// Ensure bwc works for binary force merge: the index is built on the old cluster and
// force merged once the test runs against the upgraded cluster.
public void testKNNIndexBinaryForceMerge() throws Exception {
    final int dimension = 40;

    waitForClusterHealthGreen(NODES_BWC_CLUSTER);

    if (!isRunningAgainstOldCluster()) {
        // Upgraded cluster: merge the segments that were written before the restart
        forceMergeKnnIndex(testIndex);
        return;
    }

    // Old cluster: create a faiss HNSW index with hamming space and binary vectors
    createKnnIndex(
        testIndex,
        getKNNDefaultIndexSettings(),
        createKnnIndexMapping(
            TEST_FIELD,
            dimension,
            METHOD_HNSW,
            KNNEngine.FAISS.getName(),
            SpaceType.HAMMING.getValue(),
            true,
            VectorDataType.BINARY
        )
    );

    // NOTE(review): presumably each byte carries 8 binary dimensions, hence dimension / 8 - confirm
    final int vectorLengthInBytes = dimension / 8;
    addKNNByteDocs(testIndex, TEST_FIELD, vectorLengthInBytes, DOC_ID, 100);

    // Flush to ensure that index is not re-indexed when node comes back up
    flush(testIndex, true);
}

// Custom Legacy Field Mapping
// space_type : "linf", engine : "nmslib", m : 2, ef_construction : 2
public void testKNNIndexCustomLegacyFieldMapping() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,11 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
);
}

// Update index description of Faiss for binary data type
if (KNNEngine.FAISS == knnEngine
&& VectorDataType.BINARY.getValue()
.equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))
&& parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) {
parameters.put(
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString()
);
IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
}
// In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic.
// After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility,
// we need to ensure that if the description does not contain the prefix but the type is binary, we add the
// prefix.
maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes);

// Used to determine how many threads to use when indexing
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
Expand All @@ -260,6 +254,31 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
});
}

/**
 * Prepends the binary prefix to the faiss index description when it is missing.
 * Nothing happens unless the engine is faiss, the field's vector data type attribute is binary,
 * an index description is present in the parameters, and that description lacks the prefix.
 *
 * @param knnEngine engine the index is being built with
 * @param parameters index build parameters; updated in place when the prefix is added
 * @param fieldAttributes field attributes from which the vector data type is read
 */
private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
    // The engine is checked first so the field attributes are only consulted for faiss
    final boolean isBinaryFaissField = KNNEngine.FAISS == knnEngine
        && VectorDataType.BINARY.getValue()
            .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()));
    if (!isBinaryFaissField) {
        return;
    }

    final Object rawDescription = parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER);
    if (rawDescription == null) {
        return;
    }

    final String description = rawDescription.toString();
    if (description.startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) {
        // Prefix already present, nothing to patch
        return;
    }

    parameters.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER, FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + description);
    IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
}

/**
* Merges in the fields from the readers in mergeState
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.VectorSpaceInfo;
import org.opensearch.knn.index.VectorDataType;

import java.util.Locale;
import java.util.Map;

/**
Expand All @@ -25,44 +26,94 @@ public abstract class AbstractKNNLibrary implements KNNLibrary {

@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
validateMethodExists(methodName);
throwIllegalArgOnNonNull(validateMethodExists(methodName));
KNNMethod method = methods.get(methodName);
return method.getKNNLibrarySearchContext();
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) {
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
String method = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(method);
throwIllegalArgOnNonNull(validateMethodExists(method));
KNNMethod knnMethod = methods.get(method);
return knnMethod.getKNNLibraryIndexingContext(knnMethodContext);
return knnMethod.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(methodName);
return methods.get(methodName).validate(knnMethodContext);
ValidationException validationException = null;
String invalidErrorMessage = validateMethodExists(methodName);
if (invalidErrorMessage != null) {
validationException = new ValidationException();
validationException.addValidationError(invalidErrorMessage);
return validationException;
}
invalidErrorMessage = validateDimension(knnMethodContext, knnMethodConfigContext);
if (invalidErrorMessage != null) {
validationException = new ValidationException();
validationException.addValidationError(invalidErrorMessage);
}

validateSpaceType(knnMethodContext, knnMethodConfigContext);
ValidationException methodValidation = methods.get(methodName).validate(knnMethodContext, knnMethodConfigContext);
if (methodValidation != null) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationErrors(methodValidation.validationErrors());
}

return validationException;
}

@Override
public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(methodName);
return methods.get(methodName).validateWithData(knnMethodContext, vectorSpaceInfo);
/**
 * Asks the method's space type to validate the configured vector data type.
 * Does nothing when no method context is supplied.
 *
 * @param knnMethodContext method context providing the space type; may be null
 * @param knnMethodConfigContext config context providing the vector data type
 */
private void validateSpaceType(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
    if (knnMethodContext != null) {
        knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType());
    }
}

/**
 * Checks the configured dimension against the engine's maximum and, for binary vectors,
 * against the multiple-of-8 requirement.
 *
 * @param knnMethodContext method context providing the engine; when null no check is performed
 * @param knnMethodConfigContext config context providing the dimension and vector data type
 * @return a message for the first violated rule, or null when the dimension is acceptable
 */
private String validateDimension(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
    if (knnMethodContext == null) {
        return null;
    }

    final int dimension = knnMethodConfigContext.getDimension();
    final KNNEngine engine = knnMethodContext.getKnnEngine();

    // The engine limit takes precedence over the binary constraint
    if (dimension > KNNEngine.getMaxDimensionByEngine(engine)) {
        return String.format(
            Locale.ROOT,
            "Dimension value cannot be greater than %s for vector with engine: %s",
            KNNEngine.getMaxDimensionByEngine(engine),
            engine.getName()
        );
    }

    final boolean isBinary = VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType();
    return (isBinary && dimension % 8 != 0) ? "Dimension should be multiply of 8 for binary vector data type" : null;
}

@Override
public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(methodName);
throwIllegalArgOnNonNull(validateMethodExists(methodName));
return methods.get(methodName).isTrainingRequired(knnMethodContext);
}

private void validateMethodExists(String methodName) {
private String validateMethodExists(String methodName) {
KNNMethod method = methods.get(methodName);
if (method == null) {
throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName));
return String.format("Invalid method name: %s", methodName);
}
return null;
}

/**
 * Turns a validation message into an exception.
 *
 * @param errorMessage message to raise, or null when there is nothing to report
 * @throws IllegalArgumentException carrying {@code errorMessage} when it is not null
 */
private void throwIllegalArgOnNonNull(String errorMessage) {
    if (errorMessage == null) {
        return;
    }
    throw new IllegalArgumentException(errorMessage);
}
}
Loading
Loading