Skip to content

Commit

Permalink
Remove optional and fix
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 14, 2024
1 parent 6d0ed6a commit 9ad2096
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.google.common.collect.ImmutableMap;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.Ignore;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -46,7 +45,6 @@
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.NAME;

@Ignore
public class FaissSQIT extends AbstractRestartUpgradeTestCase {
private static final String TEST_FIELD = "test-field";
private static final String TRAIN_TEST_FIELD = "train-test-field";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,17 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNN
}

private void validateSpaceType(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
if (knnMethodContext == null || knnMethodConfigContext.getVectorDataType().isEmpty()) {
if (knnMethodContext == null) {
return;
}
knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType().get());
knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType());
}

private String validateDimension(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
if (knnMethodContext == null || knnMethodConfigContext.getDimension().isEmpty()) {
if (knnMethodContext == null) {
return null;
}
int dimension = knnMethodConfigContext.getDimension().get();
int dimension = knnMethodConfigContext.getDimension();
if (dimension > KNNEngine.getMaxDimensionByEngine(knnMethodContext.getKnnEngine())) {
return String.format(
Locale.ROOT,
Expand All @@ -89,11 +89,7 @@ private String validateDimension(final KNNMethodContext knnMethodContext, KNNMet
);
}

if (knnMethodConfigContext.getVectorDataType().isEmpty()) {
return null;
}

if (VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType().get() && dimension % 8 != 0) {
if (VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType() && dimension % 8 != 0) {
return "Dimension should be multiply of 8 for binary vector data type";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,14 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {

@Override
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
return methodComponent.estimateOverheadInKB(
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext.getDimension().orElseThrow(() -> new IllegalStateException("Dimension needs to be set"))
);
return methodComponent.estimateOverheadInKB(knnMethodContext.getMethodComponentContext(), knnMethodConfigContext.getDimension());
}

protected PerDimensionValidator doGetPerDimensionValidator(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType()
.orElseThrow(
() -> new IllegalStateException("Vector data type needs to be set on KNNMethodConfigContext in order to get the processor")
);
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType();

if (VectorDataType.BINARY == vectorDataType) {
return PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
Expand Down Expand Up @@ -122,12 +116,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
methodComponent.getAsMap(knnMethodContext.getMethodComponentContext(), knnMethodConfigContext)
);
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(
KNNConstants.VECTOR_DATA_TYPE_FIELD,
knnMethodConfigContext.getVectorDataType()
.orElseThrow(() -> new IllegalStateException("Vector data type needs to be set"))
.getValue()
);
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
return KNNLibraryIndexingContextImpl.builder()
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,27 @@

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.Version;
import org.opensearch.knn.index.VectorDataType;

import java.util.Optional;

/**
* This object provides additional context that the user does not provide when {@link KNNMethodContext} is
* created via parsing. The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@NoArgsConstructor
@AllArgsConstructor
public final class KNNMethodConfigContext {
private VectorDataType vectorDataType;
private Integer dimension;
private Version versionCreated;

/**
*
* @return vector data type or null if not set
*/
public Optional<VectorDataType> getVectorDataType() {
return Optional.ofNullable(vectorDataType);
}

/**
*
* @return dimension or null if not set
*/
public Optional<Integer> getDimension() {
return Optional.ofNullable(dimension);
}

/**
*
* @return version created or null if not set
*/
public Optional<Version> getVersionCreated() {
return Optional.ofNullable(versionCreated);
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,14 @@ public ValidationException validate(MethodComponentContext methodComponentContex
Map<String, Object> providedParameters = methodComponentContext.getParameters();

ValidationException validationException = null;
if (knnMethodConfigContext.getVectorDataType().isPresent()
&& !supportedVectorDataTypes.contains(knnMethodConfigContext.getVectorDataType().get())) {
if (!supportedVectorDataTypes.contains(knnMethodConfigContext.getVectorDataType())) {
validationException = new ValidationException();
validationException.addValidationError(
String.format(
Locale.ROOT,
"Method \"%s\" is not supported for vector data type \"%s\".",
name,
knnMethodConfigContext.getVectorDataType().get()
knnMethodConfigContext.getVectorDataType()
)
);
}
Expand Down Expand Up @@ -314,8 +313,7 @@ public static Map<String, Object> getParameterMapWithDefaultsAdded(
) {
Map<String, Object> parametersWithDefaultsMap = new HashMap<>();
Map<String, Object> userProvidedParametersMap = methodComponentContext.getParameters();
Version indexCreationVersion = knnMethodConfigContext.getVersionCreated()
.orElseThrow(() -> new IllegalStateException("Version must be set"));
Version indexCreationVersion = knnMethodConfigContext.getVersionCreated();
for (Parameter<?> parameter : methodComponent.getParameters().values()) {
if (methodComponentContext.getParameters().containsKey(parameter.getName())) {
parametersWithDefaultsMap.put(parameter.getName(), userProvidedParametersMap.get(parameter.getName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ protected PerDimensionValidator doGetPerDimensionValidator(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType()
.orElseThrow(
() -> new IllegalStateException("Vector data type needs to be set on KNNMethodConfigContext in order to get the processor")
);
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType();
if (VectorDataType.BINARY == vectorDataType) {
return PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
}
Expand All @@ -65,10 +62,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType()
.orElseThrow(
() -> new IllegalStateException("Vector data type needs to be set on KNNMethodConfigContext in order to get the processor")
);
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType();

if (VectorDataType.BINARY == vectorDataType) {
return PerDimensionProcessor.NOOP_PROCESSOR;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ static boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComp
if (encoderContext == null) {
return false;
}
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false);
return (boolean) encoderContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false);
}

static MethodComponentContext extractEncoderMethodComponentContext(MethodComponentContext methodComponentContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ private static MethodComponent initMethodComponent() {
.addParameter(METHOD_ENCODER_PARAMETER, initEncoderParameter())
.setMapGenerator(((methodComponent, methodComponentContext, knnMethodConfigContext) -> {
String prefix = "";
if (knnMethodConfigContext.getVectorDataType().isPresent()
&& knnMethodConfigContext.getVectorDataType().get() == VectorDataType.BINARY) {
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class FaissHNSWPQEncoder implements Encoder {
new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, (v, context) -> {
boolean isValueGreaterThan0 = v > 0;
boolean isValueLessThanCodeCountLimit = v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT;
boolean isDimensionDivisibleByValue = context.getDimension().orElse(0) % v == 0;
boolean isDimensionDivisibleByValue = context.getDimension() % v == 0;
return isValueGreaterThan0 && isValueLessThanCodeCountLimit && isDimensionDivisibleByValue;
})
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ private static MethodComponent initMethodComponent() {
.setRequiresTraining(true)
.setMapGenerator(((methodComponent, methodComponentContext, knnMethodConfigContext) -> {
String prefix = "";
if (knnMethodConfigContext.getVectorDataType().isPresent()
&& knnMethodConfigContext.getVectorDataType().get() == VectorDataType.BINARY) {
if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class FaissIVFPQEncoder implements Encoder {
new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, (v, context) -> {
boolean isValueGreaterThan0 = v > 0;
boolean isValueLessThanCodeCountLimit = v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT;
boolean isDimensionDivisibleByValue = context.getDimension().orElse(0) % v == 0;
boolean isDimensionDivisibleByValue = context.getDimension() % v == 0;
return isValueGreaterThan0 && isValueLessThanCodeCountLimit && isDimensionDivisibleByValue;
})
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public static FlatVectorFieldMapper createFieldMapper(
final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(
fullname,
metaValue,
knnMethodConfigContext.getVectorDataType().orElseThrow(() -> new IllegalStateException("Datatype not found")),
() -> knnMethodConfigContext.getDimension().orElseThrow(() -> new IllegalStateException("Dimension not found"))
knnMethodConfigContext.getVectorDataType(),
knnMethodConfigContext::getDimension
);
return new FlatVectorFieldMapper(
simpleName,
Expand All @@ -47,7 +47,7 @@ public static FlatVectorFieldMapper createFieldMapper(
ignoreMalformed,
stored,
hasDocValues,
knnMethodConfigContext.getVersionCreated().orElseThrow(() -> new IllegalStateException("Version not found"))
knnMethodConfigContext.getVersionCreated()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static LuceneFieldMapper createFieldMapper(
final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(
fullname,
metaValue,
knnMethodConfigContext.getVectorDataType().orElseThrow(() -> new IllegalArgumentException("Vector data type cannot be empty")),
knnMethodConfigContext.getVectorDataType(),
new KNNMappingConfig() {
@Override
public Optional<KNNMethodContext> getKnnMethodContext() {
Expand All @@ -61,7 +61,7 @@ public Optional<KNNMethodContext> getKnnMethodContext() {

@Override
public int getDimension() {
return knnMethodConfigContext.getDimension().orElseThrow(() -> new IllegalStateException("Dimension cannot be empty"));
return knnMethodConfigContext.getDimension();
}
}
);
Expand All @@ -82,7 +82,7 @@ private LuceneFieldMapper(
input.getIgnoreMalformed(),
input.isStored(),
input.isHasDocValues(),
knnMethodConfigContext.getVersionCreated().orElseThrow(() -> new IllegalArgumentException("Method context cannot be empty")),
knnMethodConfigContext.getVersionCreated(),
mappedFieldType.knnMappingConfig.getKnnMethodContext().orElse(null)
);
KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static MethodFieldMapper createFieldMapper(
final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(
fullname,
metaValue,
knnMethodConfigContext.getVectorDataType().orElseThrow(() -> new IllegalStateException("Vector data type cannot be empty")),
knnMethodConfigContext.getVectorDataType(),
new KNNMappingConfig() {
@Override
public Optional<KNNMethodContext> getKnnMethodContext() {
Expand All @@ -61,7 +61,7 @@ public Optional<KNNMethodContext> getKnnMethodContext() {

@Override
public int getDimension() {
return knnMethodConfigContext.getDimension().orElseThrow(() -> new IllegalStateException("Dimension cannot be empty"));
return knnMethodConfigContext.getDimension();
}
}
);
Expand Down Expand Up @@ -98,7 +98,7 @@ private MethodFieldMapper(
ignoreMalformed,
stored,
hasDocValues,
knnMethodConfigContext.getVersionCreated().orElseThrow(() -> new IllegalArgumentException("Method context cannot be empty")),
knnMethodConfigContext.getVersionCreated(),
originalKNNMethodContext
);
this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion);
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ public TrainingJob(
new ModelMetadata(
knnMethodContext.getKnnEngine(),
knnMethodContext.getSpaceType(),
knnMethodConfigContext.getDimension().orElseThrow(() -> new IllegalArgumentException("Dimension value missing")),
knnMethodConfigContext.getDimension(),
ModelState.TRAINING,
ZonedDateTime.now(ZoneOffset.UTC).toString(),
description,
"",
nodeAssignment,
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext.getVectorDataType().orElseThrow(() -> new IllegalArgumentException("VectorDatatype value missing"))
knnMethodConfigContext.getVectorDataType()
),
null,
this.modelId
Expand Down
Loading

0 comments on commit 9ad2096

Please sign in to comment.