Skip to content

Commit

Permalink
Merge branch '2.x' into backport/backport-1950-to-2.x
Browse files Browse the repository at this point in the history
Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Aug 22, 2024
2 parents ac277d9 + 8dffc2b commit 8b843f4
Show file tree
Hide file tree
Showing 36 changed files with 1,698 additions and 163 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add script_fields context to KNNAllowlist [#1917](https://github.com/opensearch-project/k-NN/pull/1917)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936)
* Add script_fields context to KNNAllowlist [#1917](https://github.com/opensearch-project/k-NN/pull/1917)
### Infrastructure
### Documentation
### Maintenance
Expand All @@ -41,3 +42,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.indices.ModelUtil.getModelMetadata;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser;

import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;

/**
* A utility class to extract information from FieldInfo.
Expand Down Expand Up @@ -52,4 +56,18 @@ public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
}
return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT;
}

/**
 * Reads the quantization framework configuration attached to a field's attributes, if any.
 *
 * @param fieldInfo {@link FieldInfo} whose attributes are inspected
 * @return parsed {@link QuantizationConfig}, or {@code QuantizationConfig.EMPTY} when the
 *         attribute is absent or blank
 */
public static QuantizationConfig extractQuantizationConfig(final FieldInfo fieldInfo) {
    final String rawConfig = fieldInfo.getAttribute(QFRAMEWORK_CONFIG);
    return StringUtils.isEmpty(rawConfig) ? QuantizationConfig.EMPTY : QuantizationConfigParser.fromCsv(rawConfig);
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public class KNNConstants {
public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count";
public static final String SEARCH_SIZE_PARAMETER = "search_size";

// Field attribute key under which the quantization framework config is stored
// (value is parsed by QuantizationConfigParser.fromCsv in FieldInfoExtractor)
public static final String QFRAMEWORK_CONFIG = "qframe_config";

public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;
Expand Down
65 changes: 64 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.monitor.jvm.JvmInfo;
import org.opensearch.monitor.os.OsProbe;

Expand Down Expand Up @@ -88,6 +89,8 @@ public class KNNSettings {
* for native engines.
*/
public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled";
public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit";
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";

/**
* Default setting values
Expand All @@ -106,6 +109,11 @@ public class KNNSettings {
public static final String KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT = "1%";

public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1;
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 5; // By default, set aside 5% of the JVM for
// the limit
public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed
// 10% of the JVM heap
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60;

/**
* Settings Definition
Expand Down Expand Up @@ -272,6 +280,44 @@ public class KNNSettings {
NodeScope
);

/*
 * Quantization state cache settings
 */

/**
 * Dynamic, node-scoped limit on the quantization state cache size. Accepts an absolute byte
 * size or a percentage of the JVM heap (default:
 * KNN_DEFAULT_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE). Values above
 * KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE of the heap are rejected with an
 * OpenSearchParseException.
 */
public static final Setting<ByteSizeValue> QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING = new Setting<ByteSizeValue>(
    QUANTIZATION_STATE_CACHE_SIZE_LIMIT,
    percentageAsString(KNN_DEFAULT_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE),
    (s) -> {
        ByteSizeValue userDefinedLimit = parseBytesSizeValueOrHeapRatio(s, QUANTIZATION_STATE_CACHE_SIZE_LIMIT);

        // parseBytesSizeValueOrHeapRatio will make sure that the value entered falls between 0 and 100% of the
        // JVM heap. However, we want the maximum percentage of the heap to be much smaller. So, we add
        // some additional validation here before returning
        ByteSizeValue jvmHeapSize = JvmInfo.jvmInfo().getMem().getHeapMax();
        if ((userDefinedLimit.getKbFrac() / jvmHeapSize.getKbFrac()) > percentageAsFraction(
            KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE
        )) {
            throw new OpenSearchParseException(
                "{} ({} KB) cannot exceed {}% of the heap ({} KB).",
                QUANTIZATION_STATE_CACHE_SIZE_LIMIT,
                userDefinedLimit.getKb(),
                KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE,
                jvmHeapSize.getKb()
            );
        }

        return userDefinedLimit;
    },
    NodeScope,
    Dynamic
);

/**
 * Dynamic, node-scoped expiry time (in minutes, must be positive) for entries in the
 * quantization state cache; defaults to KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES.
 */
public static final Setting<TimeValue> QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING = Setting.positiveTimeSetting(
    QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES,
    TimeValue.timeValueMinutes(KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES),
    NodeScope,
    Dynamic
);

/**
* Dynamic settings
*/
Expand Down Expand Up @@ -349,6 +395,13 @@ private void setSettingsUpdateConsumers() {

NativeMemoryCacheManager.getInstance().rebuildCache(builder.build());
}, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> {
QuantizationStateCache.getInstance().setMaxCacheSizeInKB(it.getKb());
QuantizationStateCache.getInstance().rebuildCache();
});
clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> {
QuantizationStateCache.getInstance().rebuildCache();
});
}

/**
Expand Down Expand Up @@ -400,6 +453,14 @@ private Setting<?> getSetting(String key) {
return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING;
}

if (QUANTIZATION_STATE_CACHE_SIZE_LIMIT.equals(key)) {
return QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING;
}

if (QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES.equals(key)) {
return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -419,7 +480,9 @@ public List<Setting<?>> getSettings() {
ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING,
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING,
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING
KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -112,12 +111,15 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
Map<String, Object> parameterMap = new HashMap<>(
methodComponent.getAsMap(knnMethodContext.getMethodComponentContext(), knnMethodConfigContext)
KNNLibraryIndexingContext knnLibraryIndexingContext = methodComponent.getKNNLibraryIndexingContext(
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext
);
Map<String, Object> parameterMap = knnLibraryIndexingContext.getLibraryParameters();
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
return KNNLibraryIndexingContextImpl.builder()
.quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig())
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorValidator;
Expand All @@ -22,6 +23,13 @@ public interface KNNLibraryIndexingContext {
*/
Map<String, Object> getLibraryParameters();

/**
 * Get the quantization config that gets passed to the quantization framework
 *
 * @return {@link QuantizationConfig}
 */
QuantizationConfig getQuantizationConfig();

/**
*
* @return Get the vector validator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package org.opensearch.knn.index.engine;

import lombok.Builder;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Collections;
import java.util.Map;

/**
Expand All @@ -21,13 +23,21 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private VectorValidator vectorValidator;
private PerDimensionValidator perDimensionValidator;
private PerDimensionProcessor perDimensionProcessor;
private Map<String, Object> parameters;
@Builder.Default
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;

// Returns the library parameters supplied at build time; defaults to an empty map
@Override
public Map<String, Object> getLibraryParameters() {
    return parameters;
}

// Returns the quantization config supplied at build time; defaults to QuantizationConfig.EMPTY
@Override
public QuantizationConfig getQuantizationConfig() {
    return quantizationConfig;
}

@Override
public VectorValidator getVectorValidator() {
return vectorValidator;
Expand Down
38 changes: 26 additions & 12 deletions src/main/java/org/opensearch/knn/index/engine/MethodComponent.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ public class MethodComponent {
private final String name;
@Getter
private final Map<String, Parameter<?>> parameters;
private final TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, Map<String, Object>> mapGenerator;
private final TriFunction<
MethodComponent,
MethodComponentContext,
KNNMethodConfigContext,
KNNLibraryIndexingContext> knnLibraryIndexingContextGenerator;
private final TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator;
private final boolean requiresTraining;
private final Set<VectorDataType> supportedVectorDataTypes;
Expand All @@ -43,7 +47,7 @@ public class MethodComponent {
private MethodComponent(Builder builder) {
this.name = builder.name;
this.parameters = builder.parameters;
this.mapGenerator = builder.mapGenerator;
this.knnLibraryIndexingContextGenerator = builder.knnLibraryIndexingContextGenerator;
this.overheadInKBEstimator = builder.overheadInKBEstimator;
this.requiresTraining = builder.requiresTraining;
this.supportedVectorDataTypes = builder.supportedDataTypes;
Expand All @@ -55,17 +59,20 @@ private MethodComponent(Builder builder) {
* @param methodComponentContext from which to generate the indexing context
* @return Method component as a KNNLibraryIndexingContext
*/
public Map<String, Object> getAsMap(MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext) {
if (mapGenerator == null) {
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
MethodComponentContext methodComponentContext,
KNNMethodConfigContext knnMethodConfigContext
) {
if (knnLibraryIndexingContextGenerator == null) {
Map<String, Object> parameterMap = new HashMap<>();
parameterMap.put(KNNConstants.NAME, methodComponentContext.getName());
parameterMap.put(
KNNConstants.PARAMETERS,
getParameterMapWithDefaultsAdded(methodComponentContext, this, knnMethodConfigContext)
);
return parameterMap;
return KNNLibraryIndexingContextImpl.builder().parameters(parameterMap).build();
}
return mapGenerator.apply(this, methodComponentContext, knnMethodConfigContext);
return knnLibraryIndexingContextGenerator.apply(this, methodComponentContext, knnMethodConfigContext);
}

/**
Expand Down Expand Up @@ -209,7 +216,11 @@ public static class Builder {

private final String name;
private final Map<String, Parameter<?>> parameters;
private TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, Map<String, Object>> mapGenerator;
private TriFunction<
MethodComponent,
MethodComponentContext,
KNNMethodConfigContext,
KNNLibraryIndexingContext> knnLibraryIndexingContextGenerator;
private TriFunction<MethodComponent, MethodComponentContext, Integer, Long> overheadInKBEstimator;
private boolean requiresTraining;
private final Set<VectorDataType> supportedDataTypes;
Expand All @@ -227,7 +238,6 @@ public static Builder builder(String name) {
private Builder(String name) {
this.name = name;
this.parameters = new HashMap<>();
this.mapGenerator = null;
this.overheadInKBEstimator = (mc, mcc, d) -> 0L;
this.supportedDataTypes = new HashSet<>();
}
Expand All @@ -247,13 +257,17 @@ public Builder addParameter(String parameterName, Parameter<?> parameter) {
/**
* Set the function used to generate a KNNLibraryIndexingContext from a MethodComponentContext
*
* @param mapGenerator function to parse a MethodComponentContext as a map
* @param knnLibraryIndexingContextGenerator function to parse a MethodComponentContext as a knnLibraryIndexingContext
* @return this builder
*/
public Builder setMapGenerator(
TriFunction<MethodComponent, MethodComponentContext, KNNMethodConfigContext, Map<String, Object>> mapGenerator
public Builder setKnnLibraryIndexingContextGenerator(
TriFunction<
MethodComponent,
MethodComponentContext,
KNNMethodConfigContext,
KNNLibraryIndexingContext> knnLibraryIndexingContextGenerator
) {
this.mapGenerator = mapGenerator;
this.knnLibraryIndexingContextGenerator = knnLibraryIndexingContextGenerator;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.AbstractKNNMethod;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;

import java.util.Objects;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled;
import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQfp16;

Expand Down Expand Up @@ -81,4 +86,37 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(

throw new IllegalStateException("Unsupported vector data type " + vectorDataType);
}

/**
 * Builds the {@link KNNLibraryIndexingContext} from the builder, prepending the appropriate
 * prefix to the Faiss index description (the binary prefix for BINARY vector data types).
 *
 * @param methodAsMapBuilder builder whose indexDescription is adjusted before building
 * @param methodComponentContext context used to look up the encoder component
 * @param knnMethodConfigContext config context providing the vector data type
 * @return built {@link KNNLibraryIndexingContext}
 */
static KNNLibraryIndexingContext adjustPrefix(
    MethodAsMapBuilder methodAsMapBuilder,
    MethodComponentContext methodComponentContext,
    KNNMethodConfigContext knnMethodConfigContext
) {
    String prefix = "";
    MethodComponentContext encoderContext = getEncoderMethodComponent(methodComponentContext);
    // We need to update the prefix used to create the faiss index if we are using the quantization
    // framework
    if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) {
        // TODO: Uncomment to use Quantization framework
        // leaving commented now just so it wont fail creating faiss indices.
        // prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
    }

    // Binary vectors always use the binary Faiss index description prefix
    if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
        prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
    }
    methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription;
    return methodAsMapBuilder.build();
}

/**
 * Looks up the encoder sub-component in the method's parameters.
 *
 * @param methodComponentContext context whose parameters are searched
 * @return the encoder {@link MethodComponentContext}, or null if the encoder parameter is
 *         missing or not a MethodComponentContext
 */
static MethodComponentContext getEncoderMethodComponent(MethodComponentContext methodComponentContext) {
    final Object candidate = methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER);
    return candidate instanceof MethodComponentContext ? (MethodComponentContext) candidate : null;
}
}
Loading

0 comments on commit 8b843f4

Please sign in to comment.