Skip to content

Commit

Permalink
Make the config object separate from method
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 14, 2024
1 parent ae47fe7 commit 5011f16
Show file tree
Hide file tree
Showing 38 changed files with 622 additions and 739 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877)
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Added Quantization Framework and implemented 1Bit and multibit quantizer [#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@ public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) {
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
String method = knnMethodContext.getMethodComponentContext().getName();
throwIllegalArgOnNonNull(validateMethodExists(method));
KNNMethod knnMethod = methods.get(method);
return knnMethod.getKNNLibraryIndexingContext(knnMethodContext);
return knnMethod.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
ValidationException validationException = null;
String invalidErrorMessage = validateMethodExists(methodName);
Expand All @@ -49,14 +52,14 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
validationException.addValidationError(invalidErrorMessage);
return validationException;
}
invalidErrorMessage = validateDimension(knnMethodContext);
invalidErrorMessage = validateDimension(knnMethodContext, knnMethodConfigContext);
if (invalidErrorMessage != null) {
validationException = new ValidationException();
validationException.addValidationError(invalidErrorMessage);
}

validateSpaceType(knnMethodContext);
ValidationException methodValidation = methods.get(methodName).validate(knnMethodContext);
validateSpaceType(knnMethodContext, knnMethodConfigContext);
ValidationException methodValidation = methods.get(methodName).validate(knnMethodContext, knnMethodConfigContext);
if (methodValidation != null) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationErrors(methodValidation.validationErrors());
Expand All @@ -65,18 +68,18 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return validationException;
}

