Skip to content

Commit

Permalink
Multiple innerHit in nested fields
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Nov 25, 2024
1 parent 2d1a408 commit 142ae3c
Show file tree
Hide file tree
Showing 39 changed files with 1,807 additions and 137 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4'
testFixturesImplementation 'com.jayway.jsonpath:json-path:2.8.0'
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
api "net.java.dev.jna:jna:5.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
15 changes: 10 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -38,6 +37,7 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
Expand All @@ -60,11 +60,13 @@ public class ExactSearcher {
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
if (iterator == null) {
return Collections.emptyMap();
}
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cost() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
Expand Down Expand Up @@ -147,9 +149,12 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
if (fieldInfo == null) {
return null;
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;
Expand Down Expand Up @@ -233,7 +238,7 @@ public static class ExactSearcherContext {
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
DocIdSetIterator matchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNQuery extends Query {

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
Expand Down
45 changes: 38 additions & 7 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand All @@ -24,12 +26,14 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
*/
@Log4j2
public class KNNQueryFactory extends BaseQueryFactory {
private static final QueryUtils QUERY_UTILS = new QueryUtils();

/**
* Creates a Lucene query for a particular engine.
Expand All @@ -48,11 +52,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final Query filterQuery = getFilterQuery(createQueryRequest);
final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
final KNNEngine knnEngine = createQueryRequest.getKnnEngine();

BitSetProducer parentFilter = null;
boolean isInnerHitQuery = false;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
isInnerHitQuery = context.isInnerHitQuery();
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
Expand Down Expand Up @@ -95,7 +102,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QUERY_UTILS, isInnerHitQuery);
} else if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && isInnerHitQuery) {
return new NativeEngineKnnVectorQuery(knnQuery, QUERY_UTILS, isInnerHitQuery);
} else {
return knnQuery;
}
}

Integer requestEfSearch = null;
Expand All @@ -106,9 +120,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
default:
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -139,12 +153,21 @@ private static Query getKnnByteVectorQuery(
final byte[] byteVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
assert isInnerHitQuery == false;
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
byteVector,
k,
filterQuery,
parentFilter,
isInnerHitQuery
);
}
}

Expand All @@ -157,12 +180,20 @@ private static Query getKnnFloatVectorQuery(
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
floatVector,
k,
filterQuery,
parentFilter,
isInnerHitQuery
);
}
}
}
15 changes: 8 additions & 7 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -66,6 +67,7 @@ public class KNNWeight extends Weight {
private final float boost;

private final NativeMemoryCacheManager nativeMemoryCacheManager;
@Getter
private final Weight filterWeight;
private final ExactSearcher exactSearcher;

Expand Down Expand Up @@ -140,14 +142,14 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
return doExactSearch(context, filterBitSet, k);
return doExactSearch(context, new BitSetIterator(filterBitSet, filterBitSet.cardinality()), k);
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) {
final BitSet docs = filterWeight != null ? filterBitSet : null;
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, filterBitSet.cardinality()) : null;
return doExactSearch(context, docs, k);
}
return docIdsToScoreMap;
Expand Down Expand Up @@ -205,17 +207,16 @@ private int[] bitSetToIntArray(final BitSet bitSet) {
return intArray;
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException {
private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final DocIdSetIterator acceptedDocs, int k)
throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
.k(k)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(knnQuery);
if (acceptedDocs != null) {
exactSearcherContextBuilder.matchedDocs(acceptedDocs);
}
.knnQuery(knnQuery)
.matchedDocs(acceptedDocs);
return exactSearch(context, exactSearcherContextBuilder.build());
}

Expand Down
14 changes: 6 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.DocIdSetBuilder;

import java.io.IOException;
Expand Down Expand Up @@ -58,19 +57,18 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)
}

/**
* Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to
* ensure that the caller is aware that BitSet may not be present
* Convert map of docs to doc id set iterator
*
* @param resultMap Map of results
* @return BitSet of results; null is returned if the result map is empty
* @return Doc id set iterator
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap == null || resultMap.isEmpty()) {
return null;
public static DocIdSetIterator resultMapToDocIds(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return DocIdSetIterator.empty();
}
final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
return resultMapToDocIds(resultMap, maxDoc);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.nativelib;
package org.opensearch.knn.index.query.common;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -24,15 +24,15 @@
/**
* This is the same as {@link org.apache.lucene.search.AbstractKnnVectorQuery.DocAndScoreQuery}
*/
final class DocAndScoreQuery extends Query {
final public class DocAndScoreQuery extends Query {

private final int k;
private final int[] docs;
private final float[] scores;
private final int[] segmentStarts;
private final Object contextIdentity;

DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
Expand Down
Loading

0 comments on commit 142ae3c

Please sign in to comment.