From 6671e643e76fbab7fa3b95e930d4cded49361362 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Wed, 8 Jan 2025 23:01:55 -0800 Subject: [PATCH] Add cosine similarity support for faiss engine FAISS engine doesn't support cosine similarity natively. However, we can use inner product to achieve the same result because, when vectors are normalized, inner product is the same as cosine similarity. Hence, before ingestion and before performing a search, normalize the input vector and add it to the faiss index with type inner product. Since we will be storing normalized vectors in segments, the source can be used to get the actual vectors. By saving normalized vectors, we don't have to normalize whenever segments are merged. This keeps force merge time and search latency competitive, at the cost of additional latency during indexing (the one time we normalize). We also support radial search for cosine similarity. Signed-off-by: Vijayan Balasubramanian --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/SpaceType.java | 5 + .../knn/index/engine/AbstractKNNMethod.java | 24 ++- .../engine/KNNLibraryIndexingContext.java | 3 + .../engine/KNNLibraryIndexingContextImpl.java | 7 + .../engine/faiss/AbstractFaissMethod.java | 17 ++ .../knn/index/engine/faiss/Faiss.java | 22 ++- .../index/engine/faiss/FaissHNSWMethod.java | 3 +- .../index/engine/faiss/FaissIVFMethod.java | 3 +- .../index/mapper/FlatVectorFieldMapper.java | 5 + .../index/mapper/KNNVectorFieldMapper.java | 8 +- .../knn/index/mapper/LuceneFieldMapper.java | 7 + .../knn/index/mapper/MethodFieldMapper.java | 7 + .../knn/index/mapper/ModelFieldMapper.java | 26 +++ .../mapper/NormalizeVectorTransformer.java | 31 ++++ .../knn/index/mapper/VectorTransformer.java | 47 ++++++ .../mapper/VectorTransformerFactory.java | 55 +++++++ .../knn/index/query/KNNQueryBuilder.java | 15 +- .../org/opensearch/knn/index/FaissIT.java | 149 ++++++++++++++---- .../NormalizeVectorTransformerTests.java | 42 +++++
.../mapper/VectorTransformerFactoryTests.java | 72 +++++++++ .../org/opensearch/knn/KNNRestTestCase.java | 20 +++ 22 files changed, 522 insertions(+), 47 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java create mode 100644 src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java create mode 100644 src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index a09f40bbc..2ebcb2862 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] +- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. 
(#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 5d90071e8..147e260b9 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -77,6 +77,11 @@ public float scoreTranslation(float rawScore) { return Math.max((2.0F - rawScore) / 2.0F, 0.0F); } + @Override + public float scoreToDistanceTranslation(float score) { + return score; + } + @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.COSINE; diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index f53655136..98153b267 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -13,6 +13,8 @@ import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; import org.opensearch.knn.index.mapper.SpaceVectorValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; +import org.opensearch.knn.index.mapper.VectorTransformerFactory; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.ArrayList; @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( return PerDimensionProcessor.NOOP_PROCESSOR; } + protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { + return VectorTransformerFactory.getVectorTransformer(knnMethodContext); + } + @Override public KNNLibraryIndexingContext getKNNLibraryIndexingContext( KNNMethodContext knnMethodContext, @@ -116,7 +122,7 @@ public KNNLibraryIndexingContext 
getKNNLibraryIndexingContext( knnMethodConfigContext ); Map parameterMap = knnLibraryIndexingContext.getLibraryParameters(); - parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + parameterMap.put(KNNConstants.SPACE_TYPE, getCompatibleSpaceType(knnMethodContext.getSpaceType()).getValue()); parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue()); return KNNLibraryIndexingContextImpl.builder() .quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig()) @@ -124,6 +130,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( .vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext)) .perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext)) + .vectorTransformer(getVectorTransformer(knnMethodContext)) .build(); } @@ -131,4 +138,19 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( public KNNLibrarySearchContext getKNNLibrarySearchContext() { return knnLibrarySearchContext; } + + /** + * Gets the compatible space type for the given space type parameter. + * This method validates and returns the appropriate space type that + * is compatible with the system's requirements. 
+ * + * @param spaceType The space type to check for compatibility + * @return The compatible space type for the given input, returns the same + * space type if it's already compatible + * @see SpaceType + */ + + protected SpaceType getCompatibleSpaceType(SpaceType spaceType) { + return spaceType; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java index 9208661af..283576bf6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -8,6 +8,7 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Map; @@ -47,4 +48,6 @@ public interface KNNLibraryIndexingContext { * @return Get the per dimension processor */ PerDimensionProcessor getPerDimensionProcessor(); + + VectorTransformer getVectorTransformer(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java index f5329fc31..9822033b7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -9,6 +9,7 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Collections; @@ -23,6 +24,7 @@ public class 
KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext private VectorValidator vectorValidator; private PerDimensionValidator perDimensionValidator; private PerDimensionProcessor perDimensionProcessor; + private VectorTransformer vectorTransformer; @Builder.Default private Map parameters = Collections.emptyMap(); @Builder.Default @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() { return vectorValidator; } + @Override + public VectorTransformer getVectorTransformer() { + return vectorTransformer; + } + @Override public PerDimensionValidator getPerDimensionValidator() { return perDimensionValidator; diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 7ae403445..5810a5f54 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -17,6 +17,7 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.mapper.PerDimensionProcessor; import org.opensearch.knn.index.mapper.PerDimensionValidator; +import org.opensearch.knn.index.mapper.VectorTransformer; import java.util.Objects; import java.util.Set; @@ -89,6 +90,11 @@ protected PerDimensionProcessor doGetPerDimensionProcessor( throw new IllegalStateException("Unsupported vector data type " + vectorDataType); } + @Override + protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) { + return super.getVectorTransformer(knnMethodContext); + } + static KNNLibraryIndexingContext adjustIndexDescription( MethodAsMapBuilder methodAsMapBuilder, MethodComponentContext methodComponentContext, @@ -132,4 +138,15 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m } return (MethodComponentContext) object; } + + @Override + protected SpaceType getCompatibleSpaceType(SpaceType 
spaceType) { + // While FAISS doesn't directly support cosine similarity, we can leverage the mathematical + // relationship between cosine similarity and inner product for normalized vectors to add support. + // When ||a|| = ||b|| = 1, cos(θ) = a · b + if (spaceType == SpaceType.COSINESIMIL) { + return SpaceType.INNER_PRODUCT; + } + return spaceType; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index a602619a1..d4222dc8d 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -26,6 +26,7 @@ */ public class Faiss extends NativeLibrary { public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B"; + Map> distanceTransform; Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name @@ -36,7 +37,10 @@ public class Faiss extends NativeLibrary { // Map that overrides OpenSearch score translation by space type of scores returned by faiss private final static Map> SCORE_TRANSLATIONS = ImmutableMap.of( SpaceType.INNER_PRODUCT, - rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) + rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore), + // COSINESIMIL expects the raw score in 1 - cosine(x,y) + SpaceType.COSINESIMIL, + rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore) ); // Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation: @@ -45,6 +49,10 @@ public class Faiss extends NativeLibrary { SpaceType, Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 
1 - score : 1 / score - 1).build(); + private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< + SpaceType, + Function>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build(); + // Package private so that the method resolving logic can access the methods final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); @@ -53,7 +61,8 @@ public class Faiss extends NativeLibrary { SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION, - SCORE_TO_DISTANCE_TRANSFORMATIONS + SCORE_TO_DISTANCE_TRANSFORMATIONS, + DISTANCE_TRANSLATIONS ); private final MethodResolver methodResolver; @@ -71,22 +80,25 @@ private Faiss( Map> scoreTranslation, String currentVersion, String extension, - Map> scoreTransform + Map> scoreTransform, + Map> distanceTransform ) { super(methods, scoreTranslation, currentVersion, extension); this.scoreTransform = scoreTransform; + this.distanceTransform = distanceTransform; this.methodResolver = new FaissMethodResolver(); } @Override public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { - // Faiss engine uses distance as is and does not need transformation + if (this.distanceTransform.containsKey(spaceType)) { + return this.distanceTransform.get(spaceType).apply(distance); + } return distance; } @Override public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { - // Faiss engine uses distance as is and need transformation if (this.scoreTransform.containsKey(spaceType)) { return this.scoreTransform.get(spaceType).apply(score); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index c153a9328..3386f871c 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -46,7 +46,8 @@ public class FaissHNSWMethod 
extends AbstractFaissMethod { SpaceType.UNDEFINED, SpaceType.HAMMING, SpaceType.L2, - SpaceType.INNER_PRODUCT + SpaceType.INNER_PRODUCT, + SpaceType.COSINESIMIL ); private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 340c1f4d8..582029392 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod { SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, - SpaceType.HAMMING + SpaceType.HAMMING, + SpaceType.COSINESIMIL ); private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index 8da41aa59..f671b07f8 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -109,4 +109,9 @@ protected PerDimensionValidator getPerDimensionValidator() { protected PerDimensionProcessor getPerDimensionProcessor() { return PerDimensionProcessor.NOOP_PROCESSOR; } + + @Override + protected VectorTransformer getVectorTransformer() { + return VectorTransformer.NOOP_VECTOR_TRANSFORMER; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 67f3efa5b..2b94d4e9f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -681,6 +681,8 @@ protected void validatePreparse() { */ protected 
abstract PerDimensionProcessor getPerDimensionProcessor(); + protected abstract VectorTransformer getVectorTransformer(); + protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { validatePreparse(); @@ -691,7 +693,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForByteVector(array)); + final byte[] transformedArray = getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForByteVector(transformedArray)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -700,7 +703,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForFloatVector(array)); + final float[] transformedArray = getVectorTransformer().transform(array); + context.doc().addAll(getFieldsForFloatVector(transformedArray)); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 4ceb9b4b2..83f3ce4c5 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -42,6 +42,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { private final PerDimensionProcessor perDimensionProcessor; private final PerDimensionValidator perDimensionValidator; private final VectorValidator vectorValidator; + private final VectorTransformer vectorTransformer; static 
LuceneFieldMapper createFieldMapper( String fullname, @@ -122,6 +123,7 @@ private LuceneFieldMapper( this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); } @Override @@ -169,6 +171,11 @@ protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + @Override + protected VectorTransformer getVectorTransformer() { + return vectorTransformer; + } + @Override void updateEngineStats() { KNNEngine.LUCENE.setInitialized(true); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 755439ce6..814bc4f63 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { private final PerDimensionProcessor perDimensionProcessor; private final PerDimensionValidator perDimensionValidator; private final VectorValidator vectorValidator; + private final VectorTransformer vectorTransformer; public static MethodFieldMapper createFieldMapper( String fullname, @@ -180,6 +181,7 @@ private MethodFieldMapper( this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); } @Override @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() { protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + + @Override + protected 
VectorTransformer getVectorTransformer() { + return vectorTransformer; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index cbc7520cf..2458e3355 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -41,6 +41,7 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { private PerDimensionProcessor perDimensionProcessor; private PerDimensionValidator perDimensionValidator; private VectorValidator vectorValidator; + private VectorTransformer vectorTransformer; private final String modelId; @@ -192,6 +193,31 @@ protected PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + @Override + protected VectorTransformer getVectorTransformer() { + initVectorTransformer(); + return vectorTransformer; + } + + private void initVectorTransformer() { + if (vectorTransformer != null) { + return; + } + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + + KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); + KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); + // Need to handle BWC case + if (knnMethodContext == null || knnMethodConfigContext == null) { + vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType()); + return; + } + + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); + } + private void initVectorValidator() { if (vectorValidator != null) { return; diff --git a/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java 
b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java new file mode 100644 index 000000000..6a9642435 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformer.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.util.VectorUtil; + +/** + * Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures + * that the vector's magnitude becomes 1 while preserving its directional properties. + */ +public class NormalizeVectorTransformer implements VectorTransformer { + + /** + * Transforms the input vector into a unit vector by applying L2 normalization. + * + * @param vector The input vector to be normalized. Must not be null. + * @return The input array itself, L2-normalized in place (no copy is made). + * Each component is divided by the Euclidean norm of the vector. + * @throws IllegalArgumentException if the input vector is null, empty, or a zero vector + */ + @Override + public float[] transform(float[] vector) { + if (vector == null || vector.length == 0) { + throw new IllegalArgumentException("Vector cannot be null or empty"); + } + return VectorUtil.l2normalize(vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java new file mode 100644 index 000000000..d29e4a460 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformer.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +/** + * Defines operations for transforming vectors in the k-NN search context.
+ * Implementations can modify vectors while preserving their dimensional properties + * for specific use cases such as normalization, scaling, or other transformations. + * + * <p>
This interface provides default implementations that pass through the original + * vector without modification. Implementing classes should override these methods + * to provide specific transformation logic. + */ +public interface VectorTransformer { + + /** + * Transforms a float vector into a new vector of the same type. + * The default implementation returns the input vector unchanged. + * + * @param vector The input vector to transform + * @return The transformed vector + */ + default float[] transform(float[] vector) { + return vector; + } + + /** + * Transforms a byte vector into a new vector of the same type. + * The default implementation returns the input vector unchanged. + * + * @param vector The input vector to transform + * @return The transformed vector + */ + default byte[] transform(byte[] vector) { + return vector; + } + + /** + * A no-operation transformer that returns vectors unchanged. + * This constant can be used when no transformation is needed. + */ + VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() { + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java new file mode 100644 index 000000000..8726c48b0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorTransformerFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; + +/** + * Factory class responsible for creating appropriate vector transformers based on the KNN method context. + * This factory determines whether vectors need transformation based on the engine type and space type. 
+ */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class VectorTransformerFactory { + + /** + * Returns a vector transformer based on the provided KNN method context. + * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer + * since FAISS doesn't natively support cosine space type. For all other cases, + * returns a no-operation transformer. + * + * @param context The KNN method context containing engine and space type information + * @return VectorTransformer An appropriate vector transformer instance + * @throws IllegalArgumentException if the context parameter is null + */ + public static VectorTransformer getVectorTransformer(final KNNMethodContext context) { + if (context == null) { + throw new IllegalArgumentException("KNNMethod context cannot be null"); + } + return getVectorTransformer(context.getKnnEngine(), context.getSpaceType()); + } + + /** + * Returns a vector transformer based on the provided KNN engine and space type. + * For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer + * since FAISS doesn't natively support cosine space type. For all other cases, + * returns a no-operation transformer. + * + * @param knnEngine The KNN engine type + * @param spaceType The space type + * @return VectorTransformer An appropriate vector transformer instance + */ + public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) { + return shouldNormalizeVector(knnEngine, spaceType) ? 
new NormalizeVectorTransformer() : VectorTransformer.NOOP_VECTOR_TRANSFORMER; + } + + private static boolean shouldNormalizeVector(final KNNEngine knnEngine, final SpaceType spaceType) { + return knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index ee18394f6..c2998df6c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -12,6 +12,7 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.util.VectorUtil; import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; @@ -541,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine)) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .k(this.k) @@ -558,8 +559,8 @@ protected Query doToQuery(QueryShardContext context) { .knnEngine(knnEngine) .indexName(indexName) .fieldName(this.fieldName) - .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) - .byteVector(VectorDataType.BYTE == vectorDataType ? 
byteVector : null) + .vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType)) + .byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector)) .vectorDataType(vectorDataType) .radius(radius) .methodParameters(this.methodParameters) @@ -611,7 +612,13 @@ private void updateQueryStats(VectorQueryType vectorQueryType) { } } - private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) { + private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) { + + // Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize + // query vector before applying search + if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) { + return VectorUtil.l2normalize(this.vector); + } if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) { return this.vector; } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index c2e75ecb2..e7db4868b 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -17,6 +17,7 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.lucene.util.VectorUtil; import org.junit.BeforeClass; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; @@ -43,6 +44,7 @@ import java.util.Map; import java.util.Random; import java.util.TreeMap; +import java.util.function.BiFunction; import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.DIMENSION; @@ -93,6 +95,7 @@ public class FaissIT extends KNNRestTestCase { private static final String FILED_TYPE_INTEGER = 
"integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; + public static final int ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = 0; static TestUtils.TestData testData; @@ -373,16 +376,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN deleteModel(modelId); // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } @SneakyThrows @@ -555,17 +549,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { deleteKNNIndex(indexName); deleteModel(modelId); - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } @SneakyThrows @@ -1239,17 +1223,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( deleteKNNIndex(indexName); deleteModel(modelId); - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); + validateGraphEviction(); } /** @@ -2050,6 +2024,20 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( assertEquals(1, resultsQuery2.size()); } + public void testCosineSimilarity_withHNSW_withExactSearch_thenSucceed() throws Exception { + testCosineSimilarityForApproximateSearch(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + } + + public void testCosineSimilarity_withHNSW_withApproximate_thenSucceed() throws 
Exception { + testCosineSimilarityForApproximateSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + validateGraphEviction(); + } + + public void testCosineSimilarity_withHNSW_withRadialSearch_thenSucceed() throws Exception { + testCosineSimilarityForRadialSearch(ALWAYS_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + validateGraphEviction(); + } + protected void setupKNNIndexForFilterQuery() throws Exception { setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); } @@ -2161,7 +2149,7 @@ private List> validateRadiusSearchResults( if (filterQuery != null) { queryBuilder.field("filter", filterQuery); } - if (methodParameters != null) { + if (methodParameters != null && !methodParameters.isEmpty()) { queryBuilder.startObject(METHOD_PARAMETER); for (Map.Entry entry : methodParameters.entrySet()) { queryBuilder.field(entry.getKey(), entry.getValue()); @@ -2182,6 +2170,8 @@ private List> validateRadiusSearchResults( assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); } else if (spaceType == SpaceType.INNER_PRODUCT) { assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else if (spaceType == SpaceType.COSINESIMIL) { + assertTrue(KNNScoringUtil.cosinesimil(queryVector, vector) >= distance); } else { throw new IllegalArgumentException("Invalid space type"); } @@ -2190,4 +2180,97 @@ private List> validateRadiusSearchResults( } return queryResults; } + + private void testCosineSimilarityForApproximateSearch(int approximateThreshold) throws Exception { + String indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + indexTestData(approximateThreshold, indexName, spaceType, fieldName); + + // Run top-10 searches and check each hit's score against the cosine score recomputed from the returned vector + validateNearestNeighborsSearch(indexName, fieldName, spaceType, 10, VectorUtil::cosine); + + // Delete index + deleteKNNIndex(indexName); + } + + private void testCosineSimilarityForRadialSearch(int approximateThreshold) throws Exception { + String 
indexName = randomLowerCaseString(); + String fieldName = randomLowerCaseString(); + SpaceType spaceType = SpaceType.COSINESIMIL; + indexTestData(approximateThreshold, indexName, spaceType, fieldName); + + // search index + validateRadiusSearchResults(indexName, fieldName, testData.queries, 0.92f, null, spaceType, null, null); + + // Delete index + deleteKNNIndex(indexName); + } + + private void indexTestData(int approximateThreshold, String indexName, SpaceType spaceType, String fieldName) throws Exception { + int dimension = testData.indexData.vectors[0].length; + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + createKnnIndex(indexName, buildKNNIndexSettings(approximateThreshold), mapping); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + } + + @SneakyThrows + private void validateNearestNeighborsSearch( + final String indexName, + final String fieldName, + final SpaceType spaceType, + final int k, + final BiFunction<float[], float[], Float> scoringFunction + ) { + for (int i = 0; i < testData.queries.length; i++) { + final Response response = searchKNNIndex( + indexName, + 
KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[i]).k(k).build(), + k + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + final List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + final float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(scoringFunction.apply(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + } + } diff --git a/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java new file mode 100644 index 000000000..923e59a26 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/NormalizeVectorTransformerTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; + +public class NormalizeVectorTransformerTests extends KNNTestCase { + private final NormalizeVectorTransformer transformer = new NormalizeVectorTransformer(); + private static final float DELTA = 0.001f; // Delta for floating point comparisons + + public void testNormalizeTransformer_withNullVector_thenThrowsException() { + assertThrows(IllegalArgumentException.class, () -> transformer.transform((float[]) null)); + assertThrows(IllegalArgumentException.class, () -> transformer.transform((byte[]) null)); + } + + public void testNormalizeTransformer_withEmptyVector_thenThrowsException() { + assertThrows(IllegalArgumentException.class, () -> transformer.transform(new float[0])); + } + + public void testNormalizeTransformer_withValidVector_thenSuccess() { + float[] input = { -3.0f, 4.0f }; + float[] normalized = transformer.transform(input); + 
+ assertEquals(-0.6f, normalized[0], DELTA); + assertEquals(0.8f, normalized[1], DELTA); + + // Verify the magnitude is 1 + assertEquals(1.0f, calculateMagnitude(normalized), DELTA); + } + + private float calculateMagnitude(float[] vector) { + float magnitude = 0.0f; + for (float value : vector) { + magnitude += value * value; + } + return (float) Math.sqrt(magnitude); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java new file mode 100644 index 000000000..6148f83d6 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/VectorTransformerFactoryTests.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class VectorTransformerFactoryTests extends KNNTestCase { + public void testAllSpaceTypes_withFaiss() { + for (SpaceType spaceType : SpaceType.values()) { + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType); + validateTransformer(spaceType, KNNEngine.FAISS, transformer); + } + } + + public void testAllEngines_withCosine() { + // Test all engines with COSINESIMIL space type + for (KNNEngine engine : KNNEngine.values()) { + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(engine, SpaceType.COSINESIMIL); + validateTransformer(SpaceType.COSINESIMIL, engine, transformer); + } + } + + public void testGetVectorTransformer_withNullContext() { + // Test case for null context + assertThrows(IllegalArgumentException.class, () -> 
VectorTransformerFactory.getVectorTransformer(null)); + } + + public void testAllSpaceTypes_usingContext_withFaiss() { + for (SpaceType spaceType : SpaceType.values()) { + KNNMethodContext context = mock(KNNMethodContext.class); + when(context.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(context.getSpaceType()).thenReturn(spaceType); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context); + validateTransformer(spaceType, KNNEngine.FAISS, transformer); + } + } + + public void testAllEngines_usingContext_withCosine() { + // Test all engines with COSINESIMIL space type + for (KNNEngine engine : KNNEngine.values()) { + KNNMethodContext context = mock(KNNMethodContext.class); + when(context.getKnnEngine()).thenReturn(engine); + when(context.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + VectorTransformer transformer = VectorTransformerFactory.getVectorTransformer(context); + validateTransformer(SpaceType.COSINESIMIL, engine, transformer); + } + } + + private static void validateTransformer(SpaceType spaceType, KNNEngine engine, VectorTransformer transformer) { + if (spaceType == SpaceType.COSINESIMIL && engine == KNNEngine.FAISS) { + assertTrue( + "Should return NormalizeVectorTransformer for FAISS with " + spaceType, + transformer instanceof NormalizeVectorTransformer + ); + } else { + assertSame( + "Should return NOOP transformer for " + engine + " with " + spaceType, + VectorTransformer.NOOP_VECTOR_TRANSFORMER, + transformer + ); + } + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 896674a18..3fd6aedc1 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -129,6 +129,8 @@ public class KNNRestTestCase extends ODFERestTestCase { protected static final int DELAY_MILLI_SEC = 1000; protected static final int NUM_OF_ATTEMPTS = 30; 
private static final String SYSTEM_INDEX_PREFIX = ".opendistro"; + public static final int MIN_CODE_UNITS = 4; + public static final int MAX_CODE_UNITS = 10; @AfterClass public static void dumpCoverage() throws IOException, MalformedObjectNameException { @@ -1939,4 +1941,22 @@ protected boolean isApproximateThresholdSupported(final Optional bwcVers final Version version = Version.fromString(versionString); return version.onOrAfter(Version.V_2_18_0); } + + /** + * Generates a random lowercase string with length between MIN_CODE_UNITS and MAX_CODE_UNITS. + * This method is used for test fixtures to generate random string values that can be used + * as identifiers, names, or other string-based test data. + * Example usage: + * 
<pre>
+     * String randomId = randomLowerCaseString();
+     * String indexName = randomLowerCaseString();
+     * String fieldName = randomLowerCaseString();
+     * </pre>
+ * + * @return A random lowercase string of variable length between MIN_CODE_UNITS and MAX_CODE_UNITS + * @see #randomAlphaOfLengthBetween(int, int) + */ + protected static String randomLowerCaseString() { + return randomAlphaOfLengthBetween(MIN_CODE_UNITS, MAX_CODE_UNITS).toLowerCase(Locale.ROOT); + } }