Skip to content

Commit

Permalink
Fix code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 14, 2025
1 parent f35f222 commit 8764557
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 82 deletions.
5 changes: 0 additions & 5 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ public float scoreTranslation(float rawScore) {
return Math.max((2.0F - rawScore) / 2.0F, 0.0F);
}

/**
 * Translates a score into a distance. This override is the identity mapping:
 * the value passed in is returned unchanged.
 *
 * @param score the score to translate
 * @return the same value as {@code score}
 */
@Override
public float scoreToDistanceTranslation(float score) {
return score;
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;

import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -90,11 +89,6 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
throw new IllegalStateException("Unsupported vector data type " + vectorDataType);
}

/**
 * Resolves the vector transformer for the given method context by delegating
 * directly to the superclass implementation.
 *
 * NOTE(review): this override adds no behavior of its own beyond calling {@code super}.
 *
 * @param knnMethodContext the method context forwarded to the superclass
 * @return whatever the superclass implementation returns for {@code knnMethodContext}
 */
@Override
protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) {
return super.getVectorTransformer(knnMethodContext);
}

static KNNLibraryIndexingContext adjustIndexDescription(
MethodAsMapBuilder methodAsMapBuilder,
MethodComponentContext methodComponentContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ public class Faiss extends NativeLibrary {
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1)
.put(SpaceType.COSINESIMIL, score -> 2 - 2 * score)
.build();

private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build();
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());
Expand Down Expand Up @@ -99,6 +102,7 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Faiss engine uses distance as is, so the score needs to be transformed
if (this.scoreTransform.containsKey(spaceType)) {
return this.scoreTransform.get(spaceType).apply(score);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ protected void validatePreparse() {
protected abstract VectorValidator getVectorValidator();

/**
* Getter for per dimension validator during vector parsing
* Getter for per dimension validator during vector parsing, and before any transformation
*
* @return PerDimensionValidator
*/
Expand All @@ -688,6 +688,11 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

/**
* Getter for vector transformer after vector parsing and validation
*
* @return VectorTransformer
*/
protected abstract VectorTransformer getVectorTransformer();

protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
Expand All @@ -700,8 +705,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
final byte[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(transformedArray));
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(array));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

Expand All @@ -710,8 +715,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
final float[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(transformedArray));
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(array));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
Expand Down Expand Up @@ -99,4 +103,37 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext)
Mode mode = knnMappingConfig.getMode();
return compressionLevel.getDefaultRescoreContext(mode, dimension);
}

/**
 * Applies this field's configured vector transformation to a query vector. The vector is
 * modified in place; nothing is returned.
 *
 * <p>The transformer is resolved in the following order:
 * <ol>
 *   <li>If the field's vector data type is not FLOAT, the vector is left untouched.</li>
 *   <li>If a KNN method context is configured, the transformer is derived from it.</li>
 *   <li>Otherwise, if a model id is configured, the model's metadata is looked up and the
 *       transformer is derived from that metadata.</li>
 *   <li>If neither is configured, an exception is thrown.</li>
 * </ol>
 *
 * @param vector the float array to be transformed in place; must not be null
 * @throws IllegalStateException if neither KNN method context nor model id is configured
 */
