Skip to content

Commit

Permalink
Fix model field mapper
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 14, 2024
1 parent 8ca67a0 commit dcf67d9
Showing 1 changed file with 40 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -130,8 +131,15 @@ private void initVectorValidator() {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);

KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
// Need to handle BWC case
if (knnMethodContext == null || knnMethodConfigContext == null) {
vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType());
return;
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
vectorValidator = knnLibraryIndexingContext.getVectorValidator();
Expand All @@ -142,8 +150,22 @@ private void initPerDimensionValidator() {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);

KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
// Need to handle BWC case
if (knnMethodContext == null || knnMethodConfigContext == null) {
if (modelMetadata.getVectorDataType() == VectorDataType.BINARY) {
perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
} else if (modelMetadata.getVectorDataType() == VectorDataType.BYTE) {
perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR;
} else {
perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR;
}

return;
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
Expand All @@ -154,8 +176,15 @@ private void initPerDimensionProcessor() {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);

KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
// Need to handle BWC case
if (knnMethodContext == null || knnMethodConfigContext == null) {
perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR;
return;
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
Expand Down Expand Up @@ -184,13 +213,23 @@ protected void parseCreateField(ParseContext context) throws IOException {
}

private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) {
return new KNNMethodContext(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (methodComponentContext == MethodComponentContext.EMPTY) {
return null;
}
return new KNNMethodContext(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType(), methodComponentContext);
}

private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata(ModelMetadata modelMetadata) {
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (methodComponentContext == MethodComponentContext.EMPTY) {
return null;
}
// TODO: Need to fix this version check by serializing the model
return KNNMethodConfigContext.builder()
.vectorDataType(modelMetadata.getVectorDataType())
.dimension(modelMetadata.getDimension())
.versionCreated(Version.V_2_14_0)
.build();
}

Expand Down

0 comments on commit dcf67d9

Please sign in to comment.