private void validateSpaceType(final KNNMethodContext knnMethodContext) {
if (knnMethodContext == null || knnMethodContext.getKnnMethodConfigContext().getVectorDataType().isEmpty()) {
private void validateSpaceType(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
if (knnMethodContext == null || knnMethodConfigContext.getVectorDataType().isEmpty()) {
return;
}
knnMethodContext.getSpaceType().validateVectorDataType(knnMethodContext.getKnnMethodConfigContext().getVectorDataType().get());
knnMethodContext.getSpaceType().validateVectorDataType(knnMethodConfigContext.getVectorDataType().get());
}

private String validateDimension(final KNNMethodContext knnMethodContext) {
if (knnMethodContext == null || knnMethodContext.getKnnMethodConfigContext().getDimension().isEmpty()) {
private String validateDimension(final KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
if (knnMethodContext == null || knnMethodConfigContext.getDimension().isEmpty()) {
return null;
}
int dimension = knnMethodContext.getKnnMethodConfigContext().getDimension().get();
int dimension = knnMethodConfigContext.getDimension().get();
if (dimension > KNNEngine.getMaxDimensionByEngine(knnMethodContext.getKnnEngine())) {
return String.format(
Locale.ROOT,
Expand All @@ -86,11 +89,11 @@ private String validateDimension(final KNNMethodContext knnMethodContext) {
);
}

if (knnMethodContext.getKnnMethodConfigContext().getVectorDataType().isEmpty()) {
if (knnMethodConfigContext.getVectorDataType().isEmpty()) {
return null;
}

if (VectorDataType.BINARY == knnMethodContext.getKnnMethodConfigContext().getVectorDataType().get() && dimension % 8 != 0) {
if (VectorDataType.BINARY == knnMethodConfigContext.getVectorDataType().get() && dimension % 8 != 0) {
return "Dimension should be multiply of 8 for binary vector data type";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public boolean isSpaceTypeSupported(SpaceType space) {
}

@Override
public ValidationException validate(KNNMethodContext knnMethodContext) {
public ValidationException validate(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
List<String> errorMessages = new ArrayList<>();
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
Expand All @@ -55,7 +55,7 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {

ValidationException methodValidation = methodComponent.validate(
knnMethodContext.getMethodComponentContext(),
knnMethodContext.getKnnMethodConfigContext()
knnMethodConfigContext
);
if (methodValidation != null) {
errorMessages.addAll(methodValidation.validationErrors());
Expand All @@ -76,13 +76,18 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
}

@Override
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
return methodComponent.estimateOverheadInKB(knnMethodContext.getMethodComponentContext(), dimension);
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
return methodComponent.estimateOverheadInKB(
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext.getDimension().orElseThrow(() -> new IllegalStateException("Dimension needs to be set"))
);
}

protected PerDimensionValidator doGetPerDimensionValidator(KNNMethodContext knnMethodContext) {
VectorDataType vectorDataType = knnMethodContext.getKnnMethodConfigContext()
.getVectorDataType()
protected PerDimensionValidator doGetPerDimensionValidator(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
VectorDataType vectorDataType = knnMethodConfigContext.getVectorDataType()
.orElseThrow(
() -> new IllegalStateException("Vector data type needs to be set on KNNMethodConfigContext in order to get the processor")
);
Expand All @@ -97,32 +102,37 @@ protected PerDimensionValidator doGetPerDimensionValidator(KNNMethodContext knnM
return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR;
}

protected VectorValidator doGetVectorValidator(KNNMethodContext knnMethodContext) {
protected VectorValidator doGetVectorValidator(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
return new SpaceVectorValidator(knnMethodContext.getSpaceType());
}

protected PerDimensionProcessor doGetPerDimensionProcessor(KNNMethodContext knnMethodContext) {
protected PerDimensionProcessor doGetPerDimensionProcessor(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) {
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
Map<String, Object> parameterMap = new HashMap<>(
methodComponent.getAsMap(knnMethodContext.getMethodComponentContext(), knnMethodContext.getKnnMethodConfigContext())
methodComponent.getAsMap(knnMethodContext.getMethodComponentContext(), knnMethodConfigContext)
);
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(
KNNConstants.VECTOR_DATA_TYPE_FIELD,
knnMethodContext.getKnnMethodConfigContext()
.getVectorDataType()
knnMethodConfigContext.getVectorDataType()
.orElseThrow(() -> new IllegalStateException("Vector data type needs to be set"))
.getValue()
);
return KNNLibraryIndexingContextImpl.builder()
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext))
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public JVMLibrary(Map<String, KNNMethod> methods, String version) {
}

@Override
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
throw new UnsupportedOperationException("Estimating overhead is not supported for JVM based libraries.");
}

Expand Down
15 changes: 9 additions & 6 deletions src/main/java/org/opensearch/knn/index/engine/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return knnLibrary.validateMethod(knnMethodContext);
public ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
return knnLibrary.validateMethod(knnMethodContext, knnMethodConfigContext);
}

@Override
Expand All @@ -170,8 +170,11 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) {
return knnLibrary.getKNNLibraryIndexingContext(knnMethodContext);
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
) {
return knnLibrary.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
}

@Override
Expand All @@ -180,8 +183,8 @@ public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
}

@Override
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
return knnLibrary.estimateOverheadInKB(knnMethodContext, dimension);
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
return knnLibrary.estimateOverheadInKB(knnMethodContext, knnMethodConfigContext);
}

@Override
Expand Down
13 changes: 9 additions & 4 deletions src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ public interface KNNLibrary {
* deemed invalid.
*
* @param knnMethodContext to be validated
* @param knnMethodConfigContext configuration context for the method
* @return ValidationException produced by validation errors; null if there are no validation errors.
*/
ValidationException validateMethod(KNNMethodContext knnMethodContext);
ValidationException validateMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext);

/**
* Returns whether training is required or not from knnMethodContext for the given library.
Expand All @@ -91,18 +92,22 @@ public interface KNNLibrary {
* Estimate overhead of KNNMethodContext in Kilobytes.
*
* @param knnMethodContext to estimate size for
* @param dimension to estimate size for
* @param knnMethodConfigContext configuration context for the method
* @return size overhead estimate in KB
*/
int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension);
int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext);

/**
* Get the context from the library needed to build the index.
*
* @param knnMethodContext to get build context for
* @param knnMethodConfigContext configuration context for the method
* @return indexing context the library needs to build the index
*/
KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext);
KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
);

/**
* Gets metadata related to methods supported by the library
Expand Down
13 changes: 9 additions & 4 deletions src/main/java/org/opensearch/knn/index/engine/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ public interface KNNMethod {
* Validate that the configured KNNMethodContext is valid for this method
*
* @param knnMethodContext to be validated
* @param knnMethodConfigContext to be validated
* @return ValidationException produced by validation errors; null if there are no validation errors.
*/
ValidationException validate(KNNMethodContext knnMethodContext);
ValidationException validate(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext);

/**
* returns whether training is required or not
Expand All @@ -43,18 +44,22 @@ public interface KNNMethod {
* Returns the estimated overhead of the method in KB
*
* @param knnMethodContext context to estimate overhead
* @param dimension dimension to make estimate with
* @param knnMethodConfigContext config context to estimate overhead
* @return estimate overhead in KB
*/
int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension);
int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext);

/**
* Parse knnMethodContext into context that the library can use to build the index
*
* @param knnMethodContext to generate the context for
* @param knnMethodConfigContext to generate the context for
* @return KNNLibraryIndexingContext
*/
KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext);
KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext
);

/**
* Get the search context for a particular method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ public class KNNMethodContext implements ToXContentFragment, Writeable {
private SpaceType spaceType;
@NonNull
private final MethodComponentContext methodComponentContext;
@NonNull
private final KNNMethodConfigContext knnMethodConfigContext;

/**
* Constructor from stream.
Expand All @@ -58,17 +56,16 @@ public KNNMethodContext(StreamInput in) throws IOException {
this.knnEngine = KNNEngine.getEngine(in.readString());
this.spaceType = SpaceType.getSpace(in.readString());
this.methodComponentContext = new MethodComponentContext(in);
this.knnMethodConfigContext = KNNMethodConfigContext.builder().build();
}

/**
* This method uses the knnEngine to validate that the method is compatible with the engine. Ensure that if
* the {@link KNNMethodConfigContext} is changed at all, this method is called again.
* This method uses the knnEngine to validate that the method is compatible with the engine.
*
* @param knnMethodConfigContext context to validate against
* @return ValidationException produced by validation errors; null if there are no validation errors.
*/
public ValidationException validate() {
return knnEngine.validateMethod(this);
public ValidationException validate(KNNMethodConfigContext knnMethodConfigContext) {
return knnEngine.validateMethod(this, knnMethodConfigContext);
}

/**
Expand All @@ -83,11 +80,11 @@ public boolean isTrainingRequired() {
/**
* This method estimates the overhead the knn method adds irrespective of the number of vectors
*
* @param dimension dimension to make estimate with
* @param knnMethodConfigContext context to estimate overhead
* @return size in Kilobytes
*/
public int estimateOverheadInKB(int dimension) {
return knnEngine.estimateOverheadInKB(this, dimension);
public int estimateOverheadInKB(KNNMethodConfigContext knnMethodConfigContext) {
return knnEngine.estimateOverheadInKB(this, knnMethodConfigContext);
}

/**
Expand Down Expand Up @@ -176,7 +173,7 @@ public static KNNMethodContext parse(Object in) {

MethodComponentContext method = new MethodComponentContext(name, parameters);

return new KNNMethodContext(engine, spaceType, method, KNNMethodConfigContext.builder().build());
return new KNNMethodContext(engine, spaceType, method);
}

@Override
Expand All @@ -197,18 +194,13 @@ public boolean equals(Object obj) {
equalsBuilder.append(knnEngine, other.knnEngine);
equalsBuilder.append(spaceType, other.spaceType);
equalsBuilder.append(methodComponentContext, other.methodComponentContext);
// equalsBuilder.append(knnMethodConfigContext, other.knnMethodConfigContext);

return equalsBuilder.isEquals();
}

@Override
public int hashCode() {
return new HashCodeBuilder().append(knnEngine)
.append(spaceType)
.append(methodComponentContext)
// .append(knnMethodConfigContext)
.toHashCode();
return new HashCodeBuilder().append(knnEngine).append(spaceType).append(methodComponentContext).toHashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ public float score(float rawScore, SpaceType spaceType) {
}

@Override
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
public int estimateOverheadInKB(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
return methods.get(methodName).estimateOverheadInKB(knnMethodContext, dimension);
return methods.get(methodName).estimateOverheadInKB(knnMethodContext, knnMethodConfigContext);
}

@Override
Expand Down
Loading

0 comments on commit 5011f16

Please sign in to comment.