public void transformQueryVector(float[] vector) {
    // Only float vectors are ever transformed; every other data type passes through as is.
    if (VectorDataType.FLOAT != vectorDataType) {
        return;
    }

    // Preferred source: the method context stored in the mapping.
    final Optional<KNNMethodContext> methodContext = knnMappingConfig.getKnnMethodContext();
    if (methodContext.isPresent()) {
        VectorTransformerFactory.getVectorTransformer(methodContext.get()).transform(vector);
        return;
    }

    // Fallback: model-based field. The model id is only consulted when no method context exists.
    final String configuredModelId = knnMappingConfig.getModelId()
        .orElseThrow(() -> new IllegalStateException("Either KNN method context or Model Id should be configured"));
    final ModelMetadata modelMetadata = ModelDao.OpenSearchKNNModelDao.getInstance().getMetadata(configuredModelId);
    VectorTransformerFactory.getVectorTransformer(modelMetadata).transform(vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,8 @@ private void initVectorTransformer() {
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
// Need to handle BWC case
if (knnMethodContext == null || knnMethodConfigContext == null) {
log.debug(
"Method Context not available - falling back to Model Metadata for Engine and Space type to determine VectorTransformer instance"
);
vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType());
log.debug("Method Context not available - falling back to Model Metadata to determine VectorTransformer instance");
vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,25 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import org.apache.lucene.util.VectorUtil;

/**
* Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures
* that the vector's magnitude becomes 1 while preserving its directional properties.
* Normalizes vectors using L2 (Euclidean) normalization, ensuring the vector's
* magnitude becomes 1 while preserving its directional properties.
*/
public class NormalizeVectorTransformer implements VectorTransformer {

/**
* Transforms the input vector into unit vector by applying L2 normalization.
*
* @param vector The input vector to be normalized. Must not be null.
* @return A new float array containing the L2-normalized version of the input vector.
* Each component is divided by the Euclidean norm of the vector.
* @throws IllegalArgumentException if the input vector is null, empty, or a zero vector
*/
@Override
public float[] transform(float[] vector) {
public void transform(float[] vector) {
validateVector(vector);
VectorUtil.l2normalize(vector);
}

private void validateVector(float[] vector) {
if (vector == null || vector.length == 0) {
throw new IllegalArgumentException("Vector cannot be null or empty");
}
return VectorUtil.l2normalize(vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import java.util.Arrays;

/**
* Defines operations for transforming vectors in the k-NN search context.
* Implementations can modify vectors while preserving their dimensional properties
Expand All @@ -15,50 +12,31 @@
public interface VectorTransformer {

/**
* Transforms a float vector into a new vector of the same type.
*
* Example:
* <pre>{@code
* float[] input = {1.0f, 2.0f, 3.0f};
* float[] transformed = transformer.transform(input);
* }</pre>
* Transforms a float vector in place.
*
* @param vector The input vector to transform (must not be null)
* @return The transformed vector
* @throws IllegalArgumentException if the input vector is null
*/
default float[] transform(final float[] vector) {
default void transform(final float[] vector) {
if (vector == null) {
throw new IllegalArgumentException("Input vector cannot be null");
}
return Arrays.copyOf(vector, vector.length);
}

/**
* Transforms a byte vector into a new vector of the same type.
*
* Example:
* <pre>{@code
* byte[] input = {1, 2, 3};
* byte[] transformed = transformer.transform(input);
* }</pre>
* Transforms a byte vector in place.
*
* @param vector The input vector to transform (must not be null)
* @return The transformed vector
* @throws IllegalArgumentException if the input vector is null
*/
default byte[] transform(final byte[] vector) {
default void transform(final byte[] vector) {
if (vector == null) {
throw new IllegalArgumentException("Input vector cannot be null");
}
// return copy of vector to avoid side effects
return Arrays.copyOf(vector, vector.length);

}

/**
* A no-operation transformer that leaves vector values unchanged.
* This constant can be used when no transformation is needed.
*/
VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() {
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.indices.ModelMetadata;

/**
* Factory class responsible for creating appropriate vector transformers based on the KNN method context.
Expand All @@ -35,6 +36,28 @@ public static VectorTransformer getVectorTransformer(final KNNMethodContext cont
return getVectorTransformer(context.getKnnEngine(), context.getSpaceType());
}

/**
 * Creates the {@link VectorTransformer} that corresponds to a model's configuration.
 *
 * <p>The KNN engine and space type are read from the supplied metadata and handed to the
 * engine/space-type overload of this factory, which selects the transformer implementation.
 * The returned transformer modifies vectors in place.
 *
 * @param metadata model metadata providing the KNN engine and space type; must not be null
 * @return the transformer selected for the model's engine and space type
 * @throws IllegalArgumentException if {@code metadata} is null
 */
public static VectorTransformer getVectorTransformer(final ModelMetadata metadata) {
    if (metadata != null) {
        return getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType());
    }
    throw new IllegalArgumentException("ModelMetadata cannot be null");
}

/**
* Returns a vector transformer based on the provided KNN engine and space type.
* For FAISS engine with cosine similarity space type, returns a NormalizeVectorTransformer
Expand All @@ -45,7 +68,7 @@ public static VectorTransformer getVectorTransformer(final KNNMethodContext cont
* @param spaceType The space type
* @return VectorTransformer An appropriate vector transformer instance
*/
public static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) {
private static VectorTransformer getVectorTransformer(final KNNEngine knnEngine, final SpaceType spaceType) {
return shouldNormalizeVector(knnEngine, spaceType) ? new NormalizeVectorTransformer() : VectorTransformer.NOOP_VECTOR_TRANSFORMER;
}

Expand Down
14 changes: 4 additions & 10 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.common.ValidationException;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -429,6 +428,7 @@ protected Query doToQuery(QueryShardContext context) {
SpaceType spaceType = queryConfigFromMapping.get().getSpaceType();
VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType();
RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext);
knnVectorFieldType.transformQueryVector(vector);

VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore);
updateQueryStats(vectorQueryType);
Expand Down Expand Up @@ -542,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) {
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType))
.vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine))
.byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector))
.vectorDataType(vectorDataType)
.k(this.k)
Expand All @@ -559,7 +559,7 @@ protected Query doToQuery(QueryShardContext context) {
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType))
.vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine))
.byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector))
.vectorDataType(vectorDataType)
.radius(radius)
Expand Down Expand Up @@ -612,13 +612,7 @@ private void updateQueryStats(VectorQueryType vectorQueryType) {
}
}

private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) {

// Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize
// query vector before applying search
if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) {
return VectorUtil.l2normalize(this.vector);
}
private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) {
if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) {
return this.vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ public void testNormalizeTransformer_withEmptyVector_thenThrowsException() {

public void testNormalizeTransformer_withValidVector_thenSuccess() {
float[] input = { -3.0f, 4.0f };
float[] normalized = transformer.transform(input);
transformer.transform(input);

assertEquals(-0.6f, normalized[0], DELTA);
assertEquals(0.8f, normalized[1], DELTA);
assertEquals(-0.6f, input[0], DELTA);
assertEquals(0.8f, input[1], DELTA);

// Verify the magnitude is 1
assertEquals(1.0f, calculateMagnitude(normalized), DELTA);
assertEquals(1.0f, calculateMagnitude(input), DELTA);
}

private float calculateMagnitude(float[] vector) {
Expand Down
Loading

0 comments on commit 8764557

Please sign in to comment.