Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cosine similarity support for faiss engine #2376

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public float scoreTranslation(float rawScore) {
return Math.max((2.0F - rawScore) / 2.0F, 0.0F);
}

@Override
public float scoreToDistanceTranslation(float score) {
return score;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confused - Why is this correct?

}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.SpaceVectorValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.ArrayList;
Expand Down Expand Up @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) {
return VectorTransformerFactory.getVectorTransformer(knnMethodContext);
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
Expand All @@ -116,19 +122,37 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
knnMethodConfigContext
);
Map<String, Object> parameterMap = knnLibraryIndexingContext.getLibraryParameters();
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(KNNConstants.SPACE_TYPE, convertUserToMethodSpaceType(knnMethodContext.getSpaceType()).getValue());
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
return KNNLibraryIndexingContextImpl.builder()
.quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig())
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext))
.build();
}

@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext() {
return knnLibrarySearchContext;
}

/**
* Converts user defined space type to method space type that is supported by library.
* The subclass can override this method and returns the appropriate space type that
* is supported by the library. This is required because, some libraries may not
* support all the space types supported by OpenSearch, however. this can be achieved by using compatible space type by the library.
* For example, faiss does not support cosine similarity. However, we can use inner product space type for cosine similarity after normalization.
* In this case, we can return the inner product space type for cosine similarity.
*
* @param spaceType The space type to check for compatibility
* @return The compatible space type for the given input, returns the same
* space type if it's already compatible
* @see SpaceType
*/
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
return spaceType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
Expand Down Expand Up @@ -47,4 +48,12 @@ public interface KNNLibraryIndexingContext {
* @return Get the per dimension processor
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
* Get the vector transformer that will be used to transform the vector before indexing.
* This will be applied at vector level once entire vector is parsed and validated.
*
* @return VectorTransformer
*/
VectorTransformer getVectorTransformer();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

javadoc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Collections;
Expand All @@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private VectorValidator vectorValidator;
private PerDimensionValidator perDimensionValidator;
private PerDimensionProcessor perDimensionProcessor;
private VectorTransformer vectorTransformer;
@Builder.Default
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
Expand All @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() {
return vectorValidator;
}

@Override
public VectorTransformer getVectorTransformer() {
return vectorTransformer;
}

@Override
public PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;

import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -89,6 +90,11 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
throw new IllegalStateException("Unsupported vector data type " + vectorDataType);
}

@Override
protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) {
return super.getVectorTransformer(knnMethodContext);
}

static KNNLibraryIndexingContext adjustIndexDescription(
MethodAsMapBuilder methodAsMapBuilder,
MethodComponentContext methodComponentContext,
Expand Down Expand Up @@ -132,4 +138,15 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
}
return (MethodComponentContext) object;
}

@Override
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
// While FAISS doesn't directly support cosine similarity, we can leverage the mathematical
// relationship between cosine similarity and inner product for normalized vectors to add support.
// When ||a|| = ||b|| = 1, cos(θ) = a · b
if (spaceType == SpaceType.COSINESIMIL) {
return SpaceType.INNER_PRODUCT;
}
return super.convertUserToMethodSpaceType(spaceType);
}
}
22 changes: 17 additions & 5 deletions src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*/
public class Faiss extends NativeLibrary {
public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B";
Map<SpaceType, Function<Float, Float>> distanceTransform;
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
Expand All @@ -36,7 +37,10 @@ public class Faiss extends NativeLibrary {
// Map that overrides OpenSearch score translation by space type of scores returned by faiss
private final static Map<SpaceType, Function<Float, Float>> SCORE_TRANSLATIONS = ImmutableMap.of(
SpaceType.INNER_PRODUCT,
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore),
// COSINESIMIL expects the raw score in 1 - cosine(x,y)
SpaceType.COSINESIMIL,
rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore)
);

// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
Expand All @@ -45,6 +49,10 @@ public class Faiss extends NativeLibrary {
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();

private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());

Expand All @@ -53,7 +61,8 @@ public class Faiss extends NativeLibrary {
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
SCORE_TO_DISTANCE_TRANSFORMATIONS,
DISTANCE_TRANSLATIONS
);

private final MethodResolver methodResolver;
Expand All @@ -71,22 +80,25 @@ private Faiss(
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
Map<SpaceType, Function<Float, Float>> scoreTransform,
Map<SpaceType, Function<Float, Float>> distanceTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
this.distanceTransform = distanceTransform;
this.methodResolver = new FaissMethodResolver();
}

@Override
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
// Faiss engine uses distance as is and does not need transformation
if (this.distanceTransform.containsKey(spaceType)) {
return this.distanceTransform.get(spaceType).apply(distance);
}
return distance;
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Faiss engine uses distance as is and need transformation
if (this.scoreTransform.containsKey(spaceType)) {
return this.scoreTransform.get(spaceType).apply(score);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.HAMMING,
SpaceType.L2,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
SpaceType.HAMMING,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

@Override
protected VectorTransformer getVectorTransformer() {
return VectorTransformer.NOOP_VECTOR_TRANSFORMER;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,8 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

protected abstract VectorTransformer getVectorTransformer();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add java doc for this one? Also, can you update javadoc for getVectorValidator() to say that it is validated before any transform calls?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack


protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
validatePreparse();

Expand All @@ -691,7 +693,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForByteVector(array));
final byte[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(transformedArray));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

Expand All @@ -700,7 +703,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForFloatVector(array));
final float[] transformedArray = getVectorTransformer().transform(array);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be called before the per-dimension processor too? What should contract around these 2 be? Im wondering if we even need the per-dimension or if we can wrap that in this new full vector transform.

context.doc().addAll(getFieldsForFloatVector(transformedArray));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
private final PerDimensionProcessor perDimensionProcessor;
private final PerDimensionValidator perDimensionValidator;
private final VectorValidator vectorValidator;
private final VectorTransformer vectorTransformer;

static LuceneFieldMapper createFieldMapper(
String fullname,
Expand Down Expand Up @@ -122,6 +123,7 @@ private LuceneFieldMapper(
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

@Override
Expand Down Expand Up @@ -169,6 +171,11 @@ protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
protected VectorTransformer getVectorTransformer() {
return vectorTransformer;
}

@Override
void updateEngineStats() {
KNNEngine.LUCENE.setInitialized(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
private final PerDimensionProcessor perDimensionProcessor;
private final PerDimensionValidator perDimensionValidator;
private final VectorValidator vectorValidator;
private final VectorTransformer vectorTransformer;

public static MethodFieldMapper createFieldMapper(
String fullname,
Expand Down Expand Up @@ -180,6 +181,7 @@ private MethodFieldMapper(
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

@Override
Expand All @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
protected VectorTransformer getVectorTransformer() {
return vectorTransformer;
}
}
Loading
Loading