Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS #2283

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4'
testFixturesImplementation 'com.jayway.jsonpath:json-path:2.8.0'
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
api "net.java.dev.jna:jna:5.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class KNNConstants {
public static final String QFRAMEWORK_CONFIG = "qframe_config";

public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final String EXPAND_NESTED = "expand_nested_docs";
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;
public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public static class CreateQueryRequest {
private QueryBuilder filter;
private QueryShardContext context;
private RescoreContext rescoreContext;
private boolean expandNested;

public Optional<QueryBuilder> getFilter() {
return Optional.ofNullable(filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -68,8 +67,8 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
if (exactSearcherContext.getMatchedDocsIterator() != null
&& exactSearcherContext.numberOfMatchedDocs <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
Expand Down Expand Up @@ -155,7 +154,7 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
Expand Down Expand Up @@ -245,7 +244,8 @@ public static class ExactSearcherContext {
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
DocIdSetIterator matchedDocsIterator;
long numberOfMatchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNQuery extends Query {

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
Expand All @@ -74,6 +75,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField EXPAND_NESTED_FIELD = new ParseField(EXPAND_NESTED);
public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE);
public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE);
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
Expand Down Expand Up @@ -106,6 +108,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private boolean ignoreUnmapped;
@Getter
private RescoreContext rescoreContext;
@Getter
private boolean expandNested;

/**
* Constructs a new query with the given field name and vector
Expand Down Expand Up @@ -147,6 +151,7 @@ public static class Builder {
private String queryName;
private float boost = DEFAULT_BOOST;
private RescoreContext rescoreContext;
private boolean expandNested;

public Builder() {}

Expand Down Expand Up @@ -205,6 +210,11 @@ public Builder rescoreContext(RescoreContext rescoreContext) {
return this;
}

public Builder expandNested(boolean expandNested) {
this.expandNested = expandNested;
return this;
}

public KNNQueryBuilder build() {
validate();
int k = this.k == null ? 0 : this.k;
Expand All @@ -217,7 +227,8 @@ public KNNQueryBuilder build() {
methodParameters,
filter,
ignoreUnmapped,
rescoreContext
rescoreContext,
expandNested
).boost(boost).queryName(queryName);
}

Expand Down Expand Up @@ -319,6 +330,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.maxDistance = null;
this.minScore = null;
this.rescoreContext = null;
this.expandNested = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -341,6 +353,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
minScore = builder.minScore;
methodParameters = builder.methodParameters;
rescoreContext = builder.rescoreContext;
expandNested = builder.expandNested;
}

@Override
Expand Down Expand Up @@ -536,6 +549,7 @@ protected Query doToQuery(QueryShardContext context) {
.filter(this.filter)
.context(context)
.rescoreContext(processedRescoreContext)
.expandNested(expandNested)
.build();
return KNNQueryFactory.create(createQueryRequest);
}
Expand Down Expand Up @@ -621,7 +635,8 @@ protected boolean doEquals(KNNQueryBuilder other) {
&& Objects.equals(methodParameters, other.methodParameters)
&& Objects.equals(filter, other.filter)
&& Objects.equals(ignoreUnmapped, other.ignoreUnmapped)
&& Objects.equals(rescoreContext, other.rescoreContext);
&& Objects.equals(rescoreContext, other.rescoreContext)
&& Objects.equals(expandNested, other.expandNested);
}

@Override
Expand All @@ -635,7 +650,8 @@ protected int doHashCode() {
ignoreUnmapped,
maxDistance,
minScore,
rescoreContext
rescoreContext,
expandNested
);
}

Expand Down
70 changes: 55 additions & 15 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,28 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
*/
@Log4j2
public class KNNQueryFactory extends BaseQueryFactory {

/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
Expand All @@ -48,13 +49,25 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final Query filterQuery = getFilterQuery(createQueryRequest);
final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);

final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
final boolean expandNested = createQueryRequest.isExpandNested();
BitSetProducer parentFilter = null;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
}

if (parentFilter == null && expandNested) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for the [%s] field. [%s] is only supported with a nested field.",
EXPAND_NESTED,
EXPAND_NESTED
)
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine());

Expand Down Expand Up @@ -95,7 +108,16 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

return knnQuery;
}

Integer requestEfSearch = null;
Expand All @@ -106,9 +128,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested);
default:
throw new IllegalArgumentException(
String.format(
Expand All @@ -131,38 +153,56 @@ private static Query validateFilterQuerySupport(final Query filterQuery, final K
}

/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
* If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory}
* which will create query to dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnByteVectorQuery(
final String fieldName,
final byte[] byteVector,
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean expandNested
) {
if (parentFilter == null) {
assert expandNested == false : "expandNested is allowed to be true only for nested fields.";
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
byteVector,
k,
filterQuery,
parentFilter,
expandNested
);
}
}

/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
* If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory}
* which will create query to dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnFloatVectorQuery(
final String fieldName,
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean expandNested
) {
if (parentFilter == null) {
assert expandNested == false : "expandNested is allowed to be true only for nested fields.";
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
floatVector,
k,
filterQuery,
parentFilter,
expandNested
);
}
}
}
Loading
Loading