Skip to content

Commit

Permalink
Change ModelFieldMapper to initialize per method
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 9, 2024
1 parent 79e0a4b commit e22fcf8
Showing 1 changed file with 79 additions and 79 deletions.
158 changes: 79 additions & 79 deletions src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.indices.ModelDao;
Expand All @@ -19,8 +18,6 @@
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
Expand All @@ -35,14 +32,9 @@ public class ModelFieldMapper extends KNNVectorFieldMapper {
// If the dimension has not yet been set because we do not have access to model metadata, it will be -1
public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1;

private final AtomicReference<SpaceType> spaceType;
private final AtomicReference<MethodComponentContext> methodComponentContext;
private final AtomicInteger dimension;
private final AtomicReference<VectorDataType> vectorDataType;

private final AtomicReference<PerDimensionProcessor> perDimensionProcessor;
private final AtomicReference<PerDimensionValidator> perDimensionValidator;
private final AtomicReference<VectorValidator> vectorValidator;
private PerDimensionProcessor perDimensionProcessor;
private PerDimensionValidator perDimensionValidator;
private VectorValidator vectorValidator;

private final String modelId;

Expand All @@ -69,11 +61,7 @@ public Optional<String> getModelId() {

@Override
public int getDimension() {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata.getDimension();
return getModelMetadata(modelDao, modelId).getDimension();
}
});
return new ModelFieldMapper(
Expand Down Expand Up @@ -105,102 +93,114 @@ private ModelFieldMapper(
modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));
this.modelDao = modelDao;

this.spaceType = new AtomicReference<>(null);
this.methodComponentContext = new AtomicReference<>(null);
this.dimension = new AtomicInteger(UNSET_MODEL_DIMENSION_IDENTIFIER);
this.vectorDataType = new AtomicReference<>(null);
this.perDimensionProcessor = new AtomicReference<>(null);
this.perDimensionValidator = new AtomicReference<>(null);
this.vectorValidator = new AtomicReference<>(null);
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts. We do this as lazily as we can
this.perDimensionProcessor = null;
this.perDimensionValidator = null;
this.vectorValidator = null;

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
this.fieldType.putAttribute(MODEL_ID, modelId);
this.fieldType.freeze();
}

@Override
protected void validatePreparse() {
super.validatePreparse();
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts. We do this as lazily as we can
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);
// Returns the vector validator for this mapper, creating it on first use.
// initVectorValidator() is a no-op once the field is non-null; otherwise it looks up the
// model metadata and builds a SpaceVectorValidator from the model's space type. That lookup
// throws IllegalStateException when the model is not created.
// NOTE(review): vectorValidator is a plain field populated lazily without synchronization —
// confirm this is acceptable if this accessor can be reached from multiple threads.
protected VectorValidator getVectorValidator() {
initVectorValidator();
return vectorValidator;
}

if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(
String.format(
"Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
modelId,
simpleName(),
MODEL_ID
)
);
}
// Returns the per-dimension validator for this mapper, creating it on first use.
// initPerDimensionValidator() returns immediately once the field is non-null; otherwise it
// selects a validator based on the model's vector data type and method component context.
// NOTE(review): perDimensionValidator is a plain field populated lazily without
// synchronization — confirm this is acceptable if this accessor can be reached concurrently.
@Override
protected PerDimensionValidator getPerDimensionValidator() {
initPerDimensionValidator();
return perDimensionValidator;
}

maybeInitLazyVariables(modelMetadata);
// Returns the per-dimension processor for this mapper, creating it on first use.
// initPerDimensionProcessor() returns immediately once the field is non-null; otherwise it
// picks either the no-op processor or the clip-to-FP16 processor from the model's metadata.
// NOTE(review): perDimensionProcessor is a plain field populated lazily without
// synchronization — confirm this is acceptable if this accessor can be reached concurrently.
@Override
protected PerDimensionProcessor getPerDimensionProcessor() {
initPerDimensionProcessor();
return perDimensionProcessor;
}

private void maybeInitLazyVariables(ModelMetadata modelMetadata) {
vectorDataType.compareAndExchange(null, modelMetadata.getVectorDataType());
if (spaceType.get() == null) {
spaceType.compareAndExchange(null, modelMetadata.getSpaceType());
spaceType.get().validateVectorDataType(vectorDataType.get());
private void initVectorValidator() {
if (vectorValidator != null) {
return;
}
methodComponentContext.compareAndExchange(null, modelMetadata.getMethodComponentContext());
dimension.compareAndExchange(UNSET_MODEL_DIMENSION_IDENTIFIER, modelMetadata.getDimension());
maybeInitValidatorsAndProcessors();
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType());
}

private void maybeInitValidatorsAndProcessors() {
this.vectorValidator.compareAndExchange(null, new SpaceVectorValidator(spaceType.get()));

if (VectorDataType.BINARY == vectorDataType.get()) {
this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BIT_VALIDATOR);
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
private void initPerDimensionValidator() {
if (perDimensionValidator != null) {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
VectorDataType dataType = modelMetadata.getVectorDataType();

if (VectorDataType.BYTE == vectorDataType.get()) {
this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BYTE_VALIDATOR);
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
if (VectorDataType.BINARY == dataType) {
perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
return;
}

if (!isFaissSQfp16(methodComponentContext.get())) {
// Normal float and byte processor
this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR);
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
if (VectorDataType.BYTE == dataType) {
perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR;
return;
}

this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FP16_VALIDATOR);
if (!isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.get().getParameters().get(METHOD_ENCODER_PARAMETER)
)) {
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
if (!isFaissSQfp16(methodComponentContext)) {
perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR;
return;
}
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR);
}

@Override
protected VectorValidator getVectorValidator() {
return vectorValidator.get();
perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR;
}

@Override
protected PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator.get();
}
private void initPerDimensionProcessor() {
if (perDimensionProcessor != null) {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
VectorDataType dataType = modelMetadata.getVectorDataType();

@Override
protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor.get();
if (VectorDataType.BINARY == dataType) {
perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR;
return;
}

if (VectorDataType.BYTE == dataType) {
perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR;
return;
}

if (!isFaissSQfp16(methodComponentContext)) {
perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR;
return;
}

if (!isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
)) {
perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR;
return;
}
perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR;
}

@Override
protected void parseCreateField(ParseContext context) throws IOException {
validatePreparse();
parseCreateField(context, dimension.get(), vectorDataType.get());
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType());
}

/**
 * Looks up the metadata for a model and verifies that the model has been created.
 *
 * @param modelDao accessor used to fetch the metadata
 * @param modelId  identifier of the model to look up
 * @return the metadata returned by {@code modelDao} for {@code modelId}
 * @throws IllegalStateException if {@link ModelUtil#isModelCreated} reports the model is not created
 */
private static ModelMetadata getModelMetadata(ModelDao modelDao, String modelId) {
    final ModelMetadata metadata = modelDao.getMetadata(modelId);
    if (ModelUtil.isModelCreated(metadata)) {
        return metadata;
    }
    // Fall through only when the model is unusable; report which id was requested.
    final String message = String.format("Model ID '%s' is not created.", modelId);
    throw new IllegalStateException(message);
}
}

0 comments on commit e22fcf8

Please sign in to comment.