Skip to content

Commit

Permalink
Change ModelFieldMapper to initialize per method
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 9, 2024
1 parent 79e0a4b commit 5e61e3f
Showing 1 changed file with 63 additions and 77 deletions.
140 changes: 63 additions & 77 deletions src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.indices.ModelDao;
Expand All @@ -19,7 +18,6 @@
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
Expand All @@ -35,11 +33,7 @@ public class ModelFieldMapper extends KNNVectorFieldMapper {
// If the dimension has not yet been set because we do not have access to model metadata, it will be -1
public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1;

private final AtomicReference<SpaceType> spaceType;
private final AtomicReference<MethodComponentContext> methodComponentContext;
private final AtomicInteger dimension;
private final AtomicReference<VectorDataType> vectorDataType;

// private final AtomicReference<ModelMetadata> modelMetadata;
private final AtomicReference<PerDimensionProcessor> perDimensionProcessor;
private final AtomicReference<PerDimensionValidator> perDimensionValidator;
private final AtomicReference<VectorValidator> vectorValidator;
Expand Down Expand Up @@ -69,11 +63,7 @@ public Optional<String> getModelId() {

@Override
public int getDimension() {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata.getDimension();
return getModelMetadata(modelDao, modelId).getDimension();
}
});
return new ModelFieldMapper(
Expand Down Expand Up @@ -105,10 +95,9 @@ private ModelFieldMapper(
modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));
this.modelDao = modelDao;

this.spaceType = new AtomicReference<>(null);
this.methodComponentContext = new AtomicReference<>(null);
this.dimension = new AtomicInteger(UNSET_MODEL_DIMENSION_IDENTIFIER);
this.vectorDataType = new AtomicReference<>(null);
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts. We do this as lazily as we can
this.perDimensionProcessor = new AtomicReference<>(null);
this.perDimensionValidator = new AtomicReference<>(null);
this.vectorValidator = new AtomicReference<>(null);
Expand All @@ -119,88 +108,85 @@ private ModelFieldMapper(
}

@Override
protected void validatePreparse() {
super.validatePreparse();
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts. We do this as lazily as we can
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);
protected VectorValidator getVectorValidator() {
vectorValidator.compareAndSet(null, initVectorValidator());
return vectorValidator.get();
}

if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(
String.format(
"Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
modelId,
simpleName(),
MODEL_ID
)
);
}
@Override
protected PerDimensionValidator getPerDimensionValidator() {
perDimensionValidator.compareAndSet(null, initPerDimensionValidator());
return perDimensionValidator.get();
}

maybeInitLazyVariables(modelMetadata);
@Override
protected PerDimensionProcessor getPerDimensionProcessor() {
perDimensionProcessor.compareAndSet(null, initPerDimensionProcessor());
return perDimensionProcessor.get();
}

private void maybeInitLazyVariables(ModelMetadata modelMetadata) {
vectorDataType.compareAndExchange(null, modelMetadata.getVectorDataType());
if (spaceType.get() == null) {
spaceType.compareAndExchange(null, modelMetadata.getSpaceType());
spaceType.get().validateVectorDataType(vectorDataType.get());
}
methodComponentContext.compareAndExchange(null, modelMetadata.getMethodComponentContext());
dimension.compareAndExchange(UNSET_MODEL_DIMENSION_IDENTIFIER, modelMetadata.getDimension());
maybeInitValidatorsAndProcessors();
private VectorValidator initVectorValidator() {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
return new SpaceVectorValidator(modelMetadata.getSpaceType());
}

private void maybeInitValidatorsAndProcessors() {
this.vectorValidator.compareAndExchange(null, new SpaceVectorValidator(spaceType.get()));
private PerDimensionValidator initPerDimensionValidator() {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
VectorDataType dataType = modelMetadata.getVectorDataType();

if (VectorDataType.BINARY == vectorDataType.get()) {
this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BIT_VALIDATOR);
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
return;
if (VectorDataType.BINARY == dataType) {
return PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
}

if (VectorDataType.BYTE == vectorDataType.get()) {
this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BYTE_VALIDATOR);
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
return;
if (VectorDataType.BYTE == dataType) {
return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR;
}

if (!isFaissSQfp16(methodComponentContext.get())) {
// Normal float and byte processor
this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR);
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
return;
if (!isFaissSQfp16(methodComponentContext)) {
return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR;
}

this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FP16_VALIDATOR);
if (!isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.get().getParameters().get(METHOD_ENCODER_PARAMETER)
)) {
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR);
return;
}
this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR);
return PerDimensionValidator.DEFAULT_FP16_VALIDATOR;
}

@Override
protected VectorValidator getVectorValidator() {
return vectorValidator.get();
}
private PerDimensionProcessor initPerDimensionProcessor() {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
VectorDataType dataType = modelMetadata.getVectorDataType();

@Override
protected PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator.get();
}
if (VectorDataType.BINARY == dataType) {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

@Override
protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor.get();
if (VectorDataType.BYTE == dataType) {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

if (!isFaissSQfp16(methodComponentContext)) {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

if (!isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
)) {
return PerDimensionProcessor.NOOP_PROCESSOR;
}
return PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR;
}

@Override
protected void parseCreateField(ParseContext context) throws IOException {
validatePreparse();
parseCreateField(context, dimension.get(), vectorDataType.get());
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType());
}

private static ModelMetadata getModelMetadata(ModelDao modelDao, String modelId) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
}

0 comments on commit 5e61e3f

Please sign in to comment.